backend structure refactor

This commit is contained in:
Zurdi
2024-01-12 13:07:52 +01:00
parent 63afc05a6d
commit 4b9e76f550
75 changed files with 1600 additions and 1518 deletions

View File

@@ -18,6 +18,8 @@ depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Drop platform_slug foreign key on all tables
with op.batch_alter_table('states', schema=None) as batch_op:
batch_op.drop_constraint("states_ibfk_1", type_='foreignkey')
batch_op.drop_column('platform_slug')
@@ -32,25 +34,58 @@ def upgrade() -> None:
with op.batch_alter_table('roms', schema=None) as batch_op:
batch_op.drop_constraint("fk_platform_roms", type_='foreignkey')
# Drop platform_slug foreign key on all tables
# Change platforms primary key
with op.batch_alter_table('platforms', schema=None) as batch_op:
batch_op.drop_constraint(constraint_name="PRIMARY", type_="primary")
batch_op.drop_column('n_roms')
with op.batch_alter_table('platforms', schema=None) as batch_op:
batch_op.execute("ALTER TABLE platforms ADD COLUMN id INTEGER(11) NOT NULL AUTO_INCREMENT PRIMARY KEY")
# Change platforms primary key
# Create platform id foreign key
with op.batch_alter_table('states', schema=None) as batch_op:
batch_op.add_column(sa.Column('platform_id', mysql.INTEGER(display_width=11), autoincrement=False, nullable=False))
with op.batch_alter_table('states', schema=None) as batch_op:
batch_op.create_foreign_key('states_platforms_FK', 'platforms', ['platform_id'], ['id'])
with op.batch_alter_table('screenshots', schema=None) as batch_op:
batch_op.add_column(sa.Column('platform_id', mysql.INTEGER(display_width=11), autoincrement=False, nullable=False))
with op.batch_alter_table('screenshots', schema=None) as batch_op:
batch_op.create_foreign_key('screenshots_platforms_FK', 'platforms', ['platform_id'], ['id'])
with op.batch_alter_table('saves', schema=None) as batch_op:
batch_op.add_column(sa.Column('platform_id', mysql.INTEGER(display_width=11), autoincrement=False, nullable=False))
with op.batch_alter_table('saves', schema=None) as batch_op:
batch_op.create_foreign_key('saves_platforms_FK', 'platforms', ['platform_id'], ['id'])
with op.batch_alter_table('roms', schema=None) as batch_op:
batch_op.add_column(sa.Column('platform_id', mysql.INTEGER(display_width=11), autoincrement=False, nullable=False))
with op.batch_alter_table('roms', schema=None) as batch_op:
batch_op.execute("update roms inner join platforms on roms.platform_slug = platforms.slug set roms.platform_id = platforms.id")
with op.batch_alter_table('roms', schema=None) as batch_op:
batch_op.execute("update roms inner join platforms on roms.platform_slug = platforms.slug set roms.platform_id = platforms.id")
# Update platform id values on other tables
with op.batch_alter_table('roms', schema=None) as batch_op:
batch_op.create_foreign_key('roms_platforms_FK', 'platforms', ['platform_id'], ['id'])
batch_op.drop_column('platform_slug')
# Update platform id values on other tables
# Create platform id foreign key
# Clean roms table
with op.batch_alter_table('roms', schema=None) as batch_op:
batch_op.drop_column('p_sgdb_id')
batch_op.drop_column('p_igdb_id')
batch_op.drop_column('p_name')
# Clean roms table
# ### end Alembic commands ###

View File

@@ -1,21 +1,12 @@
import os
import sys
from urllib.parse import quote_plus
from typing import Final
from typing_extensions import TypedDict
from urllib.parse import quote_plus
import pydash
import yaml
from config import (
DB_HOST,
DB_NAME,
DB_PASSWD,
DB_PORT,
DB_USER,
LIBRARY_BASE_PATH,
ROMM_DB_DRIVER,
ROMM_BASE_PATH,
)
from config import (DB_HOST, DB_NAME, DB_PASSWD, DB_PORT, DB_USER,
ROMM_BASE_PATH, ROMM_DB_DRIVER)
from logger.logger import log
from yaml.loader import SafeLoader
@@ -23,19 +14,7 @@ ROMM_USER_CONFIG_PATH: Final = f"{ROMM_BASE_PATH}/config.yml"
SQLITE_DB_BASE_PATH: Final = f"{ROMM_BASE_PATH}/database"
class ConfigDict(TypedDict):
EXCLUDED_PLATFORMS: list[str]
EXCLUDED_SINGLE_EXT: list[str]
EXCLUDED_SINGLE_FILES: list[str]
EXCLUDED_MULTI_FILES: list[str]
EXCLUDED_MULTI_PARTS_EXT: list[str]
EXCLUDED_MULTI_PARTS_FILES: list[str]
PLATFORMS_BINDING: dict[str, str]
ROMS_FOLDER_NAME: str
SAVES_FOLDER_NAME: str
STATES_FOLDER_NAME: str
SCREENSHOTS_FOLDER_NAME: str
HIGH_PRIO_STRUCTURE_PATH: str
from config import LIBRARY_BASE_PATH
class Config:
@@ -100,10 +79,12 @@ class ConfigLoader:
% quote_plus(DB_PASSWD)
)
# DEPRECATED
if ROMM_DB_DRIVER == "sqlite":
if not os.path.exists(SQLITE_DB_BASE_PATH):
os.makedirs(SQLITE_DB_BASE_PATH)
return f"sqlite:////{SQLITE_DB_BASE_PATH}/romm.db"
# DEPRECATED
log.critical(f"{ROMM_DB_DRIVER} database not supported")
sys.exit(3)

View File

@@ -0,0 +1,41 @@
from typing import Any
from fastapi import Security
from fastapi.security.http import HTTPBasic
from fastapi.security.oauth2 import OAuth2PasswordBearer
from fastapi.types import DecoratedCallable
from handler.auth_handler import DEFAULT_SCOPES_MAP, FULL_SCOPES_MAP, WRITE_SCOPES_MAP
from starlette.authentication import requires
oauth2_password_bearer = OAuth2PasswordBearer(
tokenUrl="/token",
auto_error=False,
scopes={
**DEFAULT_SCOPES_MAP,
**WRITE_SCOPES_MAP,
**FULL_SCOPES_MAP,
},
)
def protected_route(
method: Any,
path: str,
scopes: list[str] = [],
**kwargs,
):
def decorator(func: DecoratedCallable):
fn = requires(scopes)(func)
return method(
path,
dependencies=[
Security(
dependency=oauth2_password_bearer,
scopes=scopes,
),
Security(dependency=HTTPBasic(auto_error=False)),
],
**kwargs,
)(fn)
return decorator

View File

