Fix batch of test fixes

This commit is contained in:
Georges-Antoine Assi
2024-01-17 10:26:37 -05:00
parent c486763e26
commit 7963e243dc
26 changed files with 210 additions and 228 deletions

View File

@@ -90,19 +90,17 @@ def upgrade() -> None:
file_size_bytes = int(row[1] * SIZE_UNIT_TO_BYTES.get(row[2], 1))
updates.append({"id": row[0], "file_size_bytes": file_size_bytes})
if not updates:
return
# Perform bulk update
connection.execute(
text("UPDATE roms SET file_size_bytes = :file_size_bytes WHERE id = :id"),
updates,
)
if updates:
# Perform bulk update
connection.execute(
text("UPDATE roms SET file_size_bytes = :file_size_bytes WHERE id = :id"),
updates,
)
# Clean roms table
with op.batch_alter_table("roms", schema=None) as batch_op:
batch_op.drop_column("file_size")
batch_op.drop_column("file_size_units")
# Clean roms table
# ### end Alembic commands ###
@@ -122,6 +120,10 @@ def downgrade() -> None:
batch_op.create_foreign_key(None, "platforms", ["platform_slug"], ["slug"])
batch_op.drop_column("platform_id")
batch_op.drop_column("file_size_bytes")
batch_op.add_column(
sa.Column("file_size_units", sa.String(length=10), nullable=False)
)
batch_op.add_column(sa.Column("file_size", sa.Float(), nullable=False))
with op.batch_alter_table("saves", schema=None) as batch_op:
batch_op.add_column(

View File

@@ -50,9 +50,10 @@ class ConfigManager:
_self = None
def __new__(cls):
def __new__(cls, *args, **kwargs):
if cls._self is None:
cls._self = super().__new__(cls)
cls._self = super().__new__(cls, *args, **kwargs)
return cls._self
# Tests require custom config path

View File

@@ -276,7 +276,7 @@ async def update_rom(
return dbromh.get_roms(id)
@protected_route(router.delete, "/roms", ["roms.write"])
@protected_route(router.post, "/roms/delete", ["roms.write"])
async def delete_roms(
request: Request,
) -> MessageResponse:

View File

@@ -60,7 +60,7 @@ def update_save(request: Request, id: int) -> MessageResponse:
pass
@protected_route(router.put, "/saves", ["assets.write"])
@protected_route(router.post, "/saves/delete", ["assets.write"])
async def delete_saves(request: Request) -> MessageResponse:
data: dict = await request.json()
save_ids: list = data["saves"]
@@ -86,7 +86,7 @@ async def delete_saves(request: Request) -> MessageResponse:
try:
fsasseth.remove_file(file_name=save.file_name, file_path=save.file_path)
except FileNotFoundError:
error = f"Save file {save.file_name} not found for platform {save.platform_slug}"
error = f"Save file {save.file_name} not found for platform {save.rom.platform_slug}"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)

View File

@@ -60,7 +60,7 @@ def update_save(request: Request, id: int) -> MessageResponse:
pass
@protected_route(router.put, "/states", ["assets.write"])
@protected_route(router.post, "/states/delete", ["assets.write"])
async def delete_states(request: Request) -> MessageResponse:
data: dict = await request.json()
state_ids: list = data["states"]
@@ -85,7 +85,7 @@ async def delete_states(request: Request) -> MessageResponse:
try:
fsasseth.remove_file(file_name=state.file_name, file_path=state.file_path)
except FileNotFoundError:
error = f"Save file {state.file_name} not found for platform {state.platform_slug}"
error = f"Save file {state.file_name} not found for platform {state.rom.platform_slug}"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)

View File

