Merge pull request #2917 from tmgast/feature/device-registration-save-sync

Add device-based save synchronization
This commit is contained in:
Georges-Antoine Assi
2026-02-03 23:25:23 -05:00
committed by GitHub
22 changed files with 4103 additions and 58 deletions

View File

@@ -0,0 +1,102 @@
"""Add device-based save synchronization
Revision ID: 0068_save_sync
Revises: 0067_romfile_category_enum_cheat
Create Date: 2026-01-17
"""
import sqlalchemy as sa
from alembic import op
revision = "0068_save_sync"
down_revision = "0067_romfile_category_enum_cheat"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"devices",
sa.Column("id", sa.String(255), primary_key=True),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(255), nullable=True),
sa.Column("platform", sa.String(50), nullable=True),
sa.Column("client", sa.String(50), nullable=True),
sa.Column("client_version", sa.String(50), nullable=True),
sa.Column("ip_address", sa.String(45), nullable=True),
sa.Column("mac_address", sa.String(17), nullable=True),
sa.Column("hostname", sa.String(255), nullable=True),
sa.Column(
"sync_mode",
sa.Enum("API", "FILE_TRANSFER", "PUSH_PULL", name="syncmode"),
nullable=False,
server_default="API",
),
sa.Column("sync_enabled", sa.Boolean(), nullable=False, server_default="1"),
sa.Column("last_seen", sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_table(
"device_save_sync",
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("save_id", sa.Integer(), nullable=False),
sa.Column("last_synced_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("is_untracked", sa.Boolean(), nullable=False, server_default="0"),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["save_id"], ["saves.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("device_id", "save_id"),
)
with op.batch_alter_table("saves", schema=None) as batch_op:
batch_op.add_column(sa.Column("slot", sa.String(255), nullable=True))
batch_op.add_column(sa.Column("content_hash", sa.String(32), nullable=True))
op.create_index("ix_devices_user_id", "devices", ["user_id"])
op.create_index("ix_devices_last_seen", "devices", ["last_seen"])
op.create_index("ix_device_save_sync_save_id", "device_save_sync", ["save_id"])
op.create_index("ix_saves_slot", "saves", ["slot"])
op.create_index(
"ix_saves_rom_user_hash", "saves", ["rom_id", "user_id", "content_hash"]
)
def downgrade():
op.drop_index("ix_saves_rom_user_hash", "saves")
op.drop_index("ix_saves_slot", "saves")
op.drop_index("ix_device_save_sync_save_id", "device_save_sync")
op.drop_index("ix_devices_last_seen", "devices")
op.drop_index("ix_devices_user_id", "devices")
with op.batch_alter_table("saves", schema=None) as batch_op:
batch_op.drop_column("content_hash")
batch_op.drop_column("slot")
op.drop_table("device_save_sync")
op.drop_table("devices")
op.execute("DROP TYPE IF EXISTS syncmode")

179
backend/endpoints/device.py Normal file
View File

@@ -0,0 +1,179 @@
import uuid
from datetime import datetime, timezone
from fastapi import HTTPException, Request, Response, status
from pydantic import BaseModel, model_validator
from decorators.auth import protected_route
from endpoints.responses.device import DeviceCreateResponse, DeviceSchema
from handler.auth.constants import Scope
from handler.database import db_device_handler, db_device_save_sync_handler
from logger.logger import log
from models.device import Device
from utils.router import APIRouter
router = APIRouter(
prefix="/devices",
tags=["devices"],
)
class DeviceCreatePayload(BaseModel):
name: str | None = None
platform: str | None = None
client: str | None = None
client_version: str | None = None
ip_address: str | None = None
mac_address: str | None = None
hostname: str | None = None
allow_existing: bool = True
allow_duplicate: bool = False
reset_syncs: bool = False
@model_validator(mode="after")
def _duplicate_disables_existing(self) -> "DeviceCreatePayload":
if self.allow_duplicate:
self.allow_existing = False
return self
class DeviceUpdatePayload(BaseModel):
name: str | None = None
platform: str | None = None
client: str | None = None
client_version: str | None = None
ip_address: str | None = None
mac_address: str | None = None
hostname: str | None = None
sync_enabled: bool | None = None
@protected_route(router.post, "", [Scope.DEVICES_WRITE])
def register_device(
request: Request,
response: Response,
payload: DeviceCreatePayload,
) -> DeviceCreateResponse:
existing_device = None
if not payload.allow_duplicate:
existing_device = db_device_handler.get_device_by_fingerprint(
user_id=request.user.id,
mac_address=payload.mac_address,
hostname=payload.hostname,
platform=payload.platform,
)
if existing_device:
if not payload.allow_existing:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"error": "device_exists",
"message": "A device with this fingerprint already exists",
"device_id": existing_device.id,
},
)
if payload.reset_syncs:
db_device_save_sync_handler.delete_syncs_for_device(
device_id=existing_device.id
)
db_device_handler.update_last_seen(
device_id=existing_device.id, user_id=request.user.id
)
log.info(
f"Returned existing device {existing_device.id} for user {request.user.username}"
)
response.status_code = status.HTTP_200_OK
return DeviceCreateResponse(
device_id=existing_device.id,
name=existing_device.name,
created_at=existing_device.created_at,
)
response.status_code = status.HTTP_201_CREATED
device_id = str(uuid.uuid4())
now = datetime.now(timezone.utc)
device = Device(
id=device_id,
user_id=request.user.id,
name=payload.name,
platform=payload.platform,
client=payload.client,
client_version=payload.client_version,
ip_address=payload.ip_address,
mac_address=payload.mac_address,
hostname=payload.hostname,
last_seen=now,
)
db_device = db_device_handler.add_device(device)
log.info(f"Registered device {device_id} for user {request.user.username}")
return DeviceCreateResponse(
device_id=db_device.id,
name=db_device.name,
created_at=db_device.created_at,
)
@protected_route(router.get, "", [Scope.DEVICES_READ])
def get_devices(request: Request) -> list[DeviceSchema]:
devices = db_device_handler.get_devices(user_id=request.user.id)
return [DeviceSchema.model_validate(device) for device in devices]
@protected_route(router.get, "/{device_id}", [Scope.DEVICES_READ])
def get_device(request: Request, device_id: str) -> DeviceSchema:
device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id)
if not device:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Device with ID {device_id} not found",
)
return DeviceSchema.model_validate(device)
@protected_route(router.put, "/{device_id}", [Scope.DEVICES_WRITE])
def update_device(
request: Request,
device_id: str,
payload: DeviceUpdatePayload,
) -> DeviceSchema:
device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id)
if not device:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Device with ID {device_id} not found",
)
update_data = payload.model_dump(exclude_unset=True)
if update_data:
device = db_device_handler.update_device(
device_id=device_id,
user_id=request.user.id,
data=update_data,
)
return DeviceSchema.model_validate(device)
@protected_route(
router.delete,
"/{device_id}",
[Scope.DEVICES_WRITE],
status_code=status.HTTP_204_NO_CONTENT,
)
def delete_device(request: Request, device_id: str) -> None:
device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id)
if not device:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Device with ID {device_id} not found",
)
db_device_handler.delete_device(device_id=device_id, user_id=request.user.id)
log.info(f"Deleted device {device_id} for user {request.user.username}")

View File

