diff --git a/backend/endpoints/saves.py b/backend/endpoints/saves.py index c97cebc4a..638bd56fb 100644 --- a/backend/endpoints/saves.py +++ b/backend/endpoints/saves.py @@ -107,8 +107,10 @@ async def add_save( platform_fs_slug=rom.platform_slug, rom_id=rom.id, ) - db_screenshot = db_screenshot_handler.get_screenshot_by_filename( - rom_id=rom.id, user_id=request.user.id, file_name=screenshotFile.filename + db_screenshot = db_screenshot_handler.get_screenshot( + filename=screenshotFile.filename, + rom_id=rom.id, + user_id=request.user.id, ) if db_screenshot: db_screenshot = db_screenshot_handler.update_screenshot( @@ -193,10 +195,10 @@ async def update_save(request: Request, id: int) -> SaveSchema: platform_fs_slug=db_save.rom.platform_slug, rom_id=db_save.rom.id, ) - db_screenshot = db_screenshot_handler.get_screenshot_by_filename( + db_screenshot = db_screenshot_handler.get_screenshot( + filename=screenshotFile.filename, rom_id=db_save.rom.id, user_id=request.user.id, - file_name=screenshotFile.filename, ) if db_screenshot: db_screenshot = db_screenshot_handler.update_screenshot( diff --git a/backend/endpoints/screenshots.py b/backend/endpoints/screenshots.py index 6ade2d397..d276b1010 100644 --- a/backend/endpoints/screenshots.py +++ b/backend/endpoints/screenshots.py @@ -65,8 +65,10 @@ async def add_screenshot( platform_fs_slug=rom.platform_slug, rom_id=rom.id, ) - db_screenshot = db_screenshot_handler.get_screenshot_by_filename( - rom_id=rom.id, user_id=current_user.id, file_name=screenshotFile.filename + db_screenshot = db_screenshot_handler.get_screenshot( + filename=screenshotFile.filename, + rom_id=rom.id, + user_id=current_user.id, ) if db_screenshot: db_screenshot = db_screenshot_handler.update_screenshot( diff --git a/backend/endpoints/states.py b/backend/endpoints/states.py index 432f25ca0..aa2bb7490 100644 --- a/backend/endpoints/states.py +++ b/backend/endpoints/states.py @@ -107,8 +107,10 @@ async def add_state( platform_fs_slug=rom.platform_slug, rom_id=rom.id, ) - db_screenshot = db_screenshot_handler.get_screenshot_by_filename( - rom_id=rom.id, user_id=request.user.id, file_name=screenshotFile.filename + db_screenshot = db_screenshot_handler.get_screenshot( + filename=screenshotFile.filename, + rom_id=rom.id, + user_id=request.user.id, ) if db_screenshot: db_screenshot = db_screenshot_handler.update_screenshot( @@ -195,10 +197,10 @@ async def update_state(request: Request, id: int) -> StateSchema: platform_fs_slug=db_state.rom.platform_slug, rom_id=db_state.rom.id, ) - db_screenshot = db_screenshot_handler.get_screenshot_by_filename( + db_screenshot = db_screenshot_handler.get_screenshot( + filename=screenshotFile.filename, rom_id=db_state.rom.id, user_id=request.user.id, - file_name=screenshotFile.filename, ) if db_screenshot: db_screenshot = db_screenshot_handler.update_screenshot( diff --git a/backend/handler/database/screenshots_handler.py b/backend/handler/database/screenshots_handler.py index 98734741d..2ccdcb131 100644 --- a/backend/handler/database/screenshots_handler.py +++ b/backend/handler/database/screenshots_handler.py @@ -1,14 +1,43 @@ from collections.abc import Sequence +from functools import partial from decorators.database import begin_session from models.assets import Screenshot -from sqlalchemy import and_, delete, select, update +from sqlalchemy import delete, select, update from sqlalchemy.orm import Session +from sqlalchemy.sql import Delete, Select, Update from .base_handler import DBBaseHandler class DBScreenshotsHandler(DBBaseHandler): + def filter[QueryT: Select[tuple[Screenshot]] | Update | Delete]( + self, + query: QueryT, + *, + filenames: Sequence[str] = (), + filenames_no_ext: Sequence[str] = (), + rom_ids: Sequence[int] = (), + user_ids: Sequence[int] = (), + exclude_filenames: Sequence[str] = (), + exclude_filenames_no_ext: Sequence[str] = (), + ) -> QueryT: + if filenames: + query = query.filter(Screenshot.file_name.in_(filenames)) + if filenames_no_ext: + query = query.filter(Screenshot.file_name_no_ext.in_(filenames_no_ext)) + if rom_ids: + query = query.filter(Screenshot.rom_id.in_(rom_ids)) + if user_ids: + query = query.filter(Screenshot.user_id.in_(user_ids)) + if exclude_filenames: + query = query.filter(Screenshot.file_name.not_in(exclude_filenames)) + if exclude_filenames_no_ext: + query = query.filter( + Screenshot.file_name_no_ext.not_in(exclude_filenames_no_ext) + ) + return query + @begin_session def add_screenshot( self, screenshot: Screenshot, session: Session = None @@ -16,18 +45,27 @@ class DBScreenshotsHandler(DBBaseHandler): return session.merge(screenshot) @begin_session - def get_screenshot(self, id, session: Session = None) -> Screenshot | None: - return session.get(Screenshot, id) + def get_screenshot( + self, + *, + filename: str | None = None, + filename_no_ext: str | None = None, + rom_id: int | None = None, + user_id: int | None = None, + session: Session = None, + ) -> Screenshot | None: + query = self.filter( + select(Screenshot), + filenames=[filename] if filename is not None else (), + filenames_no_ext=[filename_no_ext] if filename_no_ext is not None else (), + rom_ids=[rom_id] if rom_id is not None else (), + user_ids=[user_id] if user_id is not None else (), + ) + return session.scalars(query.limit(1)).first() @begin_session - def get_screenshot_by_filename( - self, rom_id: int, user_id: int, file_name: str, session: Session = None - ) -> Screenshot | None: - return session.scalars( - select(Screenshot) - .filter_by(rom_id=rom_id, user_id=user_id, file_name=file_name) - .limit(1) - ).first() + def get_screenshot_by_id(self, id, session: Session = None) -> Screenshot | None: + return session.get(Screenshot, id) @begin_session def update_screenshot( @@ -57,25 +95,17 @@ class DBScreenshotsHandler(DBBaseHandler): screenshots_to_keep: list[str], session: Session = None, ) -> Sequence[Screenshot]: - missing_screenshots = session.scalars( - select(Screenshot).filter( - and_( - Screenshot.rom_id == rom_id, - Screenshot.user_id == user_id, - Screenshot.file_name.not_in(screenshots_to_keep), - ) - ) - ).all() + query_fn = partial( + self.filter, + rom_ids=[rom_id], + user_ids=[user_id], + exclude_filenames=screenshots_to_keep, + ) + + missing_screenshots = session.scalars(query_fn(query=select(Screenshot))).all() session.execute( - update(Screenshot) - .where( - and_( - Screenshot.rom_id == rom_id, - Screenshot.user_id == user_id, - Screenshot.file_name.not_in(screenshots_to_keep), - ) - ) + query_fn(query=update(Screenshot)) .values(**{"missing_from_fs": True}) .execution_options(synchronize_session="evaluate") ) diff --git a/backend/models/assets.py b/backend/models/assets.py index ad63b336c..c017fefa3 100644 --- a/backend/models/assets.py +++ b/backend/models/assets.py @@ -59,17 +59,12 @@ class Save(RomAsset): @cached_property def screenshot(self) -> Screenshot | None: - from handler.database import db_rom_handler + from handler.database import db_screenshot_handler - db_rom = db_rom_handler.get_rom(self.rom_id) - if db_rom is None: - return None - - for screenshot in db_rom.screenshots: - if screenshot.file_name_no_ext == self.file_name_no_ext: - return screenshot - - return None + return db_screenshot_handler.get_screenshot( + filename_no_ext=self.file_name_no_ext, + rom_id=self.rom_id, + ) class State(RomAsset): @@ -83,17 +78,12 @@ class State(RomAsset): @cached_property def screenshot(self) -> Screenshot | None: - from handler.database import db_rom_handler + from handler.database import db_screenshot_handler - db_rom = db_rom_handler.get_rom(self.rom_id) - if db_rom is None: - return None - - for screenshot in db_rom.screenshots: - if screenshot.file_name_no_ext == self.file_name_no_ext: - return screenshot - - return None + return db_screenshot_handler.get_screenshot( + filename_no_ext=self.file_name_no_ext, + rom_id=self.rom_id, + ) class Screenshot(RomAsset): diff --git a/backend/tests/handler/test_db_handler.py b/backend/tests/handler/test_db_handler.py index ee9f8ee62..0520e00f5 100644 --- a/backend/tests/handler/test_db_handler.py +++ b/backend/tests/handler/test_db_handler.py @@ -196,14 +196,16 @@ def test_screenshots(screenshot: Screenshot, platform: Platform, admin_user: Use assert rom is not None assert len(rom.screenshots) == 2 - new_screenshot = db_screenshot_handler.get_screenshot(id=rom.screenshots[0].id) + new_screenshot = db_screenshot_handler.get_screenshot_by_id( + id=rom.screenshots[0].id + ) assert new_screenshot is not None assert new_screenshot.file_name == "test_screenshot.png" db_screenshot_handler.update_screenshot( new_screenshot.id, {"file_name": "test_screenshot_2.png"} ) - new_screenshot = db_screenshot_handler.get_screenshot(id=new_screenshot.id) + new_screenshot = db_screenshot_handler.get_screenshot_by_id(id=new_screenshot.id) assert new_screenshot is not None assert new_screenshot.file_name == "test_screenshot_2.png"