Merge pull request #2921 from rommapp/only-ids-param

New endpoints to fetch all IDs
This commit is contained in:
Georges-Antoine Assi
2026-01-28 14:55:42 -05:00
committed by GitHub
16 changed files with 309 additions and 71 deletions

View File

@@ -24,7 +24,11 @@ from handler.filesystem.base_handler import CoverSize
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.collection import Collection, SmartCollection
from models.collection import (
Collection,
SmartCollection,
VirtualCollection,
)
from utils.router import APIRouter
router = APIRouter(
@@ -166,7 +170,32 @@ def get_collections(
collections = db_collection_handler.get_collections(updated_after=updated_after)
return CollectionSchema.for_user(request.user.id, [c for c in collections])
return CollectionSchema.for_user(request.user.id, collections)
@protected_route(router.get, "/identifiers", [Scope.COLLECTIONS_READ])
def get_collection_identifiers(
request: Request,
) -> list[int]:
"""Get collections identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of collection IDs
"""
collections = db_collection_handler.get_collections(
only_fields=[
Collection.id,
Collection.name,
Collection.user_id,
Collection.is_public,
],
)
return [c.id for c in collections if c.user_id == request.user.id or c.is_public]
@protected_route(router.get, "/virtual", [Scope.COLLECTIONS_READ])
@@ -184,11 +213,34 @@ def get_virtual_collections(
list[VirtualCollectionSchema]: List of virtual collections
"""
virtual_collections = db_collection_handler.get_virtual_collections(type, limit)
virtual_collections = db_collection_handler.get_virtual_collections(
type=type, limit=limit
)
return [VirtualCollectionSchema.model_validate(vc) for vc in virtual_collections]
@protected_route(router.get, "/virtual/identifiers", [Scope.COLLECTIONS_READ])
def get_virtual_collection_identifiers(
request: Request,
) -> list[str]:
"""Get virtual collections identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[str]: List of generated virtual collection IDs
"""
virtual_collections = db_collection_handler.get_virtual_collections(
type="all",
only_fields=[VirtualCollection.name, VirtualCollection.type],
)
return [s.id for s in virtual_collections]
@protected_route(router.get, "/smart", [Scope.COLLECTIONS_READ])
def get_smart_collections(
request: Request,
@@ -213,10 +265,29 @@ def get_smart_collections(
request.user.id, updated_after=updated_after
)
return SmartCollectionSchema.for_user(
request.user.id, [s for s in smart_collections]
return SmartCollectionSchema.for_user(request.user.id, smart_collections)
@protected_route(router.get, "/smart/identifiers", [Scope.COLLECTIONS_READ])
def get_smart_collection_identifiers(
request: Request,
) -> list[int]:
"""Get smart collections identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of smart collection IDs
"""
smart_collections = db_collection_handler.get_smart_collections(
request.user.id,
only_fields=[SmartCollection.id],
)
return [s.id for s in smart_collections]
@protected_route(router.get, "/{id}", [Scope.COLLECTIONS_READ])
def get_collection(request: Request, id: int) -> CollectionSchema:

View File

@@ -125,6 +125,24 @@ def get_platform_firmware(
]
@protected_route(router.get, "/identifiers", [Scope.FIRMWARE_READ])
def get_firmware_identifiers(
request: Request,
) -> list[int]:
"""Get firmware identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of firmware IDs
"""
firmware = db_firmware_handler.list_firmware(
only_fields=[Firmware.id],
)
return [f.id for f in firmware]
@protected_route(
router.get,
"/{id}",

View File

@@ -16,6 +16,7 @@ from handler.scan_handler import scan_platform
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.platform import Platform
from utils.platforms import get_supported_platforms
from utils.router import APIRouter
@@ -66,6 +67,18 @@ def get_platforms(
]
@protected_route(router.get, "/identifiers", [Scope.PLATFORMS_READ])
def get_platform_identifiers(
request: Request,
) -> list[int]:
"""Retrieve platform identifiers."""
platforms = db_platform_handler.get_platforms(
only_fields=[Platform.id],
)
return [p.id for p in platforms]
@protected_route(router.get, "/supported", [Scope.PLATFORMS_READ])
def get_supported_platforms_endpoint(request: Request) -> list[PlatformSchema]:
"""Retrieve the list of supported platforms."""

View File

@@ -1,3 +1,4 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any
@@ -34,7 +35,7 @@ class CollectionSchema(BaseCollectionSchema):
@classmethod
def for_user(
cls, user_id: int, collections: list["Collection"]
cls, user_id: int, collections: Sequence["Collection"]
) -> list["CollectionSchema"]:
return [
cls.model_validate(c)
@@ -67,7 +68,7 @@ class SmartCollectionSchema(BaseCollectionSchema):
@classmethod
def for_user(
cls, user_id: int, smart_collections: list["SmartCollection"]
cls, user_id: int, smart_collections: Sequence["SmartCollection"]
) -> list["SmartCollectionSchema"]:
"""Filter smart collections visible to user and create schemas."""
return [

View File

@@ -66,7 +66,7 @@ from handler.metadata.ss_handler import get_preferred_media_types
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.rom import Rom
from models.rom import Rom, RomNote
from utils.database import safe_int, safe_str_to_bool
from utils.filesystem import sanitize_filename
from utils.hashing import crc32_to_hex
@@ -517,6 +517,19 @@ def get_roms(
)
@protected_route(router.get, "/identifiers", [Scope.ROMS_READ])
def get_rom_identifiers(
request: Request,
) -> list[int]:
"""Retrieve rom identifiers."""
db_roms = db_rom_handler.get_roms_scalar(
user_id=request.user.id,
only_fields=[Rom.id],
)
return [r.id for r in db_roms]
@protected_route(
router.get,
"/download",
@@ -1637,8 +1650,6 @@ async def get_rom_notes(
tags: list[str] = DEFAULT_TAGS,
) -> list[UserNoteSchema]:
"""Get all notes for a ROM."""
from handler.database import db_rom_handler
rom = db_rom_handler.get_rom(id)
if not rom:
raise RomNotFoundInDatabaseException(id)
@@ -1657,6 +1668,30 @@ async def get_rom_notes(
return [UserNoteSchema.model_validate(note) for note in notes]
@protected_route(
router.get,
"/{id}/notes/identifiers",
[Scope.ROMS_READ],
responses={status.HTTP_404_NOT_FOUND: {}},
)
async def get_rom_note_identifiers(
request: Request,
id: Annotated[int, PathVar(description="Rom internal id.", ge=1)],
) -> list[int]:
"""Get all note identifiers for a ROM."""
rom = db_rom_handler.get_rom(id)
if not rom:
raise RomNotFoundInDatabaseException(id)
notes = db_rom_handler.get_rom_notes(
rom_id=id,
user_id=request.user.id,
only_fields=[RomNote.id],
)
return [note.id for note in notes]
@protected_route(
router.post,
"/{id}/notes",
@@ -1669,8 +1704,6 @@ async def create_rom_note(
note_data: Annotated[dict, Body()],
) -> UserNoteSchema:
"""Create a new note for a ROM."""
from handler.database import db_rom_handler
rom = db_rom_handler.get_rom(id)
if not rom:
raise RomNotFoundInDatabaseException(id)
@@ -1702,8 +1735,6 @@ async def update_rom_note(
note_data: Annotated[dict, Body()],
) -> UserNoteSchema:
"""Update a ROM note."""
from handler.database import db_rom_handler
note = db_rom_handler.update_rom_note(
note_id=note_id,
user_id=request.user.id,
@@ -1736,8 +1767,6 @@ async def delete_rom_note(
note_id: Annotated[int, PathVar(description="Note id.", ge=1)],
) -> dict:
"""Delete a ROM note."""
from handler.database import db_rom_handler
success = db_rom_handler.delete_rom_note(note_id=note_id, user_id=request.user.id)
if not success:

View File

@@ -13,6 +13,7 @@ from handler.scan_handler import scan_save, scan_screenshot
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.assets import Save
from utils.router import APIRouter
router = APIRouter(
@@ -151,6 +152,26 @@ def get_saves(
return [SaveSchema.model_validate(save) for save in saves]
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
def get_save_identifiers(
request: Request,
) -> list[int]:
"""Get save identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of save IDs
"""
saves = db_save_handler.get_saves(
user_id=request.user.id,
only_fields=[Save.id],
)
return [save.id for save in saves]
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
def get_save(request: Request, id: int) -> SaveSchema:
save = db_save_handler.get_save(user_id=request.user.id, id=id)

View File

@@ -13,6 +13,7 @@ from handler.scan_handler import scan_screenshot, scan_state
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.assets import State
from utils.router import APIRouter
router = APIRouter(
@@ -153,6 +154,26 @@ def get_states(
return [StateSchema.model_validate(state) for state in states]
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
def get_state_identifiers(
request: Request,
) -> list[int]:
"""Get state identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of state IDs
"""
states = db_state_handler.get_states(
user_id=request.user.id,
only_fields=[State.id],
)
return [state.id for state in states]
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
def get_state(request: Request, id: int) -> StateSchema:
state = db_state_handler.get_state(user_id=request.user.id, id=id)

View File

@@ -212,6 +212,23 @@ def get_users(request: Request) -> list[UserSchema]:
return [UserSchema.model_validate(u) for u in db_user_handler.get_users()]
@protected_route(router.get, "/identifiers", [Scope.USERS_READ])
def get_user_identifiers(
request: Request,
) -> list[int]:
"""Get all user identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: All user ids stored in the RomM's database
"""
users = db_user_handler.get_users(only_fields=[User.id])
return [u.id for u in users]
@protected_route(router.get, "/me", [Scope.ME_READ])
def get_current_user(request: Request) -> UserSchema | None:
"""Get current user endpoint

View File

@@ -4,7 +4,14 @@ from datetime import datetime
from typing import Any
from sqlalchemy import delete, insert, literal, or_, select, update
from sqlalchemy.orm import Query, Session, noload, selectinload
from sqlalchemy.orm import (
Query,
QueryableAttribute,
Session,
load_only,
noload,
selectinload,
)
from decorators.database import begin_session
from models.collection import (
@@ -87,11 +94,16 @@ class DBCollectionsHandler(DBBaseHandler):
def get_collections(
self,
updated_after: datetime | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
query: Query = None, # type: ignore
session: Session = None, # type: ignore
) -> Sequence[Collection]:
if updated_after:
query = query.filter(Collection.updated_at > updated_after)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query.order_by(Collection.name.asc())).unique().all()
@begin_session
@@ -157,19 +169,21 @@ class DBCollectionsHandler(DBBaseHandler):
self,
type: str,
limit: int | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[VirtualCollection]:
return (
session.scalars(
select(VirtualCollection)
.filter(or_(VirtualCollection.type == type, literal(type == "all")))
.limit(limit)
.order_by(VirtualCollection.name.asc())
)
.unique()
.all()
query = (
select(VirtualCollection)
.filter(or_(VirtualCollection.type == type, literal(type == "all")))
.limit(limit)
.order_by(VirtualCollection.name.asc())
)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).unique().all()
# Smart collections
@begin_session
def add_smart_collection(
@@ -206,6 +220,7 @@ class DBCollectionsHandler(DBBaseHandler):
self,
user_id: int | None = None,
updated_after: datetime | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[SmartCollection]:
query = select(SmartCollection).order_by(SmartCollection.name.asc())
@@ -219,6 +234,9 @@ class DBCollectionsHandler(DBBaseHandler):
if updated_after:
query = query.filter(SmartCollection.updated_at > updated_after)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).unique().all()
@begin_session

View File

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import QueryableAttribute, Session, load_only
from decorators.database import begin_session
from models.firmware import Firmware
@@ -32,6 +32,7 @@ class DBFirmwareHandler(DBBaseHandler):
self,
*,
platform_id: int | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[Firmware]:
query = select(Firmware).order_by(Firmware.file_name.asc())
@@ -39,6 +40,9 @@ class DBFirmwareHandler(DBBaseHandler):
if platform_id:
query = query.filter_by(platform_id=platform_id)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).all()
@begin_session

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from datetime import datetime
from sqlalchemy import delete, or_, select, update
from sqlalchemy.orm import Query, Session, selectinload
from sqlalchemy.orm import Query, QueryableAttribute, Session, load_only, selectinload
from decorators.database import begin_session
from models.platform import Platform
@@ -67,11 +67,16 @@ class DBPlatformsHandler(DBBaseHandler):
def get_platforms(
self,
updated_after: datetime | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
query: Query = None, # type: ignore
session: Session = None, # type: ignore
) -> Sequence[Platform]:
if updated_after:
query = query.filter(Platform.updated_at > updated_after)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query.order_by(Platform.name.asc())).unique().all()
@begin_session

View File

@@ -21,7 +21,15 @@ from sqlalchemy import (
text,
update,
)
from sqlalchemy.orm import Query, Session, joinedload, noload, selectinload
from sqlalchemy.orm import (
Query,
QueryableAttribute,
Session,
joinedload,
load_only,
noload,
selectinload,
)
from sqlalchemy.sql.elements import KeyedColumnElement
from sqlalchemy.sql.selectable import Select
@@ -140,32 +148,6 @@ def with_details(func):
return wrapper
def with_simple(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
kwargs["query"] = select(Rom).options(
# Ensure platform is loaded for main ROM objects
selectinload(Rom.platform),
# Display properties for the current user (last_played)
selectinload(Rom.rom_users).options(noload(RomUser.rom)),
# Sort table by metadata (first_release_date)
selectinload(Rom.metadatum).options(noload(RomMetadata.rom)),
# Required for multi-file ROM actions and 3DS QR code
selectinload(Rom.files).options(
joinedload(RomFile.rom).load_only(Rom.fs_path, Rom.fs_name)
),
# Show sibling rom badges on cards
selectinload(Rom.sibling_roms).options(
noload(Rom.platform), noload(Rom.metadatum)
),
# Show notes indicator on cards
selectinload(Rom.notes),
)
return func(*args, **kwargs)
return wrapper
class DBRomsHandler(DBBaseHandler):
@begin_session
@with_details
@@ -535,6 +517,25 @@ class DBRomsHandler(DBBaseHandler):
) -> Query[Rom]:
from handler.scan_handler import MetadataSource
query = query.options(
# Ensure platform is loaded for main ROM objects
selectinload(Rom.platform),
# Display properties for the current user (last_played)
selectinload(Rom.rom_users).options(noload(RomUser.rom)),
# Sort table by metadata (first_release_date)
selectinload(Rom.metadatum).options(noload(RomMetadata.rom)),
# Required for multi-file ROM actions and 3DS QR code
selectinload(Rom.files).options(
joinedload(RomFile.rom).load_only(Rom.fs_path, Rom.fs_name)
),
# Show sibling rom badges on cards
selectinload(Rom.sibling_roms).options(
noload(Rom.platform), noload(Rom.metadatum)
),
# Show notes indicator on cards
selectinload(Rom.notes),
)
# Handle platform filtering - platform filtering always uses OR logic since ROMs belong to only one platform
if platform_ids:
query = self._filter_by_platform_ids(query, platform_ids)
@@ -730,7 +731,6 @@ class DBRomsHandler(DBBaseHandler):
return query
@with_simple
@begin_session
def get_roms_query(
self,
@@ -738,9 +738,10 @@ class DBRomsHandler(DBBaseHandler):
order_by: str = "name",
order_dir: str = "asc",
user_id: int | None = None,
query: Query = None, # type: ignore
session: Session = None, # type: ignore
) -> tuple[Query[Rom], Any]:
query = select(Rom)
if user_id:
query = query.outerjoin(
RomUser, and_(RomUser.rom_id == Rom.id, RomUser.user_id == user_id)
@@ -771,12 +772,13 @@ class DBRomsHandler(DBBaseHandler):
else:
order_attr = order_attr.asc()
return query.order_by(order_attr), order_attr_column
return query.order_by(order_attr), order_attr_column # type: ignore
@begin_session
def get_roms_scalar(
self,
*,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
**kwargs,
) -> Sequence[Rom]:
@@ -785,6 +787,10 @@ class DBRomsHandler(DBBaseHandler):
order_dir=kwargs.get("order_dir", "asc"),
user_id=kwargs.get("user_id", None),
)
if only_fields:
query = query.options(load_only(*only_fields))
roms = self.filter_roms(
query=query,
platform_ids=kwargs.get("platform_ids", None),
@@ -873,15 +879,12 @@ class DBRomsHandler(DBBaseHandler):
session: Session = None, # type: ignore
) -> dict[str, Rom]:
"""Retrieve a dictionary of roms by their filesystem names."""
roms = (
session.scalars(
query.filter(Rom.fs_name.in_(fs_names)).filter_by(
platform_id=platform_id
)
)
.unique()
.all()
query = query.filter(Rom.fs_name.in_(fs_names)).filter_by(
platform_id=platform_id
)
roms = session.scalars(query).unique().all()
return {rom.fs_name: rom for rom in roms}
@begin_session
@@ -1065,8 +1068,9 @@ class DBRomsHandler(DBBaseHandler):
rom_id: int,
user_id: int,
public_only: bool = False,
search: str = "",
search: str | None = "",
tags: list[str] | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[RomNote]:
query = session.query(RomNote).filter(RomNote.rom_id == rom_id)
@@ -1088,6 +1092,9 @@ class DBRomsHandler(DBBaseHandler):
json_array_contains_value(RomNote.tags, tag, session=session)
)
if only_fields:
query = query.options(load_only(*only_fields))
return query.order_by(RomNote.updated_at.desc()).all()
@begin_session

View File

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import QueryableAttribute, Session, load_only
from decorators.database import begin_session
from models.assets import Save
@@ -48,6 +48,7 @@ class DBSavesHandler(DBBaseHandler):
user_id: int,
rom_id: int | None = None,
platform_id: int | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[Save]:
query = select(Save).filter_by(user_id=user_id)
@@ -60,6 +61,9 @@ class DBSavesHandler(DBBaseHandler):
Rom.platform_id == platform_id
)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).all()
@begin_session

