mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 00:27:41 +01:00
fix: Use proper JSON contains function for PostgreSQL
Fix `json_array_contains_value` function to use the `@>` operator for checking if a JSON array contains a value in PostgreSQL. This is necessary because the `has_key` function only works for string values. Also, remove `get_rom_collections` method, as it was doing the same thing as `get_collections_by_rom_id`. Fixes #1441.
This commit is contained in:
@@ -1,7 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.collection import Collection
|
||||
from sqlalchemy import Select, delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import ColumnExpressionArgument
|
||||
from utils.database import json_array_contains_value
|
||||
|
||||
from .base_handler import DBBaseHandler
|
||||
|
||||
@@ -37,11 +42,19 @@ class DBCollectionsHandler(DBBaseHandler):
|
||||
|
||||
@begin_session
|
||||
def get_collections_by_rom_id(
|
||||
self, rom_id: int, session: Session = None
|
||||
self,
|
||||
rom_id: int,
|
||||
*,
|
||||
order_by: Sequence[str | ColumnExpressionArgument[Any]] | None = None,
|
||||
session: Session = None,
|
||||
) -> list[Collection]:
|
||||
return session.scalars(
|
||||
select(Collection).filter(Collection.roms.contains(rom_id))
|
||||
).all()
|
||||
query = select(Collection).filter(
|
||||
json_array_contains_value(Collection.roms, rom_id, session=session)
|
||||
)
|
||||
if order_by is not None:
|
||||
query = query.order_by(*order_by)
|
||||
|
||||
return session.scalars(query).all()
|
||||
|
||||
@begin_session
|
||||
def update_collection(
|
||||
|
||||
@@ -6,7 +6,6 @@ from models.collection import Collection
|
||||
from models.rom import Rom, RomUser
|
||||
from sqlalchemy import and_, delete, func, or_, select, update
|
||||
from sqlalchemy.orm import Query, Session, selectinload
|
||||
from utils.database import json_array_contains_value
|
||||
|
||||
from .base_handler import DBBaseHandler
|
||||
|
||||
@@ -181,24 +180,6 @@ class DBRomsHandler(DBBaseHandler):
|
||||
query.filter_by(file_name_no_ext=file_name_no_ext).limit(1)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_rom_collections(
|
||||
self, rom: Rom, session: Session = None
|
||||
) -> list[Collection]:
|
||||
return (
|
||||
session.scalars(
|
||||
select(Collection)
|
||||
.filter(
|
||||
json_array_contains_value(
|
||||
Collection.roms, str(rom.id), session=session
|
||||
)
|
||||
)
|
||||
.order_by(Collection.name.asc())
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def update_rom(self, id: int, data: dict, session: Session = None) -> Rom:
|
||||
return session.execute(
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from utils.database import CustomJSON
|
||||
@@ -141,9 +142,12 @@ class Rom(BaseModel):
|
||||
return screenshots
|
||||
|
||||
def get_collections(self) -> list[Collection]:
|
||||
from handler.database import db_rom_handler
|
||||
from handler.database import db_collection_handler
|
||||
|
||||
return db_rom_handler.get_rom_collections(self)
|
||||
return db_collection_handler.get_collections_by_rom_id(
|
||||
self.id,
|
||||
order_by=[func.lower("name")],
|
||||
)
|
||||
|
||||
# Metadata fields
|
||||
@property
|
||||
|
||||
@@ -21,5 +21,11 @@ def json_array_contains_value(
|
||||
"""Check if a JSON array column contains a single value."""
|
||||
conn = session.get_bind()
|
||||
if is_postgresql(conn):
|
||||
return sa.type_coerce(column, sa_pg.JSONB()).has_key(value)
|
||||
# In PostgreSQL, string values can be checked for containment using the `?` operator.
|
||||
# For other types, we use the `@>` operator.
|
||||
if isinstance(value, str):
|
||||
return sa.type_coerce(column, sa_pg.JSONB).has_key(value)
|
||||
return sa.type_coerce(column, sa_pg.JSONB).contains(
|
||||
func.cast(value, sa_pg.JSONB)
|
||||
)
|
||||
return func.json_contains(column, value)
|
||||
|
||||
Reference in New Issue
Block a user