diff --git a/backend/endpoints/responses/rom.py b/backend/endpoints/responses/rom.py index 72308842c..85c5e64d0 100644 --- a/backend/endpoints/responses/rom.py +++ b/backend/endpoints/responses/rom.py @@ -7,7 +7,6 @@ from endpoints.responses.assets import SaveSchema, ScreenshotSchema, StateSchema from fastapi import Request from fastapi.responses import StreamingResponse from handler.socket_handler import socket_handler -from handler.database import db_user_handler from handler.metadata.igdb_handler import IGDBMetadata from handler.metadata.moby_handler import MobyMetadata from pydantic import BaseModel, computed_field, Field @@ -35,15 +34,11 @@ class RomNoteSchema(BaseModel): last_edited_at: datetime raw_markdown: str is_public: bool + user__username: str class Config: from_attributes = True - @computed_field - @property - def user__username(self) -> str: - return db_user_handler.get_user(self.user_id).username - @classmethod def for_user(cls, db_rom: Rom, user_id: int) -> list["RomNoteSchema"]: return [ diff --git a/backend/handler/database/platforms_handler.py b/backend/handler/database/platforms_handler.py index 9e76d16f3..9f0cf378e 100644 --- a/backend/handler/database/platforms_handler.py +++ b/backend/handler/database/platforms_handler.py @@ -1,5 +1,6 @@ +import functools from sqlalchemy import delete, or_, select -from sqlalchemy.orm import Session, selectinload +from sqlalchemy.orm import Session, Query, selectinload from decorators.database import begin_session from models.platform import Platform @@ -8,32 +9,39 @@ from models.rom import Rom from .base_handler import DBBaseHandler +def with_roms(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = kwargs.get("session") + if session is None: + raise ValueError("session is required") + + kwargs["query"] = select(Platform).options( + selectinload(Platform.roms).load_only(Rom.id) + ) + return func(*args, **kwargs) + + return wrapper + + class DBPlatformsHandler(DBBaseHandler): @begin_session + @with_roms def add_platform( - self, platform: Platform, session: Session = None + self, platform: Platform, query: Query = None, session: Session = None ) -> Platform | None: session.merge(platform) session.flush() - return session.scalar( - select(Platform) - .options(selectinload(Platform.roms).load_only(Rom.id)) - .filter_by(id=platform.id) - .limit(1) - ) + return session.scalar(query.filter_by(id=platform.id).limit(1)) @begin_session + @with_roms def get_platforms( - self, id: int = None, session: Session = None + self, id: int = None, query: Query = None, session: Session = None ) -> list[Platform] | Platform | None: return ( - session.scalar( - select(Platform) - .options(selectinload(Platform.roms).load_only(Rom.id)) - .filter_by(id=id) - .limit(1) - ) + session.scalar(query.filter_by(id=id).limit(1)) if id else ( session.scalars(select(Platform).order_by(Platform.name.asc())) @@ -43,15 +51,11 @@ class DBPlatformsHandler(DBBaseHandler): ) @begin_session + @with_roms def get_platform_by_fs_slug( - self, fs_slug: str, session: Session = None + self, fs_slug: str, query: Query = None, session: Session = None ) -> Platform | None: - return session.scalar( - select(Platform) - .options(selectinload(Platform.roms).load_only(Rom.id)) - .filter_by(fs_slug=fs_slug) - .limit(1) - ) + return session.scalar(query.filter_by(fs_slug=fs_slug).limit(1)) @begin_session def delete_platform(self, id: int, session: Session = None) -> int: diff --git a/backend/handler/database/roms_handler.py b/backend/handler/database/roms_handler.py index 3ad6e38fd..7acfd65dd 100644 --- a/backend/handler/database/roms_handler.py +++ b/backend/handler/database/roms_handler.py @@ -1,11 +1,30 @@ +import functools from decorators.database import begin_session from models.rom import Rom, RomNote from sqlalchemy import and_, delete, func, select, update, or_, Select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, Query, selectinload from .base_handler import DBBaseHandler +def with_assets(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + session = kwargs.get("session") + if session is None: + raise ValueError("session is required") + + kwargs["query"] = select(Rom).options( + selectinload(Rom.saves), + selectinload(Rom.states), + selectinload(Rom.screenshots), + selectinload(Rom.notes), + ) + return func(*args, **kwargs) + + return wrapper + + class DBRomsHandler(DBBaseHandler): def _filter(self, data: Select[Rom], platform_id: int | None, search_term: str): if platform_id: @@ -33,10 +52,15 @@ class DBRomsHandler(DBBaseHandler): return data.order_by(_column.asc()) @begin_session - def add_rom(self, rom: Rom, session: Session = None): - return session.merge(rom) + @with_assets + def add_rom(self, rom: Rom, query: Query = None, session: Session = None): + session.merge(rom) + session.flush() + + return session.scalar(query.filter_by(id=rom.id).limit(1)) @begin_session + @with_assets def get_roms( self, id: int = None, @@ -44,10 +68,11 @@ class DBRomsHandler(DBBaseHandler): search_term: str = "", order_by: str = "name", order_dir: str = "asc", + query: Query = None, session: Session = None, ): return ( - session.get(Rom, id) + session.scalar(query.filter_by(id=id).limit(1)) if id else self._order( self._filter(select(Rom), platform_id, search_term), @@ -57,28 +82,35 @@ class DBRomsHandler(DBBaseHandler): ) @begin_session + @with_assets def get_rom_by_filename( - self, platform_id: int, file_name: str, session: Session = None + self, + platform_id: int, + file_name: str, + query: Query = None, + session: Session = None, ): - return session.scalars( - select(Rom).filter_by(platform_id=platform_id, file_name=file_name).limit(1) - ).first() + return session.scalar( + query.filter_by(platform_id=platform_id, file_name=file_name).limit(1) + ) @begin_session + @with_assets def get_rom_by_filename_no_tags( - self, file_name_no_tags: str, session: Session = None + self, file_name_no_tags: str, query: Query = None, session: Session = None ): - return session.scalars( - select(Rom).filter_by(file_name_no_tags=file_name_no_tags).limit(1) - ).first() + return session.scalar( + query.filter_by(file_name_no_tags=file_name_no_tags).limit(1) + ) @begin_session + @with_assets def get_rom_by_filename_no_ext( - self, file_name_no_ext: str, session: Session = None + self, file_name_no_ext: str, query: Query = None, session: Session = None ): - return session.scalars( - select(Rom).filter_by(file_name_no_ext=file_name_no_ext).limit(1) - ).first() + return session.scalar( + query.filter_by(file_name_no_ext=file_name_no_ext).limit(1) + ) @begin_session def update_rom(self, id: int, data: dict, session: Session = None): diff --git a/backend/handler/database/users_handler.py b/backend/handler/database/users_handler.py index 6a4644f1f..28d6acb8b 100644 --- a/backend/handler/database/users_handler.py +++ b/backend/handler/database/users_handler.py @@ -13,9 +13,7 @@ class DBUsersHandler(DBBaseHandler): @begin_session def get_user_by_username(self, username: str, session: Session = None): - return session.scalars( - select(User).filter_by(username=username).limit(1) - ).first() + return session.scalar(select(User).filter_by(username=username).limit(1)) @begin_session def get_user(self, id: int, session: Session = None): @@ -30,6 +28,10 @@ class DBUsersHandler(DBBaseHandler): .execution_options(synchronize_session="evaluate") ) + @begin_session + def get_users(self, session: Session = None): + return session.scalars(select(User)).all() + @begin_session def delete_user(self, id: int, session: Session = None): return session.execute( @@ -38,10 +40,6 @@ class DBUsersHandler(DBBaseHandler): .execution_options(synchronize_session="evaluate") ) - @begin_session - def get_users(self, session: Session = None): - return session.scalars(select(User)).all() - @begin_session def get_admin_users(self, session: Session = None): return session.scalars(select(User).filter_by(role=Role.ADMIN)).all() diff --git a/backend/logger/logger.py b/backend/logger/logger.py index 72a4ec10a..6280e4b2d 100644 --- a/backend/logger/logger.py +++ b/backend/logger/logger.py @@ -7,10 +7,15 @@ from logger.stdout_formatter import StdoutFormatter log = logging.getLogger("romm") log.setLevel(logging.DEBUG) +# Set up sqlachemy logger +sql_log = logging.getLogger("sqlalchemy.engine") +sql_log.setLevel(logging.DEBUG) + # Define stdout handler stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setFormatter(StdoutFormatter()) log.addHandler(stdout_handler) +sql_log.addHandler(stdout_handler) # Hush passlib warnings logging.getLogger("passlib").setLevel(logging.ERROR) diff --git a/backend/models/assets.py b/backend/models/assets.py index 4a95b6c6b..a1596a4e1 100644 --- a/backend/models/assets.py +++ b/backend/models/assets.py @@ -52,8 +52,8 @@ class Save(RomAsset): emulator = Column(String(length=50), nullable=True) - rom = relationship("Rom", lazy="selectin", back_populates="saves") - user = relationship("User", lazy="selectin", back_populates="saves") + rom = relationship("Rom", lazy="joined", back_populates="saves") + user = relationship("User", lazy="joined", back_populates="saves") @cached_property def screenshot(self) -> Optional["Screenshot"]: @@ -73,8 +73,8 @@ class State(RomAsset): emulator = Column(String(length=50), nullable=True) - rom = relationship("Rom", lazy="selectin", back_populates="states") - user = relationship("User", lazy="selectin", back_populates="states") + rom = relationship("Rom", lazy="joined", back_populates="states") + user = relationship("User", lazy="joined", back_populates="states") @cached_property def screenshot(self) -> Optional["Screenshot"]: @@ -92,5 +92,5 @@ class Screenshot(RomAsset): __tablename__ = "screenshots" __table_args__ = {"extend_existing": True} - rom = relationship("Rom", lazy="selectin", back_populates="screenshots") - user = relationship("User", lazy="selectin", back_populates="screenshots") + rom = relationship("Rom", lazy="joined", back_populates="screenshots") + user = relationship("User", lazy="joined", back_populates="screenshots") diff --git a/backend/models/rom.py b/backend/models/rom.py index a74441883..987fe94be 100644 --- a/backend/models/rom.py +++ b/backend/models/rom.py @@ -70,18 +70,13 @@ class Rom(BaseModel): saves: Mapped[list[Save]] = relationship( "Save", - lazy="selectin", back_populates="rom", ) - states: Mapped[list[State]] = relationship( - "State", lazy="selectin", back_populates="rom" - ) + states: Mapped[list[State]] = relationship("State", back_populates="rom") screenshots: Mapped[list[Screenshot]] = relationship( - "Screenshot", lazy="selectin", back_populates="rom" - ) - notes: Mapped[list["RomNote"]] = relationship( - "RomNote", lazy="selectin", back_populates="rom" + "Screenshot", back_populates="rom" ) + notes: Mapped[list["RomNote"]] = relationship("RomNote", back_populates="rom") @property def platform_slug(self) -> str: @@ -189,5 +184,9 @@ class RomNote(BaseModel): nullable=False, ) - rom = relationship("Rom", back_populates="notes") - user = relationship("User", back_populates="notes") + rom = relationship("Rom", lazy="joined", back_populates="notes") + user = relationship("User", lazy="joined", back_populates="notes") + + @property + def user__username(self) -> str: + return self.user.username diff --git a/backend/models/user.py b/backend/models/user.py index 05ff405cf..629a36fc8 100644 --- a/backend/models/user.py +++ b/backend/models/user.py @@ -32,18 +32,13 @@ class User(BaseModel, SimpleUser): saves: Mapped[list[Save]] = relationship( "Save", - lazy="selectin", back_populates="user", ) - states: Mapped[list[State]] = relationship( - "State", lazy="selectin", back_populates="user" - ) + states: Mapped[list[State]] = relationship("State", back_populates="user") screenshots: Mapped[list[Screenshot]] = relationship( - "Screenshot", lazy="selectin", back_populates="user" - ) - notes: Mapped[list[RomNote]] = relationship( - "RomNote", lazy="selectin", back_populates="user" + "Screenshot", back_populates="user" ) + notes: Mapped[list[RomNote]] = relationship("RomNote", back_populates="user") @property def oauth_scopes(self):