@@ -1,6 +1,12 @@
from datetime import datetime
from typing import Any
from pydantic import model_validator
from sqlalchemy import inspect
from sqlalchemy.exc import InvalidRequestError
from .base import BaseModel
from .device import DeviceSyncSchema
class BaseAsset(BaseModel):
@@ -31,7 +37,40 @@ class ScreenshotSchema(BaseAsset):
class SaveSchema(BaseAsset):
emulator: str | None
slot: str | None = None
content_hash: str | None = None
screenshot: ScreenshotSchema | None
device_syncs: list[DeviceSyncSchema] = []
@model_validator(mode="before")
@classmethod
def handle_lazy_relationships(cls, data: Any) -> Any:
if isinstance(data, dict):
return data
try:
state = inspect(data)
except Exception:
return data
result = {}
for field_name in cls.model_fields:
if field_name in state.unloaded:
continue
try:
result[field_name] = getattr(data, field_name)
except InvalidRequestError:
continue
return result
class SlotSummarySchema(BaseModel):
slot: str | None
count: int
latest: SaveSchema
class SaveSummarySchema(BaseModel):
total_count: int
slots: list[SlotSummarySchema]
class StateSchema(BaseAsset):

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from models.device import SyncMode
from .base import BaseModel
class DeviceSyncSchema(BaseModel):
device_id: str
device_name: str | None
last_synced_at: datetime
is_untracked: bool
is_current: bool
class Config:
from_attributes = True
class DeviceSchema(BaseModel):
id: str
user_id: int
name: str | None
platform: str | None
client: str | None
client_version: str | None
ip_address: str | None
mac_address: str | None
hostname: str | None
sync_mode: SyncMode
sync_enabled: bool
last_seen: datetime | None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class DeviceCreateResponse(BaseModel):
device_id: str
name: str | None
created_at: datetime

View File

