[ROMM-2628] Fix desirialize job func_name

This commit is contained in:
Georges-Antoine Assi
2025-11-10 17:57:09 -05:00
parent 6474a031b0
commit b2dea510c4
5 changed files with 32 additions and 10 deletions

View File

@@ -28,7 +28,7 @@ from handler.filesystem import (
)
from handler.filesystem.roms_handler import FSRom
from handler.metadata.ss_handler import get_preferred_media_types
from handler.redis_handler import high_prio_queue, redis_client
from handler.redis_handler import get_job_func_name, high_prio_queue, redis_client
from handler.scan_handler import (
ScanType,
scan_firmware,
@@ -699,7 +699,7 @@ async def stop_scan_handler(_sid: str):
existing_jobs = high_prio_queue.get_jobs()
for job in existing_jobs:
if job.func_name == "scan_platform" and job.is_started:
if get_job_func_name(job) == "scan_platform" and job.is_started:
return await cancel_job(job)
workers = Worker.all(connection=redis_client)
@@ -707,7 +707,8 @@ async def stop_scan_handler(_sid: str):
current_job = worker.get_current_job()
if (
current_job
and current_job.func_name == "endpoints.sockets.scan.scan_platforms"
and get_job_func_name(current_job)
== "endpoints.sockets.scan.scan_platforms"
and current_job.is_started
):
return await cancel_job(current_job)

View File

@@ -27,6 +27,7 @@ from endpoints.responses.tasks import GroupedTasksDict, TaskInfo
from handler.auth.constants import Scope
from handler.redis_handler import (
default_queue,
get_job_func_name,
high_prio_queue,
low_prio_queue,
redis_client,
@@ -117,8 +118,8 @@ def _build_task_status_response(
job: Job,
) -> TaskStatusResponse:
job_meta = job.get_meta()
task_name = job_meta.get("task_name") or job.func_name
task_type = job_meta.get("task_type")
task_name = job_meta.get("task_name") or get_job_func_name(job)
# Convert datetime objects to ISO format strings
created_at = job.created_at.isoformat() if job.created_at else None

View File

@@ -5,6 +5,8 @@ from enum import Enum
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from rq import Queue
from rq.exceptions import DeserializationError
from rq.job import Job
from config import IS_PYTEST_RUN, REDIS_URL
from logger.logger import log
@@ -55,3 +57,20 @@ def __get_async_cache() -> AsyncRedis:
sync_cache = __get_sync_cache()
async_cache = __get_async_cache()
def get_job_func_name(job: Job, fallback: str = "") -> str | None:
"""Safely get the function name from an RQ job, handling DeserializationError.
Args:
job: The RQ Job object to get the function name from
fallback: The value to return if deserialization fails (default: "unknown_task")
Returns:
The function name if available, otherwise the fallback value
"""
try:
return job.func_name
except DeserializationError:
# Job data cannot be deserialized (e.g., function no longer exists)
return fallback

View File

@@ -9,7 +9,7 @@ from rq_scheduler import Scheduler
from config import TASK_TIMEOUT
from exceptions.task_exceptions import SchedulerException
from handler.redis_handler import low_prio_queue
from handler.redis_handler import get_job_func_name, low_prio_queue
from logger.logger import log
from utils.context import ctx_httpx_client
@@ -79,7 +79,7 @@ class PeriodicTask(Task, ABC):
def _get_existing_job(self) -> Job | None:
existing_jobs = tasks_scheduler.get_jobs()
for job in existing_jobs:
if isinstance(job, Job) and job.func_name == self.func:
if isinstance(job, Job) and get_job_func_name(job) == self.func:
return job
return None

View File

@@ -33,7 +33,7 @@ from handler.metadata import (
meta_ss_handler,
meta_tgdb_handler,
)
from handler.redis_handler import low_prio_queue, redis_client
from handler.redis_handler import get_job_func_name, low_prio_queue, redis_client
from handler.scan_handler import MetadataSource, ScanType
from logger.formatter import CYAN
from logger.formatter import highlight as hl
@@ -82,7 +82,7 @@ def get_pending_scan_jobs() -> list[Job]:
for job in scheduled_jobs:
if (
isinstance(job, Job)
and job.func_name == "endpoints.sockets.scan.scan_platforms"
and get_job_func_name(job) == "endpoints.sockets.scan.scan_platforms"
and job.get_status()
in [JobStatus.SCHEDULED, JobStatus.QUEUED, JobStatus.STARTED]
):
@@ -93,7 +93,7 @@ def get_pending_scan_jobs() -> list[Job]:
for job in queue_jobs:
if (
isinstance(job, Job)
and job.func_name == "endpoints.sockets.scan.scan_platforms"
and get_job_func_name(job) == "endpoints.sockets.scan.scan_platforms"
and job.get_status() in [JobStatus.QUEUED, JobStatus.STARTED]
):
pending_jobs.append(job)
@@ -104,7 +104,8 @@ def get_pending_scan_jobs() -> list[Job]:
current_job = worker.get_current_job()
if (
current_job
and current_job.func_name == "endpoints.sockets.scan.scan_platforms"
and get_job_func_name(current_job)
== "endpoints.sockets.scan.scan_platforms"
and current_job.get_status() == JobStatus.STARTED
):
pending_jobs.append(current_job)