implement csrf middleware directly in repo

This commit is contained in:
Georges-Antoine Assi
2025-11-17 17:53:17 -05:00
parent 7333e5b5bd
commit 551ff72a8a
5 changed files with 166 additions and 39 deletions

View File

@@ -0,0 +1,159 @@
import functools
import http.cookies
import secrets
from re import Pattern
from typing import Optional, cast
from itsdangerous import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
from starlette.datastructures import URL, MutableHeaders
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
class CSRFMiddleware:
def __init__(
self,
app: ASGIApp,
secret: str,
*,
required_urls: Optional[list[Pattern]] = None,
exempt_urls: Optional[list[Pattern]] = None,
sensitive_cookies: Optional[set[str]] = None,
safe_methods: set[str] = {"GET", "HEAD", "OPTIONS", "TRACE"},
cookie_name: str = "csrftoken",
cookie_path: str = "/",
cookie_domain: Optional[str] = None,
cookie_secure: bool = False,
cookie_httponly: bool = False,
cookie_samesite: str = "lax",
header_name: str = "x-csrftoken",
) -> None:
self.app = app
self.serializer = URLSafeSerializer(secret, "csrftoken")
self.secret = secret
self.required_urls = required_urls
self.exempt_urls = exempt_urls
self.sensitive_cookies = sensitive_cookies
self.safe_methods = safe_methods
self.cookie_name = cookie_name
self.cookie_path = cookie_path
self.cookie_domain = cookie_domain
self.cookie_secure = cookie_secure
self.cookie_httponly = cookie_httponly
self.cookie_samesite = cookie_samesite
self.header_name = header_name
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Skip CSRF check if not an HTTP request, like websockets
if scope["type"] != "http":
await self.app(scope, receive, send)
return None
request = Request(scope, receive)
# Skip CSRF check if Authorization header is present
auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower()
if auth_scheme == "bearer" or auth_scheme == "basic":
await self.app(scope, receive, send)
return None
csrf_cookie = request.cookies.get(self.cookie_name)
if self._url_is_required(request.url) or (
request.method not in self.safe_methods
and not self._url_is_exempt(request.url)
and self._has_sensitive_cookies(request.cookies)
):
submitted_csrf_token = await self._get_submitted_csrf_token(request)
if (
not csrf_cookie
or not submitted_csrf_token
or not self._csrf_tokens_match(
csrf_cookie, submitted_csrf_token, request.user.id
)
):
response = self._get_error_response(request)
await response(scope, receive, send)
return
send = functools.partial(self.send, send=send, scope=scope)
await self.app(scope, receive, send)
async def send(self, message: Message, send: Send, scope: Scope) -> None:
request = Request(scope)
csrf_cookie = request.cookies.get(self.cookie_name)
if csrf_cookie is None:
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie()
cookie_name = self.cookie_name
cookie[cookie_name] = self._generate_csrf_token(request.user.id)
cookie[cookie_name]["path"] = self.cookie_path
cookie[cookie_name]["secure"] = self.cookie_secure
cookie[cookie_name]["httponly"] = self.cookie_httponly
cookie[cookie_name]["samesite"] = self.cookie_samesite
if self.cookie_domain is not None:
cookie[cookie_name]["domain"] = self.cookie_domain # pragma: no cover
headers.append("set-cookie", cookie.output(header="").strip())
await send(message)
def _has_sensitive_cookies(self, cookies: dict[str, str]) -> bool:
if not self.sensitive_cookies:
return True
for sensitive_cookie in self.sensitive_cookies:
if sensitive_cookie in cookies:
return True
return False
def _url_is_required(self, url: URL) -> bool:
if not self.required_urls:
return False
for required_url in self.required_urls:
if required_url.match(url.path):
return True
return False
def _url_is_exempt(self, url: URL) -> bool:
if not self.exempt_urls:
return False
for exempt_url in self.exempt_urls:
if exempt_url.match(url.path):
return True
return False
async def _get_submitted_csrf_token(self, request: Request) -> Optional[str]:
return request.headers.get(self.header_name)
def _generate_csrf_token(self, user_id: int | None = None) -> str:
obj = {"token": secrets.token_urlsafe(128), "user_id": user_id}
return cast(str, self.serializer.dumps(obj))
def _csrf_tokens_match(
self, document_cookie: str, header_cookie: str, user_id: str | None
) -> bool:
try:
decoded_doc_cookie: str = self.serializer.loads(document_cookie)
decoded_header_cookie: str = self.serializer.loads(header_cookie)
# Verify that the tokens match, the user IDs match
# and the user_id matches the authenticated user
return (
secrets.compare_digest(
decoded_doc_cookie["token"], decoded_doc_cookie["token"]
)
and decoded_header_cookie["user_id"] == decoded_header_cookie["user_id"]
and decoded_doc_cookie["user_id"] == user_id
and decoded_header_cookie["user_id"] == user_id
)
except BadSignature:
return False
def _get_error_response(self, request: Request) -> Response:
return PlainTextResponse(
content="CSRF token verification failed", status_code=403
)