@@ -1,21 +1,99 @@
import os
import re
from datetime import datetime, timezone
from typing import Annotated
from fastapi import Body, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse
from decorators.auth import protected_route
from endpoints.responses.assets import SaveSchema
from endpoints.responses.assets import SaveSchema, SaveSummarySchema, SlotSummarySchema
from endpoints.responses.device import DeviceSyncSchema
from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException
from handler.auth.constants import Scope
from handler.database import db_rom_handler, db_save_handler, db_screenshot_handler
from handler.database import (
db_device_handler,
db_device_save_sync_handler,
db_rom_handler,
db_save_handler,
db_screenshot_handler,
)
from handler.filesystem import fs_asset_handler
from handler.scan_handler import scan_save, scan_screenshot
from logger.formatter import BLUE
from logger.formatter import highlight as hl
from logger.logger import log
from models.assets import Save
from models.device import Device
from models.device_save_sync import DeviceSaveSync
from utils.datetime import to_utc
from utils.router import APIRouter
def _build_save_schema(
save: Save,
device: Device | None = None,
sync: DeviceSaveSync | None = None,
) -> SaveSchema:
save_schema = SaveSchema.model_validate(save)
if device:
if sync:
is_current = to_utc(sync.last_synced_at) >= to_utc(save.updated_at)
last_synced = sync.last_synced_at
is_untracked = sync.is_untracked
else:
is_current = False
last_synced = save.updated_at
is_untracked = False
save_schema.device_syncs = [
DeviceSyncSchema(
device_id=device.id,
device_name=device.name,
last_synced_at=last_synced,
is_untracked=is_untracked,
is_current=is_current,
)
]
return save_schema
DATETIME_TAG_PATTERN = re.compile(r" \[\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\]")
def _apply_datetime_tag(filename: str) -> str:
name, ext = os.path.splitext(filename)
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
if DATETIME_TAG_PATTERN.search(name):
name = DATETIME_TAG_PATTERN.sub("", name)
return f"{name} [{timestamp}]{ext}"
def _resolve_device(
device_id: str | None,
user_id: int,
scopes: set[str] | None = None,
required_scope: Scope | None = None,
) -> Device | None:
if not device_id:
return None
if required_scope and scopes and required_scope not in scopes:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
device = db_device_handler.get_device(device_id=device_id, user_id=user_id)
if not device:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Device with ID {device_id} not found",
)
return device
router = APIRouter(
prefix="/saves",
tags=["saves"],
@@ -27,22 +105,23 @@ async def add_save(
request: Request,
rom_id: int,
emulator: str | None = None,
slot: str | None = None,
device_id: str | None = None,
overwrite: bool = False,
autocleanup: bool = False,
autocleanup_limit: int = 10,
) -> SaveSchema:
"""Upload a save file for a ROM."""
device = _resolve_device(
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_WRITE
)
data = await request.form()
rom = db_rom_handler.get_rom(rom_id)
if not rom:
raise RomNotFoundInDatabaseException(rom_id)
log.info(f"Uploading save of {rom.name}")
saves_path = fs_asset_handler.build_saves_file_path(
user=request.user,
platform_fs_slug=rom.platform.fs_slug,
rom_id=rom_id,
emulator=emulator,
)
if "saveFile" not in data:
log.error("No save file provided")
raise HTTPException(
@@ -57,12 +136,45 @@ async def add_save(
status_code=status.HTTP_400_BAD_REQUEST, detail="Save file has no filename"
)
rom = db_rom_handler.get_rom(rom_id)
if not rom:
raise RomNotFoundInDatabaseException(rom_id)
actual_filename = saveFile.filename
if slot:
actual_filename = _apply_datetime_tag(saveFile.filename)
db_save = db_save_handler.get_save_by_filename(
user_id=request.user.id, rom_id=rom.id, file_name=actual_filename
)
if device and slot and not overwrite:
slot_saves = db_save_handler.get_saves(
user_id=request.user.id,
rom_id=rom.id,
slot=slot,
order_by="updated_at",
)
if slot_saves:
latest_in_slot = slot_saves[0]
sync = db_device_save_sync_handler.get_sync(
device_id=device.id, save_id=latest_in_slot.id
)
if not sync or to_utc(sync.last_synced_at) < to_utc(
latest_in_slot.updated_at
):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Slot has a newer save since your last sync",
)
elif device and db_save and not overwrite:
sync = db_device_save_sync_handler.get_sync(
device_id=device.id, save_id=db_save.id
)
if sync and to_utc(sync.last_synced_at) < to_utc(db_save.updated_at):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Save has been updated since your last sync",
)
log.info(
f"Uploading save {hl(saveFile.filename)} for {hl(str(rom.name), color=BLUE)}"
f"Uploading save {hl(actual_filename)} for {hl(str(rom.name), color=BLUE)}"
)
saves_path = fs_asset_handler.build_saves_file_path(
@@ -72,29 +184,72 @@ async def add_save(
emulator=emulator,
)
await fs_asset_handler.write_file(file=saveFile, path=saves_path)
await fs_asset_handler.write_file(
file=saveFile, path=saves_path, filename=actual_filename
)
# Scan or update save
scanned_save = await scan_save(
file_name=saveFile.filename,
file_name=actual_filename,
user=request.user,
platform_fs_slug=rom.platform.fs_slug,
rom_id=rom_id,
emulator=emulator,
)
db_save = db_save_handler.get_save_by_filename(
user_id=request.user.id, rom_id=rom.id, file_name=saveFile.filename
)
if db_save:
db_save = db_save_handler.update_save(
db_save.id, {"file_size_bytes": scanned_save.file_size_bytes}
if slot and scanned_save.content_hash and not overwrite:
existing_by_hash = db_save_handler.get_save_by_content_hash(
user_id=request.user.id,
rom_id=rom.id,
content_hash=scanned_save.content_hash,
)
if existing_by_hash:
try:
await fs_asset_handler.remove_file(f"{saves_path}/{actual_filename}")
except FileNotFoundError:
pass
sync = None
if device:
sync = db_device_save_sync_handler.get_sync(
device_id=device.id, save_id=existing_by_hash.id
)
return _build_save_schema(existing_by_hash, device, sync)
if db_save:
update_data: dict = {
"file_size_bytes": scanned_save.file_size_bytes,
"content_hash": scanned_save.content_hash,
}
if slot is not None:
update_data["slot"] = slot
db_save = db_save_handler.update_save(db_save.id, update_data)
else:
scanned_save.rom_id = rom.id
scanned_save.user_id = request.user.id
scanned_save.emulator = emulator
scanned_save.slot = slot
db_save = db_save_handler.add_save(save=scanned_save)
if device:
db_device_save_sync_handler.upsert_sync(
device_id=device.id, save_id=db_save.id, synced_at=db_save.updated_at
)
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)
if slot and autocleanup:
slot_saves = db_save_handler.get_saves(
user_id=request.user.id,
rom_id=rom.id,
slot=slot,
order_by="updated_at",
)
if len(slot_saves) > autocleanup_limit:
for old_save in slot_saves[autocleanup_limit:]:
db_save_handler.delete_save(old_save.id)
try:
await fs_asset_handler.remove_file(old_save.full_path)
except FileNotFoundError:
log.warning(f"Could not delete old save file: {old_save.full_path}")
screenshotFile: UploadFile | None = data.get("screenshotFile", None) # type: ignore
if screenshotFile and screenshotFile.filename:
screenshots_path = fs_asset_handler.build_screenshots_file_path(
@@ -103,7 +258,6 @@ async def add_save(
await fs_asset_handler.write_file(file=screenshotFile, path=screenshots_path)
# Scan or update screenshot
scanned_screenshot = await scan_screenshot(
file_name=screenshotFile.filename,
user=request.user,
@@ -125,7 +279,6 @@ async def add_save(
scanned_screenshot.user_id = request.user.id
db_screenshot_handler.add_screenshot(screenshot=scanned_screenshot)
# Set the last played time for the current user
rom_user = db_rom_handler.get_rom_user(rom_id=rom.id, user_id=request.user.id)
if not rom_user:
rom_user = db_rom_handler.add_rom_user(rom_id=rom.id, user_id=request.user.id)
@@ -133,37 +286,47 @@ async def add_save(
rom_user.id, {"last_played": datetime.now(timezone.utc)}
)
# Refetch the rom to get updated saves
rom = db_rom_handler.get_rom(rom_id)
if not rom:
raise RomNotFoundInDatabaseException(rom_id)
return SaveSchema.model_validate(db_save)
sync = None
if device:
sync = db_device_save_sync_handler.get_sync(
device_id=device.id, save_id=db_save.id
)
return _build_save_schema(db_save, device, sync)
@protected_route(router.get, "", [Scope.ASSETS_READ])
def get_saves(
request: Request, rom_id: int | None = None, platform_id: int | None = None
request: Request,
rom_id: int | None = None,
platform_id: int | None = None,
device_id: str | None = None,
slot: str | None = None,
) -> list[SaveSchema]:
saves = db_save_handler.get_saves(
user_id=request.user.id, rom_id=rom_id, platform_id=platform_id
"""Retrieve saves for the current user."""
device = _resolve_device(
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ
)
return [SaveSchema.model_validate(save) for save in saves]
saves = db_save_handler.get_saves(
user_id=request.user.id, rom_id=rom_id, platform_id=platform_id, slot=slot
)
if not device:
return [_build_save_schema(save) for save in saves]
syncs = db_device_save_sync_handler.get_syncs_for_device_and_saves(
device_id=device.id, save_ids=[s.id for s in saves]
)
sync_by_save_id = {s.save_id: s for s in syncs}
return [
_build_save_schema(save, device, sync_by_save_id.get(save.id)) for save in saves
]
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
def get_save_identifiers(
request: Request,
) -> list[int]:
"""Get save identifiers endpoint
Args:
request (Request): Fastapi Request object
Returns:
list[int]: List of save IDs
"""
def get_save_identifiers(request: Request) -> list[int]:
"""Retrieve save identifiers."""
saves = db_save_handler.get_saves(
user_id=request.user.id,
only_fields=[Save.id],
@@ -172,20 +335,121 @@ def get_save_identifiers(
return [save.id for save in saves]
@protected_route(router.get, "/summary", [Scope.ASSETS_READ])
def get_saves_summary(request: Request, rom_id: int) -> SaveSummarySchema:
"""Retrieve saves summary grouped by slot."""
summary_data = db_save_handler.get_saves_summary(
user_id=request.user.id, rom_id=rom_id
)
slots = [
SlotSummarySchema(
slot=slot_data["slot"],
count=slot_data["count"],
latest=_build_save_schema(slot_data["latest"]),
)
for slot_data in summary_data["slots"]
]
return SaveSummarySchema(total_count=summary_data["total_count"], slots=slots)
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
def get_save(request: Request, id: int) -> SaveSchema:
def get_save(request: Request, id: int, device_id: str | None = None) -> SaveSchema:
"""Retrieve a save by ID."""
device = _resolve_device(
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ
)
save = db_save_handler.get_save(user_id=request.user.id, id=id)
if not save:
error = f"Save with ID {id} not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Save with ID {id} not found",
)
return SaveSchema.model_validate(save)
sync = None
if device:
sync = db_device_save_sync_handler.get_sync(
device_id=device.id, save_id=save.id
)
return _build_save_schema(save, device, sync)
@protected_route(router.get, "/{id}/content", [Scope.ASSETS_READ])
def download_save(
request: Request,
id: int,
device_id: str | None = None,
optimistic: bool = True,
) -> FileResponse:
"""Download a save file."""
device = _resolve_device(
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ
)
save = db_save_handler.get_save(user_id=request.user.id, id=id)
if not save:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Save with ID {id} not found",
)
try:
file_path = fs_asset_handler.validate_path(save.full_path)
except ValueError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Save file not found",
) from None
if not file_path.exists() or not file_path.is_file():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Save file not found on disk",
)
if device and optimistic:
db_device_save_sync_handler.upsert_sync(
device_id=device.id,
save_id=save.id,
synced_at=save.updated_at,
)
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)
return FileResponse(path=str(file_path), filename=save.file_name)
@protected_route(router.post, "/{id}/downloaded", [Scope.DEVICES_WRITE])
def confirm_download(
request: Request,
id: int,
device_id: str = Body(..., embed=True),
) -> SaveSchema:
"""Confirm a save was downloaded successfully."""
save = db_save_handler.get_save(user_id=request.user.id, id=id)
if not save:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Save with ID {id} not found",
)
device = _resolve_device(device_id, request.user.id)
assert device is not None
sync = db_device_save_sync_handler.upsert_sync(
device_id=device_id,
save_id=save.id,
synced_at=save.updated_at,
)
db_device_handler.update_last_seen(device_id=device_id, user_id=request.user.id)
return _build_save_schema(save, device, sync)
@protected_route(router.put, "/{id}", [Scope.ASSETS_WRITE])
async def update_save(request: Request, id: int) -> SaveSchema:
"""Update a save file."""
data = await request.form()
db_save = db_save_handler.get_save(user_id=request.user.id, id=id)
@@ -300,3 +564,51 @@ async def delete_saves(
log.error(error)
return saves
@protected_route(router.post, "/{id}/track", [Scope.DEVICES_WRITE])
def track_save(
request: Request,
id: int,
device_id: str = Body(..., embed=True),
) -> SaveSchema:
"""Re-enable sync tracking for a save on a device."""
save = db_save_handler.get_save(user_id=request.user.id, id=id)
if not save:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Save with ID {id} not found",
)
device = _resolve_device(device_id, request.user.id)
assert device is not None
sync = db_device_save_sync_handler.set_untracked(
device_id=device_id, save_id=id, untracked=False
)
return _build_save_schema(save, device, sync)
@protected_route(router.post, "/{id}/untrack", [Scope.DEVICES_WRITE])
def untrack_save(
request: Request,
id: int,
device_id: str = Body(..., embed=True),
) -> SaveSchema:
"""Disable sync tracking for a save on a device."""
save = db_save_handler.get_save(user_id=request.user.id, id=id)
if not save:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Save with ID {id} not found",
)
device = _resolve_device(device_id, request.user.id)
assert device is not None
sync = db_device_save_sync_handler.set_untracked(
device_id=device_id, save_id=id, untracked=True
)
return _build_save_schema(save, device, sync)

View File

@@ -17,6 +17,8 @@ class Scope(enum.StrEnum):
PLATFORMS_WRITE = "platforms.write"
ASSETS_READ = "assets.read"
ASSETS_WRITE = "assets.write"
DEVICES_READ = "devices.read"
DEVICES_WRITE = "devices.write"
FIRMWARE_READ = "firmware.read"
FIRMWARE_WRITE = "firmware.write"
COLLECTIONS_READ = "collections.read"
@@ -31,6 +33,7 @@ READ_SCOPES_MAP: Final = {
Scope.ROMS_READ: "View ROMs",
Scope.PLATFORMS_READ: "View platforms",
Scope.ASSETS_READ: "View assets",
Scope.DEVICES_READ: "View devices",
Scope.FIRMWARE_READ: "View firmware",
Scope.ROMS_USER_READ: "View user-rom properties",
Scope.COLLECTIONS_READ: "View collections",
@@ -39,6 +42,7 @@ READ_SCOPES_MAP: Final = {
WRITE_SCOPES_MAP: Final = {
Scope.ME_WRITE: "Modify your profile",
Scope.ASSETS_WRITE: "Modify assets",
Scope.DEVICES_WRITE: "Modify devices",
Scope.ROMS_USER_WRITE: "Modify user-rom properties",
Scope.COLLECTIONS_WRITE: "Modify collections",
}

View File

@@ -1,4 +1,6 @@
from .collections_handler import DBCollectionsHandler
from .device_save_sync_handler import DBDeviceSaveSyncHandler
from .devices_handler import DBDevicesHandler
from .firmware_handler import DBFirmwareHandler
from .platforms_handler import DBPlatformsHandler
from .roms_handler import DBRomsHandler
@@ -8,6 +10,9 @@ from .states_handler import DBStatesHandler
from .stats_handler import DBStatsHandler
from .users_handler import DBUsersHandler
db_collection_handler = DBCollectionsHandler()
db_device_handler = DBDevicesHandler()
db_device_save_sync_handler = DBDeviceSaveSyncHandler()
db_firmware_handler = DBFirmwareHandler()
db_platform_handler = DBPlatformsHandler()
db_rom_handler = DBRomsHandler()
@@ -16,4 +21,3 @@ db_screenshot_handler = DBScreenshotsHandler()
db_state_handler = DBStatesHandler()
db_stats_handler = DBStatsHandler()
db_user_handler = DBUsersHandler()
db_collection_handler = DBCollectionsHandler()

View File

@@ -0,0 +1,129 @@
from collections.abc import Sequence
from datetime import datetime, timezone
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
from decorators.database import begin_session
from models.device_save_sync import DeviceSaveSync
from .base_handler import DBBaseHandler
class DBDeviceSaveSyncHandler(DBBaseHandler):
@begin_session
def get_sync(
self,
device_id: str,
save_id: int,
session: Session = None, # type: ignore
) -> DeviceSaveSync | None:
return session.scalar(
select(DeviceSaveSync)
.filter_by(device_id=device_id, save_id=save_id)
.limit(1)
)
@begin_session
def get_syncs_for_device_and_saves(
self,
device_id: str,
save_ids: list[int],
session: Session = None, # type: ignore
) -> Sequence[DeviceSaveSync]:
if not save_ids:
return []
return session.scalars(
select(DeviceSaveSync).filter(
DeviceSaveSync.device_id == device_id,
DeviceSaveSync.save_id.in_(save_ids),
)
).all()
@begin_session
def upsert_sync(
self,
device_id: str,
save_id: int,
synced_at: datetime | None = None,
session: Session = None, # type: ignore
) -> DeviceSaveSync:
now = synced_at or datetime.now(timezone.utc)
existing = session.scalar(
select(DeviceSaveSync)
.filter_by(device_id=device_id, save_id=save_id)
.limit(1)
)
if existing:
session.execute(
update(DeviceSaveSync)
.where(
DeviceSaveSync.device_id == device_id,
DeviceSaveSync.save_id == save_id,
)
.values(last_synced_at=now, is_untracked=False)
.execution_options(synchronize_session="evaluate")
)
existing.last_synced_at = now
existing.is_untracked = False
return existing
else:
sync = DeviceSaveSync(
device_id=device_id,
save_id=save_id,
last_synced_at=now,
is_untracked=False,
)
session.add(sync)
session.flush()
return sync
@begin_session
def set_untracked(
self,
device_id: str,
save_id: int,
untracked: bool,
session: Session = None, # type: ignore
) -> DeviceSaveSync | None:
existing = session.scalar(
select(DeviceSaveSync)
.filter_by(device_id=device_id, save_id=save_id)
.limit(1)
)
if existing:
session.execute(
update(DeviceSaveSync)
.where(
DeviceSaveSync.device_id == device_id,
DeviceSaveSync.save_id == save_id,
)
.values(is_untracked=untracked)
.execution_options(synchronize_session="evaluate")
)
existing.is_untracked = untracked
return existing
elif untracked:
now = datetime.now(timezone.utc)
sync = DeviceSaveSync(
device_id=device_id,
save_id=save_id,
last_synced_at=now,
is_untracked=True,
)
session.add(sync)
session.flush()
return sync
return None
@begin_session
def delete_syncs_for_device(
self,
device_id: str,
session: Session = None, # type: ignore
) -> None:
session.execute(
delete(DeviceSaveSync)
.where(DeviceSaveSync.device_id == device_id)
.execution_options(synchronize_session="evaluate")
)

View File

@@ -0,0 +1,111 @@
from collections.abc import Sequence
from datetime import datetime, timezone
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
from decorators.database import begin_session
from models.device import Device
from .base_handler import DBBaseHandler
class DBDevicesHandler(DBBaseHandler):
@begin_session
def add_device(
self,
device: Device,
session: Session = None, # type: ignore
) -> Device:
return session.merge(device)
@begin_session
def get_device(
self,
device_id: str,
user_id: int,
session: Session = None, # type: ignore
) -> Device | None:
return session.scalar(
select(Device).filter_by(id=device_id, user_id=user_id).limit(1)
)
@begin_session
def get_device_by_fingerprint(
self,
user_id: int,
mac_address: str | None = None,
hostname: str | None = None,
platform: str | None = None,
session: Session = None, # type: ignore
) -> Device | None:
if mac_address:
device = session.scalar(
select(Device)
.filter_by(user_id=user_id, mac_address=mac_address)
.limit(1)
)
if device:
return device
if hostname and platform:
return session.scalar(
select(Device)
.filter_by(user_id=user_id, hostname=hostname, platform=platform)
.limit(1)
)
return None
@begin_session
def get_devices(
self,
user_id: int,
session: Session = None, # type: ignore
) -> Sequence[Device]:
return session.scalars(select(Device).filter_by(user_id=user_id)).all()
@begin_session
def update_device(
self,
device_id: str,
user_id: int,
data: dict,
session: Session = None, # type: ignore
) -> Device | None:
session.execute(
update(Device)
.where(Device.id == device_id, Device.user_id == user_id)
.values(**data)
.execution_options(synchronize_session="evaluate")
)
return session.scalar(
select(Device).filter_by(id=device_id, user_id=user_id).limit(1)
)
@begin_session
def update_last_seen(
self,
device_id: str,
user_id: int,
session: Session = None, # type: ignore
) -> None:
session.execute(
update(Device)
.where(Device.id == device_id, Device.user_id == user_id)
.values(last_seen=datetime.now(timezone.utc))
.execution_options(synchronize_session="evaluate")
)
@begin_session
def delete_device(
self,
device_id: str,
user_id: int,
session: Session = None, # type: ignore
) -> None:
session.execute(
delete(Device)
.where(Device.id == device_id, Device.user_id == user_id)
.execution_options(synchronize_session="evaluate")
)

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from typing import Literal
from sqlalchemy import and_, delete, select, update
from sqlalchemy import and_, asc, delete, desc, select, update
from sqlalchemy.orm import QueryableAttribute, Session, load_only
from decorators.database import begin_session
@@ -42,12 +43,29 @@ class DBSavesHandler(DBBaseHandler):
.limit(1)
).first()
@begin_session
def get_save_by_content_hash(
self,
user_id: int,
rom_id: int,
content_hash: str,
session: Session = None, # type: ignore
) -> Save | None:
return session.scalar(
select(Save)
.filter_by(rom_id=rom_id, user_id=user_id, content_hash=content_hash)
.limit(1)
)
@begin_session
def get_saves(
self,
user_id: int,
rom_id: int | None = None,
platform_id: int | None = None,
slot: str | None = None,
order_by: Literal["updated_at", "created_at"] | None = None,
order_dir: Literal["asc", "desc"] = "desc",
only_fields: Sequence[QueryableAttribute] | None = None,
session: Session = None, # type: ignore
) -> Sequence[Save]:
@@ -61,6 +79,14 @@ class DBSavesHandler(DBBaseHandler):
Rom.platform_id == platform_id
)
if slot is not None:
query = query.filter(Save.slot == slot)
if order_by:
order_col = getattr(Save, order_by)
order_fn = asc if order_dir == "asc" else desc
query = query.order_by(order_fn(order_col))
if only_fields:
query = query.options(load_only(*only_fields))
@@ -125,3 +151,28 @@ class DBSavesHandler(DBBaseHandler):
)
return missing_saves
@begin_session
def get_saves_summary(
self,
user_id: int,
rom_id: int,
session: Session = None, # type: ignore
) -> dict:
saves = session.scalars(
select(Save)
.filter_by(user_id=user_id, rom_id=rom_id)
.order_by(desc(Save.updated_at))
).all()
slots_data: dict[str | None, dict] = {}
for save in saves:
slot_key = save.slot
if slot_key not in slots_data:
slots_data[slot_key] = {"slot": slot_key, "count": 0, "latest": save}
slots_data[slot_key]["count"] += 1
return {
"total_count": len(saves),
"slots": list(slots_data.values()),
}

