diff --git a/backend/handler/auth/middleware/csrf_middleware.py b/backend/handler/auth/middleware/csrf_middleware.py new file mode 100644 index 000000000..d6a04466e --- /dev/null +++ b/backend/handler/auth/middleware/csrf_middleware.py @@ -0,0 +1,159 @@ +import functools +import http.cookies +import secrets +from re import Pattern +from typing import Optional, cast + +from itsdangerous import BadSignature +from itsdangerous.url_safe import URLSafeSerializer +from starlette.datastructures import URL, MutableHeaders +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class CSRFMiddleware: + def __init__( + self, + app: ASGIApp, + secret: str, + *, + required_urls: Optional[list[Pattern]] = None, + exempt_urls: Optional[list[Pattern]] = None, + sensitive_cookies: Optional[set[str]] = None, + safe_methods: set[str] = {"GET", "HEAD", "OPTIONS", "TRACE"}, + cookie_name: str = "csrftoken", + cookie_path: str = "/", + cookie_domain: Optional[str] = None, + cookie_secure: bool = False, + cookie_httponly: bool = False, + cookie_samesite: str = "lax", + header_name: str = "x-csrftoken", + ) -> None: + self.app = app + self.serializer = URLSafeSerializer(secret, "csrftoken") + self.secret = secret + self.required_urls = required_urls + self.exempt_urls = exempt_urls + self.sensitive_cookies = sensitive_cookies + self.safe_methods = safe_methods + self.cookie_name = cookie_name + self.cookie_path = cookie_path + self.cookie_domain = cookie_domain + self.cookie_secure = cookie_secure + self.cookie_httponly = cookie_httponly + self.cookie_samesite = cookie_samesite + self.header_name = header_name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # Skip CSRF check if not an HTTP request, like websockets + if scope["type"] != "http": + await self.app(scope, receive, send) + return None + + request = Request(scope, receive) + + # Skip CSRF check if Authorization header is present + auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower() + if auth_scheme == "bearer" or auth_scheme == "basic": + await self.app(scope, receive, send) + return None + + csrf_cookie = request.cookies.get(self.cookie_name) + + if self._url_is_required(request.url) or ( + request.method not in self.safe_methods + and not self._url_is_exempt(request.url) + and self._has_sensitive_cookies(request.cookies) + ): + submitted_csrf_token = await self._get_submitted_csrf_token(request) + if ( + not csrf_cookie + or not submitted_csrf_token + or not self._csrf_tokens_match( + csrf_cookie, submitted_csrf_token, request.user.id + ) + ): + response = self._get_error_response(request) + await response(scope, receive, send) + return + + send = functools.partial(self.send, send=send, scope=scope) + await self.app(scope, receive, send) + + async def send(self, message: Message, send: Send, scope: Scope) -> None: + request = Request(scope) + csrf_cookie = request.cookies.get(self.cookie_name) + + if csrf_cookie is None: + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + + cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie() + cookie_name = self.cookie_name + cookie[cookie_name] = self._generate_csrf_token(request.user.id) + cookie[cookie_name]["path"] = self.cookie_path + cookie[cookie_name]["secure"] = self.cookie_secure + cookie[cookie_name]["httponly"] = self.cookie_httponly + cookie[cookie_name]["samesite"] = self.cookie_samesite + if self.cookie_domain is not None: + cookie[cookie_name]["domain"] = self.cookie_domain # pragma: no cover + headers.append("set-cookie", cookie.output(header="").strip()) + + await send(message) + + def _has_sensitive_cookies(self, cookies: dict[str, str]) -> bool: + if not self.sensitive_cookies: + return True + for sensitive_cookie in self.sensitive_cookies: + if sensitive_cookie in cookies: + return True + return False + + def _url_is_required(self, url: URL) -> bool: + if not self.required_urls: + return False + for required_url in self.required_urls: + if required_url.match(url.path): + return True + return False + + def _url_is_exempt(self, url: URL) -> bool: + if not self.exempt_urls: + return False + for exempt_url in self.exempt_urls: + if exempt_url.match(url.path): + return True + return False + + async def _get_submitted_csrf_token(self, request: Request) -> Optional[str]: + return request.headers.get(self.header_name) + + def _generate_csrf_token(self, user_id: int | None = None) -> str: + obj = {"token": secrets.token_urlsafe(128), "user_id": user_id} + return cast(str, self.serializer.dumps(obj)) + + def _csrf_tokens_match( + self, document_cookie: str, header_cookie: str, user_id: str | None + ) -> bool: + try: + decoded_doc_cookie: str = self.serializer.loads(document_cookie) + decoded_header_cookie: str = self.serializer.loads(header_cookie) + + # Verify that the tokens match, the user IDs match + # and the user_id matches the authenticated user + return ( + secrets.compare_digest( + decoded_doc_cookie["token"], decoded_doc_cookie["token"] + ) + and decoded_header_cookie["user_id"] == decoded_header_cookie["user_id"] + and decoded_doc_cookie["user_id"] == user_id + and decoded_header_cookie["user_id"] == user_id + ) + except BadSignature: + return False + + def _get_error_response(self, request: Request) -> Response: + return PlainTextResponse( + content="CSRF token verification failed", status_code=403 + ) diff --git a/backend/handler/auth/middleware.py b/backend/handler/auth/middleware/session_middleware.py similarity index 86% rename from backend/handler/auth/middleware.py rename to backend/handler/auth/middleware/session_middleware.py index efc624189..032c104d7 100644 --- a/backend/handler/auth/middleware.py +++ b/backend/handler/auth/middleware/session_middleware.py @@ -5,31 +5,11 @@ from joserfc import jwt from joserfc.errors import BadSignatureError from joserfc.jwk import OctKey from starlette.datastructures import MutableHeaders, Secret -from starlette.requests import HTTPConnection, Request +from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Message, Receive, Scope, Send -from starlette_csrf.middleware import CSRFMiddleware from config import SESSION_MAX_AGE_SECONDS - -class CustomCSRFMiddleware(CSRFMiddleware): - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - # Skip CSRF check if not an HTTP request, like websockets - if scope["type"] != "http": - await self.app(scope, receive, send) - return None - - request = Request(scope, receive) - - # Skip CSRF check if Authorization header is present - auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower() - if auth_scheme == "bearer" or auth_scheme == "basic": - await self.app(scope, receive, send) - return None - - await super().__call__(scope, receive, send) - - SecretKey = namedtuple("SecretKey", ("encode", "decode")) diff --git a/backend/main.py b/backend/main.py index b81478723..0331b78ad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -44,7 +44,8 @@ from endpoints import ( ) from handler.auth.constants import ALGORITHM from handler.auth.hybrid_auth import HybridAuthBackend -from handler.auth.middleware import CustomCSRFMiddleware, SessionMiddleware +from handler.auth.middleware.csrf_middleware import CSRFMiddleware +from handler.auth.middleware.session_middleware import SessionMiddleware from handler.socket_handler import socket_handler from logger.formatter import LOGGING_CONFIG from utils import get_version @@ -90,7 +91,7 @@ app.add_middleware( if not IS_PYTEST_RUN and not DISABLE_CSRF_PROTECTION: # CSRF protection (except endpoints listed in exempt_urls) app.add_middleware( - CustomCSRFMiddleware, + CSRFMiddleware, cookie_name="romm_csrftoken", secret=ROMM_AUTH_SECRET_KEY, exempt_urls=[re.compile(r"^/api/token.*"), re.compile(r"^/ws")], diff --git a/pyproject.toml b/pyproject.toml index dd7e30687..c66d66f1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1", "gunicorn ~= 23.0", "httpx ~= 0.27", + "itsdangerous>=2.2.0", "joserfc ~= 1.3.4", "opentelemetry-distro ~= 0.56", "opentelemetry-exporter-otlp ~= 1.36", @@ -46,7 +47,6 @@ dependencies = [ "rq-scheduler @ git+https://github.com/adamantike/rq-scheduler.git@feat/script-options-username-ssl", "sentry-sdk ~= 2.32", "starlette ~= 0.49", - "starlette-csrf ~= 3.0", "streaming-form-data ~= 1.19", "strsimpy ~= 0.2", "types-colorama ~= 0.4", diff --git a/uv.lock b/uv.lock index 3991e9aa1..4083570ab 100644 --- a/uv.lock +++ b/uv.lock @@ -1911,6 +1911,7 @@ dependencies = [ { name = "fastapi-pagination", extra = ["sqlalchemy"] }, { name = "gunicorn" }, { name = "httpx" }, + { name = "itsdangerous" }, { name = "joserfc" }, { name = "opentelemetry-distro" }, { name = "opentelemetry-exporter-otlp" }, @@ -1933,7 +1934,6 @@ dependencies = [ { name = "sentry-sdk" }, { name = "sqlalchemy", extra = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"] }, { name = "starlette" }, - { name = "starlette-csrf" }, { name = "streaming-form-data" }, { name = "strsimpy" }, { name = "types-colorama" }, @@ -1982,6 +1982,7 @@ requires-dist = [ { name = "httpx", specifier = "~=0.27" }, { name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" }, + { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "joserfc", specifier = "~=1.3.4" }, { name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" }, { name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" }, @@ -2013,7 +2014,6 @@ requires-dist = [ { name = "sentry-sdk", specifier = "~=2.32" }, { name = "sqlalchemy", extras = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"], specifier = "~=2.0" }, { name = "starlette", specifier = "~=0.49" }, - { name = "starlette-csrf", specifier = "~=3.0" }, { name = "streaming-form-data", specifier = "~=1.19" }, { name = "strsimpy", specifier = "~=0.2" }, { name = "types-colorama", specifier = "~=0.4" }, @@ -2198,19 +2198,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, ] -[[package]] -name = "starlette-csrf" -version = "3.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "itsdangerous" }, - { name = "starlette" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0f/7c/53c57b4cd76c9a4493a8525d34a76d7e4bbe0ff957de1c53f30241aa757a/starlette_csrf-3.0.0.tar.gz", hash = "sha256:7afaca8c72cc3c726e5942778af53454607ca3e653fd86cd75ee35d8cd1cfa77", size = 8371, upload-time = "2023-06-27T13:23:24.387Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/83/6641e4fdcf33b1cc614a74ecabe5835236a1b2564bf6735db7e35d788795/starlette_csrf-3.0.0-py3-none-any.whl", hash = "sha256:aac29b366e83621d3fc56be690866e16f3c56df91ab5e184b77950540a4e2761", size = 6170, upload-time = "2023-06-27T13:23:25.563Z" }, -] - [[package]] name = "streaming-form-data" version = "1.19.1"