View File

@@ -5,31 +5,11 @@ from joserfc import jwt
from joserfc.errors import BadSignatureError
from joserfc.jwk import OctKey
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection, Request
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette_csrf.middleware import CSRFMiddleware
from config import SESSION_MAX_AGE_SECONDS
class CustomCSRFMiddleware(CSRFMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Skip CSRF check if not an HTTP request, like websockets
if scope["type"] != "http":
await self.app(scope, receive, send)
return None
request = Request(scope, receive)
# Skip CSRF check if Authorization header is present
auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower()
if auth_scheme == "bearer" or auth_scheme == "basic":
await self.app(scope, receive, send)
return None
await super().__call__(scope, receive, send)
SecretKey = namedtuple("SecretKey", ("encode", "decode"))

View File

@@ -44,7 +44,8 @@ from endpoints import (
)
from handler.auth.constants import ALGORITHM
from handler.auth.hybrid_auth import HybridAuthBackend
from handler.auth.middleware import CustomCSRFMiddleware, SessionMiddleware
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
from handler.auth.middleware.session_middleware import SessionMiddleware
from handler.socket_handler import socket_handler
from logger.formatter import LOGGING_CONFIG
from utils import get_version
@@ -90,7 +91,7 @@ app.add_middleware(
if not IS_PYTEST_RUN and not DISABLE_CSRF_PROTECTION:
# CSRF protection (except endpoints listed in exempt_urls)
app.add_middleware(
CustomCSRFMiddleware,
CSRFMiddleware,
cookie_name="romm_csrftoken",
secret=ROMM_AUTH_SECRET_KEY,
exempt_urls=[re.compile(r"^/api/token.*"), re.compile(r"^/ws")],

View File

@@ -24,6 +24,7 @@ dependencies = [
"fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1",
"gunicorn ~= 23.0",
"httpx ~= 0.27",
"itsdangerous>=2.2.0",
"joserfc ~= 1.3.4",
"opentelemetry-distro ~= 0.56",
"opentelemetry-exporter-otlp ~= 1.36",
@@ -46,7 +47,6 @@ dependencies = [
"rq-scheduler @ git+https://github.com/adamantike/rq-scheduler.git@feat/script-options-username-ssl",
"sentry-sdk ~= 2.32",
"starlette ~= 0.49",
"starlette-csrf ~= 3.0",
"streaming-form-data ~= 1.19",
"strsimpy ~= 0.2",
"types-colorama ~= 0.4",

17
uv.lock generated
View File

@@ -1911,6 +1911,7 @@ dependencies = [
{ name = "fastapi-pagination", extra = ["sqlalchemy"] },
{ name = "gunicorn" },
{ name = "httpx" },
{ name = "itsdangerous" },
{ name = "joserfc" },
{ name = "opentelemetry-distro" },
{ name = "opentelemetry-exporter-otlp" },
@@ -1933,7 +1934,6 @@ dependencies = [
{ name = "sentry-sdk" },
{ name = "sqlalchemy", extra = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"] },
{ name = "starlette" },
{ name = "starlette-csrf" },
{ name = "streaming-form-data" },
{ name = "strsimpy" },
{ name = "types-colorama" },
@@ -1982,6 +1982,7 @@ requires-dist = [
{ name = "httpx", specifier = "~=0.27" },
{ name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" },
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "joserfc", specifier = "~=1.3.4" },
{ name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" },
{ name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" },
@@ -2013,7 +2014,6 @@ requires-dist = [
{ name = "sentry-sdk", specifier = "~=2.32" },
{ name = "sqlalchemy", extras = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"], specifier = "~=2.0" },
{ name = "starlette", specifier = "~=0.49" },
{ name = "starlette-csrf", specifier = "~=3.0" },
{ name = "streaming-form-data", specifier = "~=1.19" },
{ name = "strsimpy", specifier = "~=0.2" },
{ name = "types-colorama", specifier = "~=0.4" },
@@ -2198,19 +2198,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" },
]
[[package]]
name = "starlette-csrf"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "itsdangerous" },
{ name = "starlette" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0f/7c/53c57b4cd76c9a4493a8525d34a76d7e4bbe0ff957de1c53f30241aa757a/starlette_csrf-3.0.0.tar.gz", hash = "sha256:7afaca8c72cc3c726e5942778af53454607ca3e653fd86cd75ee35d8cd1cfa77", size = 8371, upload-time = "2023-06-27T13:23:24.387Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b9/83/6641e4fdcf33b1cc614a74ecabe5835236a1b2564bf6735db7e35d788795/starlette_csrf-3.0.0-py3-none-any.whl", hash = "sha256:aac29b366e83621d3fc56be690866e16f3c56df91ab5e184b77950540a4e2761", size = 6170, upload-time = "2023-06-27T13:23:25.563Z" },
]
[[package]]
name = "streaming-form-data"
version = "1.19.1"