complete updating the endpoints and models

This commit is contained in:
Georges-Antoine Assi
2024-12-20 22:41:56 -05:00
parent 0850c0cbcf
commit 3fcce6606c
26 changed files with 332 additions and 201 deletions

View File

@@ -1,6 +1,8 @@
from typing import Sequence
from decorators.database import begin_session
from models.collection import Collection
from sqlalchemy import Select, delete, select, update
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
from .base_handler import DBBaseHandler
@@ -10,10 +12,17 @@ class DBCollectionsHandler(DBBaseHandler):
@begin_session
def add_collection(
self, collection: Collection, session: Session = None
) -> Collection | None:
) -> Collection:
collection = session.merge(collection)
session.flush()
return session.scalar(select(Collection).filter_by(id=collection.id).limit(1))
new_collection = session.scalar(
select(Collection).filter_by(id=collection.id).limit(1)
)
if not new_collection:
raise ValueError("Could not find newly created collection")
return new_collection
@begin_session
def get_collection(self, id: int, session: Session = None) -> Collection | None:
@@ -28,9 +37,9 @@ class DBCollectionsHandler(DBBaseHandler):
)
@begin_session
def get_collections(self, session: Session = None) -> Select[tuple[Collection]]:
def get_collections(self, session: Session = None) -> Sequence[Collection]:
return (
session.scalars(select(Collection).order_by(Collection.name.asc())) # type: ignore[attr-defined]
session.scalars(select(Collection).order_by(Collection.name.asc()))
.unique()
.all()
)
@@ -38,7 +47,7 @@ class DBCollectionsHandler(DBBaseHandler):
@begin_session
def get_collections_by_rom_id(
self, rom_id: int, session: Session = None
) -> list[Collection]:
) -> Sequence[Collection]:
return session.scalars(
select(Collection).filter(Collection.roms.contains(rom_id))
).all()
@@ -56,8 +65,8 @@ class DBCollectionsHandler(DBBaseHandler):
return session.query(Collection).filter_by(id=id).one()
@begin_session
def delete_collection(self, id: int, session: Session = None) -> int:
return session.execute(
def delete_collection(self, id: int, session: Session = None) -> None:
session.execute(
delete(Collection)
.where(Collection.id == id)
.execution_options(synchronize_session="evaluate")

View File

@@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.firmware import Firmware
from sqlalchemy import and_, delete, select, update
@@ -26,7 +28,7 @@ class DBFirmwareHandler(DBBaseHandler):
*,
platform_id: int | None = None,
session: Session = None,
) -> list[Firmware]:
) -> Sequence[Firmware]:
return session.scalars(
select(Firmware)
.filter_by(platform_id=platform_id)
@@ -45,7 +47,7 @@ class DBFirmwareHandler(DBBaseHandler):
@begin_session
def update_firmware(self, id: int, data: dict, session: Session = None) -> Firmware:
return session.execute(
return session.scalar(
update(Firmware)
.where(Firmware.id == id)
.values(**data)
@@ -54,7 +56,7 @@ class DBFirmwareHandler(DBBaseHandler):
@begin_session
def delete_firmware(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(Firmware)
.where(Firmware.id == id)
.execution_options(synchronize_session="evaluate")
@@ -63,7 +65,7 @@ class DBFirmwareHandler(DBBaseHandler):
@begin_session
def purge_firmware(
self, platform_id: int, fs_firmwares: list[str], session: Session = None
) -> None:
) -> Sequence[Firmware]:
purged_firmware = (
session.scalars(
select(Firmware)

View File

@@ -1,7 +1,9 @@
from typing import Sequence
from decorators.database import begin_session
from models.platform import Platform
from models.rom import Rom
from sqlalchemy import Select, delete, or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from .base_handler import DBBaseHandler
@@ -21,7 +23,7 @@ class DBPlatformsHandler(DBBaseHandler):
select(Platform).filter_by(id=platform.id).limit(1)
)
if not new_platform:
raise ValueError("Could not find newlyewly created platform")
raise ValueError("Could not find newly created platform")
return new_platform
@@ -30,9 +32,9 @@ class DBPlatformsHandler(DBBaseHandler):
return session.scalar(select(Platform).filter_by(id=id).limit(1))
@begin_session
def get_platforms(self, *, session: Session = None) -> Select[tuple[Platform]]:
def get_platforms(self, *, session: Session = None) -> Sequence[Platform]:
return (
session.scalars(select(Platform).order_by(Platform.name.asc())) # type: ignore[attr-defined]
session.scalars(select(Platform).order_by(Platform.name.asc()))
.unique()
.all()
)
@@ -61,7 +63,7 @@ class DBPlatformsHandler(DBBaseHandler):
@begin_session
def purge_platforms(
self, fs_platforms: list[str], session: Session = None
) -> Select[tuple[Platform]]:
) -> Sequence[Platform]:
purged_platforms = (
session.scalars(
select(Platform)

View File

@@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.assets import Save
from sqlalchemy import and_, delete, select, update
@@ -12,7 +14,7 @@ class DBSavesHandler(DBBaseHandler):
return session.merge(save)
@begin_session
def get_save(self, id: int, session: Session = None) -> Save:
def get_save(self, id: int, session: Session = None) -> Save | None:
return session.get(Save, id)
@begin_session
@@ -27,7 +29,7 @@ class DBSavesHandler(DBBaseHandler):
@begin_session
def update_save(self, id: int, data: dict, session: Session = None) -> Save:
return session.execute(
return session.scalar(
update(Save)
.where(Save.id == id)
.values(**data)
@@ -36,7 +38,7 @@ class DBSavesHandler(DBBaseHandler):
@begin_session
def delete_save(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(Save)
.where(Save.id == id)
.execution_options(synchronize_session="evaluate")
@@ -45,8 +47,18 @@ class DBSavesHandler(DBBaseHandler):
@begin_session
def purge_saves(
self, rom_id: int, user_id: int, saves: list[str], session: Session = None
) -> None:
return session.execute(
) -> Sequence[Save]:
purged_saves = session.scalars(
select(Save).filter(
and_(
Save.rom_id == rom_id,
Save.user_id == user_id,
Save.file_name.not_in(saves),
)
)
).all()
session.execute(
delete(Save)
.where(
and_(
@@ -57,3 +69,5 @@ class DBSavesHandler(DBBaseHandler):
)
.execution_options(synchronize_session="evaluate")
)
return purged_saves

View File

@@ -1,6 +1,8 @@
from typing import Sequence
from decorators.database import begin_session
from models.assets import Screenshot
from sqlalchemy import delete, select, update
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session
from .base_handler import DBBaseHandler
@@ -14,7 +16,7 @@ class DBScreenshotsHandler(DBBaseHandler):
return session.merge(screenshot)
@begin_session
def get_screenshot(self, id, session: Session = None) -> Screenshot:
def get_screenshot(self, id, session: Session = None) -> Screenshot | None:
return session.get(Screenshot, id)
@begin_session
@@ -31,7 +33,7 @@ class DBScreenshotsHandler(DBBaseHandler):
def update_screenshot(
self, id: int, data: dict, session: Session = None
) -> Screenshot:
return session.execute(
return session.scalar(
update(Screenshot)
.where(Screenshot.id == id)
.values(**data)
@@ -40,7 +42,7 @@ class DBScreenshotsHandler(DBBaseHandler):
@begin_session
def delete_screenshot(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(Screenshot)
.where(Screenshot.id == id)
.execution_options(synchronize_session="evaluate")
@@ -49,13 +51,27 @@ class DBScreenshotsHandler(DBBaseHandler):
@begin_session
def purge_screenshots(
self, rom_id: int, user_id: int, screenshots: list[str], session: Session = None
) -> None:
return session.execute(
) -> Sequence[Screenshot]:
purged_screenshots = session.scalars(
select(Screenshot).filter(
and_(
Screenshot.rom_id == rom_id,
Screenshot.user_id == user_id,
Screenshot.file_name.not_in(screenshots),
)
)
).all()
session.execute(
delete(Screenshot)
.where(
Screenshot.rom_id == rom_id,
Screenshot.user_id == user_id,
Screenshot.file_name.not_in(screenshots),
and_(
Screenshot.rom_id == rom_id,
Screenshot.user_id == user_id,
Screenshot.file_name.not_in(screenshots),
)
)
.execution_options(synchronize_session="evaluate")
)
return purged_screenshots

View File

@@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.assets import State
from sqlalchemy import and_, delete, select, update
@@ -12,7 +14,7 @@ class DBStatesHandler(DBBaseHandler):
return session.merge(state)
@begin_session
def get_state(self, id: int, session: Session = None) -> State:
def get_state(self, id: int, session: Session = None) -> State | None:
return session.get(State, id)
@begin_session
@@ -27,7 +29,7 @@ class DBStatesHandler(DBBaseHandler):
@begin_session
def update_state(self, id: int, data: dict, session: Session = None) -> State:
return session.execute(
return session.scalar(
update(State)
.where(State.id == id)
.values(**data)
@@ -36,7 +38,7 @@ class DBStatesHandler(DBBaseHandler):
@begin_session
def delete_state(self, id: int, session: Session = None) -> None:
return session.execute(
session.execute(
delete(State)
.where(State.id == id)
.execution_options(synchronize_session="evaluate")
@@ -45,8 +47,18 @@ class DBStatesHandler(DBBaseHandler):
@begin_session
def purge_states(
self, rom_id: int, user_id: int, states: list[str], session: Session = None
) -> None:
return session.execute(
) -> Sequence[State]:
purged_states = session.scalars(
select(State).filter(
and_(
State.rom_id == rom_id,
State.user_id == user_id,
State.file_name.not_in(states),
)
)
).all()
session.execute(
delete(State)
.where(
and_(
@@ -57,3 +69,5 @@ class DBStatesHandler(DBBaseHandler):
)
.execution_options(synchronize_session="evaluate")
)
return purged_states

View File

@@ -11,25 +11,28 @@ class DBStatsHandler(DBBaseHandler):
@begin_session
def get_platforms_count(self, session: Session = None) -> int:
"""Get the number of platforms with any roms."""
return session.scalar(
select(func.count(distinct(Rom.platform_id))).select_from(Rom)
return (
session.scalar(
select(func.count(distinct(Rom.platform_id))).select_from(Rom)
)
or 0
)
@begin_session
def get_roms_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Rom))
return session.scalar(select(func.count()).select_from(Rom)) or 0
@begin_session
def get_saves_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Save))
return session.scalar(select(func.count()).select_from(Save)) or 0
@begin_session
def get_states_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(State))
return session.scalar(select(func.count()).select_from(State)) or 0
@begin_session
def get_screenshots_count(self, session: Session = None) -> int:
return session.scalar(select(func.count()).select_from(Screenshot))
return session.scalar(select(func.count()).select_from(Screenshot)) or 0
@begin_session
def get_total_filesize(self, session: Session = None) -> int:

View File

@@ -1,3 +1,5 @@
from typing import Sequence
from decorators.database import begin_session
from models.user import Role, User
from sqlalchemy import delete, select, update
@@ -27,7 +29,7 @@ class DBUsersHandler(DBBaseHandler):
@begin_session
def update_user(self, id: int, data: dict, session: Session = None) -> User:
return session.execute(
return session.scalar(
update(User)
.where(User.id == id)
.values(**data)
@@ -35,7 +37,7 @@ class DBUsersHandler(DBBaseHandler):
)
@begin_session
def get_users(self, session: Session = None) -> list[User]:
def get_users(self, session: Session = None) -> Sequence[User]:
return session.scalars(select(User)).all()
@begin_session
@@ -47,5 +49,5 @@ class DBUsersHandler(DBBaseHandler):
)
@begin_session
def get_admin_users(self, session: Session = None) -> list[User]:
def get_admin_users(self, session: Session = None) -> Sequence[User]:
return session.scalars(select(User).filter_by(role=Role.ADMIN)).all()

View File

@@ -95,7 +95,7 @@ class FSResourcesHandler(FSHandler):
return ""
async def get_cover(
self, entity: Rom | Collection | None, overwrite: bool, url_cover: str = ""
self, entity: Rom | Collection | None, overwrite: bool, url_cover: str | None
) -> tuple[str, str]:
if not entity:
return "", ""
@@ -192,9 +192,9 @@ class FSResourcesHandler(FSHandler):
return f"{rom.fs_resources_path}/screenshots/{idx}.jpg"
async def get_rom_screenshots(
self, rom: Rom | None, url_screenshots: list
self, rom: Rom | None, url_screenshots: list | None
) -> list[str]:
if not rom:
if not rom or not url_screenshots:
return []
path_screenshots: list[str] = []

View File

@@ -8,7 +8,7 @@ import tarfile
import zipfile
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Any, Final, TypedDict
from typing import Any, Final, Literal, TypedDict
import magic
import py7zr
@@ -59,7 +59,7 @@ FILE_READ_CHUNK_SIZE = 1024 * 8
class FSRom(TypedDict):
multi: bool
file_name: str
fs_name: str
files: list[RomFile]
@@ -90,7 +90,9 @@ def read_zip_file(file_path: Path) -> Iterator[bytes]:
yield chunk
def read_tar_file(file_path: Path, mode: str = "r") -> Iterator[bytes]:
def read_tar_file(
file_path: Path, mode: Literal["r", "r:*", "r:", "r:gz", "r:bz2", "r:xz"] = "r"
) -> Iterator[bytes]:
try:
with tarfile.open(file_path, mode) as f:
for member in f.getmembers():
@@ -339,10 +341,10 @@ class FSRomsHandler(FSHandler):
raise RomsNotFoundException(platform_fs_slug) from exc
fs_roms: list[dict] = [
{"multi": False, "file_name": rom}
{"multi": False, "fs_name": rom}
for rom in self._exclude_files(fs_single_roms, "single")
] + [
{"multi": True, "file_name": rom}
{"multi": True, "fs_name": rom}
for rom in self._exclude_multi_roms(fs_multi_roms)
]
@@ -350,12 +352,12 @@ class FSRomsHandler(FSHandler):
[
FSRom(
multi=rom["multi"],
file_name=rom["file_name"],
files=self.get_rom_files(rom["file_name"], roms_file_path),
fs_name=rom["fs_name"],
files=self.get_rom_files(rom["fs_name"], roms_file_path),
)
for rom in fs_roms
],
key=lambda rom: rom["file_name"],
key=lambda rom: rom["fs_name"],
)
def file_exists(self, path: str, file_name: str) -> bool:

View File

@@ -567,7 +567,7 @@ class IGDBBaseHandler(MetadataHandler):
@check_twitch_token
async def get_matched_roms_by_name(
self, search_term: str, platform_igdb_id: int
self, search_term: str, platform_igdb_id: int | None
) -> list[IGDBRom]:
if not IGDB_API_ENABLED:
return []

View File

@@ -310,7 +310,7 @@ class MobyGamesHandler(MetadataHandler):
return [rom] if rom["moby_id"] else []
async def get_matched_roms_by_name(
self, search_term: str, platform_moby_id: int
self, search_term: str, platform_moby_id: int | None
) -> list[MobyGamesRom]:
if not MOBY_API_ENABLED:
return []