@@ -1,6 +1,6 @@
import pytest
from utils.oauth import create_oauth_token
from handler import oauthh
from datetime import timedelta
from handler.tests.conftest import setup_database, clear_database, admin_user, editor_user, viewer_user, platform, rom, save, state # noqa
from ..auth import ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS
@@ -14,7 +14,7 @@ def access_token(admin_user): # noqa
"type": "access",
}
return create_oauth_token(
return oauthh.create_oauth_token(
data=data, expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
)
@@ -27,6 +27,6 @@ def refresh_token(admin_user): # noqa
"type": "refresh",
}
return create_oauth_token(
return oauthh.create_oauth_token(
data=data, expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
)

View File

@@ -14,7 +14,7 @@ def test_delete_saves(access_token, save):
assert response.status_code == 200
body = response.json()
assert body == []
assert body['msg'] == "Successfully deleted 1 saves."
def test_delete_states(access_token, state):
@@ -26,4 +26,4 @@ def test_delete_states(access_token, state):
assert response.status_code == 200
body = response.json()
assert body == []
assert body['msg'] == "Successfully deleted 1 states."

View File

@@ -0,0 +1,23 @@
from fastapi.testclient import TestClient
from main import app
client = TestClient(app)
def test_config():
response = client.get("/config")
assert response.status_code == 200
config = response.json()
assert config.get('EXCLUDED_PLATFORMS') == []
assert config.get('EXCLUDED_SINGLE_EXT') == []
assert config.get('EXCLUDED_SINGLE_FILES') == []
assert config.get('EXCLUDED_MULTI_FILES') == []
assert config.get('EXCLUDED_MULTI_PARTS_EXT') == []
assert config.get('EXCLUDED_MULTI_PARTS_FILES') == []
assert config.get('PLATFORMS_BINDING') == {}
assert config.get('ROMS_FOLDER_NAME') == 'roms'
assert config.get('SAVES_FOLDER_NAME') == 'saves'
assert config.get('STATES_FOLDER_NAME') == 'states'
assert config.get('SCREENSHOTS_FOLDER_NAME') == 'screenshots'

View File

@@ -1,7 +1,7 @@
from fastapi.testclient import TestClient
from main import app
from utils import get_version
from handler import ghh
client = TestClient(app)
@@ -9,8 +9,9 @@ client = TestClient(app)
def test_heartbeat():
response = client.get("/heartbeat")
assert response.status_code == 200
heartbeat = response.json()
assert heartbeat.get('VERSION') == get_version()
assert heartbeat.get('VERSION') == ghh.get_version()
assert heartbeat.get('ROMM_AUTH_ENABLED')
assert heartbeat.get('WATCHER').get('ENABLED')
assert heartbeat.get('WATCHER').get('TITLE') == "Rescan on filesystem change"
@@ -23,5 +24,3 @@ def test_heartbeat():
assert heartbeat.get('SCHEDULER').get('MAME_XML').get('ENABLED')
assert heartbeat.get('SCHEDULER').get('MAME_XML').get('CRON') == "0 5 * * *"
assert heartbeat.get('SCHEDULER').get('MAME_XML').get('TITLE') == "Scheduled MAME XML update"
assert heartbeat.get('CONFIG').get('EXCLUDED_MULTI_FILES') == []
assert heartbeat.get('CONFIG').get('EXCLUDED_SINGLE_EXT') == []

View File

@@ -3,7 +3,7 @@ import pytest
from fastapi.testclient import TestClient
from main import app
from utils.cache import cache
from handler.redis_handler import cache
from models.user import Role
client = TestClient(app)
@@ -27,12 +27,12 @@ def test_login_logout(admin_user):
assert response.status_code == 200
assert response.cookies.get("session")
assert response.json()["message"] == "Successfully logged in"
assert response.json()["msg"] == "Successfully logged in"
response = client.post("/logout")
assert response.status_code == 200
assert response.json()["message"] == "Successfully logged out"
assert response.json()["msg"] == "Successfully logged out"
def test_get_all_users(access_token):
@@ -92,4 +92,4 @@ def test_delete_user(access_token, editor_user):
assert response.status_code == 200
body = response.json()
assert body["message"] == "User successfully deleted"
assert body["msg"] == "User successfully deleted"

View File

@@ -2,7 +2,7 @@ from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES
from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient
from main import app
from utils.oauth import WRITE_SCOPES
from handler.auth_handler import WRITE_SCOPES
client = TestClient(app)

View File

@@ -57,7 +57,7 @@ class AuthHandler:
def clear_session(req: HTTPConnection | Request):
session_id = req.session.get("session_id")
if session_id:
redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
req.session["session_id"] = None
def authenticate_user(self, username: str, password: str):
@@ -126,7 +126,7 @@ class OAuthHandler:
def __init__(self) -> None:
pass
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
def create_oauth_token(self, data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
@@ -138,7 +138,7 @@ class OAuthHandler:
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
async def get_current_active_user_from_bearer_token(token: str):
async def get_current_active_user_from_bearer_token(self, token: str):
from handler import dbuserh
try:

View File

@@ -3,7 +3,7 @@ from typing import Any
import emoji
from config.config_manager import config_manager as cm
from handler import fsasseth, dbplatformh, igdbh, fsresourceh, fsromh
from handler import fsasseth, igdbh, fsresourceh, fsromh, dbplatformh
from logger.logger import log
from models import Platform, Rom, Save, Screenshot, State
@@ -32,7 +32,7 @@ def scan_platform(fs_slug: str, fs_platforms) -> Platform:
f" {fs_slug} not found in file system, trying to match via config..."
)
if fs_slug in SWAPPED_PLATFORM_BINDINGS.keys():
platform = dbh.get_platform_by_fs_slug(fs_slug)
platform = dbplatformh.get_platform_by_fs_slug(fs_slug)
if platform:
platform_attrs["fs_slug"] = SWAPPED_PLATFORM_BINDINGS[platform.slug]
@@ -80,6 +80,7 @@ async def scan_rom(
regs, rev, langs, other_tags = fsromh.parse_tags(rom_attrs["file_name"])
rom_attrs.update(
{
"platform_id": platform.id,
"file_path": roms_path,
"file_name": rom_attrs["file_name"],
"file_name_no_tags": fsromh.get_file_name_with_no_tags(rom_attrs["file_name"]),
@@ -92,7 +93,6 @@ async def scan_rom(
"tags": other_tags,
}
)
rom_attrs["platform_id"] = platform.id
# Search in IGDB
igdbh_rom = (

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import sessionmaker
from config.config_manager import ConfigManager
from models import Platform, Rom, User, Save, State, Screenshot
from models.user import Role
from handler import dbh, dbuserh, dbplatformh, dbromh, dbsaveh, dbstateh, authh
from handler import dbuserh, dbplatformh, dbromh, dbsaveh, dbstateh, authh, dbscreenshotsh
engine = create_engine(ConfigManager.get_db_engine(), pool_pre_ping=True)
session = sessionmaker(bind=engine, expire_on_commit=False)
@@ -39,9 +39,9 @@ def platform():
@pytest.fixture
def rom(platform: Platform):
rom = Rom(
platform_id=platform.id,
name="test_rom",
slug="test_rom_slug",
platform_slug=platform.slug,
file_name="test_rom.zip",
file_name_no_tags="test_rom",
file_extension="zip",
@@ -52,46 +52,44 @@ def rom(platform: Platform):
@pytest.fixture
def save(rom: Rom):
def save(rom: Rom, platform: Platform):
save = Save(
rom_id=rom.id,
platform_slug=rom.platform_slug,
file_name="test_save.sav",
file_name_no_tags="test_save",
file_extension="sav",
emulator="test_emulator",
file_path=f"{rom.platform_slug}/saves/test_emulator",
file_path=f"{platform.slug}/saves/test_emulator",
file_size_bytes=1.0,
)
return dbsaveh.add_save(save)
@pytest.fixture
def state(rom: Rom):
def state(rom: Rom, platform: Platform):
state = State(
rom_id=rom.id,
platform_slug=rom.platform_slug,
file_name="test_state.state",
file_name_no_tags="test_state",
file_extension="state",
emulator="test_emulator",
file_path=f"{rom.platform_slug}/states/test_emulator",
file_path=f"{platform.slug}/states/test_emulator",
file_size_bytes=2.0,
)
return dbstateh.add_state(state)
@pytest.fixture
def screenshot(rom: Rom):
def screenshot(rom: Rom, platform: Platform):
screenshot = Screenshot(
rom_id=rom.id,
file_name="test_screenshot.png",
file_name_no_tags="test_screenshot",
file_extension="png",
file_path=f"{rom.platform_slug}/screenshots",
file_path=f"{platform.slug}/screenshots",
file_size_bytes=3.0,
)
return dbh.add_screenshot(screenshot)
return dbscreenshotsh.add_screenshot(screenshot)
@pytest.fixture
def admin_user():

View File

@@ -3,34 +3,32 @@ from sqlalchemy.exc import IntegrityError
# from handler.db_handler import DBHandler
from models import Platform, Rom, User, Save, State, Screenshot
from models.user import Role
from handler import dbh, authh
# dbh = DBHandler()
from handler import authh, dbplatformh, dbromh, dbuserh, dbsaveh, dbstateh, dbscreenshotsh
def test_platforms():
platform = Platform(
name="test_platform", slug="test_platform_slug", fs_slug="test_platform_slug"
)
dbh.add_platform(platform)
dbplatformh.add_platform(platform)
platforms = dbh.get_platform()
platforms = dbplatformh.get_platform()
assert len(platforms) == 1
platform = dbh.get_platform(platform.slug)
platform = dbplatformh.get_platform(platform.slug)
assert platform.name == "test_platform"
dbh.purge_platforms([])
platforms = dbh.get_platform()
dbplatformh.purge_platforms([])
platforms = dbplatformh.get_platform()
assert len(platforms) == 0
def test_roms(rom):
dbh.add_rom(
def test_roms(rom: Rom):
dbromh.add_rom(
Rom(
platform_id=rom.platform_id,
name="test_rom_2",
slug="test_rom_slug_2",
platform_slug=rom.platform_slug,
file_name="test_rom_2",
file_name_no_tags="test_rom_2",
file_extension="zip",
@@ -39,66 +37,66 @@ def test_roms(rom):
)
)
with dbh.session.begin() as session:
roms = session.scalars(dbh.get_roms(rom.platform_slug)).all()
with dbromh.session.begin() as session:
roms = session.scalars(dbromh.get_roms(rom.platform_slug)).all()
assert len(roms) == 2
rom = dbh.get_rom(roms[0].id)
rom = dbromh.get_rom(roms[0].id)
assert rom.file_name == "test_rom.zip"
dbh.update_rom(roms[1].id, {"file_name": "test_rom_2_updated"})
rom_2 = dbh.get_rom(roms[1].id)
dbromh.update_rom(roms[1].id, {"file_name": "test_rom_2_updated"})
rom_2 = dbromh.get_rom(roms[1].id)
assert rom_2.file_name == "test_rom_2_updated"
dbh.delete_rom(rom.id)
dbromh.delete_rom(rom.id)
with dbh.session.begin() as session:
roms = session.scalars(dbh.get_roms(rom.platform_slug)).all()
with dbromh.session.begin() as session:
roms = session.scalars(dbromh.get_roms(rom.platform_slug)).all()
assert len(roms) == 1
dbh.purge_roms(rom_2.platform_slug, [rom_2.id])
dbromh.purge_roms(rom_2.platform_slug, [rom_2.id])
with dbh.session.begin() as session:
roms = session.scalars(dbh.get_roms(rom.platform_slug)).all()
with dbromh.session.begin() as session:
roms = session.scalars(dbromh.get_roms(rom.platform_slug)).all()
assert len(roms) == 0
def test_utils(rom):
with dbh.session.begin() as session:
roms = session.scalars(dbh.get_roms(rom.platform_slug)).all()
def test_utils(rom: Rom):
with dbromh.session.begin() as session:
roms = session.scalars(dbromh.get_roms(rom.platform_slug)).all()
assert (
dbh.get_rom_by_filename(rom.platform_slug, rom.file_name).id == roms[0].id
dbromh.get_rom_by_filename(rom.platform_slug, rom.file_name).id == roms[0].id
)
def test_users(admin_user):
dbh.add_user(
dbuserh.add_user(
User(
username="new_user",
hashed_password=authh.get_password_hash("new_password"),
)
)
all_users = dbh.get_users()
all_users = dbuserh.get_users()
assert len(all_users) == 2
new_user = dbh.get_user_by_username("new_user")
new_user = dbuserh.get_user_by_username("new_user")
assert new_user.username == "new_user"
assert new_user.role == Role.VIEWER
assert new_user.enabled
dbh.update_user(new_user.id, {"role": Role.EDITOR})
dbuserh.update_user(new_user.id, {"role": Role.EDITOR})
new_user = dbh.get_user(new_user.id)
new_user = dbuserh.get_user(new_user.id)
assert new_user.role == Role.EDITOR
dbh.delete_user(new_user.id)
dbuserh.delete_user(new_user.id)
all_users = dbh.get_users()
all_users = dbuserh.get_users()
assert len(all_users) == 1
try:
new_user = dbh.add_user(
new_user = dbuserh.add_user(
User(
username="test_admin",
hashed_password=authh.get_password_hash("new_password"),
@@ -109,11 +107,10 @@ def test_users(admin_user):
assert "Duplicate entry 'test_admin' for key" in str(e)
def test_saves(save):
dbh.add_save(
def test_saves(save: Save):
dbsaveh.add_save(
Save(
rom_id=save.rom_id,
platform_slug=save.platform_slug,
file_name="test_save_2.sav",
file_name_no_tags="test_save_2",
file_extension="sav",
@@ -123,27 +120,26 @@ def test_saves(save):
)
)
rom = dbh.get_rom(save.rom_id)
rom = dbsaveh.get_rom(save.rom_id)
assert len(rom.saves) == 2
save = dbh.get_save(rom.saves[0].id)
save = dbsaveh.get_save(rom.saves[0].id)
assert save.file_name == "test_save.sav"
dbh.update_save(save.id, {"file_name": "test_save_2.sav"})
save = dbh.get_save(save.id)
dbsaveh.update_save(save.id, {"file_name": "test_save_2.sav"})
save = dbsaveh.get_save(save.id)
assert save.file_name == "test_save_2.sav"
dbh.delete_save(save.id)
dbsaveh.delete_save(save.id)
rom = dbh.get_rom(save.rom_id)
rom = dbsaveh.get_rom(save.rom_id)
assert len(rom.saves) == 1
def test_states(state):
dbh.add_state(
def test_states(state: State):
dbstateh.add_state(
State(
rom_id=state.rom_id,
platform_slug=state.platform_slug,
file_name="test_state_2.state",
file_name_no_tags="test_state_2",
file_extension="state",
@@ -152,27 +148,26 @@ def test_states(state):
)
)
rom = dbh.get_rom(state.rom_id)
rom = dbstateh.get_rom(state.rom_id)
assert len(rom.states) == 2
state = dbh.get_state(rom.states[0].id)
state = dbstateh.get_state(rom.states[0].id)
assert state.file_name == "test_state.state"
dbh.update_state(state.id, {"file_name": "test_state_2.state"})
state = dbh.get_state(state.id)
dbstateh.update_state(state.id, {"file_name": "test_state_2.state"})
state = dbstateh.get_state(state.id)
assert state.file_name == "test_state_2.state"
dbh.delete_state(state.id)
dbstateh.delete_state(state.id)
rom = dbh.get_rom(state.rom_id)
rom = dbstateh.get_rom(state.rom_id)
assert len(rom.states) == 1
def test_screenshots(screenshot):
dbh.add_screenshot(
def test_screenshots(screenshot: Screenshot):
dbscreenshotsh.add_screenshot(
Screenshot(
rom_id=screenshot.rom_id,
platform_slug=screenshot.platform_slug,
file_name="test_screenshot_2.png",
file_name_no_tags="test_screenshot_2",
file_extension="png",
@@ -181,17 +176,17 @@ def test_screenshots(screenshot):
)
)
rom = dbh.get_rom(screenshot.rom_id)
rom = dbscreenshotsh.get_rom(screenshot.rom_id)
assert len(rom.screenshots) == 2
screenshot = dbscreenshotsh.get_screenshot(rom.screenshots[0].id)
assert screenshot.file_name == "test_screenshot.png"
dbscreenshotsh.update_screenshot(screenshot.id, {"file_name": "test_screenshot_2.png"})
screenshot = dbh.get_screenshot(screenshot.id)
screenshot = dbscreenshotsh.get_screenshot(screenshot.id)
assert screenshot.file_name == "test_screenshot_2.png"
dbscreenshotsh.delete_screenshot(screenshot.id)
rom = dbh.get_rom(screenshot.rom_id)
rom = dbscreenshotsh.get_rom(screenshot.rom_id)
assert len(rom.screenshots) == 1

View File

@@ -18,7 +18,6 @@ from endpoints import (
webrcade,
stats,
)
from endpoints.sockets import scan
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination

View File

@@ -1,4 +1,4 @@
from utils.oauth import DEFAULT_SCOPES, WRITE_SCOPES, FULL_SCOPES
from handler.auth_handler import FULL_SCOPES, WRITE_SCOPES, DEFAULT_SCOPES
def test_admin(admin_user):
admin_user.oauth_scopes == FULL_SCOPES

View File

@@ -20,11 +20,8 @@ class User(BaseModel, SimpleUser):
username: str = Column(String(length=255), unique=True, index=True)
hashed_password: str = Column(String(length=255))
enabled: bool = Column(Boolean(), default=True)
role: Role = Column(Enum(Role), default=Role.VIEWER)
avatar_path: str = Column(String(length=255), default="")
@property

View File

@@ -3,26 +3,17 @@ from base64 import b64encode
from fastapi.exceptions import HTTPException
from models import User
from handler import dbh
from ..auth import (
verify_password,
get_password_hash,
authenticate_user,
get_current_active_user_from_session,
create_default_admin_user,
HybridAuthBackend,
)
from ..oauth import WRITE_SCOPES, create_oauth_token
from ..cache import cache
from handler import authh, oauthh, dbuserh
from handler.redis_handler import cache
def test_verify_password():
assert verify_password("password", get_password_hash("password"))
assert not verify_password("password", get_password_hash("notpassword"))
assert authh.verify_password("password", authh.get_password_hash("password"))
assert not authh.verify_password("password", authh.get_password_hash("notpassword"))
def test_authenticate_user(admin_user):
current_user = authenticate_user("test_admin", "test_admin_password")
current_user = authh.authenticate_user("test_admin", "test_admin_password")
assert current_user
assert current_user.id == admin_user.id
@@ -37,7 +28,7 @@ async def test_get_current_active_user_from_session(editor_user):
self.session = {"session_id": session_id}
conn = MockConnection()
current_user = await get_current_active_user_from_session(conn)
current_user = await authh.get_current_active_user_from_session(conn)
assert current_user
assert isinstance(current_user, User)
@@ -53,7 +44,7 @@ async def test_get_current_active_user_from_session_bad_session_key(editor_user)
self.headers = {}
conn = MockConnection()
current_user = await get_current_active_user_from_session(conn)
current_user = await authh.get_current_active_user_from_session(conn)
assert not current_user
@@ -70,7 +61,7 @@ async def test_get_current_active_user_from_session_bad_username(editor_user):
conn = MockConnection()
try:
await get_current_active_user_from_session(conn)
await authh.get_current_active_user_from_session(conn)
except HTTPException as e:
assert e.status_code == 403
assert e.detail == "User not found"
@@ -87,28 +78,28 @@ async def test_get_current_active_user_from_session_disabled_user(editor_user):
conn = MockConnection()
dbh.update_user(editor_user.id, {"enabled": False})
dbuserh.update_user(editor_user.id, {"enabled": False})
try:
await get_current_active_user_from_session(conn)
await authh.get_current_active_user_from_session(conn)
except HTTPException as e:
assert e.status_code == 403
assert e.detail == "Inactive user"
def test_create_default_admin_user():
create_default_admin_user()
authh.create_default_admin_user()
user = dbh.get_user_by_username("test_admin")
user = dbuserh.get_user_by_username("test_admin")
assert user.username == "test_admin"
assert verify_password("test_admin_password", user.hashed_password)
assert authh.verify_password("test_admin_password", user.hashed_password)
users = dbh.get_users()
users = dbuserh.get_users()
assert len(users) == 1
create_default_admin_user()
authh.create_default_admin_user()
users = dbh.get_users()
users = dbuserh.get_users()
assert len(users) == 1
@@ -120,14 +111,14 @@ async def test_hybrid_auth_backend_session(editor_user):
def __init__(self):
self.session = {"session_id": session_id}
backend = HybridAuthBackend()
backend = authh.authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)
assert user.id == editor_user.id
assert creds.scopes == editor_user.oauth_scopes
assert creds.scopes == WRITE_SCOPES
assert creds.scopes == oauthh.WRITE_SCOPES
async def test_hybrid_auth_backend_empty_session_and_headers(editor_user):
@@ -136,7 +127,7 @@ async def test_hybrid_auth_backend_empty_session_and_headers(editor_user):
self.session = {}
self.headers = {}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)
@@ -146,7 +137,7 @@ async def test_hybrid_auth_backend_empty_session_and_headers(editor_user):
async def test_hybrid_auth_backend_bearer_auth_header(editor_user):
access_token = create_oauth_token(
access_token = oauthh.create_oauth_token(
data={
"sub": editor_user.username,
"scopes": " ".join(editor_user.oauth_scopes),
@@ -159,7 +150,7 @@ async def test_hybrid_auth_backend_bearer_auth_header(editor_user):
self.session = {}
self.headers = {"Authorization": f"Bearer {access_token}"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)
@@ -174,7 +165,7 @@ async def test_hybrid_auth_backend_bearer_invalid_token(editor_user):
self.session = {}
self.headers = {"Authorization": "Bearer invalid_token"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
with pytest.raises(HTTPException):
@@ -189,13 +180,13 @@ async def test_hybrid_auth_backend_basic_auth_header(editor_user):
self.session = {}
self.headers = {"Authorization": f"Basic {token}"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)
assert user.id == editor_user.id
assert creds.scopes == WRITE_SCOPES
assert creds.scopes == oauthh.WRITE_SCOPES
assert set(creds.scopes).issubset(editor_user.oauth_scopes)
@@ -205,7 +196,7 @@ async def test_hybrid_auth_backend_basic_auth_header_unencoded(editor_user):
self.session = {}
self.headers = {"Authorization": "Basic test_editor:test_editor_password"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
with pytest.raises(HTTPException):
@@ -218,7 +209,7 @@ async def test_hybrid_auth_backend_invalid_scheme():
self.session = {}
self.headers = {"Authorization": "Some invalid_scheme"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)
@@ -228,7 +219,7 @@ async def test_hybrid_auth_backend_invalid_scheme():
async def test_hybrid_auth_backend_with_refresh_token(editor_user):
refresh_token = create_oauth_token(
refresh_token = oauthh.create_oauth_token(
data={
"sub": editor_user.username,
"scopes": " ".join(editor_user.oauth_scopes),
@@ -241,7 +232,7 @@ async def test_hybrid_auth_backend_with_refresh_token(editor_user):
self.session = {}
self.headers = {"Authorization": f"Bearer {refresh_token}"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)
@@ -252,7 +243,7 @@ async def test_hybrid_auth_backend_with_refresh_token(editor_user):
async def test_hybrid_auth_backend_scope_subset(editor_user):
scopes = editor_user.oauth_scopes[:3]
access_token = create_oauth_token(
access_token = oauthh.create_oauth_token(
data={
"sub": editor_user.username,
"scopes": " ".join(scopes),
@@ -265,7 +256,7 @@ async def test_hybrid_auth_backend_scope_subset(editor_user):
self.session = {}
self.headers = {"Authorization": f"Bearer {access_token}"}
backend = HybridAuthBackend()
backend = authh.HybridAuthBackend()
conn = MockConnection()
creds, user = await backend.authenticate(conn)

View File

@@ -1,6 +1,6 @@
import pytest
from ..fastapi import scan_platform, scan_rom
from handler.scan_handler import scan_platform, scan_rom
from exceptions.fs_exceptions import RomsNotFoundException
from models import Platform, Rom

View File

@@ -1,29 +1,14 @@
import pytest
from unittest.mock import patch
from ...handler.fs_handler.fs_roms_handler import (
get_rom_cover,
get_platforms,
get_fs_structure,
get_roms,
get_rom_file_size,
_exclude_files,
# get_rom_screenshots # TODO: write test
# store_default_resources # TODO: write test
# get_rom_files, # TODO: write test
# rename_file, # TODO: write test
# remove_file, # TODO: write test
# build_upload_file_path, # TODO: write test
# build_artwork_path, # TODO: write test
# build_avatar_path, # TODO: write test
)
import pytest
from handler import fsresourceh, fsplatformh, fsromh
from config import DEFAULT_PATH_COVER_L, DEFAULT_PATH_COVER_S
@pytest.mark.vcr
def test_get_rom_cover():
# Game: Metroid Prime (EUR).iso
cover = get_rom_cover(
cover = fsresourceh.get_rom_cover(
overwrite=False,
fs_slug="ngc",
rom_name="Metroid Prime",
@@ -33,7 +18,7 @@ def test_get_rom_cover():
assert DEFAULT_PATH_COVER_L in cover["path_cover_l"]
# Game: Paper Mario (USA).z64
cover = get_rom_cover(
cover = fsresourceh.get_rom_cover(
overwrite=True,
fs_slug="n64",
rom_name="Paper Mario",
@@ -44,7 +29,7 @@ def test_get_rom_cover():
assert "n64/Paper%20Mario/cover/big.png" in cover["path_cover_l"]
# Game: Super Mario 64 (J) (Rev A)
cover = get_rom_cover(
cover = fsresourceh.get_rom_cover(
overwrite=False,
fs_slug="n64",
rom_name="Super Mario 64",
@@ -55,7 +40,7 @@ def test_get_rom_cover():
assert "n64/Super%20Mario%2064/cover/big.png" in cover["path_cover_l"]
# Game: Disney's Kim Possible: What's the Switch?.zip
cover = get_rom_cover(
cover = fsresourceh.get_rom_cover(
overwrite=False,
fs_slug="ps2",
rom_name="Disney's Kim Possible: What's the Switch?",
@@ -72,7 +57,7 @@ def test_get_rom_cover():
)
# Game: Fake Game.xyz
cover = get_rom_cover(
cover = fsresourceh.get_rom_cover(
overwrite=False,
fs_slug="n64",
rom_name="Fake Game",
@@ -83,20 +68,20 @@ def test_get_rom_cover():
def test_get_platforms():
platforms = get_platforms()
platforms = fsplatformh.get_platforms()
assert "n64" in platforms
assert "psx" in platforms
def test_get_fs_structure():
roms_structure = get_fs_structure(fs_slug="n64")
roms_structure = fsromh.get_fs_structure(fs_slug="n64")
assert roms_structure == "n64/roms"
def test_get_roms():
roms = get_roms(fs_slug="n64")
roms = fsromh.get_roms(fs_slug="n64")
assert len(roms) == 2
assert roms[0]["file_name"] == "Paper Mario (USA).z64"
@@ -107,16 +92,16 @@ def test_get_roms():
def test_rom_size():
rom_size = get_rom_file_size(
roms_path=get_fs_structure(fs_slug="n64"),
rom_size = fsromh.get_rom_file_size(
roms_path=fsromh.get_fs_structure(fs_slug="n64"),
file_name="Paper Mario (USA).z64",
multi=False,
)
assert rom_size == 1024
rom_size = get_rom_file_size(
roms_path=get_fs_structure(fs_slug="n64"),
rom_size = fsromh.get_rom_file_size(
roms_path=fsromh.get_fs_structure(fs_slug="n64"),
file_name="Super Mario 64 (J) (Rev A)",
multi=True,
multi_files=[
@@ -137,7 +122,7 @@ def test_exclude_files():
patch("utils.fs.config", cm.config)
filtered_files = _exclude_files(
filtered_files = fsromh._exclude_files(
files=[
"Super Mario 64 (J) (Rev A) [Part 1].z64",
"Super Mario 64 (J) (Rev A) [Part 2].z64",
@@ -149,7 +134,7 @@ def test_exclude_files():
cm.config.EXCLUDED_SINGLE_EXT = ["z64"]
filtered_files = _exclude_files(
filtered_files = fsromh._exclude_files(
files=[
"Super Mario 64 (J) (Rev A) [Part 1].z64",
"Super Mario 64 (J) (Rev A) [Part 2].z64",
@@ -161,7 +146,7 @@ def test_exclude_files():
cm.config.EXCLUDED_SINGLE_FILES = ["*.z64"]
filtered_files = _exclude_files(
filtered_files = fsromh._exclude_files(
files=[
"Super Mario 64 (J) (Rev A) [Part 1].z64",
"Super Mario 64 (J) (Rev A) [Part 2].z64",
@@ -173,7 +158,7 @@ def test_exclude_files():
cm.config.EXCLUDED_SINGLE_FILES = ["_.*"]
filtered_files = _exclude_files(
filtered_files = fsromh._exclude_files(
files=[
"Links Awakening.nsp",
"_.Links Awakening.nsp",

View File

@@ -2,29 +2,25 @@ import pytest
from fastapi import APIRouter, Request
from fastapi.exceptions import HTTPException
from handler import dbh
from ..oauth import (
create_oauth_token,
get_current_active_user_from_bearer_token,
protected_route,
)
from decorators.auth import protected_route
from handler import dbuserh, oauthh
def test_create_oauth_token():
token = create_oauth_token({"sub": "test_user"})
token = oauthh.create_oauth_token(data={"sub": "test_user"})
assert isinstance(token, str)
async def test_get_current_active_user_from_bearer_token(admin_user):
token = create_oauth_token(
{
token = oauthh.create_oauth_token(
data={
"sub": admin_user.username,
"scopes": " ".join(admin_user.oauth_scopes),
"type": "access",
},
)
user, payload = await get_current_active_user_from_bearer_token(token)
user, payload = await oauthh.get_current_active_user_from_bearer_token(token)
assert user.id == admin_user.id
assert payload["sub"] == admin_user.username
@@ -34,29 +30,29 @@ async def test_get_current_active_user_from_bearer_token(admin_user):
async def test_get_current_active_user_from_bearer_token_invalid_token():
with pytest.raises(HTTPException):
await get_current_active_user_from_bearer_token("invalid_token")
await oauthh.get_current_active_user_from_bearer_token("invalid_token")
async def test_get_current_active_user_from_bearer_token_invalid_user():
token = create_oauth_token({"sub": "invalid_user"})
token = oauthh.create_oauth_token(data={"sub": "invalid_user"})
with pytest.raises(HTTPException):
await get_current_active_user_from_bearer_token(token)
await oauthh.get_current_active_user_from_bearer_token(token)
async def test_get_current_active_user_from_bearer_token_disabled_user(admin_user):
token = create_oauth_token(
{
token = oauthh.create_oauth_token(
data={
"sub": admin_user.username,
"scopes": " ".join(admin_user.oauth_scopes),
"type": "access",
},
)
dbh.update_user(admin_user.id, {"enabled": False})
dbuserh.update_user(admin_user.id, {"enabled": False})
try:
await get_current_active_user_from_bearer_token(token)
await oauthh.get_current_active_user_from_bearer_token(token)
except HTTPException as e:
assert e.status_code == 401
assert e.detail == "Inactive user"

View File

@@ -1,69 +1,65 @@
from utils import (
parse_tags,
get_file_name_with_no_tags as gfnwt,
parse_file_extension as gfe,
)
from handler import fsromh
def test_parse_tags():
file_name = "Super Mario Bros. (World).nes"
assert parse_tags(file_name) == (["World"], "", [], [])
assert fsromh.parse_tags(file_name) == (["World"], "", [], [])
file_name = "Super Mario Bros. (W) (Rev A).nes"
assert parse_tags(file_name) == (["World"], "A", [], [])
assert fsromh.parse_tags(file_name) == (["World"], "A", [], [])
file_name = "Super Mario Bros. (USA) (Rev A) (Beta).nes"
assert parse_tags(file_name) == (["USA"], "A", [], ["Beta"])
assert fsromh.parse_tags(file_name) == (["USA"], "A", [], ["Beta"])
file_name = "Super Mario Bros. (U) (Beta).nes"
assert parse_tags(file_name) == (["USA"], "", [], ["Beta"])
assert fsromh.parse_tags(file_name) == (["USA"], "", [], ["Beta"])
file_name = "Super Mario Bros. (CH) [!].nes"
assert parse_tags(file_name) == (["China"], "", [], ["!"])
assert fsromh.parse_tags(file_name) == (["China"], "", [], ["!"])
file_name = "Super Mario Bros. (reg-T) (rev-1.2).nes"
assert parse_tags(file_name) == (["Taiwan"], "1.2", [], [])
assert fsromh.parse_tags(file_name) == (["Taiwan"], "1.2", [], [])
file_name = "Super Mario Bros. (Reg S) (Rev A).nes"
assert parse_tags(file_name) == (["Spain"], "A", [], [])
assert fsromh.parse_tags(file_name) == (["Spain"], "A", [], [])
file_name = "Super Metroid (Japan, USA) (En,Ja).zip"
assert parse_tags(file_name) == (["Japan", "USA"], "", ["English", "Japanese"], [])
assert fsromh.parse_tags(file_name) == (["Japan", "USA"], "", ["English", "Japanese"], [])
def test_get_file_name_with_no_tags():
file_name = "Super Mario Bros. (World).nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "Super Mario Bros. (W) (Rev A).nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "Super Mario Bros. (USA) (Rev A) (Beta).nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "Super Mario Bros. (U) (Beta).nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "Super Mario Bros. (U) [!].nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "Super Mario Bros. (reg-T) (rev-1.2).nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "Super Mario Bros. (Reg S) (Rev A).nes"
assert gfnwt(file_name) == "Super Mario Bros."
assert fsromh.get_file_name_with_no_tags(file_name) == "Super Mario Bros."
file_name = "007 - Agent Under Fire.nkit.iso"
assert gfnwt(file_name) == "007 - Agent Under Fire"
assert fsromh.get_file_name_with_no_tags(file_name) == "007 - Agent Under Fire"
file_name = "Jimmy Houston's Bass Tournament U.S.A..zip"
assert gfnwt(file_name) == "Jimmy Houston's Bass Tournament U.S.A."
assert fsromh.get_file_name_with_no_tags(file_name) == "Jimmy Houston's Bass Tournament U.S.A."
# This is expected behavior, since the regex is aggressive
file_name = "Battle Stadium D.O.N.zip"
assert gfnwt(file_name) == "Battle Stadium D.O.N"
assert fsromh.get_file_name_with_no_tags(file_name) == "Battle Stadium D.O.N"
def test_get_file_extension():
assert gfe("Super Mario Bros. (World).nes") == "nes"
assert gfe("007 - Agent Under Fire.nkit.iso") == "nkit.iso"
assert fsromh.parse_file_extension("Super Mario Bros. (World).nes") == "nes"
assert fsromh.parse_file_extension("007 - Agent Under Fire.nkit.iso") == "nkit.iso"

View File

@@ -172,7 +172,7 @@ async function deleteRoms({
roms: Rom[];
deleteFromFs: boolean;
}): Promise<{ data: MessageResponse }> {
return api.delete("/roms", {
return api.post("/roms/delete", {
data: { roms: roms.map((r) => r.id), delete_from_fs: deleteFromFs },
});
}

View File

@@ -22,7 +22,7 @@ async function deleteSaves({
saves: SaveSchema[];
deleteFromFs: boolean;
}) {
return api.delete("/saves", {
return api.post("/saves/delete", {
data: {
saves: saves.map((s) => s.id),
delete_from_fs: deleteFromFs,

View File

@@ -23,7 +23,7 @@ async function deleteStates({
states: StateSchema[];
deleteFromFs: boolean;
}) {
return api.put("/states", {
return api.post("/states/delete", {
data: {
states: states.map((s) => s.id),
delete_from_fs: deleteFromFs,