@@ -1,6 +1,7 @@
from pathlib import Path
from config.config_loader import config
from decorators.oauth import protected_route
from endpoints.responses.assets import (
SaveSchema,
StateSchema,
@@ -10,9 +11,7 @@ from endpoints.responses.assets import (
from fastapi import APIRouter, File, HTTPException, Request, UploadFile, status
from handler import dbh
from logger.logger import log
from utils.fastapi import scan_save, scan_state
from utils.fs import build_upload_file_path, remove_file
from utils.oauth import protected_route
from handler.scan_handler import scan_save, scan_state
router = APIRouter()
@@ -43,7 +42,7 @@ def upload_saves(
detail="No saves were uploaded",
)
saves_path = build_upload_file_path(
saves_path = fsh.build_upload_file_path(
rom.platform.fs_slug, folder=config.SAVES_FOLDER_NAME
)
@@ -91,7 +90,7 @@ async def delete_saves(request: Request) -> list[SaveSchema]:
log.info(f"Deleting {save.file_name} from filesystem")
try:
remove_file(file_name=save.file_name, file_path=save.file_path)
fsh.remove_file(file_name=save.file_name, file_path=save.file_path)
except FileNotFoundError:
error = f"Save file {save.file_name} not found for platform {save.platform_slug}"
log.error(error)
@@ -114,7 +113,7 @@ def upload_states(
detail="No states were uploaded",
)
states_path = build_upload_file_path(
states_path = fsh.build_upload_file_path(
rom.platform.fs_slug, folder=config.STATES_FOLDER_NAME
)
@@ -161,7 +160,7 @@ async def delete_states(request: Request) -> list[StateSchema]:
if delete_from_fs:
log.info(f"Deleting {state.file_name} from filesystem")
try:
remove_file(file_name=state.file_name, file_path=state.file_path)
fsh.remove_file(file_name=state.file_name, file_path=state.file_path)
except FileNotFoundError:
error = f"Save file {state.file_name} not found for platform {state.platform_slug}"
log.error(error)

View File

@@ -0,0 +1,40 @@
from typing import Optional
from fastapi import File, UploadFile
from fastapi.param_functions import Form
class UserForm:
def __init__(
self,
username: Optional[str] = None,
password: Optional[str] = None,
role: Optional[str] = None,
enabled: Optional[bool] = None,
avatar: Optional[UploadFile] = File(None),
):
self.username = username
self.password = password
self.role = role
self.enabled = enabled
self.avatar = avatar
class OAuth2RequestForm:
def __init__(
self,
grant_type: str = Form(default="password"),
scope: str = Form(default=""),
username: Optional[str] = Form(default=None),
password: Optional[str] = Form(default=None),
client_id: Optional[str] = Form(default=None),
client_secret: Optional[str] = Form(default=None),
refresh_token: Optional[str] = Form(default=None),
):
self.grant_type = grant_type
self.scopes = scope.split()
self.username = username
self.password = password
self.client_id = client_id
self.client_secret = client_secret
self.refresh_token = refresh_token

View File

@@ -0,0 +1,58 @@
from config import (
ENABLE_RESCAN_ON_FILESYSTEM_CHANGE,
ENABLE_SCHEDULED_RESCAN,
ENABLE_SCHEDULED_UPDATE_MAME_XML,
ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB,
RESCAN_ON_FILESYSTEM_CHANGE_DELAY,
ROMM_AUTH_ENABLED,
SCHEDULED_RESCAN_CRON,
SCHEDULED_UPDATE_MAME_XML_CRON,
SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
)
from config.config_loader import config
from endpoints.responses.heartbeat import HeartbeatReturn
from fastapi import APIRouter
from handler import ghh
router = APIRouter()
@router.get("/heartbeat")
def heartbeat() -> HeartbeatReturn:
"""Endpoint to set the CSFR token in cache and return all the basic RomM config
Returns:
HeartbeatReturn: TypedDict structure with all the defined values in the HeartbeatReturn class.
"""
return {
"VERSION": ghh.get_version(),
"NEW_VERSION": ghh.check_new_version(),
"ROMM_AUTH_ENABLED": ROMM_AUTH_ENABLED,
"WATCHER": {
"ENABLED": ENABLE_RESCAN_ON_FILESYSTEM_CHANGE,
"TITLE": "Rescan on filesystem change",
"MESSAGE": f"Runs a scan when a change is detected in the library path, with a {RESCAN_ON_FILESYSTEM_CHANGE_DELAY} minute delay",
},
"SCHEDULER": {
"RESCAN": {
"ENABLED": ENABLE_SCHEDULED_RESCAN,
"CRON": SCHEDULED_RESCAN_CRON,
"TITLE": "Scheduled rescan",
"MESSAGE": "Rescans the entire library",
},
"SWITCH_TITLEDB": {
"ENABLED": ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB, # noqa
"CRON": SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
"TITLE": "Scheduled Switch TitleDB update",
"MESSAGE": "Updates the Nintedo Switch TitleDB file",
},
"MAME_XML": {
"ENABLED": ENABLE_SCHEDULED_UPDATE_MAME_XML,
"CRON": SCHEDULED_UPDATE_MAME_XML_CRON,
"TITLE": "Scheduled MAME XML update",
"MESSAGE": "Updates the MAME XML file",
},
},
"CONFIG": config.__dict__,
}

View File

@@ -1,22 +1,17 @@
import secrets
from typing import Annotated
from endpoints.responses import MessageResponse
from config import ROMM_AUTH_ENABLED
from exceptions.credentials_exceptions import CredentialsException, DisabledException
from fastapi import APIRouter, Depends, File, HTTPException, Request, status
from decorators.oauth import protected_route
from endpoints.forms.identity import UserForm
from endpoints.responses import MessageResponse
from endpoints.responses.identity import UserSchema
from exceptions.auth_exceptions import AuthCredentialsException, DisabledException
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security.http import HTTPBasic
from handler import dbh
from handler import authh, dbh
from models.user import Role, User
from utils.auth import authenticate_user, clear_session, get_password_hash
from utils.cache import cache
from utils.fs import build_avatar_path
from utils.oauth import protected_route
from endpoints.responses.identity import (
UserSchema,
UserUpdateForm,
)
from handler.redis_handler import cache
router = APIRouter()
@@ -37,9 +32,9 @@ def login(request: Request, credentials=Depends(HTTPBasic())) -> MessageResponse
MessageResponse: Standard message response
"""
user = authenticate_user(credentials.username, credentials.password)
user = authh.authenticate_user(credentials.username, credentials.password)
if not user:
raise CredentialsException
raise AuthCredentialsException
if not user.enabled:
raise DisabledException
@@ -70,7 +65,7 @@ def logout(request: Request) -> MessageResponse:
if not request.user.is_authenticated:
return {"message": "Already logged out"}
clear_session(request)
authh.clear_session(request)
return {"message": "Successfully logged out"}
@@ -153,7 +148,7 @@ def create_user(
user = User(
username=username,
hashed_password=get_password_hash(password),
hashed_password=authh.get_password_hash(password),
role=Role[role.upper()],
)
@@ -162,7 +157,7 @@ def create_user(
@protected_route(router.patch, "/users/{user_id}", ["users.write"])
def update_user(
request: Request, user_id: int, form_data: Annotated[UserUpdateForm, Depends()]
request: Request, user_id: int, form_data: Annotated[UserForm, Depends()]
) -> UserSchema:
"""Update user endpoint
@@ -201,7 +196,7 @@ def update_user(
cleaned_data["username"] = form_data.username.lower()
if form_data.password:
cleaned_data["hashed_password"] = get_password_hash(form_data.password)
cleaned_data["hashed_password"] = authh.get_password_hash(form_data.password)
# You can't change your own role
if form_data.role and request.user.id != user_id:
@@ -212,7 +207,7 @@ def update_user(
cleaned_data["enabled"] = form_data.enabled # type: ignore[assignment]
if form_data.avatar is not None:
cleaned_data["avatar_path"], avatar_user_path = build_avatar_path(
cleaned_data["avatar_path"], avatar_user_path = fsh.build_avatar_path(
form_data.avatar.filename, form_data.username
)
file_location = f"{avatar_user_path}/{form_data.avatar.filename}"
@@ -227,7 +222,7 @@ def update_user(
"hashed_password"
)
if request.user.id == user_id and creds_updated:
clear_session(request)
authh.clear_session(request)
return dbh.get_user(user_id)

View File

@@ -1,15 +1,10 @@
from datetime import timedelta
from typing import Annotated, Final
from fastapi import APIRouter, Depends, HTTPException, status
from utils.auth import authenticate_user
from utils.oauth import (
OAuth2RequestForm,
create_oauth_token,
get_current_active_user_from_bearer_token,
)
from endpoints.forms.identity import OAuth2RequestForm
from endpoints.responses.oauth import TokenResponse
from fastapi import APIRouter, Depends, HTTPException, status
from handler import authh, oauthh
ACCESS_TOKEN_EXPIRE_MINUTES: Final = 30
REFRESH_TOKEN_EXPIRE_DAYS: Final = 7
@@ -45,13 +40,13 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
status_code=status.HTTP_400_BAD_REQUEST, detail="Missing refresh token"
)
user, payload = await get_current_active_user_from_bearer_token(token)
user, payload = await oauthh.get_current_active_user_from_bearer_token(token)
if payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
)
access_token = create_oauth_token(
access_token = oauthh.create_oauth_token(
data={
"sub": user.username,
"scopes": payload.get("scopes"),
@@ -74,7 +69,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
detail="Missing username or password",
)
user = authenticate_user(form_data.username, form_data.password)
user = authh.authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -102,7 +97,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
detail="Insufficient scope",
)
access_token = create_oauth_token(
access_token = oauthh.create_oauth_token(
data={
"sub": user.username,
"scopes": " ".join(form_data.scopes),
@@ -111,7 +106,7 @@ async def token(form_data: Annotated[OAuth2RequestForm, Depends()]) -> TokenResp
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
)
refresh_token = create_oauth_token(
refresh_token = oauthh.create_oauth_token(
data={
"sub": user.username,
"scopes": " ".join(form_data.scopes),

View File

@@ -3,13 +3,13 @@ from fastapi import APIRouter, HTTPException, Request, status
from handler import dbh
from logger.logger import log
from endpoints.responses import MessageResponse
from utils.oauth import protected_route
from decorators.oauth import protected_route
router = APIRouter()
@protected_route(router.get, "/platforms", ["platforms.read"])
def platforms(request: Request) -> list[PlatformSchema]:
def get_platforms(request: Request) -> list[PlatformSchema]:
"""Get platforms endpoint
Args:
@@ -19,20 +19,34 @@ def platforms(request: Request) -> list[PlatformSchema]:
list[PlatformSchema]: All platforms in the database
"""
return dbh.get_platforms()
return dbh.get_platform()
@protected_route(router.delete, "/platforms/{slug}", ["platforms.write"])
def delete_platform(request: Request, slug) -> MessageResponse:
@protected_route(router.get, "/platforms/{id}", ["platforms.read"])
def get_platforms(request: Request, id: int = None) -> PlatformSchema:
"""Get platform endpoint
Args:
request (Request): Fastapi Request object
Returns:
PlatformSchema: All platforms in the database
"""
return dbh.get_platform(id)
@protected_route(router.delete, "/platforms/{id}", ["platforms.write"])
def delete_platforms(request: Request, id: int) -> MessageResponse:
"""Detele platform from database [and filesystem]"""
platform = dbh.get_platform(slug)
platform = dbh.get_platform(id)
if not platform:
error = f"Platform {platform.name} - [{platform.fs_slug}] not found"
log.error(error)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=error)
log.info(f"Deleting {platform.name} [{platform.fs_slug}] from database")
dbh.delete_platform(platform.slug)
dbh.delete_platform(platform.id)
return {"msg": f"{platform.name} - [{platform.fs_slug}] deleted successfully!"}

View File

@@ -20,26 +20,23 @@ class BaseAsset(BaseModel):
class SaveSchema(BaseAsset):
rom_id: int
platform_slug: str
emulator: Optional[str]
class StateSchema(BaseAsset):
rom_id: int
platform_slug: str
emulator: Optional[str]
class ScreenshotSchema(BaseAsset):
rom_id: int
platform_slug: Optional[str]
class UploadedSavesResponse(TypedDict):
uploaded: int
saves: list[SaveSchema]
class StateSchema(BaseAsset):
rom_id: int
emulator: Optional[str]
class UploadedStatesResponse(TypedDict):
uploaded: int
states: list[StateSchema]
class ScreenshotSchema(BaseAsset):
rom_id: int

View File

@@ -0,0 +1,41 @@
from typing_extensions import TypedDict
class ConfigDict(TypedDict):
EXCLUDED_PLATFORMS: list[str]
EXCLUDED_SINGLE_EXT: list[str]
EXCLUDED_SINGLE_FILES: list[str]
EXCLUDED_MULTI_FILES: list[str]
EXCLUDED_MULTI_PARTS_EXT: list[str]
EXCLUDED_MULTI_PARTS_FILES: list[str]
PLATFORMS_BINDING: dict[str, str]
ROMS_FOLDER_NAME: str
SAVES_FOLDER_NAME: str
STATES_FOLDER_NAME: str
SCREENSHOTS_FOLDER_NAME: str
HIGH_PRIO_STRUCTURE_PATH: str
class WatcherDict(TypedDict):
ENABLED: bool
TITLE: str
MESSAGE: str
class TaskDict(WatcherDict):
CRON: str
class SchedulerDict(TypedDict):
RESCAN: TaskDict
SWITCH_TITLEDB: TaskDict
MAME_XML: TaskDict
class HeartbeatReturn(TypedDict):
VERSION: str
NEW_VERSION: str
ROMM_AUTH_ENABLED: bool
WATCHER: WatcherDict
SCHEDULER: SchedulerDict
CONFIG: ConfigDict

View File

@@ -1,6 +1,3 @@
from typing import Optional
from fastapi import File, UploadFile
from models.user import Role
from pydantic import BaseModel
@@ -15,19 +12,3 @@ class UserSchema(BaseModel):
class Config:
from_attributes = True
class UserUpdateForm:
def __init__(
self,
username: Optional[str] = None,
password: Optional[str] = None,
role: Optional[str] = None,
enabled: Optional[bool] = None,
avatar: Optional[UploadFile] = File(None),
):
self.username = username
self.password = password
self.role = role
self.enabled = enabled
self.avatar = avatar

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
class PlatformSchema(BaseModel):
id: int
slug: str
fs_slug: str
igdb_id: Optional[int] = None

View File

@@ -2,9 +2,9 @@ from typing import Optional
from endpoints.responses.assets import SaveSchema, ScreenshotSchema, StateSchema
from fastapi.responses import StreamingResponse
from handler import socketh
from pydantic import BaseModel
from typing_extensions import TypedDict
from utils.socket import socket_server
class RomSchema(BaseModel):
@@ -12,8 +12,7 @@ class RomSchema(BaseModel):
igdb_id: Optional[int]
sgdb_id: Optional[int]
platform_slug: str
platform_name: str
platform_id: int
file_name: str
file_name_no_tags: str
@@ -69,4 +68,4 @@ class CustomStreamingResponse(StreamingResponse):
async def stream_response(self, *args, **kwargs) -> None:
await super().stream_response(*args, **kwargs)
await socket_server.emit("download:complete", self.emit_body)
await socketh.socket_server.emit("download:complete", self.emit_body)

View File

@@ -1,5 +1,59 @@
from typing_extensions import TypedDict
WEBRCADE_SUPPORTED_PLATFORM_SLUGS = [
"3do",
"arcade",
"atari2600",
"atari5200",
"atari7800",
"lynx",
"wonderswan",
"wonderswan-color",
"colecovision",
"turbografx16--1",
"turbografx-16-slash-pc-engine-cd",
"supergrafx",
"pc-fx",
"nes",
"n64",
"snes",
"gb",
"gba",
"gbc",
"virtualboy",
"sg1000",
"sms",
"genesis-slash-megadrive",
"segacd",
"gamegear",
"neo-geo-cd",
"neogeoaes",
"neogeomvs",
"neo-geo-pocket",
"neo-geo-pocket-color",
"ps",
]
WEBRCADE_SLUG_TO_TYPE_MAP = {
"atari2600": "2600",
"atari5200": "5200",
"atari7800": "7800",
"lynx": "lnx",
"turbografx16--1": "pce",
"turbografx-16-slash-pc-engine-cd": "pce",
"supergrafx": "sgx",
"pc-fx": "pcfx",
"virtualboy": "vb",
"genesis-slash-megadrive": "genesis",
"gamegear": "gg",
"neogeoaes": "neogeo",
"neogeomvs": "neogeo",
"neo-geo-cd": "neogeocd",
"neo-geo-pocket": "ngp",
"neo-geo-pocket-color": "ngc",
"ps": "psx",
}
class WebrcadeFeedSchema(TypedDict):
title: str

View File

@@ -4,15 +4,14 @@ from stat import S_IFREG
from typing import Annotated, Optional
from config import LIBRARY_BASE_PATH
from decorators.oauth import protected_route
from endpoints.responses import MessageResponse
from endpoints.responses.rom import (
CustomStreamingResponse,
EnhancedRomSchema,
RomSchema,
UploadRomResponse,
)
from endpoints.responses.rom import (CustomStreamingResponse,
EnhancedRomSchema, RomSchema,
UploadRomResponse)
from exceptions.fs_exceptions import RomAlreadyExistsException
from fastapi import APIRouter, File, HTTPException, Query, Request, UploadFile, status
from fastapi import (APIRouter, File, HTTPException, Query, Request,
UploadFile, status)
from fastapi.responses import FileResponse
from fastapi_pagination.cursor import CursorPage, CursorParams
from fastapi_pagination.ext.sqlalchemy import paginate
@@ -20,17 +19,6 @@ from handler import dbh
from logger.logger import log
from models import Rom
from stream_zip import ZIP_64, stream_zip # type: ignore[import]
from utils import get_file_name_with_no_tags
from utils.fs import (
_file_exists,
build_artwork_path,
build_upload_file_path,
get_rom_cover,
get_rom_screenshots,
remove_file,
rename_file,
)
from utils.oauth import protected_route
router = APIRouter()
@@ -126,13 +114,13 @@ def upload_roms(
detail="No roms were uploaded",
)
roms_path = build_upload_file_path(platform_fs_slug)
roms_path = fsh.build_upload_file_path(platform_fs_slug)
uploaded_roms = []
skipped_roms = []
for rom in roms:
if _file_exists(roms_path, rom.filename):
if fsh._file_exists(roms_path, rom.filename):
log.warning(f" - Skipping {rom.filename} since the file already exists")
skipped_roms.append(rom.filename)
continue
@@ -252,7 +240,7 @@ async def update_rom(
try:
if db_rom.file_name != fs_safe_file_name:
rename_file(
fsh.rename_file(
old_name=db_rom.file_name,
new_name=fs_safe_file_name,
file_path=db_rom.file_path,
@@ -264,9 +252,9 @@ async def update_rom(
)
cleaned_data["file_name"] = fs_safe_file_name
cleaned_data["file_name_no_tags"] = get_file_name_with_no_tags(fs_safe_file_name)
cleaned_data["file_name_no_tags"] = fsh.get_file_name_with_no_tags(fs_safe_file_name)
cleaned_data.update(
get_rom_cover(
fsh.get_rom_cover(
overwrite=True,
fs_slug=platform_fs_slug,
rom_name=cleaned_data["name"],
@@ -275,7 +263,7 @@ async def update_rom(
)
cleaned_data.update(
get_rom_screenshots(
fsh.get_rom_screenshots(
fs_slug=platform_fs_slug,
rom_name=cleaned_data["name"],
url_screenshots=cleaned_data.get("url_screenshots", []),
@@ -284,7 +272,7 @@ async def update_rom(
if artwork is not None:
file_ext = artwork.filename.split(".")[-1]
path_cover_l, path_cover_s, artwork_path = build_artwork_path(
path_cover_l, path_cover_s, artwork_path = fsh.build_artwork_path(
cleaned_data["name"], platform_fs_slug, file_ext
)
@@ -332,7 +320,7 @@ def _delete_single_rom(rom_id: int, delete_from_fs: bool = False) -> Rom:
if delete_from_fs:
log.info(f"Deleting {rom.file_name} from filesystem")
try:
remove_file(file_name=rom.file_name, file_path=rom.file_path)
fsh.remove_file(file_name=rom.file_name, file_path=rom.file_path)
except FileNotFoundError:
error = (
f"Rom file {rom.file_name} not found for platform {rom.platform_slug}"

View File

@@ -1,10 +1,9 @@
import emoji
from decorators.oauth import protected_route
from endpoints.responses.search import RomSearchResponse
from fastapi import APIRouter, Request
from handler import dbh, igdbh
from logger.logger import log
from utils.oauth import protected_route
from endpoints.responses.search import RomSearchResponse
router = APIRouter()

View File

@@ -3,25 +3,17 @@ import socketio # type: ignore
from config import ENABLE_EXPERIMENTAL_REDIS
from endpoints.platform import PlatformSchema
from endpoints.rom import RomSchema
from exceptions.fs_exceptions import PlatformsNotFoundException, RomsNotFoundException
from handler import dbh
from logger.logger import log
from utils.fastapi import (
from exceptions.fs_exceptions import FolderStructureNotMatchException, RomsNotFoundException
from handler import dbh, socketh, platformh, romh, resourceh, asseth
from handler.redis_handler import high_prio_queue, redis_url
from handler.scan_handler import (
scan_platform,
scan_rom,
scan_save,
scan_screenshot,
scan_state,
)
from utils.fs import (
get_assets,
get_platforms,
get_roms,
get_screenshots,
store_default_resources,
)
from utils.redis import high_prio_queue, redis_url
from utils.socket import socket_server
from logger.logger import log
async def scan_platforms(
@@ -43,13 +35,13 @@ async def scan_platforms(
sm = (
socketio.AsyncRedisManager(redis_url, write_only=True)
if ENABLE_EXPERIMENTAL_REDIS
else socket_server
else socketh.socket_server
)
# Scanning file system
try:
fs_platforms: list[str] = get_platforms()
except PlatformsNotFoundException as e:
fs_platforms: list[str] = platformh.get_platforms()
except FolderStructureNotMatchException as e:
log.error(e)
await sm.emit("scan:done_ko", e.message)
return
@@ -66,19 +58,17 @@ async def scan_platforms(
for platform_slug in platform_list:
scanned_platform = scan_platform(platform_slug, fs_platforms)
_new_platform = dbh.add_platform(scanned_platform)
new_platform = dbh.get_platform(_new_platform.slug)
_added_platform = dbh.add_platform(scanned_platform)
platform = dbh.get_platform(_added_platform.id)
await sm.emit(
"scan:scanning_platform",
PlatformSchema.model_validate(new_platform).model_dump(),
PlatformSchema.model_validate(platform).model_dump(),
)
dbh.add_platform(scanned_platform)
# Scanning roms
try:
fs_roms = get_roms(scanned_platform.fs_slug)
fs_roms = romh.get_roms(platform.fs_slug)
except RomsNotFoundException as e:
log.error(e)
continue
@@ -91,39 +81,40 @@ async def scan_platforms(
log.info(f" {len(fs_roms)} roms found")
for fs_rom in fs_roms:
rom = dbh.get_rom_by_filename(scanned_platform.slug, fs_rom["file_name"])
rom = dbh.get_rom_by_filename(platform.id, fs_rom["file_name"])
if (rom and rom.id not in selected_roms and not complete_rescan) and not (
rescan_unidentified and rom and not rom.igdb_id
):
continue
scanned_rom = await scan_rom(scanned_platform, fs_rom)
scanned_rom = await scan_rom(platform, fs_rom)
if rom:
scanned_rom.id = rom.id
_new_rom = dbh.add_rom(scanned_rom)
new_rom = dbh.get_rom(_new_rom.id)
scanned_rom.platform_id = platform.id
_added_rom = dbh.add_rom(scanned_rom)
rom = dbh.get_rom(_added_rom.id)
await sm.emit(
"scan:scanning_rom",
{
"p_name": scanned_platform.name,
**RomSchema.model_validate(new_rom).model_dump(),
"p_name": platform.name,
**RomSchema.model_validate(rom).model_dump(),
},
)
fs_assets = get_assets(scanned_platform.fs_slug)
fs_assets = asseth.get_assets(platform.fs_slug)
# Scanning saves
log.info(f"\t · {len(fs_assets['saves'])} saves found")
for fs_emulator, fs_save_filename in fs_assets["saves"]:
scanned_save = scan_save(
platform=scanned_platform,
platform=platform,
file_name=fs_save_filename,
emulator=fs_emulator,
)
save = dbh.get_save_by_filename(scanned_platform.slug, fs_save_filename)
save = dbh.get_save_by_filename(platform.id, fs_save_filename)
if save:
# Update file size if changed
if save.file_size_bytes != scanned_save.file_size_bytes:
@@ -133,7 +124,6 @@ async def scan_platforms(
continue
scanned_save.emulator = fs_emulator
scanned_save.platform_slug = scanned_platform.slug
rom = dbh.get_rom_by_filename_no_tags(scanned_save.file_name_no_tags)
if rom:
@@ -144,12 +134,12 @@ async def scan_platforms(
log.info(f"\t · {len(fs_assets['states'])} states found")
for fs_emulator, fs_state_filename in fs_assets["states"]:
scanned_state = scan_state(
platform=scanned_platform,
platform=platform,
emulator=fs_emulator,
file_name=fs_state_filename,
)
state = dbh.get_state_by_filename(scanned_platform.slug, fs_state_filename)
state = dbh.get_state_by_filename(platform.id, fs_state_filename)
if state:
# Update file size if changed
if state.file_size_bytes != scanned_state.file_size_bytes:
@@ -160,7 +150,7 @@ async def scan_platforms(
continue
scanned_state.emulator = fs_emulator
scanned_state.platform_slug = scanned_platform.slug
# scanned_state.platform_slug = scanned_platform.slug TODO: remove
rom = dbh.get_rom_by_filename_no_tags(scanned_state.file_name_no_tags)
if rom:
@@ -171,7 +161,7 @@ async def scan_platforms(
log.info(f"\t · {len(fs_assets['screenshots'])} screenshots found")
for fs_screenshot_filename in fs_assets["screenshots"]:
scanned_screenshot = scan_screenshot(
file_name=fs_screenshot_filename, fs_platform=scanned_platform.slug
file_name=fs_screenshot_filename, platform=platform
)
screenshot = dbh.get_screenshot_by_filename(fs_screenshot_filename)
@@ -184,20 +174,20 @@ async def scan_platforms(
)
continue
scanned_screenshot.platform_slug = scanned_platform.slug
# scanned_screenshot.platform_slug = scanned_patform.slug TODO: remove
rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
if rom:
scanned_screenshot.rom_id = rom.id
dbh.add_screenshot(scanned_screenshot)
dbh.purge_saves(scanned_platform.slug, [s for _e, s in fs_assets["saves"]])
dbh.purge_states(scanned_platform.slug, [s for _e, s in fs_assets["states"]])
dbh.purge_screenshots(fs_assets["screenshots"], scanned_platform.slug)
dbh.purge_roms(scanned_platform.slug, [rom["file_name"] for rom in fs_roms])
dbh.purge_saves(platform.id, [s for _e, s in fs_assets["saves"]])
dbh.purge_states(platform.id, [s for _e, s in fs_assets["states"]])
dbh.purge_screenshots(platform.id, fs_assets["screenshots"])
dbh.purge_roms(platform.id, [rom["file_name"] for rom in fs_roms])
# Scanning screenshots outside platform folders
fs_screenshots = get_screenshots()
fs_screenshots = asseth.get_screenshots()
log.info("Screenshots")
log.info(f" · {len(fs_screenshots)} screenshots found")
for fs_platform, fs_screenshot_filename in fs_screenshots:
@@ -218,7 +208,7 @@ async def scan_platforms(
rom = dbh.get_rom_by_filename_no_tags(scanned_screenshot.file_name_no_tags)
if rom:
scanned_screenshot.rom_id = rom.id
scanned_screenshot.platform_slug = rom.platform_slug
# scanned_screenshot.platform_slug = rom.platform_slug TODO: remove
dbh.add_screenshot(scanned_screenshot)
dbh.purge_screenshots([s for _e, s in fs_screenshots])
@@ -229,7 +219,7 @@ async def scan_platforms(
await sm.emit("scan:done", {})
@socket_server.on("scan")
@socketh.socket_server.on("scan")
async def scan_handler(_sid: str, options: dict):
"""Scan socket endpoint
@@ -238,7 +228,7 @@ async def scan_handler(_sid: str, options: dict):
"""
log.info(emoji.emojize(":magnifying_glass_tilted_right: Scanning "))
store_default_resources()
resourceh.store_default_resources()
platform_slugs = options.get("platforms", [])
complete_rescan = options.get("completeRescan", False)

View File

@@ -1,8 +1,8 @@
from decorators.oauth import protected_route
from endpoints.responses import MessageResponse
from fastapi import APIRouter, Request
from tasks.update_mame_xml import update_mame_xml_task
from tasks.update_switch_titledb import update_switch_titledb_task
from utils.oauth import protected_route
router = APIRouter()

View File

@@ -1,65 +1,15 @@
from config import ROMM_HOST
from endpoints.responses.webrcade import WebrcadeFeedSchema
from decorators.oauth import protected_route
from endpoints.responses.webrcade import (
WEBRCADE_SLUG_TO_TYPE_MAP,
WEBRCADE_SUPPORTED_PLATFORM_SLUGS,
WebrcadeFeedSchema,
)
from fastapi import APIRouter, Request
from handler import dbh
from utils.oauth import protected_route
router = APIRouter()
WEBRCADE_SUPPORTED_PLATFORM_SLUGS = [
"3do",
"arcade",
"atari2600",
"atari5200",
"atari7800",
"lynx",
"wonderswan",
"wonderswan-color",
"colecovision",
"turbografx16--1",
"turbografx-16-slash-pc-engine-cd",
"supergrafx",
"pc-fx",
"nes",
"n64",
"snes",
"gb",
"gba",
"gbc",
"virtualboy",
"sg1000",
"sms",
"genesis-slash-megadrive",
"segacd",
"gamegear",
"neo-geo-cd",
"neogeoaes",
"neogeomvs",
"neo-geo-pocket",
"neo-geo-pocket-color",
"ps",
]
WEBRCADE_SLUG_TO_TYPE_MAP = {
"atari2600": "2600",
"atari5200": "5200",
"atari7800": "7800",
"lynx": "lnx",
"turbografx16--1": "pce",
"turbografx-16-slash-pc-engine-cd": "pce",
"supergrafx": "sgx",
"pc-fx": "pcfx",
"virtualboy": "vb",
"genesis-slash-megadrive": "genesis",
"gamegear": "gg",
"neogeoaes": "neogeo",
"neogeomvs": "neogeo",
"neo-geo-cd": "neogeocd",
"neo-geo-pocket": "ngp",
"neo-geo-pocket-color": "ngc",
"ps": "psx",
}
@protected_route(router.get, "/platforms/webrcade/feed", [])
def platforms_webrcade_feed(request: Request) -> WebrcadeFeedSchema:

View File

@@ -1,6 +1,6 @@
from fastapi import HTTPException, status
CredentialsException = HTTPException(
AuthCredentialsException = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
@@ -14,3 +14,9 @@ DisabledException = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Disabled user",
)
OAuthCredentialsException = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

View File

@@ -1,7 +1,7 @@
folder_struct_msg = "Check RomM folder structure here: https://github.com/zurdi15/romm#-folder-structure"
class PlatformsNotFoundException(Exception):
class FolderStructureNotMatchException(Exception):
def __init__(self):
self.message = f"Platforms not found. {folder_struct_msg}"
super().__init__(self.message)
@@ -28,7 +28,6 @@ class RomsNotFoundException(Exception):
return self.message
class RomAlreadyExistsException(Exception):
def __init__(self, rom_name: str):
self.message = f"Can't rename: {rom_name} already exists"

View File

@@ -1,7 +1,22 @@
from handler.auth_handler.auth_handler import AuthHandler, OAuthHandler
from handler.db_handler import DBHandler
from handler.fs_handler.platforms_handler import PlatformsHandler
from handler.fs_handler.roms_handler import RomsHandler
from handler.fs_handler.assets_handler import AssetsHandler
from handler.fs_handler.resources_handler import ResourceHandler
from handler.gh_handler import GHHandler
from handler.igdb_handler import IGDBHandler
from handler.sgdb_handler import SGDBHandler
from handler.socket_handler import SocketHandler
igdbh: IGDBHandler = IGDBHandler()
sgdbh: SGDBHandler = SGDBHandler()
dbh: DBHandler = DBHandler()
igdbh = IGDBHandler()
sgdbh = SGDBHandler()
dbh = DBHandler()
ghh = GHHandler()
authh = AuthHandler()
oauthh = OAuthHandler()
socketh = SocketHandler()
platformh = PlatformsHandler()
romh = RomsHandler()
asseth = AssetsHandler()
resourceh = ResourceHandler()

View File

@@ -0,0 +1,28 @@
from typing import Final
ALGORITHM: Final = "HS256"
DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15
DEFAULT_SCOPES_MAP: Final = {
"me.read": "View your profile",
"me.write": "Modify your profile",
"roms.read": "View ROMs",
"platforms.read": "View platforms",
"assets.read": "View assets",
}
WRITE_SCOPES_MAP: Final = {
"roms.write": "Modify ROMs",
"platforms.write": "Modify platforms",
"assets.write": "Modify assets",
}
FULL_SCOPES_MAP: Final = {
"users.read": "View users",
"users.write": "Modify users",
"tasks.run": "Run tasks",
}
DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys())
WRITE_SCOPES: Final = DEFAULT_SCOPES + list(WRITE_SCOPES_MAP.keys())
FULL_SCOPES: Final = WRITE_SCOPES + list(FULL_SCOPES_MAP.keys())

View File

@@ -0,0 +1,135 @@
from datetime import datetime, timedelta
from config import (
ROMM_AUTH_ENABLED,
ROMM_AUTH_PASSWORD,
ROMM_AUTH_SECRET_KEY,
ROMM_AUTH_USERNAME,
)
from exceptions.auth_exceptions import OAuthCredentialsException
from fastapi import HTTPException, Request, status
from handler.auth_handler import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
from starlette.requests import HTTPConnection
from handler.redis_handler import cache
class AuthHandler:
def __init__(self) -> None:
self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def _verify_password(self, plain_password, hashed_password):
return self.pwd_context.verify(plain_password, hashed_password)
def get_password_hash(self, password):
return self.pwd_context.hash(password)
@staticmethod
def clear_session(req: HTTPConnection | Request):
session_id = req.session.get("session_id")
if session_id:
redish.cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
req.session["session_id"] = None
def authenticate_user(self, username: str, password: str):
from handler import dbh
user = dbh.get_user_by_username(username)
if not user:
return None
if not self._verify_password(password, user.hashed_password):
return None
return user
async def get_current_active_user_from_session(self, conn: HTTPConnection):
from handler import dbh
# Check if session key already stored in cache
session_id = conn.session.get("session_id")
if not session_id:
return None
username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
if not username:
return None
# Key exists therefore user is probably authenticated
user = dbh.get_user_by_username(username)
if user is None:
self.clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User not found",
)
if not user.enabled:
self.clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
def create_default_admin_user(self):
from handler import dbh
from models.user import Role, User
if not ROMM_AUTH_ENABLED:
return
try:
dbh.add_user(
User(
username=ROMM_AUTH_USERNAME,
hashed_password=self.get_password_hash(ROMM_AUTH_PASSWORD),
role=Role.ADMIN,
)
)
except IntegrityError:
pass
class OAuthHandler:
def __init__(self) -> None:
pass
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
async def get_current_active_user_from_bearer_token(token: str):
from handler import dbh
try:
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
raise OAuthCredentialsException
username = payload.get("sub")
if username is None:
raise OAuthCredentialsException
user = dbh.get_user_by_username(username)
if user is None:
raise OAuthCredentialsException
if not user.enabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
)
return user, payload

View File

@@ -0,0 +1,52 @@
from config import ROMM_AUTH_ENABLED
from fastapi.security.http import HTTPBasic
from handler import authh
from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.requests import HTTPConnection
from handler import oauthh
from handler.auth_handler import FULL_SCOPES
class HybridAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if not ROMM_AUTH_ENABLED:
return (AuthCredentials(FULL_SCOPES), None)
# Check if session key already stored in cache
user = await authh.get_current_active_user_from_session(conn)
if user:
return (AuthCredentials(user.oauth_scopes), user)
# Check if Authorization header exists
if "Authorization" not in conn.headers:
return (AuthCredentials([]), None)
scheme, token = conn.headers["Authorization"].split()
# Check if basic auth header is valid
if scheme.lower() == "basic":
credentials = await HTTPBasic().__call__(conn) # type: ignore[arg-type]
if not credentials:
return (AuthCredentials([]), None)
user = authh.authenticate_user(credentials.username, credentials.password)
if user is None:
return (AuthCredentials([]), None)
return (AuthCredentials(user.oauth_scopes), user)
# Check if bearer auth header is valid
if scheme.lower() == "bearer":
user, payload = await oauthh.get_current_active_user_from_bearer_token(token)
# Only access tokens can request resources
if payload.get("type") != "access":
return (AuthCredentials([]), None)
# Only grant access to resources with overlapping scopes
token_scopes = set(list(payload.get("scopes").split(" ")))
overlapping_scopes = list(token_scopes & set(user.oauth_scopes))
return (AuthCredentials(overlapping_scopes), user)
return (AuthCredentials([]), None)

View File

@@ -0,0 +1,11 @@
from starlette.types import Receive, Scope, Send
from starlette_csrf.middleware import CSRFMiddleware
class CustomCSRFMiddleware(CSRFMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
await super().__call__(scope, receive, send)

View File

@@ -38,22 +38,20 @@ class DBHandler:
return session.merge(platform)
@begin_session
def get_platforms(self, session: Session = None):
def get_platform(self, id: int = None, session: Session = None):
return (
session.scalars(select(Platform).order_by(Platform.slug.asc()))
.unique()
.all()
if not id
else session.get(Platform, id)
)
@begin_session
def get_platform(self, slug: str, session: Session = None):
return session.get(Platform, slug)
@begin_session
def get_platform_by_fs_slug(self, fs_slug: str, session: Session = None):
return session.scalars(
select(Platform).filter_by(fs_slug=fs_slug).limit(1)
).first()
# @begin_session
# def get_platform_by_fs_slug(self, fs_slug: str, session: Session = None):
# return session.scalars(
# select(Platform).filter_by(fs_slug=fs_slug).limit(1)
# ).first()
@begin_session
def delete_platform(self, slug: str, session: Session = None):
@@ -68,6 +66,12 @@ class DBHandler:
.where(Platform.slug == slug)
.execution_options(synchronize_session="evaluate")
)
@begin_session
def get_rom_count(self, platform_id: int, session: Session = None):
return session.scalar(
select(func.count()).select_from(Rom).filter_by(platform_id=platform_id)
)
@begin_session
def purge_platforms(self, platforms: list[str], session: Session = None):
@@ -132,28 +136,20 @@ class DBHandler:
)
@begin_session
def purge_roms(self, platform_slug: str, roms: list[str], session: Session = None):
def purge_roms(self, platform_id: int, roms: list[str], session: Session = None):
return session.execute(
delete(Rom)
.where(and_(Rom.platform_slug == platform_slug, Rom.file_name.not_in(roms)))
.where(and_(Rom.platform_id == platform_id, Rom.file_name.not_in(roms)))
.execution_options(synchronize_session="evaluate")
)
@begin_session
def get_rom_count(self, platform_slug: str, session: Session = None):
return session.scalar(
select(func.count()).select_from(Rom).filter_by(platform_slug=platform_slug)
)
# ==== Utils ======
@begin_session
def get_rom_by_filename(
self, platform_slug: str, file_name: str, session: Session = None
self, platform_id: int, file_name: str, session: Session = None
):
return session.scalars(
select(Rom)
.filter_by(platform_slug=platform_slug, file_name=file_name)
.limit(1)
select(Rom).filter_by(platform_id=platform_id, file_name=file_name).limit(1)
).first()
@begin_session
@@ -201,14 +197,10 @@ class DBHandler:
)
@begin_session
def purge_saves(
self, platform_slug: str, saves: list[str], session: Session = None
):
def purge_saves(self, platform_id: int, saves: list[str], session: Session = None):
return session.execute(
delete(Save)
.where(
and_(Save.platform_slug == platform_slug, Save.file_name.not_in(saves))
)
.where(and_(Save.platform_id == platform_id, Save.file_name.not_in(saves)))
.execution_options(synchronize_session="evaluate")
)
@@ -223,11 +215,11 @@ class DBHandler:
@begin_session
def get_state_by_filename(
self, platform_slug: str, file_name: str, session: Session = None
self, platform_id: int, file_name: str, session: Session = None
):
return session.scalars(
select(State)
.filter_by(platform_slug=platform_slug, file_name=file_name)
.filter_by(platform_slug=platform_id, file_name=file_name)
.limit(1)
).first()
@@ -250,14 +242,12 @@ class DBHandler:
@begin_session
def purge_states(
self, platform_slug: str, states: list[str], session: Session = None
self, platform_id: int, states: list[str], session: Session = None
):
return session.execute(
delete(State)
.where(
and_(
State.platform_slug == platform_slug, State.file_name.not_in(states)
)
and_(State.platform_id == platform_id, State.file_name.not_in(states))
)
.execution_options(synchronize_session="evaluate")
)
@@ -296,12 +286,12 @@ class DBHandler:
@begin_session
def purge_screenshots(
self, screenshots: list[str], platform_slug: str = None, session: Session = None
self, platform_id: int, screenshots: list[str], session: Session = None
):
return session.execute(
delete(Screenshot)
.where(
Screenshot.platform_slug == platform_slug,
Screenshot.platform_id == platform_id,
Screenshot.file_name.not_in(screenshots),
)
.execution_options(synchronize_session="evaluate")

View File

@@ -0,0 +1,75 @@
from enum import Enum
from typing import Final
from config import ROMM_BASE_PATH
RESOURCES_BASE_PATH: Final = f"{ROMM_BASE_PATH}/resources"
DEFAULT_WIDTH_COVER_L: Final = 264 # Width of big cover of IGDB
DEFAULT_HEIGHT_COVER_L: Final = 352 # Height of big cover of IGDB
DEFAULT_WIDTH_COVER_S: Final = 90 # Width of small cover of IGDB
DEFAULT_HEIGHT_COVER_S: Final = 120 # Height of small cover of IGDB
LANGUAGES = [
("Ar", "Arabic"),
("Da", "Danish"),
("De", "German"),
("En", "English"),
("Es", "Spanish"),
("Fi", "Finnish"),
("Fr", "French"),
("It", "Italian"),
("Ja", "Japanese"),
("Ko", "Korean"),
("Nl", "Dutch"),
("No", "Norwegian"),
("Pl", "Polish"),
("Pt", "Portuguese"),
("Ru", "Russian"),
("Sv", "Swedish"),
("Zh", "Chinese"),
("nolang", "No Language"),
]
REGIONS = [
("A", "Australia"),
("AS", "Asia"),
("B", "Brazil"),
("C", "Canada"),
("CH", "China"),
("E", "Europe"),
("F", "France"),
("FN", "Finland"),
("G", "Germany"),
("GR", "Greece"),
("H", "Holland"),
("HK", "Hong Kong"),
("I", "Italy"),
("J", "Japan"),
("K", "Korea"),
("NL", "Netherlands"),
("NO", "Norway"),
("PD", "Public Domain"),
("R", "Russia"),
("S", "Spain"),
("SW", "Sweden"),
("T", "Taiwan"),
("U", "USA"),
("UK", "England"),
("UNK", "Unknown"),
("UNL", "Unlicensed"),
("W", "World"),
]
REGIONS_BY_SHORTCODE = {region[0].lower(): region[1] for region in REGIONS}
REGIONS_NAME_KEYS = [region[1].lower() for region in REGIONS]
LANGUAGES_BY_SHORTCODE = {lang[0].lower(): lang[1] for lang in LANGUAGES}
LANGUAGES_NAME_KEYS = [lang[1].lower() for lang in LANGUAGES]
TAG_REGEX = r"\(([^)]+)\)|\[([^]]+)\]"
EXTENSION_REGEX = r"\.(([a-z]+\.)*\w+)$"
class CoverSize(Enum):
SMALL = "small"
BIG = "big"

View File

@@ -0,0 +1,142 @@
import os
import shutil
from pathlib import Path
from urllib.parse import quote
import requests
from config import LIBRARY_BASE_PATH
from config.config_loader import config
from handler.fs_handler import RESOURCES_BASE_PATH
from handler.fs_handler.fs_handler import FSHandler
class AssetsHandler(FSHandler):
def __init__(self) -> None:
pass
@staticmethod
def _store_screenshot(fs_slug: str, rom_name: str, url: str, idx: int):
"""Store roms resources in filesystem
Args:
fs_slug: short name of the platform
file_name: name of rom
url: url to get the screenshot
"""
screenshot_file: str = f"{idx}.jpg"
screenshot_path: str = f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/screenshots"
res = requests.get(url, stream=True, timeout=120)
if res.status_code == 200:
Path(screenshot_path).mkdir(parents=True, exist_ok=True)
with open(f"{screenshot_path}/{screenshot_file}", "wb") as f:
shutil.copyfileobj(res.raw, f)
@staticmethod
def _get_screenshot_path(fs_slug: str, rom_name: str, idx: str):
"""Returns rom cover filesystem path adapted to frontend folder structure
Args:
fs_slug: short name of the platform
file_name: name of rom
idx: index number of screenshot
"""
return f"{fs_slug}/{rom_name}/screenshots/{idx}.jpg"
def get_rom_screenshots(
self, platform_fs_slug: str, rom_name: str, url_screenshots: list
) -> dict:
q_rom_name = quote(rom_name)
path_screenshots: list[str] = []
for idx, url in enumerate(url_screenshots):
self._store_screenshot(platform_fs_slug, rom_name, url, idx)
path_screenshots.append(
self._get_screenshot_path(platform_fs_slug, q_rom_name, str(idx))
)
return {"path_screenshots": path_screenshots}
def get_assets(self, platform_slug: str):
saves_path = self.get_fs_structure(
platform_slug, folder=config.SAVES_FOLDER_NAME
)
saves_file_path = f"{LIBRARY_BASE_PATH}/{saves_path}"
fs_saves: list[str] = []
fs_states: list[str] = []
fs_screenshots: list[str] = []
try:
emulators = list(os.walk(saves_file_path))[0][1]
for emulator in emulators:
fs_saves += [
(emulator, file)
for file in list(os.walk(f"{saves_file_path}/{emulator}"))[0][2]
]
fs_saves += [(None, file) for file in list(os.walk(saves_file_path))[0][2]]
except IndexError:
pass
states_path = self.get_fs_structure(
platform_slug, folder=config.STATES_FOLDER_NAME
)
states_file_path = f"{LIBRARY_BASE_PATH}/{states_path}"
try:
emulators = list(os.walk(states_file_path))[0][1]
for emulator in emulators:
fs_states += [
(emulator, file)
for file in list(os.walk(f"{states_file_path}/{emulator}"))[0][2]
]
fs_states += [
(None, file) for file in list(os.walk(states_file_path))[0][2]
]
except IndexError:
pass
screenshots_path = self.get_fs_structure(
platform_slug, folder=config.SCREENSHOTS_FOLDER_NAME
)
screenshots_file_path = f"{LIBRARY_BASE_PATH}/{screenshots_path}"
try:
fs_screenshots += [
file for file in list(os.walk(screenshots_file_path))[0][2]
]
except IndexError:
pass
return {
"saves": fs_saves,
"states": fs_states,
"screenshots": fs_screenshots,
}
@staticmethod
def get_screenshots():
screenshots_path = f"{LIBRARY_BASE_PATH}/{config.SCREENSHOTS_FOLDER_NAME}"
fs_screenshots = []
try:
platforms = list(os.walk(screenshots_path))[0][1]
for platform in platforms:
fs_screenshots += [
(platform, file)
for file in list(os.walk(f"{screenshots_path}/{platform}"))[0][2]
]
fs_screenshots += [
(None, file) for file in list(os.walk(screenshots_path))[0][2]
]
except IndexError:
pass
return fs_screenshots
@staticmethod
def get_asset_size(asset_path: str, file_name: str):
return os.stat(f"{LIBRARY_BASE_PATH}/{asset_path}/{file_name}").st_size

View File

@@ -0,0 +1,40 @@
import os
import re
from abc import ABC
from config import LIBRARY_BASE_PATH
from config.config_loader import config
from handler.fs_handler import EXTENSION_REGEX, TAG_REGEX
class FSHandler(ABC):
def __init__(self) -> None:
pass
@staticmethod
def get_fs_structure(fs_slug: str, folder: str = config.ROMS_FOLDER_NAME):
return (
f"{folder}/{fs_slug}"
if os.path.exists(config.HIGH_PRIO_STRUCTURE_PATH)
else f"{fs_slug}/{folder}"
)
@staticmethod
def _get_file_name_with_no_extension(file_name: str) -> str:
return re.sub(EXTENSION_REGEX, "", file_name).strip()
@staticmethod
def get_file_name_with_no_tags(file_name: str) -> str:
file_name_no_extension = re.sub(EXTENSION_REGEX, "", file_name).strip()
return re.split(TAG_REGEX, file_name_no_extension)[0].strip()
@staticmethod
def parse_file_extension(file_name) -> str:
match = re.search(EXTENSION_REGEX, file_name)
return match.group(1) if match else ""
def build_upload_file_path(
self, fs_slug: str, folder: str = config.ROMS_FOLDER_NAME
):
rom_path = self.get_fs_structure(fs_slug, folder=folder)
return f"{LIBRARY_BASE_PATH}/{rom_path}"

View File

@@ -0,0 +1,35 @@
import os
from config import LIBRARY_BASE_PATH
from config.config_loader import config
from exceptions.fs_exceptions import FolderStructureNotMatchException
from handler.fs_handler.fs_handler import FSHandler
class PlatformsHandler(FSHandler):
def __init__(self) -> None:
pass
@staticmethod
def _exclude_platforms(platforms: list):
return [
platform
for platform in platforms
if platform not in config.EXCLUDED_PLATFORMS
]
def get_platforms(self) -> list[str]:
"""Gets all filesystem platforms
Returns list with all the filesystem platforms found in the LIBRARY_BASE_PATH.
Automatically exclude folders defined in user config.
"""
try:
platforms: list[str] = (
list(os.walk(config.HIGH_PRIO_STRUCTURE_PATH))[0][1]
if os.path.exists(config.HIGH_PRIO_STRUCTURE_PATH)
else list(os.walk(LIBRARY_BASE_PATH))[0][1]
)
return self._exclude_platforms(platforms)
except IndexError as exc:
raise FolderStructureNotMatchException from exc

View File

@@ -0,0 +1,169 @@
import datetime
import os
import shutil
from pathlib import Path
from urllib.parse import quote
import requests
from config import (
DEFAULT_PATH_COVER_L,
DEFAULT_PATH_COVER_S,
DEFAULT_URL_COVER_L,
DEFAULT_URL_COVER_S,
)
from handler.fs_handler import (
DEFAULT_HEIGHT_COVER_L,
DEFAULT_HEIGHT_COVER_S,
DEFAULT_WIDTH_COVER_L,
DEFAULT_WIDTH_COVER_S,
RESOURCES_BASE_PATH,
CoverSize,
)
from handler.fs_handler.fs_handler import FSHandler
from PIL import Image
class ResourceHandler(FSHandler):
def __init__(self) -> None:
pass
@staticmethod
def _cover_exists(fs_slug: str, rom_name: str, size: CoverSize):
"""Check if rom cover exists in filesystem
Args:
fs_slug: short name of the platform
rom_name: name of rom file
size: size of the cover
Returns
True if cover exists in filesystem else False
"""
return bool(
os.path.exists(
f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/cover/{size.value}.png"
)
)
@staticmethod
def _resize_cover(cover_path: str, size: CoverSize) -> None:
"""Resizes the cover image to the standard size
Args:
cover_path: path where the original cover were stored
size: size of the cover
"""
cover = Image.open(cover_path)
if cover.size[1] > DEFAULT_HEIGHT_COVER_L:
if size == CoverSize.BIG:
big_dimensions = (DEFAULT_WIDTH_COVER_L, DEFAULT_HEIGHT_COVER_L)
background = Image.new("RGBA", big_dimensions, (0, 0, 0, 0))
cover.thumbnail(big_dimensions)
offset = (
int(round(((DEFAULT_WIDTH_COVER_L - cover.size[0]) / 2), 0)),
0,
)
elif size == CoverSize.SMALL:
small_dimensions = (DEFAULT_WIDTH_COVER_S, DEFAULT_HEIGHT_COVER_S)
background = Image.new("RGBA", small_dimensions, (0, 0, 0, 0))
cover.thumbnail(small_dimensions)
offset = (
int(round(((DEFAULT_WIDTH_COVER_S - cover.size[0]) / 2), 0)),
0,
)
else:
return
background.paste(cover, offset)
background.save(cover_path)
@staticmethod
def _store_cover(
self, fs_slug: str, rom_name: str, url_cover: str, size: CoverSize
):
"""Store roms resources in filesystem
Args:
fs_slug: short name of the platform
rom_name: name of rom file
url_cover: url to get the cover
size: size of the cover
"""
cover_file = f"{size.value}.png"
cover_path = f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/cover"
res = requests.get(
url_cover.replace("t_thumb", f"t_cover_{size.value}"),
stream=True,
timeout=120,
)
if res.status_code == 200:
Path(cover_path).mkdir(parents=True, exist_ok=True)
with open(f"{cover_path}/{cover_file}", "wb") as f:
shutil.copyfileobj(res.raw, f)
self._resize_cover(f"{cover_path}/{cover_file}", size)
@staticmethod
def _get_cover_path(fs_slug: str, rom_name: str, size: CoverSize):
"""Returns rom cover filesystem path adapted to frontend folder structure
Args:
fs_slug: short name of the platform
file_name: name of rom file
size: size of the cover
"""
strtime = str(datetime.datetime.now().timestamp())
return f"{fs_slug}/{rom_name}/cover/{size.value}.png?timestamp={strtime}"
def get_rom_cover(
self, overwrite: bool, platform_fs_slug: str, rom_name: str, url_cover: str = ""
) -> dict:
q_rom_name = quote(rom_name)
if (
overwrite or not self._cover_exists(platform_fs_slug, rom_name, CoverSize.SMALL)
) and url_cover:
self._store_cover(platform_fs_slug, rom_name, url_cover, CoverSize.SMALL)
path_cover_s = (
self._get_cover_path(platform_fs_slug, q_rom_name, CoverSize.SMALL)
if self._cover_exists(platform_fs_slug, rom_name, CoverSize.SMALL)
else DEFAULT_PATH_COVER_S
)
if (
overwrite or not self._cover_exists(platform_fs_slug, rom_name, CoverSize.BIG)
) and url_cover:
self._store_cover(platform_fs_slug, rom_name, url_cover, CoverSize.BIG)
path_cover_l = (
self._get_cover_path(platform_fs_slug, q_rom_name, CoverSize.BIG)
if self._cover_exists(platform_fs_slug, rom_name, CoverSize.BIG)
else DEFAULT_PATH_COVER_L
)
return {
"path_cover_s": path_cover_s,
"path_cover_l": path_cover_l,
}
def store_default_resources(self):
"""Store default cover resources in the filesystem"""
defaul_covers = [
{"url": DEFAULT_URL_COVER_L, "size": CoverSize.BIG},
{"url": DEFAULT_URL_COVER_S, "size": CoverSize.SMALL},
]
for cover in defaul_covers:
if not self._cover_exists("default", "default", cover["size"]):
self._store_cover("default", "default", cover["url"], cover["size"])
@staticmethod
def build_artwork_path(rom_name: str, fs_slug: str, file_ext: str):
q_rom_name = quote(rom_name)
strtime = str(datetime.datetime.now().timestamp())
path_cover_l = f"{fs_slug}/{q_rom_name}/cover/{CoverSize.BIG.value}.{file_ext}?timestamp={strtime}"
path_cover_s = f"{fs_slug}/{q_rom_name}/cover/{CoverSize.SMALL.value}.{file_ext}?timestamp={strtime}"
artwork_path = f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/cover"
Path(artwork_path).mkdir(parents=True, exist_ok=True)
return path_cover_l, path_cover_s, artwork_path
@staticmethod
def build_avatar_path(avatar_path: str, username: str):
avatar_user_path = f"{RESOURCES_BASE_PATH}/users/{username}"
Path(avatar_user_path).mkdir(parents=True, exist_ok=True)
return f"users/{username}/{avatar_path}", avatar_user_path

View File

@@ -0,0 +1,200 @@
import fnmatch
import os
import re
import shutil
from pathlib import Path
from models.platform import Platform
from config import LIBRARY_BASE_PATH
from config.config_loader import config
from exceptions.fs_exceptions import RomAlreadyExistsException, RomsNotFoundException
from handler.fs_handler import (
LANGUAGES_BY_SHORTCODE,
LANGUAGES_NAME_KEYS,
REGIONS_BY_SHORTCODE,
REGIONS_NAME_KEYS,
TAG_REGEX,
)
from handler.fs_handler.fs_handler import FSHandler
class RomsHandler(FSHandler):
def __init__(self) -> None:
pass
@staticmethod
def parse_tags(file_name: str) -> tuple:
rev = ""
regs = []
langs = []
other_tags = []
tags = [tag[0] or tag[1] for tag in re.findall(TAG_REGEX, file_name)]
tags = [tag for subtags in tags for tag in subtags.split(",")]
tags = [tag.strip() for tag in tags]
for tag in tags:
if tag.lower() in REGIONS_BY_SHORTCODE.keys():
regs.append(REGIONS_BY_SHORTCODE[tag.lower()])
continue
if tag.lower() in REGIONS_NAME_KEYS:
regs.append(tag)
continue
if tag.lower() in LANGUAGES_BY_SHORTCODE.keys():
langs.append(LANGUAGES_BY_SHORTCODE[tag.lower()])
continue
if tag.lower() in LANGUAGES_NAME_KEYS:
langs.append(tag)
continue
if "reg" in tag.lower():
match = re.match(r"^reg[\s|-](.*)$", tag, re.IGNORECASE)
if match:
regs.append(
REGIONS_BY_SHORTCODE[match.group(1).lower()]
if match.group(1).lower() in REGIONS_BY_SHORTCODE.keys()
else match.group(1)
)
continue
if "rev" in tag.lower():
match = re.match(r"^rev[\s|-](.*)$", tag, re.IGNORECASE)
if match:
rev = match.group(1)
continue
other_tags.append(tag)
return regs, rev, langs, other_tags
def _exclude_files(self, files, filetype) -> list[str]:
excluded_extensions = getattr(config, f"EXCLUDED_{filetype.upper()}_EXT")
excluded_names = getattr(config, f"EXCLUDED_{filetype.upper()}_FILES")
excluded_files: list = []
for file_name in files:
# Split the file name to get the extension.
ext = self.parse_file_extension(file_name)
# Exclude the file if it has no extension or the extension is in the excluded list.
if not ext or ext in excluded_extensions:
excluded_files.append(file_name)
# Additionally, check if the file name mathes a pattern in the excluded list.
if len(excluded_names) > 0:
[
excluded_files.append(file_name)
for name in excluded_names
if file_name == name or fnmatch.fnmatch(file_name, name)
]
# Return files that are not in the filtered list.
return [f for f in files if f not in excluded_files]
@staticmethod
def _exclude_multi_roms(roms) -> list[str]:
excluded_names = config.EXCLUDED_MULTI_FILES
filtered_files: list = []
for rom in roms:
if rom in excluded_names:
filtered_files.append(rom)
return [f for f in roms if f not in filtered_files]
def get_rom_files(self, rom: str, roms_path: str) -> list[str]:
rom_files: list = []
for path, _, files in os.walk(f"{roms_path}/{rom}"):
for f in self._exclude_files(files, "multi_parts"):
rom_files.append(f"{Path(path, f)}".replace(f"{roms_path}/{rom}/", ""))
return rom_files
def get_roms(self, platform: Platform):
"""Gets all filesystem roms for a platform
Args:
platform: platform where roms belong
Returns:
list with all the filesystem roms for a platform found in the LIBRARY_BASE_PATH
"""
roms_path = self.get_fs_structure(platform.fs_slug)
roms_file_path = f"{LIBRARY_BASE_PATH}/{roms_path}"
try:
fs_single_roms: list[str] = list(os.walk(roms_file_path))[0][2]
except IndexError as exc:
raise RomsNotFoundException(platform.fs_slug) from exc
try:
fs_multi_roms: list[str] = list(os.walk(roms_file_path))[0][1]
except IndexError as exc:
raise RomsNotFoundException(platform.fs_slug) from exc
fs_roms: list[dict] = [
{"multi": False, "file_name": rom}
for rom in self._exclude_files(fs_single_roms, "single")
] + [
{"multi": True, "file_name": rom}
for rom in self._exclude_multi_roms(fs_multi_roms)
]
return [
dict(
rom,
files=self.get_rom_files(rom["file_name"], roms_file_path),
)
for rom in fs_roms
]
@staticmethod
def get_rom_file_size(
roms_path: str, file_name: str, multi: bool, multi_files: list = []
):
files = (
[f"{LIBRARY_BASE_PATH}/{roms_path}/{file_name}"]
if not multi
else [
f"{LIBRARY_BASE_PATH}/{roms_path}/{file_name}/{file}"
for file in multi_files
]
)
total_size: float = 0.0
for file in files:
total_size += os.stat(file).st_size
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if total_size < 1024.0 or unit == "PB":
break
total_size /= 1024.0
return round(total_size, 2), unit
@staticmethod
def _file_exists(path: str, file_name: str):
"""Check if file exists in filesystem
Args:
path: path to file
file_name: name of file
Returns
True if file exists in filesystem else False
"""
return bool(os.path.exists(f"{LIBRARY_BASE_PATH}/{path}/{file_name}"))
def rename_file(self, old_name: str, new_name: str, file_path: str):
if new_name != old_name:
if self._file_exists(path=file_path, file_name=new_name):
raise RomAlreadyExistsException(new_name)
os.rename(
f"{LIBRARY_BASE_PATH}/{file_path}/{old_name}",
f"{LIBRARY_BASE_PATH}/{file_path}/{new_name}",
)
@staticmethod
def remove_file(file_name: str, file_path: str):
try:
os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
except IsADirectoryError:
shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")

View File

@@ -0,0 +1,54 @@
import subprocess as sp
import requests
from __version__ import __version__
from logger.logger import log
from packaging.version import InvalidVersion, parse
from requests.exceptions import ReadTimeout
class GHHandler:
def __init__(self) -> None:
pass
@staticmethod
def get_version() -> str | None:
"""Returns current version or branch name."""
if not __version__ == "<version>":
return __version__
else:
try:
output = str(
sp.check_output(["git", "branch"], universal_newlines=True)
)
except sp.CalledProcessError:
return None
branch = [a for a in output.split("\n") if a.find("*") >= 0][0]
return branch[branch.find("*") + 2 :]
def check_new_version(self) -> str:
"""Check for new RomM versions
Returns:
str: New RomM version or empty if in dev mode
"""
try:
response = requests.get(
"https://api.github.com/repos/zurdi15/romm/releases/latest", timeout=0.5
)
except ReadTimeout:
log.warning("Couldn't check last RomM version.")
return ""
try:
last_version = response.json()["name"][
1:
] # remove leading 'v' from 'vX.X.X'
except KeyError: # rate limit reached
return ""
try:
if parse(self.get_version()) < parse(last_version):
return last_version
except InvalidVersion:
pass
return ""

View File

@@ -16,9 +16,7 @@ from tasks.update_mame_xml import update_mame_xml_task
from tasks.update_switch_titledb import update_switch_titledb_task
from typing_extensions import TypedDict
from unidecode import unidecode as uc
from utils import get_file_name_with_no_tags as get_search_term
from utils import normalize_search_term
from utils.cache import cache
from handler.redis_handler import cache
MAIN_GAME_CATEGORY: Final = 0
EXPANDED_GAME_CATEGORY: Final = 10
@@ -82,6 +80,16 @@ class IGDBHandler:
return func(*args)
return wrapper
@staticmethod
def normalize_search_term(search_term: str) -> str:
return (
search_term.replace("\u2122", "") # Remove trademark symbol
.replace("\u00ae", "") # Remove registered symbol
.replace("\u00a9", "") # Remove copywrite symbol
.replace("\u2120", "") # Remove service mark symbol
.strip() # Remove leading and trailing spaces
)
def _request(self, url: str, data: str, timeout: int = 120) -> list:
try:
@@ -225,8 +233,8 @@ class IGDBHandler:
search_term = index_entry["name"] # type: ignore
return search_term
@staticmethod
async def _mame_format(search_term: str) -> str:
async def _mame_format(self, search_term: str) -> str:
mame_index = {"menu": {"game": []}}
try:
@@ -248,6 +256,9 @@ class IGDBHandler:
]
if index_entry:
# Run through get_search_term to remove tags
# TODO: refactor
from handler.fs_handler.fs_handler import FSHandler
get_search_term = FSHandler.get_file_name_with_no_tags
search_term = get_search_term(
index_entry[0].get("description", search_term)
)
@@ -272,6 +283,9 @@ class IGDBHandler:
@check_twitch_token
async def get_rom(self, file_name: str, platform_idgb_id: int) -> IGDBRomType:
# TODO: refactor
from handler.fs_handler.fs_handler import FSHandler
get_search_term = FSHandler.get_file_name_with_no_tags
search_term = get_search_term(file_name)
# Support for PS2 OPL filename format
@@ -293,7 +307,7 @@ class IGDBHandler:
if platform_idgb_id in ARCADE_IGDB_IDS:
search_term = await self._mame_format(search_term)
search_term = normalize_search_term(search_term)
search_term = self.normalize_search_term(search_term)
res = (
self._search_rom(uc(search_term), platform_idgb_id, MAIN_GAME_CATEGORY)

View File

@@ -1,5 +1,15 @@
from enum import Enum
from config import ENABLE_EXPERIMENTAL_REDIS, REDIS_HOST, REDIS_PASSWORD, REDIS_PORT
from logger.logger import log
from redis import Redis
from rq import Queue
class QueuePrio(Enum):
HIGH = "high"
DEFAULT = "default"
LOW = "low"
class FallbackCache:
@@ -28,6 +38,20 @@ class FallbackCache:
return repr(self)
redis_client = Redis(
host=REDIS_HOST, port=int(REDIS_PORT), password=REDIS_PASSWORD, db=0
)
redis_url = (
f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}"
if REDIS_PASSWORD
else f"redis://{REDIS_HOST}:{REDIS_PORT}"
)
high_prio_queue = Queue(name=QueuePrio.HIGH.name, connection=redis_client)
default_queue = Queue(name=QueuePrio.DEFAULT.name, connection=redis_client)
low_prio_queue = Queue(name=QueuePrio.LOW.name, connection=redis_client)
# A seperate client that auto-decodes responses is needed
_cache_client = Redis(
host=REDIS_HOST,
@@ -37,4 +61,6 @@ _cache_client = Redis(
decode_responses=True,
)
_fallback_cache = FallbackCache()
if ENABLE_EXPERIMENTAL_REDIS:
log.info("Redis enabled: Connecting...")
cache = _cache_client if ENABLE_EXPERIMENTAL_REDIS else _fallback_cache

View File

@@ -1,12 +1,11 @@
import os
from typing import Any
import emoji
import os
from config.config_loader import config
from models import Platform, Rom, Save, State, Screenshot
from handler import dbh, igdbh
from handler import dbh, igdbh, romh, resourceh, asseth
from logger.logger import log
from utils import fs, get_file_extension, get_file_name_with_no_tags, parse_tags
from models import Platform, Rom, Save, Screenshot, State
SWAPPED_PLATFORM_BINDINGS = dict((v, k) for k, v in config.PLATFORMS_BINDING.items())
@@ -61,7 +60,7 @@ async def scan_rom(
r_igbd_id_search: str = "",
overwrite: bool = False,
) -> Rom:
roms_path = fs.get_fs_structure(platform.fs_slug)
roms_path = romh.get_fs_structure(platform.fs_slug)
log.info(f"\t · {r_igbd_id_search or rom_attrs['file_name']}")
@@ -70,19 +69,21 @@ async def scan_rom(
log.info(f"\t\t · {file}")
# Update properties that don't require IGDB
file_size, file_size_units = fs.get_rom_file_size(
file_size, file_size_units = romh.get_rom_file_size(
multi=rom_attrs["multi"],
file_name=rom_attrs["file_name"],
multi_files=rom_attrs["files"],
roms_path=roms_path,
)
regs, rev, langs, other_tags = parse_tags(rom_attrs["file_name"])
regs, rev, langs, other_tags = romh.parse_tags(rom_attrs["file_name"])
rom_attrs.update(
{
"file_path": roms_path,
"file_name": rom_attrs["file_name"],
"file_name_no_tags": get_file_name_with_no_tags(rom_attrs["file_name"]),
"file_extension": get_file_extension(rom_attrs["file_name"]),
"file_name_no_tags": romh.get_file_name_with_no_tags(
rom_attrs["file_name"]
),
"file_extension": romh.parse_file_extension(rom_attrs["file_name"]),
"file_size": file_size,
"file_size_units": file_size_units,
"multi": rom_attrs["multi"],
@@ -92,7 +93,7 @@ async def scan_rom(
"tags": other_tags,
}
)
rom_attrs["platform_slug"] = platform.slug
rom_attrs["platform_id"] = platform.id
# Search in IGDB
igdbh_rom = (
@@ -114,16 +115,16 @@ async def scan_rom(
# Update properties from IGDB
rom_attrs.update(
fs.get_rom_cover(
resourceh.get_rom_cover(
overwrite=overwrite,
fs_slug=platform.slug,
platform_fs_slug=platform.slug,
rom_name=rom_attrs["name"],
url_cover=rom_attrs["url_cover"],
)
)
rom_attrs.update(
fs.get_rom_screenshots(
fs_slug=platform.slug,
asseth.get_rom_screenshots(
platform_fs_slug=platform.slug,
rom_name=rom_attrs["name"],
url_screenshots=rom_attrs["url_screenshots"],
)
@@ -135,19 +136,21 @@ async def scan_rom(
def _scan_asset(file_name: str, path: str):
log.info(f"\t\t · {file_name}")
file_size = fs.get_fs_file_size(file_name=file_name, asset_path=path)
file_size = asseth.get_asset_size(file_name=file_name, asset_path=path)
return {
"file_path": path,
"file_name": file_name,
"file_name_no_tags": get_file_name_with_no_tags(file_name),
"file_extension": get_file_extension(file_name),
"file_name_no_tags": asseth.get_file_name_with_no_tags(file_name),
"file_extension": asseth.parse_file_extension(file_name),
"file_size_bytes": file_size,
}
def scan_save(platform: Platform, file_name: str, emulator: str = None) -> Save:
saves_path = fs.get_fs_structure(platform.fs_slug, folder=config.SAVES_FOLDER_NAME)
saves_path = asseth.get_fs_structure(
platform.fs_slug, folder=config.SAVES_FOLDER_NAME
)
# Scan asset with the sames path and emulator folder name
if emulator:
@@ -157,7 +160,7 @@ def scan_save(platform: Platform, file_name: str, emulator: str = None) -> Save:
def scan_state(platform: Platform, file_name: str, emulator: str = None) -> State:
states_path = fs.get_fs_structure(
states_path = asseth.get_fs_structure(
platform.fs_slug, folder=config.STATES_FOLDER_NAME
)
@@ -168,12 +171,12 @@ def scan_state(platform: Platform, file_name: str, emulator: str = None) -> Stat
return State(**_scan_asset(file_name, states_path))
def scan_screenshot(file_name: str, fs_platform: str = None) -> Screenshot:
screenshots_path = fs.get_fs_structure(
fs_platform, folder=config.SCREENSHOTS_FOLDER_NAME
def scan_screenshot(file_name: str, platform: Platform = None) -> Screenshot:
screenshots_path = asseth.get_fs_structure(
platform.fs_slug, folder=config.SCREENSHOTS_FOLDER_NAME
)
if fs_platform:
if platform.fs_slug:
return Screenshot(**_scan_asset(file_name, screenshots_path))
return Screenshot(**_scan_asset(file_name, config.SCREENSHOTS_FOLDER_NAME))

View File

@@ -0,0 +1,18 @@
import socketio # type: ignore
from config import ENABLE_EXPERIMENTAL_REDIS
from handler.redis_handler import redis_url
class SocketHandler:
def __init__(self) -> None:
self.socket_server = socketio.AsyncServer(
cors_allowed_origins="*",
async_mode="asgi",
logger=False,
engineio_logger=False,
client_manager=socketio.AsyncRedisManager(redis_url)
if ENABLE_EXPERIMENTAL_REDIS
else None,
)
self.socket_app = socketio.ASGIApp(self.socket_server)

View File

@@ -5,8 +5,8 @@ from pathlib import Path
from typing import Final
from config import ROMM_BASE_PATH
from .stdout_formatter import StdoutFormatter
from .file_formatter import FileFormatter
from logger.stdout_formatter import StdoutFormatter
from logger.file_formatter import FileFormatter
LOGS_BASE_PATH: Final = f"{ROMM_BASE_PATH}/logs"

View File

@@ -3,48 +3,20 @@ import sys
import alembic.config
import uvicorn
from config import DEV_HOST, DEV_PORT, ROMM_AUTH_ENABLED, ROMM_AUTH_SECRET_KEY
from endpoints import (assets, heartbeat, identity, oauth, platform, rom,
search, tasks, webrcade)
from endpoints.sockets import scan
from config import (
DEV_HOST,
DEV_PORT,
ENABLE_RESCAN_ON_FILESYSTEM_CHANGE,
ENABLE_SCHEDULED_RESCAN,
ENABLE_SCHEDULED_UPDATE_MAME_XML,
ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB,
RESCAN_ON_FILESYSTEM_CHANGE_DELAY,
ROMM_AUTH_ENABLED,
ROMM_AUTH_SECRET_KEY,
SCHEDULED_RESCAN_CRON,
SCHEDULED_UPDATE_MAME_XML_CRON,
SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
)
from config.config_loader import ConfigDict, config
from endpoints import (
assets,
identity,
oauth,
platform,
rom,
search,
tasks,
webrcade,
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
from handler import dbh
from handler import authh, dbh, ghh, socketh
from handler.auth_handler.hybrid_auth import HybridAuthBackend
from handler.auth_handler.middleware import CustomCSRFMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
from typing_extensions import TypedDict
from utils import check_new_version, get_version
from utils.auth import (
CustomCSRFMiddleware,
HybridAuthBackend,
create_default_admin_user,
)
from utils.socket import socket_app
app = FastAPI(title="RomM API", version=get_version())
app = FastAPI(title="RomM API", version=ghh.get_version())
app.add_middleware(
CORSMiddleware,
@@ -76,6 +48,7 @@ app.add_middleware(
https_only=False,
)
app.include_router(heartbeat.router)
app.include_router(oauth.router)
app.include_router(identity.router)
app.include_router(platform.router)
@@ -86,73 +59,7 @@ app.include_router(tasks.router)
app.include_router(webrcade.router)
add_pagination(app)
app.mount("/ws", socket_app)
class WatcherDict(TypedDict):
ENABLED: bool
TITLE: str
MESSAGE: str
class TaskDict(WatcherDict):
CRON: str
class SchedulerDict(TypedDict):
RESCAN: TaskDict
SWITCH_TITLEDB: TaskDict
MAME_XML: TaskDict
class HeartbeatReturn(TypedDict):
VERSION: str
NEW_VERSION: str
ROMM_AUTH_ENABLED: bool
WATCHER: WatcherDict
SCHEDULER: SchedulerDict
CONFIG: ConfigDict
@app.get("/heartbeat")
def heartbeat() -> HeartbeatReturn:
"""Endpoint to set the CSFR token in cache and return all the basic RomM config
Returns:
HeartbeatReturn: TypedDict structure with all the defined values in the HeartbeatReturn class.
"""
return {
"VERSION": get_version(),
"NEW_VERSION": check_new_version(),
"ROMM_AUTH_ENABLED": ROMM_AUTH_ENABLED,
"WATCHER": {
"ENABLED": ENABLE_RESCAN_ON_FILESYSTEM_CHANGE,
"TITLE": "Rescan on filesystem change",
"MESSAGE": f"Runs a scan when a change is detected in the library path, with a {RESCAN_ON_FILESYSTEM_CHANGE_DELAY} minute delay",
},
"SCHEDULER": {
"RESCAN": {
"ENABLED": ENABLE_SCHEDULED_RESCAN,
"CRON": SCHEDULED_RESCAN_CRON,
"TITLE": "Scheduled rescan",
"MESSAGE": "Rescans the entire library",
},
"SWITCH_TITLEDB": {
"ENABLED": ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB, # noqa
"CRON": SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
"TITLE": "Scheduled Switch TitleDB update",
"MESSAGE": "Updates the Nintedo Switch TitleDB file",
},
"MAME_XML": {
"ENABLED": ENABLE_SCHEDULED_UPDATE_MAME_XML,
"CRON": SCHEDULED_UPDATE_MAME_XML_CRON,
"TITLE": "Scheduled MAME XML update",
"MESSAGE": "Updates the MAME XML file",
},
},
"CONFIG": config.__dict__,
}
app.mount("/ws", socketh.socket_app)
@app.on_event("startup")
@@ -161,7 +68,7 @@ def startup() -> None:
# Create default admin user if no admin user exists
if len(dbh.get_admin_users()) == 0 and "pytest" not in sys.modules:
create_default_admin_user()
authh.create_default_admin_user()
if __name__ == "__main__":

View File

@@ -1,4 +1,4 @@
from .platform import Platform # noqa[401]
from .rom import Rom # noqa[401]
from .user import User, Role # noqa[401]
from .assets import Save, State, Screenshot # noqa[401]
from models.platform import Platform # noqa[401]
from models.rom import Rom # noqa[401]
from models.user import User, Role # noqa[401]
from models.assets import Save, State, Screenshot # noqa[401]

View File

@@ -1,9 +1,10 @@
from sqlalchemy import Integer, Column, ForeignKey, String, DateTime, func
from sqlalchemy.orm import relationship
from functools import cached_property
from config import FRONTEND_LIBRARY_PATH
from .base import BaseModel
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import relationship
from models.base import BaseModel
class BaseAsset(BaseModel):
@@ -24,6 +25,10 @@ class BaseAsset(BaseModel):
file_path = Column(String(length=1000), nullable=False)
file_size_bytes = Column(Integer(), default=0, nullable=False)
rom_id = Column(
Integer(), ForeignKey("roms.id", ondelete="CASCADE"), nullable=False
)
@cached_property
def full_path(self) -> str:
return f"{self.file_path}/{self.file_name}"
@@ -39,18 +44,8 @@ class Save(BaseAsset):
emulator = Column(String(length=50), nullable=True)
rom_id = Column(
Integer(), ForeignKey("roms.id", ondelete="CASCADE"), nullable=False
)
rom = relationship("Rom", lazy="selectin", back_populates="saves")
platform_slug = Column(
String(length=50),
ForeignKey("platforms.slug", ondelete="CASCADE"),
nullable=False,
)
platform = relationship("Platform", lazy="selectin", back_populates="saves")
class State(BaseAsset):
__tablename__ = "states"
@@ -58,31 +53,11 @@ class State(BaseAsset):
emulator = Column(String(length=50), nullable=True)
rom_id = Column(
Integer(), ForeignKey("roms.id", ondelete="CASCADE"), nullable=False
)
rom = relationship("Rom", lazy="selectin", back_populates="states")
platform_slug = Column(
String(length=50),
ForeignKey("platforms.slug", ondelete="CASCADE"),
nullable=False,
)
platform = relationship("Platform", lazy="selectin", back_populates="states")
class Screenshot(BaseAsset):
__tablename__ = "screenshots"
__table_args__ = {"extend_existing": True}
rom_id = Column(
Integer(), ForeignKey("roms.id", ondelete="CASCADE"), nullable=False
)
rom = relationship("Rom", lazy="selectin", back_populates="screenshots")
platform_slug = Column(
String(length=50),
ForeignKey("platforms.slug", ondelete="CASCADE"),
nullable=True,
)
platform = relationship("Platform", lazy="selectin", back_populates="screenshots")

View File

@@ -1,46 +1,30 @@
from sqlalchemy import Column, String, Integer
from sqlalchemy.orm import relationship, Mapped
from config import DEFAULT_PATH_COVER_S
from .base import BaseModel
from models.base import BaseModel
from models.rom import Rom
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import Mapped, relationship
class Platform(BaseModel):
from .rom import Rom
from .assets import Save, State, Screenshot
__tablename__ = "platforms"
slug: str = Column(String(length=50), primary_key=True)
fs_slug: str = Column(String(length=50), nullable=False)
name: str = Column(String(length=400))
id = Column(Integer(), primary_key=True, autoincrement=True)
igdb_id: int = Column(Integer())
sgdb_id: int = Column(Integer())
slug: str = Column(String(length=50))
fs_slug: str = Column(String(length=50), nullable=False)
name: str = Column(String(length=400))
logo_path: str = Column(String(length=1000), default=DEFAULT_PATH_COVER_S)
roms: Mapped[set[Rom]] = relationship(
"Rom", lazy="selectin", back_populates="platform"
)
saves: Mapped[set[Save]] = relationship(
"Save", lazy="selectin", back_populates="platform"
)
states: Mapped[set[State]] = relationship(
"State", lazy="selectin", back_populates="platform"
)
screenshots: Mapped[set[State]] = relationship(
"Screenshot", lazy="selectin", back_populates="platform"
)
### DEPRECATED ###
n_roms: int = Column(Integer, default=0)
### DEPRECATED ###
@property
def rom_count(self) -> int:
from handler import dbh
return dbh.get_rom_count(self.slug)
return dbh.get_rom_count(self.id)
def __repr__(self) -> str:
return self.name

View File

@@ -2,16 +2,16 @@ import re
from functools import cached_property
from config import (
DEFAULT_PATH_COVER_S,
DEFAULT_PATH_COVER_L,
DEFAULT_PATH_COVER_S,
FRONTEND_LIBRARY_PATH,
FRONTEND_RESOURCES_PATH,
)
from models.assets import Save, Screenshot, State
from models.base import BaseModel
from sqlalchemy import JSON, Boolean, Column, Float, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, relationship
from .base import BaseModel
SIZE_UNIT_TO_BYTES = {
"B": 1,
"KB": 1024,
@@ -25,8 +25,6 @@ SORT_COMPARE_REGEX = r"^([Tt]he|[Aa]|[Aa]nd)\s"
class Rom(BaseModel):
from .assets import Save, State, Screenshot
__tablename__ = "roms"
id = Column(Integer(), primary_key=True, autoincrement=True)
@@ -34,33 +32,6 @@ class Rom(BaseModel):
igdb_id: int = Column(Integer())
sgdb_id: int = Column(Integer())
platform_slug = Column(
String(length=50),
ForeignKey("platforms.slug"),
nullable=False,
)
platform = relationship(
"Platform", lazy="selectin", back_populates="roms"
)
saves: Mapped[list[Save]] = relationship(
"Save",
lazy="selectin",
back_populates="rom",
)
states: Mapped[list[State]] = relationship(
"State", lazy="selectin", back_populates="rom"
)
screenshots: Mapped[list[Screenshot]] = relationship(
"Screenshot", lazy="selectin", back_populates="rom"
)
### DEPRECATED ###
p_name: str = Column(String(length=150), default="")
p_igdb_id: str = Column(String(length=10), default="")
p_sgdb_id: str = Column(String(length=10), default="")
### DEPRECATED ###
file_name: str = Column(String(length=450), nullable=False)
file_name_no_tags: str = Column(String(length=450), nullable=False)
file_extension: str = Column(String(length=100), nullable=False)
@@ -80,11 +51,41 @@ class Rom(BaseModel):
regions: JSON = Column(JSON, default=[])
languages: JSON = Column(JSON, default=[])
tags: JSON = Column(JSON, default=[])
multi: bool = Column(Boolean, default=False)
files: JSON = Column(JSON, default=[])
url_screenshots: JSON = Column(JSON, default=[])
path_screenshots: JSON = Column(JSON, default=[])
multi: bool = Column(Boolean, default=False)
files: JSON = Column(JSON, default=[])
platform_id = Column(
Integer(),
ForeignKey("platforms.id", ondelete="CASCADE"),
nullable=False,
)
platform = relationship("Platform", lazy="selectin", back_populates="roms")
saves: Mapped[list[Save]] = relationship(
"Save",
lazy="selectin",
back_populates="rom",
)
states: Mapped[list[State]] = relationship(
"State", lazy="selectin", back_populates="rom"
)
screenshots: Mapped[list[Screenshot]] = relationship(
"Screenshot", lazy="selectin", back_populates="rom"
)
@property
def platform_slug(self) -> str:
return self.platform.slug
@property
def platform_fs_slug(self) -> str:
return self.platform.fs_slug
@property
def platform_name(self) -> str:
return self.platform.name

View File

@@ -1,10 +1,9 @@
import enum
from handler.auth_handler import DEFAULT_SCOPES, FULL_SCOPES, WRITE_SCOPES
from models.base import BaseModel
from sqlalchemy import Boolean, Column, Enum, Integer, String
from starlette.authentication import SimpleUser
from utils.oauth import DEFAULT_SCOPES, FULL_SCOPES, WRITE_SCOPES
from .base import BaseModel
class Role(enum.Enum):
@@ -15,13 +14,17 @@ class Role(enum.Enum):
class User(BaseModel, SimpleUser):
__tablename__ = "users"
__table_args__ = {'extend_existing': True}
__table_args__ = {"extend_existing": True}
id = Column(Integer(), primary_key=True, autoincrement=True)
username: str = Column(String(length=255), unique=True, index=True)
hashed_password: str = Column(String(length=255))
enabled: bool = Column(Boolean(), default=True)
role: Role = Column(Enum(Role), default=Role.VIEWER)
avatar_path: str = Column(String(length=255), default="")
@property

View File

@@ -5,7 +5,7 @@ from logger.logger import log
from tasks.scan_library import scan_library_task
from tasks.update_mame_xml import update_mame_xml_task
from tasks.update_switch_titledb import update_switch_titledb_task
from tasks.utils import tasks_scheduler
from tasks.tasks import tasks_scheduler
if __name__ == "__main__":
if not ENABLE_EXPERIMENTAL_REDIS:

View File

@@ -1,8 +1,7 @@
from config import ENABLE_SCHEDULED_RESCAN, SCHEDULED_RESCAN_CRON
from endpoints.sockets.scan import scan_platforms
from logger.logger import log
from .utils import PeriodicTask
from tasks.tasks import PeriodicTask
class ScanLibraryTask(PeriodicTask):

View File

@@ -2,11 +2,10 @@ from abc import ABC, abstractmethod
import requests
from config import ENABLE_EXPERIMENTAL_REDIS
from exceptions.task_exceptions import SchedulerException
from logger.logger import log
from rq_scheduler import Scheduler
from utils.redis import low_prio_queue
from .exceptions import SchedulerException
from handler.redis_handler import low_prio_queue
tasks_scheduler = Scheduler(queue=low_prio_queue, connection=low_prio_queue.connection)

View File

@@ -3,8 +3,7 @@ from pathlib import Path
from typing import Final
from config import ENABLE_SCHEDULED_UPDATE_MAME_XML, SCHEDULED_UPDATE_MAME_XML_CRON
from .utils import RemoteFilePullTask
from tasks.tasks import RemoteFilePullTask
FIXTURE_FILE_PATH: Final = (
Path(os.path.dirname(__file__)).parent / "handler" / "fixtures" / "mame.xml"

View File

@@ -7,8 +7,7 @@ from config import (
ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB,
SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
)
from .utils import RemoteFilePullTask
from tasks.tasks import RemoteFilePullTask
FIXTURE_FILE_PATH: Final = (
Path(os.path.dirname(__file__)).parent

View File

@@ -1,177 +0,0 @@
import re
import subprocess as sp
import requests
from __version__ import __version__
from logger.logger import log
from packaging.version import InvalidVersion, parse
from requests.exceptions import ReadTimeout
LANGUAGES = [
("Ar", "Arabic"),
("Da", "Danish"),
("De", "German"),
("En", "English"),
("Es", "Spanish"),
("Fi", "Finnish"),
("Fr", "French"),
("It", "Italian"),
("Ja", "Japanese"),
("Ko", "Korean"),
("Nl", "Dutch"),
("No", "Norwegian"),
("Pl", "Polish"),
("Pt", "Portuguese"),
("Ru", "Russian"),
("Sv", "Swedish"),
("Zh", "Chinese"),
("nolang", "No Language"),
]
REGIONS = [
("A", "Australia"),
("AS", "Asia"),
("B", "Brazil"),
("C", "Canada"),
("CH", "China"),
("E", "Europe"),
("F", "France"),
("FN", "Finland"),
("G", "Germany"),
("GR", "Greece"),
("H", "Holland"),
("HK", "Hong Kong"),
("I", "Italy"),
("J", "Japan"),
("K", "Korea"),
("NL", "Netherlands"),
("NO", "Norway"),
("PD", "Public Domain"),
("R", "Russia"),
("S", "Spain"),
("SW", "Sweden"),
("T", "Taiwan"),
("U", "USA"),
("UK", "England"),
("UNK", "Unknown"),
("UNL", "Unlicensed"),
("W", "World"),
]
REGIONS_BY_SHORTCODE = {region[0].lower(): region[1] for region in REGIONS}
REGIONS_NAME_KEYS = [region[1].lower() for region in REGIONS]
LANGUAGES_BY_SHORTCODE = {lang[0].lower(): lang[1] for lang in LANGUAGES}
LANGUAGES_NAME_KEYS = [lang[1].lower() for lang in LANGUAGES]
TAG_REGEX = r"\(([^)]+)\)|\[([^]]+)\]"
EXTENSION_REGEX = r"\.(([a-z]+\.)*\w+)$"
def parse_tags(file_name: str) -> tuple:
rev = ""
regs = []
langs = []
other_tags = []
tags = [tag[0] or tag[1] for tag in re.findall(TAG_REGEX, file_name)]
tags = [tag for subtags in tags for tag in subtags.split(",")]
tags = [tag.strip() for tag in tags]
for tag in tags:
if tag.lower() in REGIONS_BY_SHORTCODE.keys():
regs.append(REGIONS_BY_SHORTCODE[tag.lower()])
continue
if tag.lower() in REGIONS_NAME_KEYS:
regs.append(tag)
continue
if tag.lower() in LANGUAGES_BY_SHORTCODE.keys():
langs.append(LANGUAGES_BY_SHORTCODE[tag.lower()])
continue
if tag.lower() in LANGUAGES_NAME_KEYS:
langs.append(tag)
continue
if "reg" in tag.lower():
match = re.match(r"^reg[\s|-](.*)$", tag, re.IGNORECASE)
if match:
regs.append(
REGIONS_BY_SHORTCODE[match.group(1).lower()]
if match.group(1).lower() in REGIONS_BY_SHORTCODE.keys()
else match.group(1)
)
continue
if "rev" in tag.lower():
match = re.match(r"^rev[\s|-](.*)$", tag, re.IGNORECASE)
if match:
rev = match.group(1)
continue
other_tags.append(tag)
return regs, rev, langs, other_tags
def get_file_name_with_no_extension(file_name: str) -> str:
return re.sub(EXTENSION_REGEX, "", file_name).strip()
def get_file_name_with_no_tags(file_name: str) -> str:
file_name_no_extension = get_file_name_with_no_extension(file_name)
return re.split(TAG_REGEX, file_name_no_extension)[0].strip()
def normalize_search_term(search_term: str) -> str:
return (
search_term.replace("\u2122", "") # Remove trademark symbol
.replace("\u00ae", "") # Remove registered symbol
.replace("\u00a9", "") # Remove copywrite symbol
.replace("\u2120", "") # Remove service mark symbol
.strip() # Remove leading and trailing spaces
)
def get_file_extension(file_name) -> str:
match = re.search(EXTENSION_REGEX, file_name)
return match.group(1) if match else ""
def get_version() -> str | None:
"""Returns current version or branch name."""
if not __version__ == "<version>":
return __version__
else:
try:
output = str(sp.check_output(["git", "branch"], universal_newlines=True))
except sp.CalledProcessError:
return None
branch = [a for a in output.split("\n") if a.find("*") >= 0][0]
return branch[branch.find("*") + 2 :]
def check_new_version() -> str:
"""Check for new RomM versions
Returns:
str: New RomM version or empty if in dev mode
"""
try:
response = requests.get(
"https://api.github.com/repos/zurdi15/romm/releases/latest", timeout=0.5
)
except ReadTimeout:
log.warning("Couldn't check last RomM version.")
return ""
try:
last_version = response.json()["name"][1:] # remove leading 'v' from 'vX.X.X'
except KeyError: # rate limit reached
return ""
try:
if parse(get_version()) < parse(last_version):
return last_version
except InvalidVersion:
pass
return ""

View File

@@ -1,142 +0,0 @@
from config import ROMM_AUTH_ENABLED, ROMM_AUTH_PASSWORD, ROMM_AUTH_USERNAME
from fastapi import HTTPException, Request, status
from fastapi.security.http import HTTPBasic
from handler import dbh
from models.user import Role, User
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.requests import HTTPConnection
from starlette.types import Receive, Scope, Send
from starlette_csrf.middleware import CSRFMiddleware
from utils.cache import cache
from .oauth import FULL_SCOPES, get_current_active_user_from_bearer_token
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def authenticate_user(username: str, password: str):
user = dbh.get_user_by_username(username)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def clear_session(req: HTTPConnection | Request):
session_id = req.session.get("session_id")
if session_id:
cache.delete(f"romm:{session_id}") # type: ignore[attr-defined]
req.session["session_id"] = None
async def get_current_active_user_from_session(conn: HTTPConnection):
# Check if session key already stored in cache
session_id = conn.session.get("session_id")
if not session_id:
return None
username = cache.get(f"romm:{session_id}") # type: ignore[attr-defined]
if not username:
return None
# Key exists therefore user is probably authenticated
user = dbh.get_user_by_username(username)
if user is None:
clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User not found",
)
if not user.enabled:
clear_session(conn)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
def create_default_admin_user():
if not ROMM_AUTH_ENABLED:
return
try:
dbh.add_user(
User(
username=ROMM_AUTH_USERNAME,
hashed_password=get_password_hash(ROMM_AUTH_PASSWORD),
role=Role.ADMIN,
)
)
except IntegrityError:
pass
class HybridAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if not ROMM_AUTH_ENABLED:
return (AuthCredentials(FULL_SCOPES), None)
# Check if session key already stored in cache
user = await get_current_active_user_from_session(conn)
if user:
return (AuthCredentials(user.oauth_scopes), user)
# Check if Authorization header exists
if "Authorization" not in conn.headers:
return (AuthCredentials([]), None)
scheme, token = conn.headers["Authorization"].split()
# Check if basic auth header is valid
if scheme.lower() == "basic":
credentials = await HTTPBasic().__call__(conn) # type: ignore[arg-type]
if not credentials:
return (AuthCredentials([]), None)
user = authenticate_user(credentials.username, credentials.password)
if user is None:
return (AuthCredentials([]), None)
return (AuthCredentials(user.oauth_scopes), user)
# Check if bearer auth header is valid
if scheme.lower() == "bearer":
user, payload = await get_current_active_user_from_bearer_token(token)
# Only access tokens can request resources
if payload.get("type") != "access":
return (AuthCredentials([]), None)
# Only grant access to resources with overlapping scopes
token_scopes = set(list(payload.get("scopes").split(" ")))
overlapping_scopes = list(token_scopes & set(user.oauth_scopes))
return (AuthCredentials(overlapping_scopes), user)
return (AuthCredentials([]), None)
class CustomCSRFMiddleware(CSRFMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
await super().__call__(scope, receive, send)

View File

@@ -1,462 +0,0 @@
import datetime
import fnmatch
import os
import shutil
from enum import Enum
from pathlib import Path
from urllib.parse import quote
from PIL import Image
from typing import Final
import requests
from config import (
LIBRARY_BASE_PATH,
ROMM_BASE_PATH,
DEFAULT_URL_COVER_L,
DEFAULT_PATH_COVER_L,
DEFAULT_URL_COVER_S,
DEFAULT_PATH_COVER_S,
)
from config.config_loader import config
from exceptions.fs_exceptions import (
PlatformsNotFoundException,
RomAlreadyExistsException,
RomsNotFoundException,
)
from . import get_file_extension
RESOURCES_BASE_PATH: Final = f"{ROMM_BASE_PATH}/resources"
DEFAULT_WIDTH_COVER_L: Final = 264 # Width of big cover of IGDB
DEFAULT_HEIGHT_COVER_L: Final = 352 # Height of big cover of IGDB
DEFAULT_WIDTH_COVER_S: Final = 90 # Width of small cover of IGDB
DEFAULT_HEIGHT_COVER_S: Final = 120 # Height of small cover of IGDB
# ========= Resources utils =========
class CoverSize(Enum):
SMALL = "small"
BIG = "big"
def _cover_exists(fs_slug: str, rom_name: str, size: CoverSize):
"""Check if rom cover exists in filesystem
Args:
fs_slug: short name of the platform
rom_name: name of rom file
size: size of the cover
Returns
True if cover exists in filesystem else False
"""
return bool(
os.path.exists(
f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/cover/{size.value}.png"
)
)
def _resize_cover(cover_path: str, size: CoverSize) -> None:
"""Resizes the cover image to the standard size
Args:
cover_path: path where the original cover were stored
size: size of the cover
"""
cover = Image.open(cover_path)
if cover.size[1] > DEFAULT_HEIGHT_COVER_L:
if size == CoverSize.BIG:
big_dimensions = (DEFAULT_WIDTH_COVER_L, DEFAULT_HEIGHT_COVER_L)
background = Image.new("RGBA", big_dimensions, (0, 0, 0, 0))
cover.thumbnail(big_dimensions)
offset = (int(round(((DEFAULT_WIDTH_COVER_L - cover.size[0]) / 2), 0)), 0)
elif size == CoverSize.SMALL:
small_dimensions = (DEFAULT_WIDTH_COVER_S, DEFAULT_HEIGHT_COVER_S)
background = Image.new("RGBA", small_dimensions, (0, 0, 0, 0))
cover.thumbnail(small_dimensions)
offset = (int(round(((DEFAULT_WIDTH_COVER_S - cover.size[0]) / 2), 0)), 0)
else:
return
background.paste(cover, offset)
background.save(cover_path)
def _store_cover(fs_slug: str, rom_name: str, url_cover: str, size: CoverSize):
"""Store roms resources in filesystem
Args:
fs_slug: short name of the platform
rom_name: name of rom file
url_cover: url to get the cover
size: size of the cover
"""
cover_file = f"{size.value}.png"
cover_path = f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/cover"
res = requests.get(
url_cover.replace("t_thumb", f"t_cover_{size.value}"), stream=True, timeout=120
)
if res.status_code == 200:
Path(cover_path).mkdir(parents=True, exist_ok=True)
with open(f"{cover_path}/{cover_file}", "wb") as f:
shutil.copyfileobj(res.raw, f)
_resize_cover(f"{cover_path}/{cover_file}", size)
def _get_cover_path(fs_slug: str, rom_name: str, size: CoverSize):
"""Returns rom cover filesystem path adapted to frontend folder structure
Args:
fs_slug: short name of the platform
file_name: name of rom file
size: size of the cover
"""
strtime = str(datetime.datetime.now().timestamp())
return f"{fs_slug}/{rom_name}/cover/{size.value}.png?timestamp={strtime}"
def get_rom_cover(
overwrite: bool, fs_slug: str, rom_name: str, url_cover: str = ""
) -> dict:
q_rom_name = quote(rom_name)
# Cover small
if (
overwrite or not _cover_exists(fs_slug, rom_name, CoverSize.SMALL)
) and url_cover:
_store_cover(fs_slug, rom_name, url_cover, CoverSize.SMALL)
path_cover_s = (
_get_cover_path(fs_slug, q_rom_name, CoverSize.SMALL)
if _cover_exists(fs_slug, rom_name, CoverSize.SMALL)
else DEFAULT_PATH_COVER_S
)
# Cover big
if (overwrite or not _cover_exists(fs_slug, rom_name, CoverSize.BIG)) and url_cover:
_store_cover(fs_slug, rom_name, url_cover, CoverSize.BIG)
path_cover_l = (
_get_cover_path(fs_slug, q_rom_name, CoverSize.BIG)
if _cover_exists(fs_slug, rom_name, CoverSize.BIG)
else DEFAULT_PATH_COVER_L
)
return {
"path_cover_s": path_cover_s,
"path_cover_l": path_cover_l,
}
def _store_screenshot(fs_slug: str, rom_name: str, url: str, idx: int):
"""Store roms resources in filesystem
Args:
fs_slug: short name of the platform
file_name: name of rom
url: url to get the screenshot
"""
screenshot_file: str = f"{idx}.jpg"
screenshot_path: str = f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/screenshots"
res = requests.get(url, stream=True, timeout=120)
if res.status_code == 200:
Path(screenshot_path).mkdir(parents=True, exist_ok=True)
with open(f"{screenshot_path}/{screenshot_file}", "wb") as f:
shutil.copyfileobj(res.raw, f)
def _get_screenshot_path(fs_slug: str, rom_name: str, idx: str):
"""Returns rom cover filesystem path adapted to frontend folder structure
Args:
fs_slug: short name of the platform
file_name: name of rom
idx: index number of screenshot
"""
return f"{fs_slug}/{rom_name}/screenshots/{idx}.jpg"
def get_rom_screenshots(fs_slug: str, rom_name: str, url_screenshots: list) -> dict:
q_rom_name = quote(rom_name)
path_screenshots: list[str] = []
for idx, url in enumerate(url_screenshots):
_store_screenshot(fs_slug, rom_name, url, idx)
path_screenshots.append(_get_screenshot_path(fs_slug, q_rom_name, str(idx)))
return {"path_screenshots": path_screenshots}
def store_default_resources():
"""Store default cover resources in the filesystem"""
defaul_covers = [
{"url": DEFAULT_URL_COVER_L, "size": CoverSize.BIG},
{"url": DEFAULT_URL_COVER_S, "size": CoverSize.SMALL},
]
for cover in defaul_covers:
if not _cover_exists("default", "default", cover["size"]):
_store_cover("default", "default", cover["url"], cover["size"])
# ========= Platforms utils =========
def _exclude_platforms(platforms: list):
return [
platform for platform in platforms if platform not in config.EXCLUDED_PLATFORMS
]
def get_platforms() -> list[str]:
"""Gets all filesystem platforms
Returns list with all the filesystem platforms found in the LIBRARY_BASE_PATH.
Automatically exclude folders defined in user config.
"""
try:
platforms: list[str] = (
list(os.walk(config.HIGH_PRIO_STRUCTURE_PATH))[0][1]
if os.path.exists(config.HIGH_PRIO_STRUCTURE_PATH)
else list(os.walk(LIBRARY_BASE_PATH))[0][1]
)
return _exclude_platforms(platforms)
except IndexError as exc:
raise PlatformsNotFoundException from exc
# ========= Roms utils =========
def get_fs_structure(fs_slug: str, folder: str = config.ROMS_FOLDER_NAME):
return (
f"{folder}/{fs_slug}"
if os.path.exists(config.HIGH_PRIO_STRUCTURE_PATH)
else f"{fs_slug}/{folder}"
)
def _exclude_files(files, filetype) -> list[str]:
excluded_extensions = getattr(config, f"EXCLUDED_{filetype.upper()}_EXT")
excluded_names = getattr(config, f"EXCLUDED_{filetype.upper()}_FILES")
excluded_files: list = []
for file_name in files:
# Split the file name to get the extension.
ext = get_file_extension(file_name)
# Exclude the file if it has no extension or the extension is in the excluded list.
if not ext or ext in excluded_extensions:
excluded_files.append(file_name)
# Additionally, check if the file name mathes a pattern in the excluded list.
if len(excluded_names) > 0:
[
excluded_files.append(file_name)
for name in excluded_names
if file_name == name or fnmatch.fnmatch(file_name, name)
]
# Return files that are not in the filtered list.
return [f for f in files if f not in excluded_files]
def _exclude_multi_roms(roms) -> list[str]:
excluded_names = config.EXCLUDED_MULTI_FILES
filtered_files: list = []
for rom in roms:
if rom in excluded_names:
filtered_files.append(rom)
return [f for f in roms if f not in filtered_files]
def get_rom_files(rom: str, roms_path: str) -> list[str]:
rom_files: list = []
for path, _, files in os.walk(f"{roms_path}/{rom}"):
for f in _exclude_files(files, "multi_parts"):
rom_files.append(f"{Path(path, f)}".replace(f"{roms_path}/{rom}/", ""))
return rom_files
def get_roms(fs_slug: str):
"""Gets all filesystem roms for a platform
Args:
fs_slug: short name of the platform
Returns:
list with all the filesystem roms for a platform found in the LIBRARY_BASE_PATH
"""
roms_path = get_fs_structure(fs_slug)
roms_file_path = f"{LIBRARY_BASE_PATH}/{roms_path}"
try:
fs_single_roms: list[str] = list(os.walk(roms_file_path))[0][2]
except IndexError as exc:
raise RomsNotFoundException(fs_slug) from exc
try:
fs_multi_roms: list[str] = list(os.walk(roms_file_path))[0][1]
except IndexError as exc:
raise RomsNotFoundException(fs_slug) from exc
fs_roms: list[dict] = [
{"multi": False, "file_name": rom}
for rom in _exclude_files(fs_single_roms, "single")
] + [
{"multi": True, "file_name": rom} for rom in _exclude_multi_roms(fs_multi_roms)
]
return [
dict(
rom,
files=get_rom_files(rom["file_name"], roms_file_path),
)
for rom in fs_roms
]
def get_assets(platform_slug: str):
saves_path = get_fs_structure(platform_slug, folder=config.SAVES_FOLDER_NAME)
saves_file_path = f"{LIBRARY_BASE_PATH}/{saves_path}"
fs_saves: list[str] = []
fs_states: list[str] = []
fs_screenshots: list[str] = []
try:
emulators = list(os.walk(saves_file_path))[0][1]
for emulator in emulators:
fs_saves += [
(emulator, file)
for file in list(os.walk(f"{saves_file_path}/{emulator}"))[0][2]
]
fs_saves += [(None, file) for file in list(os.walk(saves_file_path))[0][2]]
except IndexError:
pass
states_path = get_fs_structure(platform_slug, folder=config.STATES_FOLDER_NAME)
states_file_path = f"{LIBRARY_BASE_PATH}/{states_path}"
try:
emulators = list(os.walk(states_file_path))[0][1]
for emulator in emulators:
fs_states += [
(emulator, file)
for file in list(os.walk(f"{states_file_path}/{emulator}"))[0][2]
]
fs_states += [(None, file) for file in list(os.walk(states_file_path))[0][2]]
except IndexError:
pass
screenshots_path = get_fs_structure(
platform_slug, folder=config.SCREENSHOTS_FOLDER_NAME
)
screenshots_file_path = f"{LIBRARY_BASE_PATH}/{screenshots_path}"
try:
fs_screenshots += [file for file in list(os.walk(screenshots_file_path))[0][2]]
except IndexError:
pass
return {
"saves": fs_saves,
"states": fs_states,
"screenshots": fs_screenshots,
}
def get_screenshots():
screenshots_path = f"{LIBRARY_BASE_PATH}/{config.SCREENSHOTS_FOLDER_NAME}"
fs_screenshots = []
try:
platforms = list(os.walk(screenshots_path))[0][1]
for platform in platforms:
fs_screenshots += [
(platform, file)
for file in list(os.walk(f"{screenshots_path}/{platform}"))[0][2]
]
fs_screenshots += [
(None, file) for file in list(os.walk(screenshots_path))[0][2]
]
except IndexError:
pass
return fs_screenshots
def get_rom_file_size(
roms_path: str, file_name: str, multi: bool, multi_files: list = []
):
files = (
[f"{LIBRARY_BASE_PATH}/{roms_path}/{file_name}"]
if not multi
else [
f"{LIBRARY_BASE_PATH}/{roms_path}/{file_name}/{file}"
for file in multi_files
]
)
total_size: float = 0.0
for file in files:
total_size += os.stat(file).st_size
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if total_size < 1024.0 or unit == "PB":
break
total_size /= 1024.0
return round(total_size, 2), unit
def get_fs_file_size(asset_path: str, file_name: str):
return os.stat(f"{LIBRARY_BASE_PATH}/{asset_path}/{file_name}").st_size
def _file_exists(path: str, file_name: str):
"""Check if file exists in filesystem
Args:
path: path to file
file_name: name of file
Returns
True if file exists in filesystem else False
"""
return bool(os.path.exists(f"{LIBRARY_BASE_PATH}/{path}/{file_name}"))
def rename_file(old_name: str, new_name: str, file_path: str):
if new_name != old_name:
if _file_exists(path=file_path, file_name=new_name):
raise RomAlreadyExistsException(new_name)
os.rename(
f"{LIBRARY_BASE_PATH}/{file_path}/{old_name}",
f"{LIBRARY_BASE_PATH}/{file_path}/{new_name}",
)
def remove_file(file_name: str, file_path: str):
try:
os.remove(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
except IsADirectoryError:
shutil.rmtree(f"{LIBRARY_BASE_PATH}/{file_path}/{file_name}")
def build_upload_file_path(fs_slug: str, folder: str = config.ROMS_FOLDER_NAME):
rom_path = get_fs_structure(fs_slug, folder=folder)
return f"{LIBRARY_BASE_PATH}/{rom_path}"
def build_artwork_path(rom_name: str, fs_slug: str, file_ext: str):
q_rom_name = quote(rom_name)
strtime = str(datetime.datetime.now().timestamp())
path_cover_l = f"{fs_slug}/{q_rom_name}/cover/{CoverSize.BIG.value}.{file_ext}?timestamp={strtime}"
path_cover_s = f"{fs_slug}/{q_rom_name}/cover/{CoverSize.SMALL.value}.{file_ext}?timestamp={strtime}"
artwork_path = f"{RESOURCES_BASE_PATH}/{fs_slug}/{rom_name}/cover"
Path(artwork_path).mkdir(parents=True, exist_ok=True)
return path_cover_l, path_cover_s, artwork_path
# ========= Users utils =========
def build_avatar_path(avatar_path: str, username: str):
avatar_user_path = f"{RESOURCES_BASE_PATH}/users/{username}"
Path(avatar_user_path).mkdir(parents=True, exist_ok=True)
return f"users/{username}/{avatar_path}", avatar_user_path

View File

@@ -1,135 +0,0 @@
from datetime import datetime, timedelta
from typing import Any, Final, Optional
from config import ROMM_AUTH_SECRET_KEY
from fastapi import HTTPException, Security, status
from fastapi.param_functions import Form
from fastapi.security.http import HTTPBasic
from fastapi.security.oauth2 import OAuth2PasswordBearer
from fastapi.types import DecoratedCallable
from jose import JWTError, jwt
from starlette.authentication import requires
ALGORITHM: Final = "HS256"
DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15
DEFAULT_SCOPES_MAP: Final = {
"me.read": "View your profile",
"me.write": "Modify your profile",
"roms.read": "View ROMs",
"platforms.read": "View platforms",
"assets.read": "View assets",
}
WRITE_SCOPES_MAP: Final = {
"roms.write": "Modify ROMs",
"platforms.write": "Modify platforms",
"assets.write": "Modify assets",
}
FULL_SCOPES_MAP: Final = {
"users.read": "View users",
"users.write": "Modify users",
"tasks.run": "Run tasks",
}
DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys())
WRITE_SCOPES: Final = DEFAULT_SCOPES + list(WRITE_SCOPES_MAP.keys())
FULL_SCOPES: Final = WRITE_SCOPES + list(FULL_SCOPES_MAP.keys())
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
async def get_current_active_user_from_bearer_token(token: str):
from handler import dbh
try:
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
raise credentials_exception
username = payload.get("sub")
if username is None:
raise credentials_exception
user = dbh.get_user_by_username(username)
if user is None:
raise credentials_exception
if not user.enabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
)
return user, payload
class OAuth2RequestForm:
def __init__(
self,
grant_type: str = Form(default="password"),
scope: str = Form(default=""),
username: Optional[str] = Form(default=None),
password: Optional[str] = Form(default=None),
client_id: Optional[str] = Form(default=None),
client_secret: Optional[str] = Form(default=None),
refresh_token: Optional[str] = Form(default=None),
):
self.grant_type = grant_type
self.scopes = scope.split()
self.username = username
self.password = password
self.client_id = client_id
self.client_secret = client_secret
self.refresh_token = refresh_token
oauth2_password_bearer = OAuth2PasswordBearer(
tokenUrl="/token",
auto_error=False,
scopes={
**DEFAULT_SCOPES_MAP,
**WRITE_SCOPES_MAP,
**FULL_SCOPES_MAP,
},
)
def protected_route(
method: Any,
path: str,
scopes: list[str] = [],
**kwargs,
):
def decorator(func: DecoratedCallable):
fn = requires(scopes)(func)
return method(
path,
dependencies=[
Security(
dependency=oauth2_password_bearer,
scopes=scopes,
),
Security(dependency=HTTPBasic(auto_error=False)),
],
**kwargs,
)(fn)
return decorator

View File

@@ -1,17 +0,0 @@
from config import REDIS_HOST, REDIS_PASSWORD, REDIS_PORT
from redis import Redis
from rq import Queue
redis_client = Redis(
host=REDIS_HOST, port=int(REDIS_PORT), password=REDIS_PASSWORD, db=0
)
redis_url = (
f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}"
if REDIS_PASSWORD
else f"redis://{REDIS_HOST}:{REDIS_PORT}"
)
high_prio_queue = Queue(name="high", connection=redis_client)
default_queue = Queue(name="default", connection=redis_client)
low_prio_queue = Queue(name="low", connection=redis_client)

View File

@@ -1,15 +0,0 @@
import socketio # type: ignore
from config import ENABLE_EXPERIMENTAL_REDIS
from utils.redis import redis_url
socket_server = socketio.AsyncServer(
cors_allowed_origins="*",
async_mode="asgi",
logger=False,
engineio_logger=False,
client_manager=socketio.AsyncRedisManager(redis_url)
if ENABLE_EXPERIMENTAL_REDIS
else None,
)
socket_app = socketio.ASGIApp(socket_server)

View File

@@ -1,7 +1,7 @@
import pytest
from unittest.mock import patch
from ..fs import (
from ...handler.fs_handler.roms_handler import (
get_rom_cover,
get_platforms,
get_fs_structure,

View File

@@ -1,7 +1,7 @@
from utils import (
parse_tags,
get_file_name_with_no_tags as gfnwt,
get_file_extension as gfe,
parse_file_extension as gfe,
)

View File

@@ -1,7 +1,5 @@
import os
from datetime import timedelta
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from config import (
ENABLE_RESCAN_ON_FILESYSTEM_CHANGE,
@@ -11,7 +9,9 @@ from config import (
from config.config_loader import config
from endpoints.sockets.scan import scan_platforms
from logger.logger import log
from tasks.utils import tasks_scheduler
from tasks.tasks import tasks_scheduler
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
path = (
config.HIGH_PRIO_STRUCTURE_PATH

View File

@@ -2,7 +2,7 @@ import sys
from config import ENABLE_EXPERIMENTAL_REDIS
from rq import Connection, Queue, Worker
from utils.redis import redis_client
from handler.redis_handler import redis_client
listen = ["high", "default", "low"]

View File

@@ -12,8 +12,7 @@ export type EnhancedRomSchema = {
id: number;
igdb_id: (number | null);
sgdb_id: (number | null);
platform_slug: string;
platform_name: string;
platform_id: number;
file_name: string;
file_name_no_tags: string;
file_extension: string;

View File

@@ -4,6 +4,7 @@
/* eslint-disable */
export type PlatformSchema = {
id: number;
slug: string;
fs_slug: string;
igdb_id?: (number | null);

View File

@@ -11,8 +11,7 @@ export type RomSchema = {
id: number;
igdb_id: (number | null);
sgdb_id: (number | null);
platform_slug: string;
platform_name: string;
platform_id: number;
file_name: string;
file_name_no_tags: string;
file_extension: string;

View File

@@ -13,7 +13,6 @@ export type SaveSchema = {
full_path: string;
download_path: string;
rom_id: number;
platform_slug: string;
emulator: (string | null);
};

View File

@@ -13,6 +13,5 @@ export type ScreenshotSchema = {
full_path: string;
download_path: string;
rom_id: number;
platform_slug: (string | null);
};

View File

@@ -13,7 +13,6 @@ export type StateSchema = {
full_path: string;
download_path: string;
rom_id: number;
platform_slug: string;
emulator: (string | null);
};

View File

@@ -1,10 +1,9 @@
<script setup lang="ts">
import { useDisplay } from "vuetify";
import PlatformIcon from "@/components/Platform/PlatformIcon.vue";
import { regionToEmoji, languageToEmoji } from "@/utils";
import type { Rom } from "@/stores/roms";
import type { RomSchema, PlatformSchema } from "@/__generated__/"
defineProps<{ rom: Rom }>();
defineProps<{ rom: RomSchema, platform: PlatformSchema }>();
</script>
<template>
<v-row class="text-white text-shadow" no-gutters>
@@ -17,11 +16,11 @@ defineProps<{ rom: Rom }>();
<v-col cols="12">
<v-chip
class="font-italic px-3 my-2"
:to="`/platform/${rom.platform_slug}`"
:to="`/platform/${platform.slug}`"
>
{{ rom.platform_name || rom.platform_slug }}
{{ platform.name }}
<v-avatar :rounded="0" size="40" class="ml-2 pa-2">
<platform-icon :platform="rom.platform_slug"></platform-icon>
<platform-icon :platform="platform.slug"></platform-icon>
</v-avatar>
</v-chip>
<v-chip

View File

@@ -1,11 +1,11 @@
<script setup lang="ts">
import { ref, watch } from "vue";
import { useRouter } from "vue-router";
import { regionToEmoji, languageToEmoji } from "@/utils";
import type { PlatformSchema, RomSchema } from "@/__generated__";
import type { Rom } from "@/stores/roms";
import type { RomSchema } from "@/__generated__";
import { languageToEmoji, regionToEmoji } from "@/utils";
import { ref } from "vue";
import { useRouter } from "vue-router";
const props = defineProps<{ rom: Rom }>();
const props = defineProps<{ rom: Rom; platform: PlatformSchema }>();
const router = useRouter();
const version = ref(props.rom.id);
@@ -18,7 +18,8 @@ function formatItem(rom: RomSchema) {
function updateVersion() {
router.push({
path: `/platform/${props.rom.platform_slug}/${version.value}`,
name: "rom",
params: { platform: props.platform.slug, rom: version.value },
});
}
</script>

View File

@@ -18,12 +18,12 @@ const routes = [
component: () => import("@/views/Dashboard/Base.vue"),
},
{
path: "/platform/:platform",
path: "/gallery/:platform",
name: "platform",
component: () => import("@/views/Gallery/Base.vue"),
},
{
path: "/platform/:platform/:rom",
path: "/gallery/:platform/:rom",
name: "rom",
component: () => import("@/views/Details/Base.vue"),
},

View File

@@ -8,7 +8,7 @@ import type { Events } from "@/types/emitter";
import api from "@/services/api";
import storeRoms, { type Rom } from "@/stores/roms";
import BackgroundHeader from "@/components/Details/BackgroundHeader.vue";
import TitleInfo from "@/components/Details/Title.vue";
import Title from "@/components/Details/Title.vue";
import Cover from "@/components/Details/Cover.vue";
import ActionBar from "@/components/Details/ActionBar.vue";
import DetailsInfo from "@/components/Details/DetailsInfo.vue";
@@ -20,11 +20,11 @@ import EditRomDialog from "@/components/Dialog/Rom/EditRom.vue";
import DeleteRomDialog from "@/components/Dialog/Rom/DeleteRom.vue";
import LoadingDialog from "@/components/Dialog/Loading.vue";
import DeleteAssetDialog from "@/components/Details/DeleteAssets.vue";
import type { EnhancedRomSchema } from "@/__generated__";
import type { PlatformSchema, EnhancedRomSchema } from "@/__generated__";
const route = useRoute();
const romsStore = storeRoms();
const rom = ref<EnhancedRomSchema | null>(null);
const rom = ref<EnhancedRomSchema>();
const platform = ref<PlatformSchema>();
const tab = ref<"details" | "saves" | "screenshots">("details");
const { smAndDown, mdAndUp } = useDisplay();
const emitter = inject<Emitter<Events>>("emitter");
@@ -36,7 +36,6 @@ async function fetchRom() {
.fetchRom({ romId: parseInt(route.params.rom as string) })
.then((response) => {
rom.value = response.data;
romsStore.update(response.data);
})
.catch((error) => {
console.log(error);
@@ -108,7 +107,7 @@ watch(
>
<v-row :class="{ 'position-absolute title-lg mr-16': mdAndUp, 'justify-center': smAndDown }" no-gutters>
<v-col cols="12">
<title-info :rom="rom" />
<title :rom="rom" :platform="platform" />
</v-col>
</v-row>
<v-row

View File

@@ -16,11 +16,10 @@ const { mdAndDown } = useDisplay();
const platforms = storePlatforms();
const scanning = storeScanning();
const auth = storeAuth();
const refreshDrawer = ref(false);
// Event listeners bus
const emitter = inject<Emitter<Events>>("emitter");
emitter?.on("refreshDrawer", async () => {
emitter?.on("refreshDrawer", async () => {
const { data: platformData } = await api.fetchPlatforms();
platforms.set(platformData);
});

View File

@@ -169,14 +169,14 @@ onBeforeUnmount(() => {
<platform-icon :platform="platform.slug"></platform-icon>
</v-avatar>
<span class="text-body-2 ml-5"> {{ platform.name }}</span>
<v-list-item v-for="rom in platform.roms" class="text-body-2" disabled>
<!-- <v-list-item v-for="rom in platform.roms" class="text-body-2" disabled>
<span v-if="rom.igdb_id" class="ml-10">
Identified <b>{{ rom.name }} 👾</b>
</span>
<span v-else class="ml-10">
{{ rom.file_name }} not found in IGDB
</span>
</v-list-item>
</v-list-item> -->
</v-col>
</v-row>
</template>

View File

@@ -1,11 +1,11 @@
<script setup lang="ts">
import { ref, inject, onBeforeMount } from "vue";
import { useRouter } from "vue-router";
import type { Emitter } from "mitt";
import type { Events } from "@/types/emitter";
import type { Emitter } from "mitt";
import { inject, onBeforeMount, ref } from "vue";
import { useRouter } from "vue-router";
import storeAuth from "@/stores/auth";
import { api } from "@/services/api";
import storeAuth from "@/stores/auth";
// Props
const auth = storeAuth();
@@ -57,7 +57,7 @@ function login() {
onBeforeMount(async () => {
// Check if authentication is enabled
if (!auth.enabled) {
return router.push("/");
return router.push({"name": "dashboard"});
}
});
</script>