View File

@@ -1,11 +1,44 @@
import hashlib
import os
import zipfile
from config import ASSETS_BASE_PATH
from logger.logger import log
from models.user import User
from .base_handler import FSHandler
def compute_file_hash(file_path: str) -> str:
hash_obj = hashlib.md5(usedforsecurity=False)
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hash_obj.update(chunk)
return hash_obj.hexdigest()
def compute_zip_hash(zip_path: str) -> str:
with zipfile.ZipFile(zip_path, "r") as zf:
file_hashes = []
for name in sorted(zf.namelist()):
if not name.endswith("/"):
content = zf.read(name)
file_hash = hashlib.md5(content, usedforsecurity=False).hexdigest()
file_hashes.append(f"{name}:{file_hash}")
combined = "\n".join(file_hashes)
return hashlib.md5(combined.encode(), usedforsecurity=False).hexdigest()
def compute_content_hash(file_path: str) -> str | None:
try:
if zipfile.is_zipfile(file_path):
return compute_zip_hash(file_path)
return compute_file_hash(file_path)
except Exception as e:
log.debug(f"Failed to compute content hash for {file_path}: {e}")
return None
class FSAssetsHandler(FSHandler):
def __init__(self) -> None:
super().__init__(base_path=ASSETS_BASE_PATH)

