diff --git a/backend/handler/database/roms_handler.py b/backend/handler/database/roms_handler.py index 0032396c4..ee85dd738 100644 --- a/backend/handler/database/roms_handler.py +++ b/backend/handler/database/roms_handler.py @@ -29,6 +29,7 @@ from handler.metadata.base_hander import UniversalPlatformSlug as UPS from models.assets import Save, Screenshot, State from models.platform import Platform from models.rom import Rom, RomFile, RomMetadata, RomUser +from utils.database import json_array_contains_value from .base_handler import DBBaseHandler @@ -305,75 +306,30 @@ class DBRomsHandler(DBBaseHandler): or_(*(Rom.hasheous_metadata[key].as_boolean() for key in keys_to_check)) ) - def filter_by_genre(self, query: Query, selected_genre: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("genres @> (:genre)::jsonb").bindparams( - genre=f'["{selected_genre}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(genres, JSON_ARRAY(:genre))").bindparams( - genre=selected_genre - ) - ) + def filter_by_genre(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(RomMetadata.genres, value, session=session) + ) - def filter_by_franchise(self, query: Query, selected_franchise: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("franchises @> (:franchise)::jsonb").bindparams( - franchise=f'["{selected_franchise}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(franchises, JSON_ARRAY(:franchise))").bindparams( - franchise=selected_franchise - ) - ) + def filter_by_franchise(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(RomMetadata.franchises, value, session=session) + ) - def filter_by_collection(self, query: Query, selected_collection: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("collections @> (:collection)::jsonb").bindparams( - collection=f'["{selected_collection}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(collections, JSON_ARRAY(:collection))").bindparams( - collection=selected_collection - ) - ) + def filter_by_collection(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(RomMetadata.collections, value, session=session) + ) - def filter_by_company(self, query: Query, selected_company: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("companies @> (:company)::jsonb").bindparams( - company=f'["{selected_company}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(companies, JSON_ARRAY(:company))").bindparams( - company=selected_company - ) - ) + def filter_by_company(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(RomMetadata.companies, value, session=session) + ) - def filter_by_age_rating(self, query: Query, selected_age_rating: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("age_ratings @> (:age_rating)::jsonb").bindparams( - age_rating=f'["{selected_age_rating}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(age_ratings, JSON_ARRAY(:age_rating))").bindparams( - age_rating=selected_age_rating - ) - ) + def filter_by_age_rating(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(RomMetadata.age_ratings, value, session=session) + ) def filter_by_status(self, query: Query, selected_status: str): status_filter = RomUser.status == selected_status @@ -389,33 +345,15 @@ class DBRomsHandler(DBBaseHandler): return query.filter(status_filter, RomUser.hidden.is_(False)) - def filter_by_region(self, query: Query, selected_region: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("regions @> (:region)::jsonb").bindparams( - region=f'["{selected_region}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(regions, JSON_ARRAY(:region))").bindparams( - region=selected_region - ) - ) + def filter_by_region(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(Rom.regions, value, session=session) + ) - def filter_by_language(self, query: Query, selected_language: str): - if ROMM_DB_DRIVER == "postgresql": - return query.filter( - text("languages @> (:language)::jsonb").bindparams( - language=f'["{selected_language}"]' - ) - ) - else: - return query.filter( - text("JSON_OVERLAPS(languages, JSON_ARRAY(:language))").bindparams( - language=selected_language - ) - ) + def filter_by_language(self, query: Query, session: Session, value: str) -> Query: + return query.filter( + json_array_contains_value(Rom.languages, value, session=session) + ) @begin_session def filter_roms( @@ -591,25 +529,29 @@ class DBRomsHandler(DBBaseHandler): query = query.outerjoin(RomMetadata) if selected_genre: - query = self.filter_by_genre(query, selected_genre) - + query = self.filter_by_genre(query, session=session, value=selected_genre) if selected_franchise: - query = self.filter_by_franchise(query, selected_franchise) - + query = self.filter_by_franchise( + query, session=session, value=selected_franchise + ) if selected_collection: - query = self.filter_by_collection(query, selected_collection) - + query = self.filter_by_collection( + query, session=session, value=selected_collection + ) if selected_company: - query = self.filter_by_company(query, selected_company) - + query = self.filter_by_company( + query, session=session, value=selected_company + ) if selected_age_rating: - query = self.filter_by_age_rating(query, selected_age_rating) - + query = self.filter_by_age_rating( + query, session=session, value=selected_age_rating + ) if selected_region: - query = self.filter_by_region(query, selected_region) - + query = self.filter_by_region(query, session=session, value=selected_region) if selected_language: - query = self.filter_by_language(query, selected_language) + query = self.filter_by_language( + query, session=session, value=selected_language + ) # The RomUser table is already joined if user_id is set if selected_status and user_id: diff --git a/backend/utils/database.py b/backend/utils/database.py index dc5595685..4145ca367 100644 --- a/backend/utils/database.py +++ b/backend/utils/database.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, Sequence import sqlalchemy as sa from sqlalchemy.dialects import postgresql as sa_pg @@ -12,18 +12,41 @@ def CustomJSON(**kwargs: Any) -> sa.JSON: return sa.JSON(**kwargs).with_variant(sa_pg.JSONB(**kwargs), "postgresql") -def is_postgresql(conn: sa.Connection) -> bool: - return conn.engine.name == "postgresql" +def is_db_version_compatible( + conn: sa.Connection, + min_version: tuple[int, ...] | None = None, +) -> bool: + """Check if the database server version complies with the given version constraints.""" + if min_version is None: + return True + server_version = conn.engine.dialect.server_version_info + return bool(server_version and server_version >= min_version) -def is_mysql(conn: sa.Connection) -> bool: - return conn.engine.name == "mysql" +def is_postgresql( + conn: sa.Connection, min_version: tuple[int, ...] | None = None +) -> bool: + if conn.engine.name != "postgresql": + return False + return is_db_version_compatible(conn, min_version=min_version) + + +def is_mysql(conn: sa.Connection, min_version: tuple[int, ...] | None = None) -> bool: + if conn.engine.name != "mysql": + return False + return is_db_version_compatible(conn, min_version=min_version) + + +def is_mariadb(conn: sa.Connection, min_version: tuple[int, ...] | None = None) -> bool: + if conn.engine.name != "mariadb": + return False + return is_db_version_compatible(conn, min_version=min_version) def json_array_contains_value( - column: sa.Column, value: Any, *, session: Session + column: sa.Column, value: str | int, *, session: Session ) -> ColumnElement: - """Check if a JSON array column contains a single value.""" + """Check if a JSON array column contains the given value.""" conn = session.get_bind() if is_postgresql(conn): # In PostgreSQL, string values can be checked for containment using the `?` operator. @@ -33,10 +56,72 @@ def json_array_contains_value( return sa.type_coerce(column, sa_pg.JSONB).contains( func.cast(value, sa_pg.JSONB) ) - elif is_mysql(conn): - # In MySQL, JSON.contains() requires a JSON-formatted string (even if it's an int) + elif is_mysql(conn) or is_mariadb(conn): + # In MySQL and MariaDB, JSON_CONTAINS requires a JSON-formatted string (even if it's an int). return func.json_contains(column, json.dumps(value)) - return func.json_contains(column, value) + + raise NotImplementedError( + f"json_array_contains_value is not implemented for engine: {conn.engine.name}" + ) + + +def json_array_contains_any( + column: sa.Column, values: Sequence[str] | Sequence[int], *, session: Session +) -> ColumnElement: + """Check if a JSON array column contains any of the given values.""" + if not values: + return sa.false() + + conn = session.get_bind() + if is_postgresql(conn): + # In PostgreSQL, string arrays can be checked for overlap using the `?|` operator. + # For other types, we combine element-wise checks with OR. + if isinstance(values[0], str): + return sa.type_coerce(column, sa_pg.JSONB).has_any( + sa.type_coerce(values, sa_pg.ARRAY(sa_pg.TEXT)) + ) + return sa.or_( + *[json_array_contains_value(column, v, session=session) for v in values] + ) + elif is_mysql(conn) or is_mariadb(conn, min_version=(10, 9)): + # In MySQL and MariaDB, JSON_OVERLAPS requires a JSON-formatted string (even if it's an int). + return func.json_overlaps(column, json.dumps(values)) + elif is_mariadb(conn): + # MariaDB before 10.9 does not have JSON_OVERLAPS, so we fall back to element-wise checks. + return sa.or_( + *[json_array_contains_value(column, v, session=session) for v in values] + ) + + raise NotImplementedError( + f"json_array_contains_any is not implemented for engine: {conn.engine.name}" + ) + + +def json_array_contains_all( + column: sa.Column, values: Sequence[Any], *, session: Session +) -> ColumnElement: + """Check if a JSON array column contains all of the given values.""" + if not values: + return sa.false() + + conn = session.get_bind() + if is_postgresql(conn): + # In PostgreSQL, string arrays can be checked for containment using the `?&` operator. + # For other types, we combine element-wise checks with AND. + if isinstance(values[0], str): + return sa.type_coerce(column, sa_pg.JSONB).has_all( + sa.type_coerce(values, sa_pg.ARRAY(sa_pg.TEXT)) + ) + return sa.and_( + *[json_array_contains_value(column, v, session=session) for v in values] + ) + elif is_mysql(conn) or is_mariadb(conn): + # In MySQL and MariaDB, JSON_CONTAINS requires a JSON-formatted string (even if it's an int). + return func.json_contains(column, json.dumps(values)) + + raise NotImplementedError( + f"json_array_contains_all is not implemented for engine: {conn.engine.name}" + ) def safe_float(value: Any, default: float = 0.0) -> float: