mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 00:27:41 +01:00
fix: Database JSON array utils
Fix existing JSON array util `json_array_contains_value`, and add two new utils: `json_array_contains_any` and `json_array_contains_all`. These utils have been tested with arrays of strings and integers, on the following database engine versions: - PostgreSQL: 12, 13, 14, 15, 16, 17, 18 - MySQL: 8.0, 8.4, 9.0, 9.4 - MariaDB: 10.5, 10.6, 10.11, 11.4, 11.8, 12.0
This commit is contained in:
@@ -29,6 +29,7 @@ from handler.metadata.base_hander import UniversalPlatformSlug as UPS
|
||||
from models.assets import Save, Screenshot, State
|
||||
from models.platform import Platform
|
||||
from models.rom import Rom, RomFile, RomMetadata, RomUser
|
||||
from utils.database import json_array_contains_value
|
||||
|
||||
from .base_handler import DBBaseHandler
|
||||
|
||||
@@ -305,75 +306,30 @@ class DBRomsHandler(DBBaseHandler):
|
||||
or_(*(Rom.hasheous_metadata[key].as_boolean() for key in keys_to_check))
|
||||
)
|
||||
|
||||
def filter_by_genre(self, query: Query, selected_genre: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("genres @> (:genre)::jsonb").bindparams(
|
||||
genre=f'["{selected_genre}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(genres, JSON_ARRAY(:genre))").bindparams(
|
||||
genre=selected_genre
|
||||
)
|
||||
)
|
||||
def filter_by_genre(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(RomMetadata.genres, value, session=session)
|
||||
)
|
||||
|
||||
def filter_by_franchise(self, query: Query, selected_franchise: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("franchises @> (:franchise)::jsonb").bindparams(
|
||||
franchise=f'["{selected_franchise}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(franchises, JSON_ARRAY(:franchise))").bindparams(
|
||||
franchise=selected_franchise
|
||||
)
|
||||
)
|
||||
def filter_by_franchise(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(RomMetadata.franchises, value, session=session)
|
||||
)
|
||||
|
||||
def filter_by_collection(self, query: Query, selected_collection: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("collections @> (:collection)::jsonb").bindparams(
|
||||
collection=f'["{selected_collection}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(collections, JSON_ARRAY(:collection))").bindparams(
|
||||
collection=selected_collection
|
||||
)
|
||||
)
|
||||
def filter_by_collection(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(RomMetadata.collections, value, session=session)
|
||||
)
|
||||
|
||||
def filter_by_company(self, query: Query, selected_company: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("companies @> (:company)::jsonb").bindparams(
|
||||
company=f'["{selected_company}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(companies, JSON_ARRAY(:company))").bindparams(
|
||||
company=selected_company
|
||||
)
|
||||
)
|
||||
def filter_by_company(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(RomMetadata.companies, value, session=session)
|
||||
)
|
||||
|
||||
def filter_by_age_rating(self, query: Query, selected_age_rating: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("age_ratings @> (:age_rating)::jsonb").bindparams(
|
||||
age_rating=f'["{selected_age_rating}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(age_ratings, JSON_ARRAY(:age_rating))").bindparams(
|
||||
age_rating=selected_age_rating
|
||||
)
|
||||
)
|
||||
def filter_by_age_rating(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(RomMetadata.age_ratings, value, session=session)
|
||||
)
|
||||
|
||||
def filter_by_status(self, query: Query, selected_status: str):
|
||||
status_filter = RomUser.status == selected_status
|
||||
@@ -389,33 +345,15 @@ class DBRomsHandler(DBBaseHandler):
|
||||
|
||||
return query.filter(status_filter, RomUser.hidden.is_(False))
|
||||
|
||||
def filter_by_region(self, query: Query, selected_region: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("regions @> (:region)::jsonb").bindparams(
|
||||
region=f'["{selected_region}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(regions, JSON_ARRAY(:region))").bindparams(
|
||||
region=selected_region
|
||||
)
|
||||
)
|
||||
def filter_by_region(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(Rom.regions, value, session=session)
|
||||
)
|
||||
|
||||
def filter_by_language(self, query: Query, selected_language: str):
|
||||
if ROMM_DB_DRIVER == "postgresql":
|
||||
return query.filter(
|
||||
text("languages @> (:language)::jsonb").bindparams(
|
||||
language=f'["{selected_language}"]'
|
||||
)
|
||||
)
|
||||
else:
|
||||
return query.filter(
|
||||
text("JSON_OVERLAPS(languages, JSON_ARRAY(:language))").bindparams(
|
||||
language=selected_language
|
||||
)
|
||||
)
|
||||
def filter_by_language(self, query: Query, session: Session, value: str) -> Query:
|
||||
return query.filter(
|
||||
json_array_contains_value(Rom.languages, value, session=session)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def filter_roms(
|
||||
@@ -591,25 +529,29 @@ class DBRomsHandler(DBBaseHandler):
|
||||
query = query.outerjoin(RomMetadata)
|
||||
|
||||
if selected_genre:
|
||||
query = self.filter_by_genre(query, selected_genre)
|
||||
|
||||
query = self.filter_by_genre(query, session=session, value=selected_genre)
|
||||
if selected_franchise:
|
||||
query = self.filter_by_franchise(query, selected_franchise)
|
||||
|
||||
query = self.filter_by_franchise(
|
||||
query, session=session, value=selected_franchise
|
||||
)
|
||||
if selected_collection:
|
||||
query = self.filter_by_collection(query, selected_collection)
|
||||
|
||||
query = self.filter_by_collection(
|
||||
query, session=session, value=selected_collection
|
||||
)
|
||||
if selected_company:
|
||||
query = self.filter_by_company(query, selected_company)
|
||||
|
||||
query = self.filter_by_company(
|
||||
query, session=session, value=selected_company
|
||||
)
|
||||
if selected_age_rating:
|
||||
query = self.filter_by_age_rating(query, selected_age_rating)
|
||||
|
||||
query = self.filter_by_age_rating(
|
||||
query, session=session, value=selected_age_rating
|
||||
)
|
||||
if selected_region:
|
||||
query = self.filter_by_region(query, selected_region)
|
||||
|
||||
query = self.filter_by_region(query, session=session, value=selected_region)
|
||||
if selected_language:
|
||||
query = self.filter_by_language(query, selected_language)
|
||||
query = self.filter_by_language(
|
||||
query, session=session, value=selected_language
|
||||
)
|
||||
|
||||
# The RomUser table is already joined if user_id is set
|
||||
if selected_status and user_id:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as sa_pg
|
||||
@@ -12,18 +12,41 @@ def CustomJSON(**kwargs: Any) -> sa.JSON:
|
||||
return sa.JSON(**kwargs).with_variant(sa_pg.JSONB(**kwargs), "postgresql")
|
||||
|
||||
|
||||
def is_postgresql(conn: sa.Connection) -> bool:
|
||||
return conn.engine.name == "postgresql"
|
||||
def is_db_version_compatible(
|
||||
conn: sa.Connection,
|
||||
min_version: tuple[int, ...] | None = None,
|
||||
) -> bool:
|
||||
"""Check if the database server version complies with the given version constraints."""
|
||||
if min_version is None:
|
||||
return True
|
||||
server_version = conn.engine.dialect.server_version_info
|
||||
return bool(server_version and server_version >= min_version)
|
||||
|
||||
|
||||
def is_mysql(conn: sa.Connection) -> bool:
|
||||
return conn.engine.name == "mysql"
|
||||
def is_postgresql(
|
||||
conn: sa.Connection, min_version: tuple[int, ...] | None = None
|
||||
) -> bool:
|
||||
if conn.engine.name != "postgresql":
|
||||
return False
|
||||
return is_db_version_compatible(conn, min_version=min_version)
|
||||
|
||||
|
||||
def is_mysql(conn: sa.Connection, min_version: tuple[int, ...] | None = None) -> bool:
|
||||
if conn.engine.name != "mysql":
|
||||
return False
|
||||
return is_db_version_compatible(conn, min_version=min_version)
|
||||
|
||||
|
||||
def is_mariadb(conn: sa.Connection, min_version: tuple[int, ...] | None = None) -> bool:
|
||||
if conn.engine.name != "mariadb":
|
||||
return False
|
||||
return is_db_version_compatible(conn, min_version=min_version)
|
||||
|
||||
|
||||
def json_array_contains_value(
|
||||
column: sa.Column, value: Any, *, session: Session
|
||||
column: sa.Column, value: str | int, *, session: Session
|
||||
) -> ColumnElement:
|
||||
"""Check if a JSON array column contains a single value."""
|
||||
"""Check if a JSON array column contains the given value."""
|
||||
conn = session.get_bind()
|
||||
if is_postgresql(conn):
|
||||
# In PostgreSQL, string values can be checked for containment using the `?` operator.
|
||||
@@ -33,10 +56,72 @@ def json_array_contains_value(
|
||||
return sa.type_coerce(column, sa_pg.JSONB).contains(
|
||||
func.cast(value, sa_pg.JSONB)
|
||||
)
|
||||
elif is_mysql(conn):
|
||||
# In MySQL, JSON.contains() requires a JSON-formatted string (even if it's an int)
|
||||
elif is_mysql(conn) or is_mariadb(conn):
|
||||
# In MySQL and MariaDB, JSON_CONTAINS requires a JSON-formatted string (even if it's an int).
|
||||
return func.json_contains(column, json.dumps(value))
|
||||
return func.json_contains(column, value)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"json_array_contains_value is not implemented for engine: {conn.engine.name}"
|
||||
)
|
||||
|
||||
|
||||
def json_array_contains_any(
|
||||
column: sa.Column, values: Sequence[str] | Sequence[int], *, session: Session
|
||||
) -> ColumnElement:
|
||||
"""Check if a JSON array column contains any of the given values."""
|
||||
if not values:
|
||||
return sa.false()
|
||||
|
||||
conn = session.get_bind()
|
||||
if is_postgresql(conn):
|
||||
# In PostgreSQL, string arrays can be checked for overlap using the `?|` operator.
|
||||
# For other types, we combine element-wise checks with OR.
|
||||
if isinstance(values[0], str):
|
||||
return sa.type_coerce(column, sa_pg.JSONB).has_any(
|
||||
sa.type_coerce(values, sa_pg.ARRAY(sa_pg.TEXT))
|
||||
)
|
||||
return sa.or_(
|
||||
*[json_array_contains_value(column, v, session=session) for v in values]
|
||||
)
|
||||
elif is_mysql(conn) or is_mariadb(conn, min_version=(10, 9)):
|
||||
# In MySQL and MariaDB, JSON_OVERLAPS requires a JSON-formatted string (even if it's an int).
|
||||
return func.json_overlaps(column, json.dumps(values))
|
||||
elif is_mariadb(conn):
|
||||
# MariaDB before 10.9 does not have JSON_OVERLAPS, so we fall back to element-wise checks.
|
||||
return sa.or_(
|
||||
*[json_array_contains_value(column, v, session=session) for v in values]
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"json_array_contains_any is not implemented for engine: {conn.engine.name}"
|
||||
)
|
||||
|
||||
|
||||
def json_array_contains_all(
|
||||
column: sa.Column, values: Sequence[Any], *, session: Session
|
||||
) -> ColumnElement:
|
||||
"""Check if a JSON array column contains all of the given values."""
|
||||
if not values:
|
||||
return sa.false()
|
||||
|
||||
conn = session.get_bind()
|
||||
if is_postgresql(conn):
|
||||
# In PostgreSQL, string arrays can be checked for containment using the `?&` operator.
|
||||
# For other types, we combine element-wise checks with AND.
|
||||
if isinstance(values[0], str):
|
||||
return sa.type_coerce(column, sa_pg.JSONB).has_all(
|
||||
sa.type_coerce(values, sa_pg.ARRAY(sa_pg.TEXT))
|
||||
)
|
||||
return sa.and_(
|
||||
*[json_array_contains_value(column, v, session=session) for v in values]
|
||||
)
|
||||
elif is_mysql(conn) or is_mariadb(conn):
|
||||
# In MySQL and MariaDB, JSON_CONTAINS requires a JSON-formatted string (even if it's an int).
|
||||
return func.json_contains(column, json.dumps(values))
|
||||
|
||||
raise NotImplementedError(
|
||||
f"json_array_contains_all is not implemented for engine: {conn.engine.name}"
|
||||
)
|
||||
|
||||
|
||||
def safe_float(value: Any, default: float = 0.0) -> float:
|
||||
|
||||
Reference in New Issue
Block a user