View File

@@ -4,10 +4,12 @@ from typing import Any
import socketio # type: ignore
from config import ASSETS_BASE_PATH
from config.config_manager import config_manager as cm
from endpoints.responses.rom import SimpleRomSchema
from handler.database import db_platform_handler, db_rom_handler
from handler.filesystem import fs_asset_handler, fs_firmware_handler
from handler.filesystem.assets_handler import compute_content_hash
from handler.filesystem.roms_handler import FSRom
from handler.metadata import (
meta_flashpoint_handler,
@@ -817,11 +819,11 @@ async def scan_rom(
return Rom(**rom_attrs)
async def _scan_asset(file_name: str, asset_path: str):
async def _scan_asset(file_name: str, asset_path: str, should_hash: bool = False):
file_path = f"{asset_path}/{file_name}"
file_size = await fs_asset_handler.get_file_size(file_path)
return {
result = {
"file_path": asset_path,
"file_name": file_name,
"file_name_no_tags": fs_asset_handler.get_file_name_with_no_tags(file_name),
@@ -830,6 +832,12 @@ async def _scan_asset(file_name: str, asset_path: str):
"file_size_bytes": file_size,
}
if should_hash:
absolute_path = f"{ASSETS_BASE_PATH}/{file_path}"
result["content_hash"] = compute_content_hash(absolute_path)
return result
async def scan_save(
file_name: str,
@@ -841,7 +849,7 @@ async def scan_save(
saves_path = fs_asset_handler.build_saves_file_path(
user=user, platform_fs_slug=platform_fs_slug, rom_id=rom_id, emulator=emulator
)
scanned_asset = await _scan_asset(file_name, saves_path)
scanned_asset = await _scan_asset(file_name, saves_path, should_hash=True)
return Save(**scanned_asset)

View File

@@ -28,6 +28,7 @@ from endpoints import (
auth,
collections,
configs,
device,
feeds,
firmware,
gamelist,
@@ -122,6 +123,7 @@ app.middleware("http")(set_context_middleware)
app.include_router(heartbeat.router, prefix="/api")
app.include_router(auth.router, prefix="/api")
app.include_router(user.router, prefix="/api")
app.include_router(device.router, prefix="/api")
app.include_router(platform.router, prefix="/api")
app.include_router(rom.router, prefix="/api")
app.include_router(search.router, prefix="/api")

View File

@@ -14,6 +14,7 @@ from models.base import (
)
if TYPE_CHECKING:
from models.device_save_sync import DeviceSaveSync
from models.rom import Rom
from models.user import User
@@ -54,9 +55,16 @@ class Save(RomAsset):
__table_args__ = {"extend_existing": True}
emulator: Mapped[str | None] = mapped_column(String(length=50))
slot: Mapped[str | None] = mapped_column(String(length=255))
content_hash: Mapped[str | None] = mapped_column(String(length=32))
rom: Mapped[Rom] = relationship(lazy="joined", back_populates="saves")
user: Mapped[User] = relationship(lazy="joined", back_populates="saves")
device_syncs: Mapped[list[DeviceSaveSync]] = relationship(
back_populates="save",
cascade="all, delete-orphan",
lazy="raise",
)
@cached_property
def screenshot(self) -> Screenshot | None:

49
backend/models/device.py Normal file
View File

@@ -0,0 +1,49 @@
from __future__ import annotations
import enum
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import TIMESTAMP, Boolean, Enum, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from models.base import BaseModel
if TYPE_CHECKING:
from models.device_save_sync import DeviceSaveSync
from models.user import User
class SyncMode(enum.StrEnum):
API = "api"
FILE_TRANSFER = "file_transfer"
PUSH_PULL = "push_pull"
class Device(BaseModel):
__tablename__ = "devices"
__table_args__ = {"extend_existing": True}
id: Mapped[str] = mapped_column(String(255), primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
name: Mapped[str | None] = mapped_column(String(255))
platform: Mapped[str | None] = mapped_column(String(50))
client: Mapped[str | None] = mapped_column(String(50))
client_version: Mapped[str | None] = mapped_column(String(50))
ip_address: Mapped[str | None] = mapped_column(String(45))
mac_address: Mapped[str | None] = mapped_column(String(17))
hostname: Mapped[str | None] = mapped_column(String(255))
sync_mode: Mapped[SyncMode] = mapped_column(Enum(SyncMode), default=SyncMode.API)
sync_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
last_seen: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))
user: Mapped[User] = relationship(lazy="joined")
save_syncs: Mapped[list[DeviceSaveSync]] = relationship(
back_populates="device",
cascade="all, delete-orphan",
lazy="raise",
)

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import TIMESTAMP, Boolean, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from models.base import BaseModel
if TYPE_CHECKING:
from models.assets import Save
from models.device import Device
class DeviceSaveSync(BaseModel):
__tablename__ = "device_save_sync"
__table_args__ = {"extend_existing": True}
device_id: Mapped[str] = mapped_column(
String(255),
ForeignKey("devices.id", ondelete="CASCADE"),
primary_key=True,
)
save_id: Mapped[int] = mapped_column(
ForeignKey("saves.id", ondelete="CASCADE"),
primary_key=True,
)
last_synced_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True))
is_untracked: Mapped[bool] = mapped_column(Boolean, default=False)
device: Mapped[Device] = relationship(back_populates="save_syncs", lazy="raise")
save: Mapped[Save] = relationship(back_populates="device_syncs", lazy="raise")

View File

@@ -22,6 +22,7 @@ from utils.database import CustomJSON
if TYPE_CHECKING:
from models.assets import Save, Screenshot, State
from models.collection import Collection, SmartCollection
from models.device import Device
from models.rom import RomNote, RomUser
@@ -79,6 +80,9 @@ class User(BaseModel, SimpleUser):
smart_collections: Mapped[list["SmartCollection"]] = relationship(
lazy="raise", back_populates="user"
)
devices: Mapped[list["Device"]] = relationship(
lazy="raise", back_populates="user", cascade="all, delete-orphan"
)
@classmethod
def kiosk_mode_user(cls) -> User:

View File

@@ -14,6 +14,8 @@ from handler.database import (
db_user_handler,
)
from models.assets import Save, Screenshot, State
from models.device import Device
from models.device_save_sync import DeviceSaveSync
from models.platform import Platform
from models.rom import Rom
from models.user import Role, User
@@ -30,6 +32,8 @@ def setup_database():
@pytest.fixture(autouse=True)
def clear_database():
with session.begin() as s:
s.query(DeviceSaveSync).delete(synchronize_session="evaluate")
s.query(Device).delete(synchronize_session="evaluate")
s.query(Save).delete(synchronize_session="evaluate")
s.query(State).delete(synchronize_session="evaluate")
s.query(Screenshot).delete(synchronize_session="evaluate")

View File

@@ -0,0 +1,509 @@
from datetime import timedelta
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from main import app
from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES
from handler.auth import oauth_handler
from handler.database import db_device_handler
from handler.redis_handler import sync_cache
from models.device import Device
from models.user import User
@pytest.fixture
def client():
with TestClient(app) as client:
yield client
@pytest.fixture(autouse=True)
def clear_cache():
yield
sync_cache.flushall()
@pytest.fixture
def editor_access_token(editor_user: User):
return oauth_handler.create_oauth_token(
data={
"sub": editor_user.username,
"iss": "romm:oauth",
"scopes": " ".join(editor_user.oauth_scopes),
"type": "access",
},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
class TestDeviceEndpoints:
def test_register_device(self, client, access_token: str):
response = client.post(
"/api/devices",
json={
"name": "Test Device",
"platform": "android",
"client": "argosy",
"client_version": "0.16.0",
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["name"] == "Test Device"
assert "device_id" in data
assert "created_at" in data
def test_register_device_minimal(self, client, access_token: str):
response = client.post(
"/api/devices",
json={},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["name"] is None
assert "device_id" in data
def test_list_devices(self, client, access_token: str, admin_user: User):
db_device_handler.add_device(
Device(
id="test-device-1",
user_id=admin_user.id,
name="Device 1",
)
)
db_device_handler.add_device(
Device(
id="test-device-2",
user_id=admin_user.id,
name="Device 2",
)
)
response = client.get(
"/api/devices",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 2
names = [d["name"] for d in data]
assert "Device 1" in names
assert "Device 2" in names
def test_get_device(self, client, access_token: str, admin_user: User):
device = db_device_handler.add_device(
Device(
id="test-device-get",
user_id=admin_user.id,
name="Get Test Device",
platform="linux",
)
)
response = client.get(
f"/api/devices/{device.id}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["id"] == "test-device-get"
assert data["name"] == "Get Test Device"
assert data["platform"] == "linux"
def test_get_device_not_found(self, client, access_token: str):
response = client.get(
"/api/devices/nonexistent-device",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_update_device(self, client, access_token: str, admin_user: User):
device = db_device_handler.add_device(
Device(
id="test-device-update",
user_id=admin_user.id,
name="Original Name",
)
)
response = client.put(
f"/api/devices/{device.id}",
json={
"name": "Updated Name",
"platform": "android",
"client": "daijishou",
"client_version": "4.0.0",
"ip_address": "192.168.1.100",
"mac_address": "AA:BB:CC:DD:EE:FF",
"hostname": "my-odin3",
"sync_enabled": False,
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["name"] == "Updated Name"
assert data["platform"] == "android"
assert data["client"] == "daijishou"
assert data["client_version"] == "4.0.0"
assert data["ip_address"] == "192.168.1.100"
assert data["mac_address"] == "AA:BB:CC:DD:EE:FF"
assert data["hostname"] == "my-odin3"
assert data["sync_enabled"] is False
def test_delete_device(self, client, access_token: str, admin_user: User):
device = db_device_handler.add_device(
Device(
id="test-device-delete",
user_id=admin_user.id,
name="To Delete",
)
)
response = client.delete(
f"/api/devices/{device.id}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
get_response = client.get(
f"/api/devices/{device.id}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert get_response.status_code == status.HTTP_404_NOT_FOUND
class TestDeviceUserIsolation:
def test_list_devices_only_returns_own_devices(
self,
client,
access_token: str,
editor_access_token: str,
admin_user: User,
editor_user: User,
):
db_device_handler.add_device(
Device(id="admin-device", user_id=admin_user.id, name="Admin Device")
)
db_device_handler.add_device(
Device(id="editor-device", user_id=editor_user.id, name="Editor Device")
)
admin_response = client.get(
"/api/devices",
headers={"Authorization": f"Bearer {access_token}"},
)
assert admin_response.status_code == status.HTTP_200_OK
admin_devices = admin_response.json()
assert len(admin_devices) == 1
assert admin_devices[0]["name"] == "Admin Device"
editor_response = client.get(
"/api/devices",
headers={"Authorization": f"Bearer {editor_access_token}"},
)
assert editor_response.status_code == status.HTTP_200_OK
editor_devices = editor_response.json()
assert len(editor_devices) == 1
assert editor_devices[0]["name"] == "Editor Device"
def test_cannot_get_other_users_device(
self,
client,
editor_access_token: str,
admin_user: User,
):
device = db_device_handler.add_device(
Device(id="admin-only-device", user_id=admin_user.id, name="Admin Only")
)
response = client.get(
f"/api/devices/{device.id}",
headers={"Authorization": f"Bearer {editor_access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_update_other_users_device(
self,
client,
editor_access_token: str,
admin_user: User,
):
device = db_device_handler.add_device(
Device(id="admin-protected-device", user_id=admin_user.id, name="Protected")
)
response = client.put(
f"/api/devices/{device.id}",
json={"name": "Hacked Name"},
headers={"Authorization": f"Bearer {editor_access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
original = db_device_handler.get_device(
device_id=device.id, user_id=admin_user.id
)
assert original.name == "Protected"
def test_cannot_delete_other_users_device(
self,
client,
editor_access_token: str,
admin_user: User,
):
device = db_device_handler.add_device(
Device(id="admin-nodelete-device", user_id=admin_user.id, name="No Delete")
)
response = client.delete(
f"/api/devices/{device.id}",
headers={"Authorization": f"Bearer {editor_access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
still_exists = db_device_handler.get_device(
device_id=device.id, user_id=admin_user.id
)
assert still_exists is not None
class TestDeviceDuplicateHandling:
def test_duplicate_mac_address_returns_existing(
self, client, access_token: str, admin_user: User
):
db_device_handler.add_device(
Device(
id="existing-mac-device",
user_id=admin_user.id,
name="Existing Device",
mac_address="AA:BB:CC:DD:EE:FF",
)
)
response = client.post(
"/api/devices",
json={
"name": "New Device",
"mac_address": "AA:BB:CC:DD:EE:FF",
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["device_id"] == "existing-mac-device"
assert data["name"] == "Existing Device"
def test_duplicate_hostname_platform_returns_existing(
self, client, access_token: str, admin_user: User
):
db_device_handler.add_device(
Device(
id="existing-hostname-device",
user_id=admin_user.id,
name="Existing Device",
hostname="my-device",
platform="android",
)
)
response = client.post(
"/api/devices",
json={
"name": "New Device",
"hostname": "my-device",
"platform": "android",
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["device_id"] == "existing-hostname-device"
assert data["name"] == "Existing Device"
def test_duplicate_with_allow_existing_false_returns_409(
self, client, access_token: str, admin_user: User
):
db_device_handler.add_device(
Device(
id="reject-duplicate-device",
user_id=admin_user.id,
name="Existing Device",
mac_address="FF:EE:DD:CC:BB:AA",
)
)
response = client.post(
"/api/devices",
json={
"name": "New Device",
"mac_address": "FF:EE:DD:CC:BB:AA",
"allow_existing": False,
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_409_CONFLICT
data = response.json()["detail"]
assert data["error"] == "device_exists"
assert data["device_id"] == "reject-duplicate-device"
def test_allow_existing_returns_existing_device(
self, client, access_token: str, admin_user: User
):
existing = db_device_handler.add_device(
Device(
id="allow-existing-device",
user_id=admin_user.id,
name="Existing Device",
mac_address="11:22:33:44:55:66",
)
)
response = client.post(
"/api/devices",
json={
"name": "New Device Name",
"mac_address": "11:22:33:44:55:66",
"allow_existing": True,
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["device_id"] == existing.id
assert data["name"] == "Existing Device"
def test_allow_existing_with_reset_syncs(
self, client, access_token: str, admin_user: User, rom
):
from handler.database import db_device_save_sync_handler, db_save_handler
from models.assets import Save
existing = db_device_handler.add_device(
Device(
id="reset-syncs-device",
user_id=admin_user.id,
name="Device With Syncs",
mac_address="77:88:99:AA:BB:CC",
)
)
save = db_save_handler.add_save(
Save(
file_name="test.sav",
file_name_no_tags="test",
file_name_no_ext="test",
file_extension="sav",
file_path="/saves",
file_size_bytes=100,
rom_id=rom.id,
user_id=admin_user.id,
)
)
db_device_save_sync_handler.upsert_sync(device_id=existing.id, save_id=save.id)
sync_before = db_device_save_sync_handler.get_sync(
device_id=existing.id, save_id=save.id
)
assert sync_before is not None
response = client.post(
"/api/devices",
json={
"mac_address": "77:88:99:AA:BB:CC",
"allow_existing": True,
"reset_syncs": True,
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
assert response.json()["device_id"] == existing.id
sync_after = db_device_save_sync_handler.get_sync(
device_id=existing.id, save_id=save.id
)
assert sync_after is None
def test_allow_duplicate_creates_new_device(
self, client, access_token: str, admin_user: User
):
existing = db_device_handler.add_device(
Device(
id="original-device",
user_id=admin_user.id,
name="Original Device",
mac_address="DD:EE:FF:00:11:22",
)
)
response = client.post(
"/api/devices",
json={
"name": "Duplicate Install",
"mac_address": "DD:EE:FF:00:11:22",
"allow_duplicate": True,
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["device_id"] != existing.id
assert data["name"] == "Duplicate Install"
def test_no_conflict_without_fingerprint(self, client, access_token: str):
response1 = client.post(
"/api/devices",
json={"name": "Device 1"},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response1.status_code == status.HTTP_201_CREATED
response2 = client.post(
"/api/devices",
json={"name": "Device 2"},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response2.status_code == status.HTTP_201_CREATED
assert response1.json()["device_id"] != response2.json()["device_id"]
def test_hostname_only_no_conflict_without_platform(
self, client, access_token: str, admin_user: User
):
db_device_handler.add_device(
Device(
id="hostname-only-device",
user_id=admin_user.id,
name="Existing",
hostname="my-device",
)
)
response = client.post(
"/api/devices",
json={
"name": "New Device",
"hostname": "my-device",
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_201_CREATED

File diff suppressed because it is too large Load Diff

View File

@@ -166,3 +166,280 @@ class TestDBSavesHandlerPlatformFiltering:
# Verify the save is associated with the correct platform through ROM
assert retrieved_save.rom.platform_id == platform.id
class TestDBSavesHandlerSlotFiltering:
def test_get_saves_with_slot_filter(self, admin_user: User, rom: Rom):
save1 = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="slot_test_1.sav",
file_name_no_tags="slot_test_1",
file_name_no_ext="slot_test_1",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="Slot A",
)
save2 = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="slot_test_2.sav",
file_name_no_tags="slot_test_2",
file_name_no_ext="slot_test_2",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="Slot A",
)
save3 = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="slot_test_3.sav",
file_name_no_tags="slot_test_3",
file_name_no_ext="slot_test_3",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="Slot B",
)
db_save_handler.add_save(save1)
db_save_handler.add_save(save2)
db_save_handler.add_save(save3)
slot_a_saves = db_save_handler.get_saves(
user_id=admin_user.id, rom_id=rom.id, slot="Slot A"
)
assert len(slot_a_saves) == 2
assert all(s.slot == "Slot A" for s in slot_a_saves)
slot_b_saves = db_save_handler.get_saves(
user_id=admin_user.id, rom_id=rom.id, slot="Slot B"
)
assert len(slot_b_saves) == 1
assert slot_b_saves[0].slot == "Slot B"
def test_get_saves_with_null_slot_filter(self, admin_user: User, rom: Rom):
save_with_slot = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="with_slot.sav",
file_name_no_tags="with_slot",
file_name_no_ext="with_slot",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="Main",
)
save_without_slot = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="without_slot.sav",
file_name_no_tags="without_slot",
file_name_no_ext="without_slot",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot=None,
)
db_save_handler.add_save(save_with_slot)
db_save_handler.add_save(save_without_slot)
all_saves = db_save_handler.get_saves(user_id=admin_user.id, rom_id=rom.id)
assert len(all_saves) >= 2
def test_get_saves_order_by(self, admin_user: User, rom: Rom):
from datetime import datetime, timedelta, timezone
base_time = datetime.now(timezone.utc)
save1 = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="order_test_1.sav",
file_name_no_tags="order_test_1",
file_name_no_ext="order_test_1",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="order_test",
)
save2 = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="order_test_2.sav",
file_name_no_tags="order_test_2",
file_name_no_ext="order_test_2",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="order_test",
)
created1 = db_save_handler.add_save(save1)
created2 = db_save_handler.add_save(save2)
db_save_handler.update_save(
created1.id, {"updated_at": base_time - timedelta(hours=2)}
)
db_save_handler.update_save(
created2.id, {"updated_at": base_time - timedelta(hours=1)}
)
ordered_saves_desc = db_save_handler.get_saves(
user_id=admin_user.id,
rom_id=rom.id,
slot="order_test",
order_by="updated_at",
)
assert len(ordered_saves_desc) == 2
assert ordered_saves_desc[0].id == created2.id
assert ordered_saves_desc[1].id == created1.id
ordered_saves_asc = db_save_handler.get_saves(
user_id=admin_user.id,
rom_id=rom.id,
slot="order_test",
order_by="updated_at",
order_dir="asc",
)
assert len(ordered_saves_asc) == 2
assert ordered_saves_asc[0].id == created1.id
assert ordered_saves_asc[1].id == created2.id
class TestDBSavesHandlerSummary:
def test_get_saves_summary_basic(self, admin_user: User, rom: Rom):
from datetime import datetime, timedelta, timezone
base_time = datetime.now(timezone.utc)
configs = [
("summary_a_1.sav", "Slot A", -3),
("summary_a_2.sav", "Slot A", -1),
("summary_b_1.sav", "Slot B", -2),
("summary_none_1.sav", None, -4),
]
for filename, slot, hours_offset in configs:
save = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name=filename,
file_name_no_tags=filename.replace(".sav", ""),
file_name_no_ext=filename.replace(".sav", ""),
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot=slot,
)
created = db_save_handler.add_save(save)
db_save_handler.update_save(
created.id, {"updated_at": base_time + timedelta(hours=hours_offset)}
)
summary = db_save_handler.get_saves_summary(
user_id=admin_user.id, rom_id=rom.id
)
assert "total_count" in summary
assert "slots" in summary
assert summary["total_count"] == 4
assert len(summary["slots"]) == 3
def test_get_saves_summary_latest_per_slot(self, admin_user: User, rom: Rom):
from datetime import datetime, timedelta, timezone
base_time = datetime.now(timezone.utc)
old_save = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="latest_test_old.sav",
file_name_no_tags="latest_test_old",
file_name_no_ext="latest_test_old",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="latest_test",
)
new_save = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name="latest_test_new.sav",
file_name_no_tags="latest_test_new",
file_name_no_ext="latest_test_new",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="latest_test",
)
old_created = db_save_handler.add_save(old_save)
new_created = db_save_handler.add_save(new_save)
db_save_handler.update_save(
old_created.id, {"updated_at": base_time - timedelta(hours=5)}
)
db_save_handler.update_save(
new_created.id, {"updated_at": base_time - timedelta(hours=1)}
)
summary = db_save_handler.get_saves_summary(
user_id=admin_user.id, rom_id=rom.id
)
latest_slot = next(
(s for s in summary["slots"] if s["slot"] == "latest_test"), None
)
assert latest_slot is not None
assert latest_slot["count"] == 2
assert latest_slot["latest"].file_name == "latest_test_new.sav"
def test_get_saves_summary_empty_rom(self, admin_user: User):
summary = db_save_handler.get_saves_summary(
user_id=admin_user.id, rom_id=999999
)
assert summary["total_count"] == 0
assert summary["slots"] == []
def test_get_saves_summary_count_accuracy(self, admin_user: User, rom: Rom):
for i in range(5):
save = Save(
rom_id=rom.id,
user_id=admin_user.id,
file_name=f"count_test_{i}.sav",
file_name_no_tags=f"count_test_{i}",
file_name_no_ext=f"count_test_{i}",
file_extension="sav",
emulator="test_emu",
file_path=f"{rom.platform_slug}/saves",
file_size_bytes=100,
slot="count_test",
)
db_save_handler.add_save(save)
summary = db_save_handler.get_saves_summary(
user_id=admin_user.id, rom_id=rom.id
)
count_slot = next(
(s for s in summary["slots"] if s["slot"] == "count_test"), None
)
assert count_slot is not None
assert count_slot["count"] == 5

View File

@@ -0,0 +1,7 @@
from datetime import datetime, timezone
def to_utc(dt: datetime) -> datetime:
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)