mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 00:27:41 +01:00
Merge pull request #2917 from tmgast/feature/device-registration-save-sync
Add device-based save synchronization
This commit is contained in:
102
backend/alembic/versions/0068_save_sync.py
Normal file
102
backend/alembic/versions/0068_save_sync.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Add device-based save synchronization
|
||||
|
||||
Revision ID: 0068_save_sync
|
||||
Revises: 0067_romfile_category_enum_cheat
|
||||
Create Date: 2026-01-17
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "0068_save_sync"
|
||||
down_revision = "0067_romfile_category_enum_cheat"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"devices",
|
||||
sa.Column("id", sa.String(255), primary_key=True),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=True),
|
||||
sa.Column("platform", sa.String(50), nullable=True),
|
||||
sa.Column("client", sa.String(50), nullable=True),
|
||||
sa.Column("client_version", sa.String(50), nullable=True),
|
||||
sa.Column("ip_address", sa.String(45), nullable=True),
|
||||
sa.Column("mac_address", sa.String(17), nullable=True),
|
||||
sa.Column("hostname", sa.String(255), nullable=True),
|
||||
sa.Column(
|
||||
"sync_mode",
|
||||
sa.Enum("API", "FILE_TRANSFER", "PUSH_PULL", name="syncmode"),
|
||||
nullable=False,
|
||||
server_default="API",
|
||||
),
|
||||
sa.Column("sync_enabled", sa.Boolean(), nullable=False, server_default="1"),
|
||||
sa.Column("last_seen", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"device_save_sync",
|
||||
sa.Column("device_id", sa.String(255), nullable=False),
|
||||
sa.Column("save_id", sa.Integer(), nullable=False),
|
||||
sa.Column("last_synced_at", sa.TIMESTAMP(timezone=True), nullable=False),
|
||||
sa.Column("is_untracked", sa.Boolean(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["device_id"], ["devices.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["save_id"], ["saves.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("device_id", "save_id"),
|
||||
)
|
||||
|
||||
with op.batch_alter_table("saves", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("slot", sa.String(255), nullable=True))
|
||||
batch_op.add_column(sa.Column("content_hash", sa.String(32), nullable=True))
|
||||
|
||||
op.create_index("ix_devices_user_id", "devices", ["user_id"])
|
||||
op.create_index("ix_devices_last_seen", "devices", ["last_seen"])
|
||||
op.create_index("ix_device_save_sync_save_id", "device_save_sync", ["save_id"])
|
||||
op.create_index("ix_saves_slot", "saves", ["slot"])
|
||||
op.create_index(
|
||||
"ix_saves_rom_user_hash", "saves", ["rom_id", "user_id", "content_hash"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("ix_saves_rom_user_hash", "saves")
|
||||
op.drop_index("ix_saves_slot", "saves")
|
||||
op.drop_index("ix_device_save_sync_save_id", "device_save_sync")
|
||||
op.drop_index("ix_devices_last_seen", "devices")
|
||||
op.drop_index("ix_devices_user_id", "devices")
|
||||
|
||||
with op.batch_alter_table("saves", schema=None) as batch_op:
|
||||
batch_op.drop_column("content_hash")
|
||||
batch_op.drop_column("slot")
|
||||
|
||||
op.drop_table("device_save_sync")
|
||||
op.drop_table("devices")
|
||||
op.execute("DROP TYPE IF EXISTS syncmode")
|
||||
179
backend/endpoints/device.py
Normal file
179
backend/endpoints/device.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from decorators.auth import protected_route
|
||||
from endpoints.responses.device import DeviceCreateResponse, DeviceSchema
|
||||
from handler.auth.constants import Scope
|
||||
from handler.database import db_device_handler, db_device_save_sync_handler
|
||||
from logger.logger import log
|
||||
from models.device import Device
|
||||
from utils.router import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/devices",
|
||||
tags=["devices"],
|
||||
)
|
||||
|
||||
|
||||
class DeviceCreatePayload(BaseModel):
|
||||
name: str | None = None
|
||||
platform: str | None = None
|
||||
client: str | None = None
|
||||
client_version: str | None = None
|
||||
ip_address: str | None = None
|
||||
mac_address: str | None = None
|
||||
hostname: str | None = None
|
||||
allow_existing: bool = True
|
||||
allow_duplicate: bool = False
|
||||
reset_syncs: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _duplicate_disables_existing(self) -> "DeviceCreatePayload":
|
||||
if self.allow_duplicate:
|
||||
self.allow_existing = False
|
||||
return self
|
||||
|
||||
|
||||
class DeviceUpdatePayload(BaseModel):
|
||||
name: str | None = None
|
||||
platform: str | None = None
|
||||
client: str | None = None
|
||||
client_version: str | None = None
|
||||
ip_address: str | None = None
|
||||
mac_address: str | None = None
|
||||
hostname: str | None = None
|
||||
sync_enabled: bool | None = None
|
||||
|
||||
|
||||
@protected_route(router.post, "", [Scope.DEVICES_WRITE])
|
||||
def register_device(
|
||||
request: Request,
|
||||
response: Response,
|
||||
payload: DeviceCreatePayload,
|
||||
) -> DeviceCreateResponse:
|
||||
existing_device = None
|
||||
if not payload.allow_duplicate:
|
||||
existing_device = db_device_handler.get_device_by_fingerprint(
|
||||
user_id=request.user.id,
|
||||
mac_address=payload.mac_address,
|
||||
hostname=payload.hostname,
|
||||
platform=payload.platform,
|
||||
)
|
||||
|
||||
if existing_device:
|
||||
if not payload.allow_existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"error": "device_exists",
|
||||
"message": "A device with this fingerprint already exists",
|
||||
"device_id": existing_device.id,
|
||||
},
|
||||
)
|
||||
|
||||
if payload.reset_syncs:
|
||||
db_device_save_sync_handler.delete_syncs_for_device(
|
||||
device_id=existing_device.id
|
||||
)
|
||||
|
||||
db_device_handler.update_last_seen(
|
||||
device_id=existing_device.id, user_id=request.user.id
|
||||
)
|
||||
log.info(
|
||||
f"Returned existing device {existing_device.id} for user {request.user.username}"
|
||||
)
|
||||
|
||||
response.status_code = status.HTTP_200_OK
|
||||
return DeviceCreateResponse(
|
||||
device_id=existing_device.id,
|
||||
name=existing_device.name,
|
||||
created_at=existing_device.created_at,
|
||||
)
|
||||
|
||||
response.status_code = status.HTTP_201_CREATED
|
||||
device_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
device = Device(
|
||||
id=device_id,
|
||||
user_id=request.user.id,
|
||||
name=payload.name,
|
||||
platform=payload.platform,
|
||||
client=payload.client,
|
||||
client_version=payload.client_version,
|
||||
ip_address=payload.ip_address,
|
||||
mac_address=payload.mac_address,
|
||||
hostname=payload.hostname,
|
||||
last_seen=now,
|
||||
)
|
||||
|
||||
db_device = db_device_handler.add_device(device)
|
||||
log.info(f"Registered device {device_id} for user {request.user.username}")
|
||||
|
||||
return DeviceCreateResponse(
|
||||
device_id=db_device.id,
|
||||
name=db_device.name,
|
||||
created_at=db_device.created_at,
|
||||
)
|
||||
|
||||
|
||||
@protected_route(router.get, "", [Scope.DEVICES_READ])
|
||||
def get_devices(request: Request) -> list[DeviceSchema]:
|
||||
devices = db_device_handler.get_devices(user_id=request.user.id)
|
||||
return [DeviceSchema.model_validate(device) for device in devices]
|
||||
|
||||
|
||||
@protected_route(router.get, "/{device_id}", [Scope.DEVICES_READ])
|
||||
def get_device(request: Request, device_id: str) -> DeviceSchema:
|
||||
device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id)
|
||||
if not device:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Device with ID {device_id} not found",
|
||||
)
|
||||
return DeviceSchema.model_validate(device)
|
||||
|
||||
|
||||
@protected_route(router.put, "/{device_id}", [Scope.DEVICES_WRITE])
|
||||
def update_device(
|
||||
request: Request,
|
||||
device_id: str,
|
||||
payload: DeviceUpdatePayload,
|
||||
) -> DeviceSchema:
|
||||
device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id)
|
||||
if not device:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Device with ID {device_id} not found",
|
||||
)
|
||||
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
device = db_device_handler.update_device(
|
||||
device_id=device_id,
|
||||
user_id=request.user.id,
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
return DeviceSchema.model_validate(device)
|
||||
|
||||
|
||||
@protected_route(
|
||||
router.delete,
|
||||
"/{device_id}",
|
||||
[Scope.DEVICES_WRITE],
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
def delete_device(request: Request, device_id: str) -> None:
|
||||
device = db_device_handler.get_device(device_id=device_id, user_id=request.user.id)
|
||||
if not device:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Device with ID {device_id} not found",
|
||||
)
|
||||
|
||||
db_device_handler.delete_device(device_id=device_id, user_id=request.user.id)
|
||||
log.info(f"Deleted device {device_id} for user {request.user.username}")
|
||||
@@ -1,6 +1,12 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.exc import InvalidRequestError
|
||||
|
||||
from .base import BaseModel
|
||||
from .device import DeviceSyncSchema
|
||||
|
||||
|
||||
class BaseAsset(BaseModel):
|
||||
@@ -31,7 +37,40 @@ class ScreenshotSchema(BaseAsset):
|
||||
|
||||
class SaveSchema(BaseAsset):
|
||||
emulator: str | None
|
||||
slot: str | None = None
|
||||
content_hash: str | None = None
|
||||
screenshot: ScreenshotSchema | None
|
||||
device_syncs: list[DeviceSyncSchema] = []
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def handle_lazy_relationships(cls, data: Any) -> Any:
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
try:
|
||||
state = inspect(data)
|
||||
except Exception:
|
||||
return data
|
||||
result = {}
|
||||
for field_name in cls.model_fields:
|
||||
if field_name in state.unloaded:
|
||||
continue
|
||||
try:
|
||||
result[field_name] = getattr(data, field_name)
|
||||
except InvalidRequestError:
|
||||
continue
|
||||
return result
|
||||
|
||||
|
||||
class SlotSummarySchema(BaseModel):
|
||||
slot: str | None
|
||||
count: int
|
||||
latest: SaveSchema
|
||||
|
||||
|
||||
class SaveSummarySchema(BaseModel):
|
||||
total_count: int
|
||||
slots: list[SlotSummarySchema]
|
||||
|
||||
|
||||
class StateSchema(BaseAsset):
|
||||
|
||||
42
backend/endpoints/responses/device.py
Normal file
42
backend/endpoints/responses/device.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
|
||||
from models.device import SyncMode
|
||||
|
||||
from .base import BaseModel
|
||||
|
||||
|
||||
class DeviceSyncSchema(BaseModel):
|
||||
device_id: str
|
||||
device_name: str | None
|
||||
last_synced_at: datetime
|
||||
is_untracked: bool
|
||||
is_current: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DeviceSchema(BaseModel):
|
||||
id: str
|
||||
user_id: int
|
||||
name: str | None
|
||||
platform: str | None
|
||||
client: str | None
|
||||
client_version: str | None
|
||||
ip_address: str | None
|
||||
mac_address: str | None
|
||||
hostname: str | None
|
||||
sync_mode: SyncMode
|
||||
sync_enabled: bool
|
||||
last_seen: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DeviceCreateResponse(BaseModel):
|
||||
device_id: str
|
||||
name: str | None
|
||||
created_at: datetime
|
||||
@@ -1,21 +1,99 @@
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Body, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from decorators.auth import protected_route
|
||||
from endpoints.responses.assets import SaveSchema
|
||||
from endpoints.responses.assets import SaveSchema, SaveSummarySchema, SlotSummarySchema
|
||||
from endpoints.responses.device import DeviceSyncSchema
|
||||
from exceptions.endpoint_exceptions import RomNotFoundInDatabaseException
|
||||
from handler.auth.constants import Scope
|
||||
from handler.database import db_rom_handler, db_save_handler, db_screenshot_handler
|
||||
from handler.database import (
|
||||
db_device_handler,
|
||||
db_device_save_sync_handler,
|
||||
db_rom_handler,
|
||||
db_save_handler,
|
||||
db_screenshot_handler,
|
||||
)
|
||||
from handler.filesystem import fs_asset_handler
|
||||
from handler.scan_handler import scan_save, scan_screenshot
|
||||
from logger.formatter import BLUE
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
from models.assets import Save
|
||||
from models.device import Device
|
||||
from models.device_save_sync import DeviceSaveSync
|
||||
from utils.datetime import to_utc
|
||||
from utils.router import APIRouter
|
||||
|
||||
|
||||
def _build_save_schema(
|
||||
save: Save,
|
||||
device: Device | None = None,
|
||||
sync: DeviceSaveSync | None = None,
|
||||
) -> SaveSchema:
|
||||
save_schema = SaveSchema.model_validate(save)
|
||||
|
||||
if device:
|
||||
if sync:
|
||||
is_current = to_utc(sync.last_synced_at) >= to_utc(save.updated_at)
|
||||
last_synced = sync.last_synced_at
|
||||
is_untracked = sync.is_untracked
|
||||
else:
|
||||
is_current = False
|
||||
last_synced = save.updated_at
|
||||
is_untracked = False
|
||||
|
||||
save_schema.device_syncs = [
|
||||
DeviceSyncSchema(
|
||||
device_id=device.id,
|
||||
device_name=device.name,
|
||||
last_synced_at=last_synced,
|
||||
is_untracked=is_untracked,
|
||||
is_current=is_current,
|
||||
)
|
||||
]
|
||||
|
||||
return save_schema
|
||||
|
||||
|
||||
DATETIME_TAG_PATTERN = re.compile(r" \[\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\]")
|
||||
|
||||
|
||||
def _apply_datetime_tag(filename: str) -> str:
|
||||
name, ext = os.path.splitext(filename)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
if DATETIME_TAG_PATTERN.search(name):
|
||||
name = DATETIME_TAG_PATTERN.sub("", name)
|
||||
|
||||
return f"{name} [{timestamp}]{ext}"
|
||||
|
||||
|
||||
def _resolve_device(
|
||||
device_id: str | None,
|
||||
user_id: int,
|
||||
scopes: set[str] | None = None,
|
||||
required_scope: Scope | None = None,
|
||||
) -> Device | None:
|
||||
if not device_id:
|
||||
return None
|
||||
|
||||
if required_scope and scopes and required_scope not in scopes:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
device = db_device_handler.get_device(device_id=device_id, user_id=user_id)
|
||||
if not device:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Device with ID {device_id} not found",
|
||||
)
|
||||
return device
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/saves",
|
||||
tags=["saves"],
|
||||
@@ -27,22 +105,23 @@ async def add_save(
|
||||
request: Request,
|
||||
rom_id: int,
|
||||
emulator: str | None = None,
|
||||
slot: str | None = None,
|
||||
device_id: str | None = None,
|
||||
overwrite: bool = False,
|
||||
autocleanup: bool = False,
|
||||
autocleanup_limit: int = 10,
|
||||
) -> SaveSchema:
|
||||
"""Upload a save file for a ROM."""
|
||||
device = _resolve_device(
|
||||
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_WRITE
|
||||
)
|
||||
|
||||
data = await request.form()
|
||||
|
||||
rom = db_rom_handler.get_rom(rom_id)
|
||||
if not rom:
|
||||
raise RomNotFoundInDatabaseException(rom_id)
|
||||
|
||||
log.info(f"Uploading save of {rom.name}")
|
||||
|
||||
saves_path = fs_asset_handler.build_saves_file_path(
|
||||
user=request.user,
|
||||
platform_fs_slug=rom.platform.fs_slug,
|
||||
rom_id=rom_id,
|
||||
emulator=emulator,
|
||||
)
|
||||
|
||||
if "saveFile" not in data:
|
||||
log.error("No save file provided")
|
||||
raise HTTPException(
|
||||
@@ -57,12 +136,45 @@ async def add_save(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Save file has no filename"
|
||||
)
|
||||
|
||||
rom = db_rom_handler.get_rom(rom_id)
|
||||
if not rom:
|
||||
raise RomNotFoundInDatabaseException(rom_id)
|
||||
actual_filename = saveFile.filename
|
||||
if slot:
|
||||
actual_filename = _apply_datetime_tag(saveFile.filename)
|
||||
|
||||
db_save = db_save_handler.get_save_by_filename(
|
||||
user_id=request.user.id, rom_id=rom.id, file_name=actual_filename
|
||||
)
|
||||
|
||||
if device and slot and not overwrite:
|
||||
slot_saves = db_save_handler.get_saves(
|
||||
user_id=request.user.id,
|
||||
rom_id=rom.id,
|
||||
slot=slot,
|
||||
order_by="updated_at",
|
||||
)
|
||||
if slot_saves:
|
||||
latest_in_slot = slot_saves[0]
|
||||
sync = db_device_save_sync_handler.get_sync(
|
||||
device_id=device.id, save_id=latest_in_slot.id
|
||||
)
|
||||
if not sync or to_utc(sync.last_synced_at) < to_utc(
|
||||
latest_in_slot.updated_at
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Slot has a newer save since your last sync",
|
||||
)
|
||||
elif device and db_save and not overwrite:
|
||||
sync = db_device_save_sync_handler.get_sync(
|
||||
device_id=device.id, save_id=db_save.id
|
||||
)
|
||||
if sync and to_utc(sync.last_synced_at) < to_utc(db_save.updated_at):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Save has been updated since your last sync",
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Uploading save {hl(saveFile.filename)} for {hl(str(rom.name), color=BLUE)}"
|
||||
f"Uploading save {hl(actual_filename)} for {hl(str(rom.name), color=BLUE)}"
|
||||
)
|
||||
|
||||
saves_path = fs_asset_handler.build_saves_file_path(
|
||||
@@ -72,29 +184,72 @@ async def add_save(
|
||||
emulator=emulator,
|
||||
)
|
||||
|
||||
await fs_asset_handler.write_file(file=saveFile, path=saves_path)
|
||||
await fs_asset_handler.write_file(
|
||||
file=saveFile, path=saves_path, filename=actual_filename
|
||||
)
|
||||
|
||||
# Scan or update save
|
||||
scanned_save = await scan_save(
|
||||
file_name=saveFile.filename,
|
||||
file_name=actual_filename,
|
||||
user=request.user,
|
||||
platform_fs_slug=rom.platform.fs_slug,
|
||||
rom_id=rom_id,
|
||||
emulator=emulator,
|
||||
)
|
||||
db_save = db_save_handler.get_save_by_filename(
|
||||
user_id=request.user.id, rom_id=rom.id, file_name=saveFile.filename
|
||||
)
|
||||
if db_save:
|
||||
db_save = db_save_handler.update_save(
|
||||
db_save.id, {"file_size_bytes": scanned_save.file_size_bytes}
|
||||
|
||||
if slot and scanned_save.content_hash and not overwrite:
|
||||
existing_by_hash = db_save_handler.get_save_by_content_hash(
|
||||
user_id=request.user.id,
|
||||
rom_id=rom.id,
|
||||
content_hash=scanned_save.content_hash,
|
||||
)
|
||||
if existing_by_hash:
|
||||
try:
|
||||
await fs_asset_handler.remove_file(f"{saves_path}/{actual_filename}")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
sync = None
|
||||
if device:
|
||||
sync = db_device_save_sync_handler.get_sync(
|
||||
device_id=device.id, save_id=existing_by_hash.id
|
||||
)
|
||||
return _build_save_schema(existing_by_hash, device, sync)
|
||||
|
||||
if db_save:
|
||||
update_data: dict = {
|
||||
"file_size_bytes": scanned_save.file_size_bytes,
|
||||
"content_hash": scanned_save.content_hash,
|
||||
}
|
||||
if slot is not None:
|
||||
update_data["slot"] = slot
|
||||
db_save = db_save_handler.update_save(db_save.id, update_data)
|
||||
else:
|
||||
scanned_save.rom_id = rom.id
|
||||
scanned_save.user_id = request.user.id
|
||||
scanned_save.emulator = emulator
|
||||
scanned_save.slot = slot
|
||||
db_save = db_save_handler.add_save(save=scanned_save)
|
||||
|
||||
if device:
|
||||
db_device_save_sync_handler.upsert_sync(
|
||||
device_id=device.id, save_id=db_save.id, synced_at=db_save.updated_at
|
||||
)
|
||||
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)
|
||||
|
||||
if slot and autocleanup:
|
||||
slot_saves = db_save_handler.get_saves(
|
||||
user_id=request.user.id,
|
||||
rom_id=rom.id,
|
||||
slot=slot,
|
||||
order_by="updated_at",
|
||||
)
|
||||
if len(slot_saves) > autocleanup_limit:
|
||||
for old_save in slot_saves[autocleanup_limit:]:
|
||||
db_save_handler.delete_save(old_save.id)
|
||||
try:
|
||||
await fs_asset_handler.remove_file(old_save.full_path)
|
||||
except FileNotFoundError:
|
||||
log.warning(f"Could not delete old save file: {old_save.full_path}")
|
||||
|
||||
screenshotFile: UploadFile | None = data.get("screenshotFile", None) # type: ignore
|
||||
if screenshotFile and screenshotFile.filename:
|
||||
screenshots_path = fs_asset_handler.build_screenshots_file_path(
|
||||
@@ -103,7 +258,6 @@ async def add_save(
|
||||
|
||||
await fs_asset_handler.write_file(file=screenshotFile, path=screenshots_path)
|
||||
|
||||
# Scan or update screenshot
|
||||
scanned_screenshot = await scan_screenshot(
|
||||
file_name=screenshotFile.filename,
|
||||
user=request.user,
|
||||
@@ -125,7 +279,6 @@ async def add_save(
|
||||
scanned_screenshot.user_id = request.user.id
|
||||
db_screenshot_handler.add_screenshot(screenshot=scanned_screenshot)
|
||||
|
||||
# Set the last played time for the current user
|
||||
rom_user = db_rom_handler.get_rom_user(rom_id=rom.id, user_id=request.user.id)
|
||||
if not rom_user:
|
||||
rom_user = db_rom_handler.add_rom_user(rom_id=rom.id, user_id=request.user.id)
|
||||
@@ -133,37 +286,47 @@ async def add_save(
|
||||
rom_user.id, {"last_played": datetime.now(timezone.utc)}
|
||||
)
|
||||
|
||||
# Refetch the rom to get updated saves
|
||||
rom = db_rom_handler.get_rom(rom_id)
|
||||
if not rom:
|
||||
raise RomNotFoundInDatabaseException(rom_id)
|
||||
|
||||
return SaveSchema.model_validate(db_save)
|
||||
sync = None
|
||||
if device:
|
||||
sync = db_device_save_sync_handler.get_sync(
|
||||
device_id=device.id, save_id=db_save.id
|
||||
)
|
||||
return _build_save_schema(db_save, device, sync)
|
||||
|
||||
|
||||
@protected_route(router.get, "", [Scope.ASSETS_READ])
|
||||
def get_saves(
|
||||
request: Request, rom_id: int | None = None, platform_id: int | None = None
|
||||
request: Request,
|
||||
rom_id: int | None = None,
|
||||
platform_id: int | None = None,
|
||||
device_id: str | None = None,
|
||||
slot: str | None = None,
|
||||
) -> list[SaveSchema]:
|
||||
saves = db_save_handler.get_saves(
|
||||
user_id=request.user.id, rom_id=rom_id, platform_id=platform_id
|
||||
"""Retrieve saves for the current user."""
|
||||
device = _resolve_device(
|
||||
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ
|
||||
)
|
||||
|
||||
return [SaveSchema.model_validate(save) for save in saves]
|
||||
saves = db_save_handler.get_saves(
|
||||
user_id=request.user.id, rom_id=rom_id, platform_id=platform_id, slot=slot
|
||||
)
|
||||
|
||||
if not device:
|
||||
return [_build_save_schema(save) for save in saves]
|
||||
|
||||
syncs = db_device_save_sync_handler.get_syncs_for_device_and_saves(
|
||||
device_id=device.id, save_ids=[s.id for s in saves]
|
||||
)
|
||||
sync_by_save_id = {s.save_id: s for s in syncs}
|
||||
|
||||
return [
|
||||
_build_save_schema(save, device, sync_by_save_id.get(save.id)) for save in saves
|
||||
]
|
||||
|
||||
|
||||
@protected_route(router.get, "/identifiers", [Scope.ASSETS_READ])
|
||||
def get_save_identifiers(
|
||||
request: Request,
|
||||
) -> list[int]:
|
||||
"""Get save identifiers endpoint
|
||||
|
||||
Args:
|
||||
request (Request): Fastapi Request object
|
||||
|
||||
Returns:
|
||||
list[int]: List of save IDs
|
||||
"""
|
||||
def get_save_identifiers(request: Request) -> list[int]:
|
||||
"""Retrieve save identifiers."""
|
||||
saves = db_save_handler.get_saves(
|
||||
user_id=request.user.id,
|
||||
only_fields=[Save.id],
|
||||
@@ -172,20 +335,121 @@ def get_save_identifiers(
|
||||
return [save.id for save in saves]
|
||||
|
||||
|
||||
@protected_route(router.get, "/summary", [Scope.ASSETS_READ])
|
||||
def get_saves_summary(request: Request, rom_id: int) -> SaveSummarySchema:
|
||||
"""Retrieve saves summary grouped by slot."""
|
||||
summary_data = db_save_handler.get_saves_summary(
|
||||
user_id=request.user.id, rom_id=rom_id
|
||||
)
|
||||
|
||||
slots = [
|
||||
SlotSummarySchema(
|
||||
slot=slot_data["slot"],
|
||||
count=slot_data["count"],
|
||||
latest=_build_save_schema(slot_data["latest"]),
|
||||
)
|
||||
for slot_data in summary_data["slots"]
|
||||
]
|
||||
|
||||
return SaveSummarySchema(total_count=summary_data["total_count"], slots=slots)
|
||||
|
||||
|
||||
@protected_route(router.get, "/{id}", [Scope.ASSETS_READ])
|
||||
def get_save(request: Request, id: int) -> SaveSchema:
|
||||
def get_save(request: Request, id: int, device_id: str | None = None) -> SaveSchema:
|
||||
"""Retrieve a save by ID."""
|
||||
device = _resolve_device(
|
||||
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ
|
||||
)
|
||||
|
||||
save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
|
||||
if not save:
|
||||
error = f"Save with ID {id} not found"
|
||||
log.error(error)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Save with ID {id} not found",
|
||||
)
|
||||
|
||||
return SaveSchema.model_validate(save)
|
||||
sync = None
|
||||
if device:
|
||||
sync = db_device_save_sync_handler.get_sync(
|
||||
device_id=device.id, save_id=save.id
|
||||
)
|
||||
return _build_save_schema(save, device, sync)
|
||||
|
||||
|
||||
@protected_route(router.get, "/{id}/content", [Scope.ASSETS_READ])
|
||||
def download_save(
|
||||
request: Request,
|
||||
id: int,
|
||||
device_id: str | None = None,
|
||||
optimistic: bool = True,
|
||||
) -> FileResponse:
|
||||
"""Download a save file."""
|
||||
device = _resolve_device(
|
||||
device_id, request.user.id, request.auth.scopes, Scope.DEVICES_READ
|
||||
)
|
||||
|
||||
save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
if not save:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Save with ID {id} not found",
|
||||
)
|
||||
|
||||
try:
|
||||
file_path = fs_asset_handler.validate_path(save.full_path)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Save file not found",
|
||||
) from None
|
||||
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Save file not found on disk",
|
||||
)
|
||||
|
||||
if device and optimistic:
|
||||
db_device_save_sync_handler.upsert_sync(
|
||||
device_id=device.id,
|
||||
save_id=save.id,
|
||||
synced_at=save.updated_at,
|
||||
)
|
||||
db_device_handler.update_last_seen(device_id=device.id, user_id=request.user.id)
|
||||
|
||||
return FileResponse(path=str(file_path), filename=save.file_name)
|
||||
|
||||
|
||||
@protected_route(router.post, "/{id}/downloaded", [Scope.DEVICES_WRITE])
|
||||
def confirm_download(
|
||||
request: Request,
|
||||
id: int,
|
||||
device_id: str = Body(..., embed=True),
|
||||
) -> SaveSchema:
|
||||
"""Confirm a save was downloaded successfully."""
|
||||
save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
if not save:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Save with ID {id} not found",
|
||||
)
|
||||
|
||||
device = _resolve_device(device_id, request.user.id)
|
||||
assert device is not None
|
||||
|
||||
sync = db_device_save_sync_handler.upsert_sync(
|
||||
device_id=device_id,
|
||||
save_id=save.id,
|
||||
synced_at=save.updated_at,
|
||||
)
|
||||
db_device_handler.update_last_seen(device_id=device_id, user_id=request.user.id)
|
||||
|
||||
return _build_save_schema(save, device, sync)
|
||||
|
||||
|
||||
@protected_route(router.put, "/{id}", [Scope.ASSETS_WRITE])
|
||||
async def update_save(request: Request, id: int) -> SaveSchema:
|
||||
"""Update a save file."""
|
||||
data = await request.form()
|
||||
|
||||
db_save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
@@ -300,3 +564,51 @@ async def delete_saves(
|
||||
log.error(error)
|
||||
|
||||
return saves
|
||||
|
||||
|
||||
@protected_route(router.post, "/{id}/track", [Scope.DEVICES_WRITE])
|
||||
def track_save(
|
||||
request: Request,
|
||||
id: int,
|
||||
device_id: str = Body(..., embed=True),
|
||||
) -> SaveSchema:
|
||||
"""Re-enable sync tracking for a save on a device."""
|
||||
save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
if not save:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Save with ID {id} not found",
|
||||
)
|
||||
|
||||
device = _resolve_device(device_id, request.user.id)
|
||||
assert device is not None
|
||||
|
||||
sync = db_device_save_sync_handler.set_untracked(
|
||||
device_id=device_id, save_id=id, untracked=False
|
||||
)
|
||||
|
||||
return _build_save_schema(save, device, sync)
|
||||
|
||||
|
||||
@protected_route(router.post, "/{id}/untrack", [Scope.DEVICES_WRITE])
|
||||
def untrack_save(
|
||||
request: Request,
|
||||
id: int,
|
||||
device_id: str = Body(..., embed=True),
|
||||
) -> SaveSchema:
|
||||
"""Disable sync tracking for a save on a device."""
|
||||
save = db_save_handler.get_save(user_id=request.user.id, id=id)
|
||||
if not save:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Save with ID {id} not found",
|
||||
)
|
||||
|
||||
device = _resolve_device(device_id, request.user.id)
|
||||
assert device is not None
|
||||
|
||||
sync = db_device_save_sync_handler.set_untracked(
|
||||
device_id=device_id, save_id=id, untracked=True
|
||||
)
|
||||
|
||||
return _build_save_schema(save, device, sync)
|
||||
|
||||
@@ -17,6 +17,8 @@ class Scope(enum.StrEnum):
|
||||
PLATFORMS_WRITE = "platforms.write"
|
||||
ASSETS_READ = "assets.read"
|
||||
ASSETS_WRITE = "assets.write"
|
||||
DEVICES_READ = "devices.read"
|
||||
DEVICES_WRITE = "devices.write"
|
||||
FIRMWARE_READ = "firmware.read"
|
||||
FIRMWARE_WRITE = "firmware.write"
|
||||
COLLECTIONS_READ = "collections.read"
|
||||
@@ -31,6 +33,7 @@ READ_SCOPES_MAP: Final = {
|
||||
Scope.ROMS_READ: "View ROMs",
|
||||
Scope.PLATFORMS_READ: "View platforms",
|
||||
Scope.ASSETS_READ: "View assets",
|
||||
Scope.DEVICES_READ: "View devices",
|
||||
Scope.FIRMWARE_READ: "View firmware",
|
||||
Scope.ROMS_USER_READ: "View user-rom properties",
|
||||
Scope.COLLECTIONS_READ: "View collections",
|
||||
@@ -39,6 +42,7 @@ READ_SCOPES_MAP: Final = {
|
||||
WRITE_SCOPES_MAP: Final = {
|
||||
Scope.ME_WRITE: "Modify your profile",
|
||||
Scope.ASSETS_WRITE: "Modify assets",
|
||||
Scope.DEVICES_WRITE: "Modify devices",
|
||||
Scope.ROMS_USER_WRITE: "Modify user-rom properties",
|
||||
Scope.COLLECTIONS_WRITE: "Modify collections",
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from .collections_handler import DBCollectionsHandler
|
||||
from .device_save_sync_handler import DBDeviceSaveSyncHandler
|
||||
from .devices_handler import DBDevicesHandler
|
||||
from .firmware_handler import DBFirmwareHandler
|
||||
from .platforms_handler import DBPlatformsHandler
|
||||
from .roms_handler import DBRomsHandler
|
||||
@@ -8,6 +10,9 @@ from .states_handler import DBStatesHandler
|
||||
from .stats_handler import DBStatsHandler
|
||||
from .users_handler import DBUsersHandler
|
||||
|
||||
db_collection_handler = DBCollectionsHandler()
|
||||
db_device_handler = DBDevicesHandler()
|
||||
db_device_save_sync_handler = DBDeviceSaveSyncHandler()
|
||||
db_firmware_handler = DBFirmwareHandler()
|
||||
db_platform_handler = DBPlatformsHandler()
|
||||
db_rom_handler = DBRomsHandler()
|
||||
@@ -16,4 +21,3 @@ db_screenshot_handler = DBScreenshotsHandler()
|
||||
db_state_handler = DBStatesHandler()
|
||||
db_stats_handler = DBStatsHandler()
|
||||
db_user_handler = DBUsersHandler()
|
||||
db_collection_handler = DBCollectionsHandler()
|
||||
|
||||
129
backend/handler/database/device_save_sync_handler.py
Normal file
129
backend/handler/database/device_save_sync_handler.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.device_save_sync import DeviceSaveSync
|
||||
|
||||
from .base_handler import DBBaseHandler
|
||||
|
||||
|
||||
class DBDeviceSaveSyncHandler(DBBaseHandler):
|
||||
@begin_session
|
||||
def get_sync(
|
||||
self,
|
||||
device_id: str,
|
||||
save_id: int,
|
||||
session: Session = None, # type: ignore
|
||||
) -> DeviceSaveSync | None:
|
||||
return session.scalar(
|
||||
select(DeviceSaveSync)
|
||||
.filter_by(device_id=device_id, save_id=save_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_syncs_for_device_and_saves(
|
||||
self,
|
||||
device_id: str,
|
||||
save_ids: list[int],
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[DeviceSaveSync]:
|
||||
if not save_ids:
|
||||
return []
|
||||
return session.scalars(
|
||||
select(DeviceSaveSync).filter(
|
||||
DeviceSaveSync.device_id == device_id,
|
||||
DeviceSaveSync.save_id.in_(save_ids),
|
||||
)
|
||||
).all()
|
||||
|
||||
@begin_session
|
||||
def upsert_sync(
|
||||
self,
|
||||
device_id: str,
|
||||
save_id: int,
|
||||
synced_at: datetime | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> DeviceSaveSync:
|
||||
now = synced_at or datetime.now(timezone.utc)
|
||||
existing = session.scalar(
|
||||
select(DeviceSaveSync)
|
||||
.filter_by(device_id=device_id, save_id=save_id)
|
||||
.limit(1)
|
||||
)
|
||||
if existing:
|
||||
session.execute(
|
||||
update(DeviceSaveSync)
|
||||
.where(
|
||||
DeviceSaveSync.device_id == device_id,
|
||||
DeviceSaveSync.save_id == save_id,
|
||||
)
|
||||
.values(last_synced_at=now, is_untracked=False)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
existing.last_synced_at = now
|
||||
existing.is_untracked = False
|
||||
return existing
|
||||
else:
|
||||
sync = DeviceSaveSync(
|
||||
device_id=device_id,
|
||||
save_id=save_id,
|
||||
last_synced_at=now,
|
||||
is_untracked=False,
|
||||
)
|
||||
session.add(sync)
|
||||
session.flush()
|
||||
return sync
|
||||
|
||||
@begin_session
|
||||
def set_untracked(
|
||||
self,
|
||||
device_id: str,
|
||||
save_id: int,
|
||||
untracked: bool,
|
||||
session: Session = None, # type: ignore
|
||||
) -> DeviceSaveSync | None:
|
||||
existing = session.scalar(
|
||||
select(DeviceSaveSync)
|
||||
.filter_by(device_id=device_id, save_id=save_id)
|
||||
.limit(1)
|
||||
)
|
||||
if existing:
|
||||
session.execute(
|
||||
update(DeviceSaveSync)
|
||||
.where(
|
||||
DeviceSaveSync.device_id == device_id,
|
||||
DeviceSaveSync.save_id == save_id,
|
||||
)
|
||||
.values(is_untracked=untracked)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
existing.is_untracked = untracked
|
||||
return existing
|
||||
elif untracked:
|
||||
now = datetime.now(timezone.utc)
|
||||
sync = DeviceSaveSync(
|
||||
device_id=device_id,
|
||||
save_id=save_id,
|
||||
last_synced_at=now,
|
||||
is_untracked=True,
|
||||
)
|
||||
session.add(sync)
|
||||
session.flush()
|
||||
return sync
|
||||
return None
|
||||
|
||||
@begin_session
|
||||
def delete_syncs_for_device(
|
||||
self,
|
||||
device_id: str,
|
||||
session: Session = None, # type: ignore
|
||||
) -> None:
|
||||
session.execute(
|
||||
delete(DeviceSaveSync)
|
||||
.where(DeviceSaveSync.device_id == device_id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
111
backend/handler/database/devices_handler.py
Normal file
111
backend/handler/database/devices_handler.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from decorators.database import begin_session
|
||||
from models.device import Device
|
||||
|
||||
from .base_handler import DBBaseHandler
|
||||
|
||||
|
||||
class DBDevicesHandler(DBBaseHandler):
|
||||
@begin_session
|
||||
def add_device(
|
||||
self,
|
||||
device: Device,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Device:
|
||||
return session.merge(device)
|
||||
|
||||
@begin_session
|
||||
def get_device(
|
||||
self,
|
||||
device_id: str,
|
||||
user_id: int,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Device | None:
|
||||
return session.scalar(
|
||||
select(Device).filter_by(id=device_id, user_id=user_id).limit(1)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_device_by_fingerprint(
|
||||
self,
|
||||
user_id: int,
|
||||
mac_address: str | None = None,
|
||||
hostname: str | None = None,
|
||||
platform: str | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Device | None:
|
||||
if mac_address:
|
||||
device = session.scalar(
|
||||
select(Device)
|
||||
.filter_by(user_id=user_id, mac_address=mac_address)
|
||||
.limit(1)
|
||||
)
|
||||
if device:
|
||||
return device
|
||||
|
||||
if hostname and platform:
|
||||
return session.scalar(
|
||||
select(Device)
|
||||
.filter_by(user_id=user_id, hostname=hostname, platform=platform)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@begin_session
|
||||
def get_devices(
|
||||
self,
|
||||
user_id: int,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[Device]:
|
||||
return session.scalars(select(Device).filter_by(user_id=user_id)).all()
|
||||
|
||||
@begin_session
|
||||
def update_device(
|
||||
self,
|
||||
device_id: str,
|
||||
user_id: int,
|
||||
data: dict,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Device | None:
|
||||
session.execute(
|
||||
update(Device)
|
||||
.where(Device.id == device_id, Device.user_id == user_id)
|
||||
.values(**data)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
return session.scalar(
|
||||
select(Device).filter_by(id=device_id, user_id=user_id).limit(1)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def update_last_seen(
|
||||
self,
|
||||
device_id: str,
|
||||
user_id: int,
|
||||
session: Session = None, # type: ignore
|
||||
) -> None:
|
||||
session.execute(
|
||||
update(Device)
|
||||
.where(Device.id == device_id, Device.user_id == user_id)
|
||||
.values(last_seen=datetime.now(timezone.utc))
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def delete_device(
|
||||
self,
|
||||
device_id: str,
|
||||
user_id: int,
|
||||
session: Session = None, # type: ignore
|
||||
) -> None:
|
||||
session.execute(
|
||||
delete(Device)
|
||||
.where(Device.id == device_id, Device.user_id == user_id)
|
||||
.execution_options(synchronize_session="evaluate")
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy import and_, asc, delete, desc, select, update
|
||||
from sqlalchemy.orm import QueryableAttribute, Session, load_only
|
||||
|
||||
from decorators.database import begin_session
|
||||
@@ -42,12 +43,29 @@ class DBSavesHandler(DBBaseHandler):
|
||||
.limit(1)
|
||||
).first()
|
||||
|
||||
@begin_session
|
||||
def get_save_by_content_hash(
|
||||
self,
|
||||
user_id: int,
|
||||
rom_id: int,
|
||||
content_hash: str,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Save | None:
|
||||
return session.scalar(
|
||||
select(Save)
|
||||
.filter_by(rom_id=rom_id, user_id=user_id, content_hash=content_hash)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@begin_session
|
||||
def get_saves(
|
||||
self,
|
||||
user_id: int,
|
||||
rom_id: int | None = None,
|
||||
platform_id: int | None = None,
|
||||
slot: str | None = None,
|
||||
order_by: Literal["updated_at", "created_at"] | None = None,
|
||||
order_dir: Literal["asc", "desc"] = "desc",
|
||||
only_fields: Sequence[QueryableAttribute] | None = None,
|
||||
session: Session = None, # type: ignore
|
||||
) -> Sequence[Save]:
|
||||
@@ -61,6 +79,14 @@ class DBSavesHandler(DBBaseHandler):
|
||||
Rom.platform_id == platform_id
|
||||
)
|
||||
|
||||
if slot is not None:
|
||||
query = query.filter(Save.slot == slot)
|
||||
|
||||
if order_by:
|
||||
order_col = getattr(Save, order_by)
|
||||
order_fn = asc if order_dir == "asc" else desc
|
||||
query = query.order_by(order_fn(order_col))
|
||||
|
||||
if only_fields:
|
||||
query = query.options(load_only(*only_fields))
|
||||
|
||||
@@ -125,3 +151,28 @@ class DBSavesHandler(DBBaseHandler):
|
||||
)
|
||||
|
||||
return missing_saves
|
||||
|
||||
@begin_session
|
||||
def get_saves_summary(
|
||||
self,
|
||||
user_id: int,
|
||||
rom_id: int,
|
||||
session: Session = None, # type: ignore
|
||||
) -> dict:
|
||||
saves = session.scalars(
|
||||
select(Save)
|
||||
.filter_by(user_id=user_id, rom_id=rom_id)
|
||||
.order_by(desc(Save.updated_at))
|
||||
).all()
|
||||
|
||||
slots_data: dict[str | None, dict] = {}
|
||||
for save in saves:
|
||||
slot_key = save.slot
|
||||
if slot_key not in slots_data:
|
||||
slots_data[slot_key] = {"slot": slot_key, "count": 0, "latest": save}
|
||||
slots_data[slot_key]["count"] += 1
|
||||
|
||||
return {
|
||||
"total_count": len(saves),
|
||||
"slots": list(slots_data.values()),
|
||||
}
|
||||
|
||||
@@ -1,11 +1,44 @@
|
||||
import hashlib
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
from config import ASSETS_BASE_PATH
|
||||
from logger.logger import log
|
||||
from models.user import User
|
||||
|
||||
from .base_handler import FSHandler
|
||||
|
||||
|
||||
def compute_file_hash(file_path: str) -> str:
|
||||
hash_obj = hashlib.md5(usedforsecurity=False)
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
hash_obj.update(chunk)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
|
||||
def compute_zip_hash(zip_path: str) -> str:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
file_hashes = []
|
||||
for name in sorted(zf.namelist()):
|
||||
if not name.endswith("/"):
|
||||
content = zf.read(name)
|
||||
file_hash = hashlib.md5(content, usedforsecurity=False).hexdigest()
|
||||
file_hashes.append(f"{name}:{file_hash}")
|
||||
combined = "\n".join(file_hashes)
|
||||
return hashlib.md5(combined.encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
|
||||
def compute_content_hash(file_path: str) -> str | None:
|
||||
try:
|
||||
if zipfile.is_zipfile(file_path):
|
||||
return compute_zip_hash(file_path)
|
||||
return compute_file_hash(file_path)
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to compute content hash for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class FSAssetsHandler(FSHandler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(base_path=ASSETS_BASE_PATH)
|
||||
|
||||
@@ -4,10 +4,12 @@ from typing import Any
|
||||
|
||||
import socketio # type: ignore
|
||||
|
||||
from config import ASSETS_BASE_PATH
|
||||
from config.config_manager import config_manager as cm
|
||||
from endpoints.responses.rom import SimpleRomSchema
|
||||
from handler.database import db_platform_handler, db_rom_handler
|
||||
from handler.filesystem import fs_asset_handler, fs_firmware_handler
|
||||
from handler.filesystem.assets_handler import compute_content_hash
|
||||
from handler.filesystem.roms_handler import FSRom
|
||||
from handler.metadata import (
|
||||
meta_flashpoint_handler,
|
||||
@@ -817,11 +819,11 @@ async def scan_rom(
|
||||
return Rom(**rom_attrs)
|
||||
|
||||
|
||||
async def _scan_asset(file_name: str, asset_path: str):
|
||||
async def _scan_asset(file_name: str, asset_path: str, should_hash: bool = False):
|
||||
file_path = f"{asset_path}/{file_name}"
|
||||
file_size = await fs_asset_handler.get_file_size(file_path)
|
||||
|
||||
return {
|
||||
result = {
|
||||
"file_path": asset_path,
|
||||
"file_name": file_name,
|
||||
"file_name_no_tags": fs_asset_handler.get_file_name_with_no_tags(file_name),
|
||||
@@ -830,6 +832,12 @@ async def _scan_asset(file_name: str, asset_path: str):
|
||||
"file_size_bytes": file_size,
|
||||
}
|
||||
|
||||
if should_hash:
|
||||
absolute_path = f"{ASSETS_BASE_PATH}/{file_path}"
|
||||
result["content_hash"] = compute_content_hash(absolute_path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def scan_save(
|
||||
file_name: str,
|
||||
@@ -841,7 +849,7 @@ async def scan_save(
|
||||
saves_path = fs_asset_handler.build_saves_file_path(
|
||||
user=user, platform_fs_slug=platform_fs_slug, rom_id=rom_id, emulator=emulator
|
||||
)
|
||||
scanned_asset = await _scan_asset(file_name, saves_path)
|
||||
scanned_asset = await _scan_asset(file_name, saves_path, should_hash=True)
|
||||
return Save(**scanned_asset)
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from endpoints import (
|
||||
auth,
|
||||
collections,
|
||||
configs,
|
||||
device,
|
||||
feeds,
|
||||
firmware,
|
||||
gamelist,
|
||||
@@ -122,6 +123,7 @@ app.middleware("http")(set_context_middleware)
|
||||
app.include_router(heartbeat.router, prefix="/api")
|
||||
app.include_router(auth.router, prefix="/api")
|
||||
app.include_router(user.router, prefix="/api")
|
||||
app.include_router(device.router, prefix="/api")
|
||||
app.include_router(platform.router, prefix="/api")
|
||||
app.include_router(rom.router, prefix="/api")
|
||||
app.include_router(search.router, prefix="/api")
|
||||
|
||||
@@ -14,6 +14,7 @@ from models.base import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.device_save_sync import DeviceSaveSync
|
||||
from models.rom import Rom
|
||||
from models.user import User
|
||||
|
||||
@@ -54,9 +55,16 @@ class Save(RomAsset):
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
emulator: Mapped[str | None] = mapped_column(String(length=50))
|
||||
slot: Mapped[str | None] = mapped_column(String(length=255))
|
||||
content_hash: Mapped[str | None] = mapped_column(String(length=32))
|
||||
|
||||
rom: Mapped[Rom] = relationship(lazy="joined", back_populates="saves")
|
||||
user: Mapped[User] = relationship(lazy="joined", back_populates="saves")
|
||||
device_syncs: Mapped[list[DeviceSaveSync]] = relationship(
|
||||
back_populates="save",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def screenshot(self) -> Screenshot | None:
|
||||
|
||||
49
backend/models/device.py
Normal file
49
backend/models/device.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import TIMESTAMP, Boolean, Enum, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.device_save_sync import DeviceSaveSync
|
||||
from models.user import User
|
||||
|
||||
|
||||
class SyncMode(enum.StrEnum):
|
||||
API = "api"
|
||||
FILE_TRANSFER = "file_transfer"
|
||||
PUSH_PULL = "push_pull"
|
||||
|
||||
|
||||
class Device(BaseModel):
|
||||
__tablename__ = "devices"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
|
||||
|
||||
name: Mapped[str | None] = mapped_column(String(255))
|
||||
platform: Mapped[str | None] = mapped_column(String(50))
|
||||
client: Mapped[str | None] = mapped_column(String(50))
|
||||
client_version: Mapped[str | None] = mapped_column(String(50))
|
||||
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45))
|
||||
mac_address: Mapped[str | None] = mapped_column(String(17))
|
||||
hostname: Mapped[str | None] = mapped_column(String(255))
|
||||
|
||||
sync_mode: Mapped[SyncMode] = mapped_column(Enum(SyncMode), default=SyncMode.API)
|
||||
sync_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
last_seen: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))
|
||||
|
||||
user: Mapped[User] = relationship(lazy="joined")
|
||||
save_syncs: Mapped[list[DeviceSaveSync]] = relationship(
|
||||
back_populates="device",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="raise",
|
||||
)
|
||||
34
backend/models/device_save_sync.py
Normal file
34
backend/models/device_save_sync.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import TIMESTAMP, Boolean, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.assets import Save
|
||||
from models.device import Device
|
||||
|
||||
|
||||
class DeviceSaveSync(BaseModel):
|
||||
__tablename__ = "device_save_sync"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
device_id: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
ForeignKey("devices.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
save_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("saves.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
last_synced_at: Mapped[datetime] = mapped_column(TIMESTAMP(timezone=True))
|
||||
is_untracked: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
device: Mapped[Device] = relationship(back_populates="save_syncs", lazy="raise")
|
||||
save: Mapped[Save] = relationship(back_populates="device_syncs", lazy="raise")
|
||||
@@ -22,6 +22,7 @@ from utils.database import CustomJSON
|
||||
if TYPE_CHECKING:
|
||||
from models.assets import Save, Screenshot, State
|
||||
from models.collection import Collection, SmartCollection
|
||||
from models.device import Device
|
||||
from models.rom import RomNote, RomUser
|
||||
|
||||
|
||||
@@ -79,6 +80,9 @@ class User(BaseModel, SimpleUser):
|
||||
smart_collections: Mapped[list["SmartCollection"]] = relationship(
|
||||
lazy="raise", back_populates="user"
|
||||
)
|
||||
devices: Mapped[list["Device"]] = relationship(
|
||||
lazy="raise", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def kiosk_mode_user(cls) -> User:
|
||||
|
||||
@@ -14,6 +14,8 @@ from handler.database import (
|
||||
db_user_handler,
|
||||
)
|
||||
from models.assets import Save, Screenshot, State
|
||||
from models.device import Device
|
||||
from models.device_save_sync import DeviceSaveSync
|
||||
from models.platform import Platform
|
||||
from models.rom import Rom
|
||||
from models.user import Role, User
|
||||
@@ -30,6 +32,8 @@ def setup_database():
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_database():
|
||||
with session.begin() as s:
|
||||
s.query(DeviceSaveSync).delete(synchronize_session="evaluate")
|
||||
s.query(Device).delete(synchronize_session="evaluate")
|
||||
s.query(Save).delete(synchronize_session="evaluate")
|
||||
s.query(State).delete(synchronize_session="evaluate")
|
||||
s.query(Screenshot).delete(synchronize_session="evaluate")
|
||||
|
||||
509
backend/tests/endpoints/test_device.py
Normal file
509
backend/tests/endpoints/test_device.py
Normal file
@@ -0,0 +1,509 @@
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
|
||||
from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
from handler.auth import oauth_handler
|
||||
from handler.database import db_device_handler
|
||||
from handler.redis_handler import sync_cache
|
||||
from models.device import Device
|
||||
from models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
yield
|
||||
sync_cache.flushall()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_access_token(editor_user: User):
|
||||
return oauth_handler.create_oauth_token(
|
||||
data={
|
||||
"sub": editor_user.username,
|
||||
"iss": "romm:oauth",
|
||||
"scopes": " ".join(editor_user.oauth_scopes),
|
||||
"type": "access",
|
||||
},
|
||||
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
)


class TestDeviceEndpoints:
    def test_register_device(self, client, access_token: str):
        response = client.post(
            "/api/devices",
            json={
                "name": "Test Device",
                "platform": "android",
                "client": "argosy",
                "client_version": "0.16.0",
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_201_CREATED
        data = response.json()
        assert data["name"] == "Test Device"
        assert "device_id" in data
        assert "created_at" in data

    def test_register_device_minimal(self, client, access_token: str):
        response = client.post(
            "/api/devices",
            json={},
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_201_CREATED
        data = response.json()
        assert data["name"] is None
        assert "device_id" in data

    def test_list_devices(self, client, access_token: str, admin_user: User):
        db_device_handler.add_device(
            Device(
                id="test-device-1",
                user_id=admin_user.id,
                name="Device 1",
            )
        )
        db_device_handler.add_device(
            Device(
                id="test-device-2",
                user_id=admin_user.id,
                name="Device 2",
            )
        )

        response = client.get(
            "/api/devices",
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert len(data) == 2
        names = [d["name"] for d in data]
        assert "Device 1" in names
        assert "Device 2" in names

    def test_get_device(self, client, access_token: str, admin_user: User):
        device = db_device_handler.add_device(
            Device(
                id="test-device-get",
                user_id=admin_user.id,
                name="Get Test Device",
                platform="linux",
            )
        )

        response = client.get(
            f"/api/devices/{device.id}",
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["id"] == "test-device-get"
        assert data["name"] == "Get Test Device"
        assert data["platform"] == "linux"

    def test_get_device_not_found(self, client, access_token: str):
        response = client.get(
            "/api/devices/nonexistent-device",
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_update_device(self, client, access_token: str, admin_user: User):
        device = db_device_handler.add_device(
            Device(
                id="test-device-update",
                user_id=admin_user.id,
                name="Original Name",
            )
        )

        response = client.put(
            f"/api/devices/{device.id}",
            json={
                "name": "Updated Name",
                "platform": "android",
                "client": "daijishou",
                "client_version": "4.0.0",
                "ip_address": "192.168.1.100",
                "mac_address": "AA:BB:CC:DD:EE:FF",
                "hostname": "my-odin3",
                "sync_enabled": False,
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["name"] == "Updated Name"
        assert data["platform"] == "android"
        assert data["client"] == "daijishou"
        assert data["client_version"] == "4.0.0"
        assert data["ip_address"] == "192.168.1.100"
        assert data["mac_address"] == "AA:BB:CC:DD:EE:FF"
        assert data["hostname"] == "my-odin3"
        assert data["sync_enabled"] is False

    def test_delete_device(self, client, access_token: str, admin_user: User):
        device = db_device_handler.add_device(
            Device(
                id="test-device-delete",
                user_id=admin_user.id,
                name="To Delete",
            )
        )

        response = client.delete(
            f"/api/devices/{device.id}",
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_204_NO_CONTENT

        get_response = client.get(
            f"/api/devices/{device.id}",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert get_response.status_code == status.HTTP_404_NOT_FOUND


class TestDeviceUserIsolation:
    def test_list_devices_only_returns_own_devices(
        self,
        client,
        access_token: str,
        editor_access_token: str,
        admin_user: User,
        editor_user: User,
    ):
        db_device_handler.add_device(
            Device(id="admin-device", user_id=admin_user.id, name="Admin Device")
        )
        db_device_handler.add_device(
            Device(id="editor-device", user_id=editor_user.id, name="Editor Device")
        )

        admin_response = client.get(
            "/api/devices",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert admin_response.status_code == status.HTTP_200_OK
        admin_devices = admin_response.json()
        assert len(admin_devices) == 1
        assert admin_devices[0]["name"] == "Admin Device"

        editor_response = client.get(
            "/api/devices",
            headers={"Authorization": f"Bearer {editor_access_token}"},
        )
        assert editor_response.status_code == status.HTTP_200_OK
        editor_devices = editor_response.json()
        assert len(editor_devices) == 1
        assert editor_devices[0]["name"] == "Editor Device"

    def test_cannot_get_other_users_device(
        self,
        client,
        editor_access_token: str,
        admin_user: User,
    ):
        device = db_device_handler.add_device(
            Device(id="admin-only-device", user_id=admin_user.id, name="Admin Only")
        )

        response = client.get(
            f"/api/devices/{device.id}",
            headers={"Authorization": f"Bearer {editor_access_token}"},
        )
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_cannot_update_other_users_device(
        self,
        client,
        editor_access_token: str,
        admin_user: User,
    ):
        device = db_device_handler.add_device(
            Device(id="admin-protected-device", user_id=admin_user.id, name="Protected")
        )

        response = client.put(
            f"/api/devices/{device.id}",
            json={"name": "Hacked Name"},
            headers={"Authorization": f"Bearer {editor_access_token}"},
        )
        assert response.status_code == status.HTTP_404_NOT_FOUND

        original = db_device_handler.get_device(
            device_id=device.id, user_id=admin_user.id
        )
        assert original.name == "Protected"

    def test_cannot_delete_other_users_device(
        self,
        client,
        editor_access_token: str,
        admin_user: User,
    ):
        device = db_device_handler.add_device(
            Device(id="admin-nodelete-device", user_id=admin_user.id, name="No Delete")
        )

        response = client.delete(
            f"/api/devices/{device.id}",
            headers={"Authorization": f"Bearer {editor_access_token}"},
        )
        assert response.status_code == status.HTTP_404_NOT_FOUND

        still_exists = db_device_handler.get_device(
            device_id=device.id, user_id=admin_user.id
        )
        assert still_exists is not None


class TestDeviceDuplicateHandling:
    def test_duplicate_mac_address_returns_existing(
        self, client, access_token: str, admin_user: User
    ):
        db_device_handler.add_device(
            Device(
                id="existing-mac-device",
                user_id=admin_user.id,
                name="Existing Device",
                mac_address="AA:BB:CC:DD:EE:FF",
            )
        )

        response = client.post(
            "/api/devices",
            json={
                "name": "New Device",
                "mac_address": "AA:BB:CC:DD:EE:FF",
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["device_id"] == "existing-mac-device"
        assert data["name"] == "Existing Device"

    def test_duplicate_hostname_platform_returns_existing(
        self, client, access_token: str, admin_user: User
    ):
        db_device_handler.add_device(
            Device(
                id="existing-hostname-device",
                user_id=admin_user.id,
                name="Existing Device",
                hostname="my-device",
                platform="android",
            )
        )

        response = client.post(
            "/api/devices",
            json={
                "name": "New Device",
                "hostname": "my-device",
                "platform": "android",
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["device_id"] == "existing-hostname-device"
        assert data["name"] == "Existing Device"

    def test_duplicate_with_allow_existing_false_returns_409(
        self, client, access_token: str, admin_user: User
    ):
        db_device_handler.add_device(
            Device(
                id="reject-duplicate-device",
                user_id=admin_user.id,
                name="Existing Device",
                mac_address="FF:EE:DD:CC:BB:AA",
            )
        )

        response = client.post(
            "/api/devices",
            json={
                "name": "New Device",
                "mac_address": "FF:EE:DD:CC:BB:AA",
                "allow_existing": False,
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_409_CONFLICT
        data = response.json()["detail"]
        assert data["error"] == "device_exists"
        assert data["device_id"] == "reject-duplicate-device"

    def test_allow_existing_returns_existing_device(
        self, client, access_token: str, admin_user: User
    ):
        existing = db_device_handler.add_device(
            Device(
                id="allow-existing-device",
                user_id=admin_user.id,
                name="Existing Device",
                mac_address="11:22:33:44:55:66",
            )
        )

        response = client.post(
            "/api/devices",
            json={
                "name": "New Device Name",
                "mac_address": "11:22:33:44:55:66",
                "allow_existing": True,
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["device_id"] == existing.id
        assert data["name"] == "Existing Device"

    def test_allow_existing_with_reset_syncs(
        self, client, access_token: str, admin_user: User, rom
    ):
        from handler.database import db_device_save_sync_handler, db_save_handler
        from models.assets import Save

        existing = db_device_handler.add_device(
            Device(
                id="reset-syncs-device",
                user_id=admin_user.id,
                name="Device With Syncs",
                mac_address="77:88:99:AA:BB:CC",
            )
        )

        save = db_save_handler.add_save(
            Save(
                file_name="test.sav",
                file_name_no_tags="test",
                file_name_no_ext="test",
                file_extension="sav",
                file_path="/saves",
                file_size_bytes=100,
                rom_id=rom.id,
                user_id=admin_user.id,
            )
        )
        db_device_save_sync_handler.upsert_sync(device_id=existing.id, save_id=save.id)

        sync_before = db_device_save_sync_handler.get_sync(
            device_id=existing.id, save_id=save.id
        )
        assert sync_before is not None

        response = client.post(
            "/api/devices",
            json={
                "mac_address": "77:88:99:AA:BB:CC",
                "allow_existing": True,
                "reset_syncs": True,
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.json()["device_id"] == existing.id

        sync_after = db_device_save_sync_handler.get_sync(
            device_id=existing.id, save_id=save.id
        )
        assert sync_after is None

    def test_allow_duplicate_creates_new_device(
        self, client, access_token: str, admin_user: User
    ):
        existing = db_device_handler.add_device(
            Device(
                id="original-device",
                user_id=admin_user.id,
                name="Original Device",
                mac_address="DD:EE:FF:00:11:22",
            )
        )

        response = client.post(
            "/api/devices",
            json={
                "name": "Duplicate Install",
                "mac_address": "DD:EE:FF:00:11:22",
                "allow_duplicate": True,
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_201_CREATED
        data = response.json()
        assert data["device_id"] != existing.id
        assert data["name"] == "Duplicate Install"

    def test_no_conflict_without_fingerprint(self, client, access_token: str):
        response1 = client.post(
            "/api/devices",
            json={"name": "Device 1"},
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert response1.status_code == status.HTTP_201_CREATED

        response2 = client.post(
            "/api/devices",
            json={"name": "Device 2"},
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert response2.status_code == status.HTTP_201_CREATED
        assert response1.json()["device_id"] != response2.json()["device_id"]

    def test_hostname_only_no_conflict_without_platform(
        self, client, access_token: str, admin_user: User
    ):
        db_device_handler.add_device(
            Device(
                id="hostname-only-device",
                user_id=admin_user.id,
                name="Existing",
                hostname="my-device",
            )
        )

        response = client.post(
            "/api/devices",
            json={
                "name": "New Device",
                "hostname": "my-device",
            },
            headers={"Authorization": f"Bearer {access_token}"},
        )

        assert response.status_code == status.HTTP_201_CREATED
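Note: the duplicate-handling tests above pin down the registration behaviour without showing the endpoint itself. The sketch below is only an illustration of the matching order those tests imply; it is not the RomM endpoint code, and the names DeviceRegistration and find_existing_device are hypothetical.

# Illustrative sketch only -- not the actual /api/devices implementation.
from dataclasses import dataclass


@dataclass
class DeviceRegistration:
    name: str | None = None
    platform: str | None = None
    hostname: str | None = None
    mac_address: str | None = None
    allow_existing: bool = True    # return the matched device instead of erroring
    allow_duplicate: bool = False  # force-create a new device even on a match
    reset_syncs: bool = False      # drop the device's save-sync rows when reusing it


def find_existing_device(devices: list, reg: DeviceRegistration):
    """Match an incoming registration against the user's existing devices.

    Mirrors what the tests assert: a MAC address match wins; otherwise
    hostname and platform must both match; a hostname alone never matches.
    """
    if reg.mac_address:
        for d in devices:
            if d.mac_address == reg.mac_address:
                return d
    if reg.hostname and reg.platform:
        for d in devices:
            if d.hostname == reg.hostname and d.platform == reg.platform:
                return d
    return None

Per the tests, when a match is found and allow_duplicate is false, allow_existing=True yields a 200 with the existing device (clearing its sync rows if reset_syncs is true), while allow_existing=False yields a 409 whose detail carries "error": "device_exists" and the matched device_id; with no match, or with allow_duplicate=True, the endpoint creates a new device and returns 201.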
2137
backend/tests/endpoints/test_saves.py
Normal file
File diff suppressed because it is too large
@@ -166,3 +166,280 @@ class TestDBSavesHandlerPlatformFiltering:
        # Verify the save is associated with the correct platform through ROM
        assert retrieved_save.rom.platform_id == platform.id


class TestDBSavesHandlerSlotFiltering:
    def test_get_saves_with_slot_filter(self, admin_user: User, rom: Rom):
        save1 = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="slot_test_1.sav",
            file_name_no_tags="slot_test_1",
            file_name_no_ext="slot_test_1",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="Slot A",
        )
        save2 = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="slot_test_2.sav",
            file_name_no_tags="slot_test_2",
            file_name_no_ext="slot_test_2",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="Slot A",
        )
        save3 = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="slot_test_3.sav",
            file_name_no_tags="slot_test_3",
            file_name_no_ext="slot_test_3",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="Slot B",
        )

        db_save_handler.add_save(save1)
        db_save_handler.add_save(save2)
        db_save_handler.add_save(save3)

        slot_a_saves = db_save_handler.get_saves(
            user_id=admin_user.id, rom_id=rom.id, slot="Slot A"
        )
        assert len(slot_a_saves) == 2
        assert all(s.slot == "Slot A" for s in slot_a_saves)

        slot_b_saves = db_save_handler.get_saves(
            user_id=admin_user.id, rom_id=rom.id, slot="Slot B"
        )
        assert len(slot_b_saves) == 1
        assert slot_b_saves[0].slot == "Slot B"

    def test_get_saves_with_null_slot_filter(self, admin_user: User, rom: Rom):
        save_with_slot = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="with_slot.sav",
            file_name_no_tags="with_slot",
            file_name_no_ext="with_slot",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="Main",
        )
        save_without_slot = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="without_slot.sav",
            file_name_no_tags="without_slot",
            file_name_no_ext="without_slot",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot=None,
        )

        db_save_handler.add_save(save_with_slot)
        db_save_handler.add_save(save_without_slot)

        all_saves = db_save_handler.get_saves(user_id=admin_user.id, rom_id=rom.id)
        assert len(all_saves) >= 2

    def test_get_saves_order_by(self, admin_user: User, rom: Rom):
        from datetime import datetime, timedelta, timezone

        base_time = datetime.now(timezone.utc)

        save1 = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="order_test_1.sav",
            file_name_no_tags="order_test_1",
            file_name_no_ext="order_test_1",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="order_test",
        )
        save2 = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="order_test_2.sav",
            file_name_no_tags="order_test_2",
            file_name_no_ext="order_test_2",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="order_test",
        )

        created1 = db_save_handler.add_save(save1)
        created2 = db_save_handler.add_save(save2)

        db_save_handler.update_save(
            created1.id, {"updated_at": base_time - timedelta(hours=2)}
        )
        db_save_handler.update_save(
            created2.id, {"updated_at": base_time - timedelta(hours=1)}
        )

        ordered_saves_desc = db_save_handler.get_saves(
            user_id=admin_user.id,
            rom_id=rom.id,
            slot="order_test",
            order_by="updated_at",
        )

        assert len(ordered_saves_desc) == 2
        assert ordered_saves_desc[0].id == created2.id
        assert ordered_saves_desc[1].id == created1.id

        ordered_saves_asc = db_save_handler.get_saves(
            user_id=admin_user.id,
            rom_id=rom.id,
            slot="order_test",
            order_by="updated_at",
            order_dir="asc",
        )

        assert len(ordered_saves_asc) == 2
        assert ordered_saves_asc[0].id == created1.id
        assert ordered_saves_asc[1].id == created2.id


class TestDBSavesHandlerSummary:
    def test_get_saves_summary_basic(self, admin_user: User, rom: Rom):
        from datetime import datetime, timedelta, timezone

        base_time = datetime.now(timezone.utc)

        configs = [
            ("summary_a_1.sav", "Slot A", -3),
            ("summary_a_2.sav", "Slot A", -1),
            ("summary_b_1.sav", "Slot B", -2),
            ("summary_none_1.sav", None, -4),
        ]

        for filename, slot, hours_offset in configs:
            save = Save(
                rom_id=rom.id,
                user_id=admin_user.id,
                file_name=filename,
                file_name_no_tags=filename.replace(".sav", ""),
                file_name_no_ext=filename.replace(".sav", ""),
                file_extension="sav",
                emulator="test_emu",
                file_path=f"{rom.platform_slug}/saves",
                file_size_bytes=100,
                slot=slot,
            )
            created = db_save_handler.add_save(save)
            db_save_handler.update_save(
                created.id, {"updated_at": base_time + timedelta(hours=hours_offset)}
            )

        summary = db_save_handler.get_saves_summary(
            user_id=admin_user.id, rom_id=rom.id
        )

        assert "total_count" in summary
        assert "slots" in summary
        assert summary["total_count"] == 4
        assert len(summary["slots"]) == 3

    def test_get_saves_summary_latest_per_slot(self, admin_user: User, rom: Rom):
        from datetime import datetime, timedelta, timezone

        base_time = datetime.now(timezone.utc)

        old_save = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="latest_test_old.sav",
            file_name_no_tags="latest_test_old",
            file_name_no_ext="latest_test_old",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="latest_test",
        )
        new_save = Save(
            rom_id=rom.id,
            user_id=admin_user.id,
            file_name="latest_test_new.sav",
            file_name_no_tags="latest_test_new",
            file_name_no_ext="latest_test_new",
            file_extension="sav",
            emulator="test_emu",
            file_path=f"{rom.platform_slug}/saves",
            file_size_bytes=100,
            slot="latest_test",
        )

        old_created = db_save_handler.add_save(old_save)
        new_created = db_save_handler.add_save(new_save)

        db_save_handler.update_save(
            old_created.id, {"updated_at": base_time - timedelta(hours=5)}
        )
        db_save_handler.update_save(
            new_created.id, {"updated_at": base_time - timedelta(hours=1)}
        )

        summary = db_save_handler.get_saves_summary(
            user_id=admin_user.id, rom_id=rom.id
        )

        latest_slot = next(
            (s for s in summary["slots"] if s["slot"] == "latest_test"), None
        )
        assert latest_slot is not None
        assert latest_slot["count"] == 2
        assert latest_slot["latest"].file_name == "latest_test_new.sav"

    def test_get_saves_summary_empty_rom(self, admin_user: User):
        summary = db_save_handler.get_saves_summary(
            user_id=admin_user.id, rom_id=999999
        )

        assert summary["total_count"] == 0
        assert summary["slots"] == []

    def test_get_saves_summary_count_accuracy(self, admin_user: User, rom: Rom):
        for i in range(5):
            save = Save(
                rom_id=rom.id,
                user_id=admin_user.id,
                file_name=f"count_test_{i}.sav",
                file_name_no_tags=f"count_test_{i}",
                file_name_no_ext=f"count_test_{i}",
                file_extension="sav",
                emulator="test_emu",
                file_path=f"{rom.platform_slug}/saves",
                file_size_bytes=100,
                slot="count_test",
            )
            db_save_handler.add_save(save)

        summary = db_save_handler.get_saves_summary(
            user_id=admin_user.id, rom_id=rom.id
        )

        count_slot = next(
            (s for s in summary["slots"] if s["slot"] == "count_test"), None
        )
        assert count_slot is not None
        assert count_slot["count"] == 5
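Note: the summary tests above fix the shape get_saves_summary must return (a total_count plus one entry per slot with a count and the most recent save). The snippet below is a minimal pure-Python sketch of that grouping over an in-memory list of saves; the real handler presumably does the equivalent in SQL, so this is not its implementation.

# Illustrative sketch of the summary shape the tests above expect.
from collections import defaultdict


def summarize_saves(saves):
    """Group saves by slot; report per-slot count and the latest save by updated_at."""
    by_slot = defaultdict(list)
    for save in saves:
        by_slot[save.slot].append(save)

    slots = [
        {
            "slot": slot,
            "count": len(group),
            "latest": max(group, key=lambda s: s.updated_at),
        }
        for slot, group in by_slot.items()
    ]
    return {"total_count": len(saves), "slots": slots}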
7
backend/utils/datetime.py
Normal file
@@ -0,0 +1,7 @@
from datetime import datetime, timezone


def to_utc(dt: datetime) -> datetime:
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)
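A quick usage example of the new helper: naive datetimes are treated as already being UTC and just tagged with tzinfo, while aware datetimes are converted. The import path assumes the backend package root exposes the file as utils.datetime; the sample timestamps are arbitrary.

from datetime import datetime, timedelta, timezone

from utils.datetime import to_utc

# Naive datetime: assumed to already be UTC, only the tzinfo is attached.
naive = datetime(2026, 1, 17, 12, 0, 0)
assert to_utc(naive) == datetime(2026, 1, 17, 12, 0, 0, tzinfo=timezone.utc)

# Aware datetime (UTC+1): converted to the equivalent UTC instant.
cet = datetime(2026, 1, 17, 13, 0, 0, tzinfo=timezone(timedelta(hours=1)))
assert to_utc(cet) == datetime(2026, 1, 17, 12, 0, 0, tzinfo=timezone.utc)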