mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 00:27:41 +01:00
Merge pull request #2921 from rommapp/only-ids-param
New endpoints to fetch all IDs
This commit is contained in:
@@ -24,7 +24,11 @@ from handler.filesystem.base_handler import CoverSize
|
||||
from logger.formatter import BLUE
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
from models.collection import Collection, SmartCollection
|
||||
from models.collection import (
|
||||
Collection,
|
||||
SmartCollection,
|
||||
VirtualCollection,
|
||||
)
|
||||
from utils.router import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
@@ -166,7 +170,32 @@ def get_collections(
|
||||
|
||||
collections = db_collection_handler.get_collections(updated_after=updated_after)
|
||||
|
||||
return CollectionSchema.for_user(request.user.id, [c for c in collections])
|
||||
return CollectionSchema.for_user(request.user.id, collections)
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.COLLECTIONS_READ])
|
||||
def get_collection_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get collections identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: List of collection IDs
|
||||
"""
|
||||
|
||||
collections = db_collection_handler.get_collections(
|
||||
only_fields=[
|
||||
Collection.id,
|
||||
Collection.name,
|
||||
Collection.user_id,
|
||||
Collection.is_public,
|
||||
],
|
||||
)
|
||||
|
||||
return [c.id for c in collections if c.user_id == request.user.id or c.is_public]
|
||||
|
||||
|
||||
@protected_route(router.get, "/virtual", [Scope.COLLECTIONS_READ])
|
||||
@@ -184,11 +213,34 @@ def get_virtual_collections(
|
||||
list[VirtualCollectionSchema]: List of virtual collections
|
||||
"""
|
||||
|
||||
virtual_collections = db_collection_handler.get_virtual_collections(type, limit)
|
||||
virtual_collections = db_collection_handler.get_virtual_collections(
|
||||
type=type, limit=limit
|
||||
)
|
||||
|
||||
return [VirtualCollectionSchema.model_validate(vc) for vc in virtual_collections]
|
||||
|
||||
|
||||
@protected_route(router.get, "/virtual/identifiers", [Scope.COLLECTIONS_READ])
|
||||
def get_virtual_collection_identifiers(
|
||||
request: Request,
|
||||
) -> list[str]:
|
||||
"""Get virtual collections identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[str]: List of generated virtual collection IDs
|
||||
"""
|
||||
|
||||
virtual_collections = db_collection_handler.get_virtual_collections(
|
||||
type="all",
|
||||
only_fields=[VirtualCollection.name, VirtualCollection.type],
|
||||
)
|
||||
|
||||
return [s.id for s in virtual_collections]
|
||||
|
||||
|
||||
@protected_route(router.get, "/smart", [Scope.COLLECTIONS_READ])
|
||||
def get_smart_collections(
|
||||
request: Request,
|
||||
@@ -213,10 +265,29 @@ def get_smart_collections(
|
||||
request.user.id, updated_after=updated_after
|
||||
)
|
||||
|
||||
return SmartCollectionSchema.for_user(
|
||||
request.user.id, [s for s in smart_collections]
|
||||
return SmartCollectionSchema.for_user(request.user.id, smart_collections)
|
||||
|
||||
|
||||
@protected_route(router.get, "/smart/identifiers", [Scope.COLLECTIONS_READ])
|
||||
def get_smart_collection_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get smart collections identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: List of smart collection IDs
|
||||
"""
|
||||
|
||||
smart_collections = db_collection_handler.get_smart_collections(
|
||||
request.user.id,
|
||||
only_fields=[SmartCollection.id],
|
||||
)
|
||||
|
||||
return [s.id for s in smart_collections]
|
||||
|
||||
|
||||
@protected_route(router.get, "/{id}", [Scope.COLLECTIONS_READ])
|
||||
def get_collection(request: Request, id: int) -> CollectionSchema:
|
||||
|
||||
@@ -125,6 +125,24 @@ def get_platform_firmware(
|
||||
]
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.FIRMWARE_READ])
|
||||
def get_firmware_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get firmware identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: List of firmware IDs
|
||||
"""
|
||||
firmware = db_firmware_handler.list_firmware(
|
||||
only_fields=[Firmware.id],
|
||||
)
|
||||
return [f.id for f in firmware]
|
||||
|
||||
|
||||
@protected_route(
|
||||
router.get,
|
||||
"/{id}",
|
||||
|
||||
@@ -16,6 +16,7 @@ from handler.scan_handler import scan_platform
|
||||
from logger.formatter import BLUE
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
from models.platform import Platform
|
||||
from utils.platforms import get_supported_platforms
|
||||
from utils.router import APIRouter
|
||||
|
||||
@@ -66,6 +67,18 @@ def get_platforms(
|
||||
]
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.PLATFORMS_READ])
|
||||
def get_platform_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Retrieve platform identifiers."""
|
||||
|
||||
platforms = db_platform_handler.get_platforms(
|
||||
only_fields=[Platform.id],
|
||||
)
|
||||
return [p.id for p in platforms]
|
||||
|
||||
|
||||
@protected_route(router.get, "/supported", [Scope.PLATFORMS_READ])
|
||||
def get_supported_platforms_endpoint(request: Request) -> list[PlatformSchema]:
|
||||
"""Retrieve the list of supported platforms."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
@@ -34,7 +35,7 @@ class CollectionSchema(BaseCollectionSchema):
|
||||
|
||||
@classmethod
|
||||
def for_user(
|
||||
cls, user_id: int, collections: list["Collection"]
|
||||
cls, user_id: int, collections: Sequence["Collection"]
|
||||
) -> list["CollectionSchema"]:
|
||||
return [
|
||||
cls.model_validate(c)
|
||||
@@ -67,7 +68,7 @@ class SmartCollectionSchema(BaseCollectionSchema):
|
||||
|
||||
@classmethod
|
||||
def for_user(
|
||||
cls, user_id: int, smart_collections: list["SmartCollection"]
|
||||
cls, user_id: int, smart_collections: Sequence["SmartCollection"]
|
||||
) -> list["SmartCollectionSchema"]:
|
||||
"""Filter smart collections visible to user and create schemas."""
|
||||
return [
|
||||
|
||||
@@ -66,7 +66,7 @@ from handler.metadata.ss_handler import get_preferred_media_types
|
||||
from logger.formatter import BLUE
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
from models.rom import Rom
|
||||
from models.rom import Rom, RomNote
|
||||
from utils.database import safe_int, safe_str_to_bool
|
||||
from utils.filesystem import sanitize_filename
|
||||
from utils.hashing import crc32_to_hex
|
||||
@@ -517,6 +517,19 @@ def get_roms(
|
||||
)
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.ROMS_READ])
|
||||
def get_rom_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Retrieve rom identifiers."""
|
||||
db_roms = db_rom_handler.get_roms_scalar(
|
||||
user_id=request.user.id,
|
||||
only_fields=[Rom.id],
|
||||
)
|
||||
|
||||
return [r.id for r in db_roms]
|
||||
|
||||
|
||||
@protected_route(
|
||||
router.get,
|
||||
"/download",
|
||||
@@ -1637,8 +1650,6 @@ async def get_rom_notes(
|
||||
tags: list[str] = DEFAULT_TAGS,
|
||||
) -> list[UserNoteSchema]:
|
||||
"""Get all notes for a ROM."""
|
||||
from handler.database import db_rom_handler
|
||||
|
||||
rom = db_rom_handler.get_rom(id)
|
||||
if not rom:
|
||||
raise RomNotFoundInDatabaseException(id)
|
||||
@@ -1657,6 +1668,30 @@ async def get_rom_notes(
|
||||
return [UserNoteSchema.model_validate(note) for note in notes]
|
||||
|
||||
|
||||
@protected_route(
|
||||
router.get,
|
||||
"/{id}/notes/identifiers",
|
||||
[Scope.ROMS_READ],
|
||||
responses={status.HTTP_404_NOT_FOUND: {}},
|
||||
)
|
||||
async def get_rom_note_identifiers(
|
||||
request: Request,
|
||||
id: Annotated[int, PathVar(description="Rom internal id.", ge=1)],
|
||||
) -> list[int]:
|
||||
"""Get all note identifiers for a ROM."""
|
||||
rom = db_rom_handler.get_rom(id)
|
||||
if not rom:
|
||||
raise RomNotFoundInDatabaseException(id)
|
||||
|
||||
notes = db_rom_handler.get_rom_notes(
|
||||
rom_id=id,
|
||||
user_id=request.user.id,
|
||||
only_fields=[RomNote.id],
|
||||
)
|
||||
|
||||
return [note.id for note in notes]
|
||||
|
||||
|
||||
@protected_route(
|
||||
router.post,
|
||||
"/{id}/notes",
|
||||
@@ -1669,8 +1704,6 @@ async def create_rom_note(
|
||||
note_data: Annotated[dict, Body()],
|
||||
) -> UserNoteSchema:
|
||||
"""Create a new note for a ROM."""
|
||||
from handler.database import db_rom_handler
|
||||
|
||||
rom = db_rom_handler.get_rom(id)
|
||||
if not rom:
|
||||
raise RomNotFoundInDatabaseException(id)
|
||||
@@ -1702,8 +1735,6 @@ async def update_rom_note(
|
||||
note_data: Annotated[dict, Body()],
|
||||
) -> UserNoteSchema:
|
||||
"""Update a ROM note."""
|
||||
from handler.database import db_rom_handler
|
||||
|
||||
note = db_rom_handler.update_rom_note(
|
||||
note_id=note_id,
|
||||
user_id=request.user.id,
|
||||
@@ -1736,8 +1767,6 @@ async def delete_rom_note(
|
||||
note_id: Annotated[int, PathVar(description="Note id.", ge=1)],
|
||||
) -> dict:
|
||||
"""Delete a ROM note."""
|
||||
from handler.database import db_rom_handler
|
||||
|
||||
success = db_rom_handler.delete_rom_note(note_id=note_id, user_id=request.user.id)
|
||||
|
||||
if not success:
|
||||
|
||||
@@ -13,6 +13,7 @@ from handler.scan_handler import scan_save, scan_screenshot
|
||||
from logger.formatter import BLUE
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
from models.assets import Save
|
||||
from utils.router import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
@@ -151,6 +152,26 @@ def get_saves(
|
||||
return [SaveSchema.model_validate(save) for save in saves]
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
|
||||
def get_save_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get save identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: List of save IDs
|
||||
"""
|
||||
saves = db_save_handler.get_saves(
|
||||
user_id=request.user.id,
|
||||
only_fields=[Save.id],
|
||||
)
|
||||
|
||||
return [save.id for save in saves]
|
||||
|
||||
|
||||
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
|
||||
def get_save(request: Request, id: int) -> SaveSchema:
|
||||
save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
|
||||
@@ -13,6 +13,7 @@ from handler.scan_handler import scan_screenshot, scan_state
|
||||
from logger.formatter import BLUE
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
from models.assets import State
|
||||
from utils.router import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
@@ -153,6 +154,26 @@ def get_states(
|
||||
return [StateSchema.model_validate(state) for state in states]
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
|
||||
def get_state_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get state identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: List of state IDs
|
||||
"""
|
||||
states = db_state_handler.get_states(
|
||||
user_id=request.user.id,
|
||||
only_fields=[State.id],
|
||||
)
|
||||
|
||||
return [state.id for state in states]
|
||||
|
||||
|
||||
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
|
||||
def get_state(request: Request, id: int) -> StateSchema:
|
||||
state = db_state_handler.get_state(user_id=request.user.id, id=id)
|
||||
|
||||
@@ -212,6 +212,23 @@ def get_users(request: Request) -> list[UserSchema]:
|
||||
return [UserSchema.model_validate(u) for u in db_user_handler.get_users()]
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.USERS_READ])
|
||||
def get_user_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get all user identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: All user ids stored in the RomM's database
|
||||
"""
|
||||
|
||||
users = db_user_handler.get_users(only_fields=[User.id])
|
||||
return [u.id for u in users]
|
||||
|
||||
|
||||
@protected_route(router.get, "/me", [Scope.ME_READ])
|
||||
def get_current_user(request: Request) -> UserSchema | None:
|
||||
"""Get current user endpoint
|
||||
|
||||
@@ -4,7 +4,14 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, insert, literal, or_, select, update
|
||||
from sqlalchemy.orm import Query, Session, noload, selectinload
|
||||
from sqlalchemy.orm import (
|
||||
Query,
|
||||
QueryableAttribute,
|
||||
Session,
|
||||
load_only,
|
||||
noload,
|
||||
selectinload,
|
||||
)
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.collection import (
|
||||
@@ -87,11 +94,16 @@ class DBCollectionsHandler(DBBaseHandler):
|
||||
def get_collections(
|
||||
self,
|
||||
updated_after: datetime | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
query: Query = None, # type: ignore
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[Collection]:
|
||||
if updated_after:
|
||||
query = query.filter(Collection.updated_at > updated_after)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query.order_by(Collection.name.asc())).unique().all()
|
||||
|
||||
@begin_session
|
||||
@@ -157,19 +169,21 @@ class DBCollectionsHandler(DBBaseHandler):
|
||||
self,
|
||||
type: str,
|
||||
limit: int | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[VirtualCollection]:
|
||||
return (
|
||||
session.scalars(
|
||||
select(VirtualCollection)
|
||||
.filter(or_(VirtualCollection.type == type, literal(type == "all")))
|
||||
.limit(limit)
|
||||
.order_by(VirtualCollection.name.asc())
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
query = (
|
||||
select(VirtualCollection)
|
||||
.filter(or_(VirtualCollection.type == type, literal(type == "all")))
|
||||
.limit(limit)
|
||||
.order_by(VirtualCollection.name.asc())
|
||||
)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query).unique().all()
|
||||
|
||||
# Smart collections
|
||||
@begin_session
|
||||
def add_smart_collection(
|
||||
@@ -206,6 +220,7 @@ class DBCollectionsHandler(DBBaseHandler):
|
||||
self,
|
||||
user_id: int | None = None,
|
||||
updated_after: datetime | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[SmartCollection]:
|
||||
query = select(SmartCollection).order_by(SmartCollection.name.asc())
|
||||
@@ -219,6 +234,9 @@ class DBCollectionsHandler(DBBaseHandler):
|
||||
if updated_after:
|
||||
query = query.filter(SmartCollection.updated_at > updated_after)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query).unique().all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import QueryableAttribute, Session, load_only
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.firmware import Firmware
|
||||
@@ -32,6 +32,7 @@ class DBFirmwareHandler(DBBaseHandler):
|
||||
self,
|
||||
*,
|
||||
platform_id: int | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[Firmware]:
|
||||
query = select(Firmware).order_by(Firmware.file_name.asc())
|
||||
@@ -39,6 +40,9 @@ class DBFirmwareHandler(DBBaseHandler):
|
||||
if platform_id:
|
||||
query = query.filter_by(platform_id=platform_id)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query).all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import delete, or_, select, update
|
||||
from sqlalchemy.orm import Query, Session, selectinload
|
||||
from sqlalchemy.orm import Query, QueryableAttribute, Session, load_only, selectinload
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.platform import Platform
|
||||
@@ -67,11 +67,16 @@ class DBPlatformsHandler(DBBaseHandler):
|
||||
def get_platforms(
|
||||
self,
|
||||
updated_after: datetime | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
query: Query = None, # type: ignore
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[Platform]:
|
||||
if updated_after:
|
||||
query = query.filter(Platform.updated_at > updated_after)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query.order_by(Platform.name.asc())).unique().all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -21,7 +21,15 @@ from sqlalchemy import (
|
||||
text,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.orm import Query, Session, joinedload, noload, selectinload
|
||||
from sqlalchemy.orm import (
|
||||
Query,
|
||||
QueryableAttribute,
|
||||
Session,
|
||||
joinedload,
|
||||
load_only,
|
||||
noload,
|
||||
selectinload,
|
||||
)
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
|
||||
@@ -140,32 +148,6 @@ def with_details(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def with_simple(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["query"] = select(Rom).options(
|
||||
# Ensure platform is loaded for main ROM objects
|
||||
selectinload(Rom.platform),
|
||||
# Display properties for the current user (last_played)
|
||||
selectinload(Rom.rom_users).options(noload(RomUser.rom)),
|
||||
# Sort table by metadata (first_release_date)
|
||||
selectinload(Rom.metadatum).options(noload(RomMetadata.rom)),
|
||||
# Required for multi-file ROM actions and 3DS QR code
|
||||
selectinload(Rom.files).options(
|
||||
joinedload(RomFile.rom).load_only(Rom.fs_path, Rom.fs_name)
|
||||
),
|
||||
# Show sibling rom badges on cards
|
||||
selectinload(Rom.sibling_roms).options(
|
||||
noload(Rom.platform), noload(Rom.metadatum)
|
||||
),
|
||||
# Show notes indicator on cards
|
||||
selectinload(Rom.notes),
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class DBRomsHandler(DBBaseHandler):
|
||||
@begin_session
|
||||
@with_details
|
||||
@@ -535,6 +517,25 @@ class DBRomsHandler(DBBaseHandler):
|
||||
) -> Query[Rom]:
|
||||
from handler.scan_handler import MetadataSource
|
||||
|
||||
query = query.options(
|
||||
# Ensure platform is loaded for main ROM objects
|
||||
selectinload(Rom.platform),
|
||||
# Display properties for the current user (last_played)
|
||||
selectinload(Rom.rom_users).options(noload(RomUser.rom)),
|
||||
# Sort table by metadata (first_release_date)
|
||||
selectinload(Rom.metadatum).options(noload(RomMetadata.rom)),
|
||||
# Required for multi-file ROM actions and 3DS QR code
|
||||
selectinload(Rom.files).options(
|
||||
joinedload(RomFile.rom).load_only(Rom.fs_path, Rom.fs_name)
|
||||
),
|
||||
# Show sibling rom badges on cards
|
||||
selectinload(Rom.sibling_roms).options(
|
||||
noload(Rom.platform), noload(Rom.metadatum)
|
||||
),
|
||||
# Show notes indicator on cards
|
||||
selectinload(Rom.notes),
|
||||
)
|
||||
|
||||
# Handle platform filtering - platform filtering always uses OR logic since ROMs belong to only one platform
|
||||
if platform_ids:
|
||||
query = self._filter_by_platform_ids(query, platform_ids)
|
||||
@@ -730,7 +731,6 @@ class DBRomsHandler(DBBaseHandler):
|
||||
|
||||
return query
|
||||
|
||||
@with_simple
|
||||
@begin_session
|
||||
def get_roms_query(
|
||||
self,
|
||||
@@ -738,9 +738,10 @@ class DBRomsHandler(DBBaseHandler):
|
||||
order_by: str = "name",
|
||||
order_dir: str = "asc",
|
||||
user_id: int | None = None,
|
||||
query: Query = None, # type: ignore
|
||||
session: Session = None, # type: ignore
|
||||
) -> tuple[Query[Rom], Any]:
|
||||
query = select(Rom)
|
||||
|
||||
if user_id:
|
||||
query = query.outerjoin(
|
||||
RomUser, and_(RomUser.rom_id == Rom.id, RomUser.user_id == user_id)
|
||||
@@ -771,12 +772,13 @@ class DBRomsHandler(DBBaseHandler):
|
||||
else:
|
||||
order_attr = order_attr.asc()
|
||||
|
||||
return query.order_by(order_attr), order_attr_column
|
||||
return query.order_by(order_attr), order_attr_column # type: ignore
|
||||
|
||||
@begin_session
|
||||
def get_roms_scalar(
|
||||
self,
|
||||
*,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
**kwargs,
|
||||
) -> Sequence[Rom]:
|
||||
@@ -785,6 +787,10 @@ class DBRomsHandler(DBBaseHandler):
|
||||
order_dir=kwargs.get("order_dir", "asc"),
|
||||
user_id=kwargs.get("user_id", None),
|
||||
)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
roms = self.filter_roms(
|
||||
query=query,
|
||||
platform_ids=kwargs.get("platform_ids", None),
|
||||
@@ -873,15 +879,12 @@ class DBRomsHandler(DBBaseHandler):
|
||||
session: Session = None, # type: ignore
|
||||
) -> dict[str, Rom]:
|
||||
"""Retrieve a dictionary of roms by their filesystem names."""
|
||||
roms = (
|
||||
session.scalars(
|
||||
query.filter(Rom.fs_name.in_(fs_names)).filter_by(
|
||||
platform_id=platform_id
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
query = query.filter(Rom.fs_name.in_(fs_names)).filter_by(
|
||||
platform_id=platform_id
|
||||
)
|
||||
|
||||
roms = session.scalars(query).unique().all()
|
||||
|
||||
return {rom.fs_name: rom for rom in roms}
|
||||
|
||||
@begin_session
|
||||
@@ -1065,8 +1068,9 @@ class DBRomsHandler(DBBaseHandler):
|
||||
rom_id: int,
|
||||
user_id: int,
|
||||
public_only: bool = False,
|
||||
search: str = "",
|
||||
search: str | None = "",
|
||||
tags: list[str] | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[RomNote]:
|
||||
query = session.query(RomNote).filter(RomNote.rom_id == rom_id)
|
||||
@@ -1088,6 +1092,9 @@ class DBRomsHandler(DBBaseHandler):
|
||||
json_array_contains_value(RomNote.tags, tag, session=session)
|
||||
)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return query.order_by(RomNote.updated_at.desc()).all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import QueryableAttribute, Session, load_only
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.assets import Save
|
||||
@@ -48,6 +48,7 @@ class DBSavesHandler(DBBaseHandler):
|
||||
user_id: int,
|
||||
rom_id: int | None = None,
|
||||
platform_id: int | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[Save]:
|
||||
query = select(Save).filter_by(user_id=user_id)
|
||||
@@ -60,6 +61,9 @@ class DBSavesHandler(DBBaseHandler):
|
||||
Rom.platform_id == platform_id
|
||||
)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query).all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import QueryableAttribute, Session, load_only
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.assets import State
|
||||
@@ -48,6 +48,7 @@ class DBStatesHandler(DBBaseHandler):
|
||||
user_id: int,
|
||||
rom_id: int | None = None,
|
||||
platform_id: int | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[State]:
|
||||
query = select(State).filter_by(user_id=user_id)
|
||||
@@ -60,6 +61,9 @@ class DBStatesHandler(DBBaseHandler):
|
||||
Rom.platform_id == platform_id
|
||||
)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query).all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import and_, delete, func, not_, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import QueryableAttribute, Session, load_only
|
||||
from sqlalchemy.sql import Delete, Select, Update
|
||||
|
||||
from decorators.database import begin_session
|
||||
@@ -94,6 +94,7 @@ class DBUsersHandler(DBBaseHandler):
|
||||
emails: Sequence[str] = (),
|
||||
roles: Sequence[Role] = (),
|
||||
has_ra_username: bool | None = None,
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[User]:
|
||||
query = self.filter(
|
||||
@@ -103,6 +104,10 @@ class DBUsersHandler(DBBaseHandler):
|
||||
roles=roles,
|
||||
has_ra_username=has_ra_username,
|
||||
)
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
return session.scalars(query).all()
|
||||
|
||||
@begin_session
|
||||
|
||||
@@ -746,7 +746,7 @@ class IGDBHandler(MetadataHandler):
|
||||
)
|
||||
matched_roms.extend(alternative_roms)
|
||||
|
||||
# Use a dictionary to keep track of unique ids
|
||||
# Use a dictionary to keep track of unique IDs
|
||||
unique_ids: dict[int, Game] = {}
|
||||
|
||||
# Use a list comprehension to filter duplicates based on the 'id' key
|
||||
|
||||
Reference in New Issue
Block a user