View File

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import QueryableAttribute, Session, load_only
from decorators.database import begin_session
from models.assets import State
@@ -48,6 +48,7 @@ class DBStatesHandler(DBBaseHandler):
user_id: int,
rom_id: int | None = None,
platform_id: int | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[State]:
query = select(State).filter_by(user_id=user_id)
@@ -60,6 +61,9 @@ class DBStatesHandler(DBBaseHandler):
Rom.platform_id == platform_id
)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).all()
@begin_session

View File

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from sqlalchemy import and_, delete, func, not_, select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import QueryableAttribute, Session, load_only
from sqlalchemy.sql import Delete, Select, Update
from decorators.database import begin_session
@@ -94,6 +94,7 @@ class DBUsersHandler(DBBaseHandler):
emails: Sequence[str] = (),
roles: Sequence[Role] = (),
has_ra_username: bool | None = None,
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[User]:
query = self.filter(
@@ -103,6 +104,10 @@ class DBUsersHandler(DBBaseHandler):
roles=roles,
has_ra_username=has_ra_username,
)
if only_fields:
query = query.options(load_only(*only_fields))
return session.scalars(query).all()
@begin_session

View File

@@ -746,7 +746,7 @@ class IGDBHandler(MetadataHandler):
)
matched_roms.extend(alternative_roms)
# Use a dictionary to keep track of unique ids
# Use a dictionary to keep track of unique IDs
unique_ids: dict[int, Game] = {}
# Use a list comprehension to filter duplicates based on the 'id' key