mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 00:27:41 +01:00
implement csrf middleware directly in repo
This commit is contained in:
159
backend/handler/auth/middleware/csrf_middleware.py
Normal file
159
backend/handler/auth/middleware/csrf_middleware.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import functools
|
||||
import http.cookies
|
||||
import secrets
|
||||
from re import Pattern
|
||||
from typing import Optional, cast
|
||||
|
||||
from itsdangerous import BadSignature
|
||||
from itsdangerous.url_safe import URLSafeSerializer
|
||||
from starlette.datastructures import URL, MutableHeaders
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
class CSRFMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
secret: str,
|
||||
*,
|
||||
required_urls: Optional[list[Pattern]] = None,
|
||||
exempt_urls: Optional[list[Pattern]] = None,
|
||||
sensitive_cookies: Optional[set[str]] = None,
|
||||
safe_methods: set[str] = {"GET", "HEAD", "OPTIONS", "TRACE"},
|
||||
cookie_name: str = "csrftoken",
|
||||
cookie_path: str = "/",
|
||||
cookie_domain: Optional[str] = None,
|
||||
cookie_secure: bool = False,
|
||||
cookie_httponly: bool = False,
|
||||
cookie_samesite: str = "lax",
|
||||
header_name: str = "x-csrftoken",
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.serializer = URLSafeSerializer(secret, "csrftoken")
|
||||
self.secret = secret
|
||||
self.required_urls = required_urls
|
||||
self.exempt_urls = exempt_urls
|
||||
self.sensitive_cookies = sensitive_cookies
|
||||
self.safe_methods = safe_methods
|
||||
self.cookie_name = cookie_name
|
||||
self.cookie_path = cookie_path
|
||||
self.cookie_domain = cookie_domain
|
||||
self.cookie_secure = cookie_secure
|
||||
self.cookie_httponly = cookie_httponly
|
||||
self.cookie_samesite = cookie_samesite
|
||||
self.header_name = header_name
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Skip CSRF check if not an HTTP request, like websockets
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return None
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Skip CSRF check if Authorization header is present
|
||||
auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower()
|
||||
if auth_scheme == "bearer" or auth_scheme == "basic":
|
||||
await self.app(scope, receive, send)
|
||||
return None
|
||||
|
||||
csrf_cookie = request.cookies.get(self.cookie_name)
|
||||
|
||||
if self._url_is_required(request.url) or (
|
||||
request.method not in self.safe_methods
|
||||
and not self._url_is_exempt(request.url)
|
||||
and self._has_sensitive_cookies(request.cookies)
|
||||
):
|
||||
submitted_csrf_token = await self._get_submitted_csrf_token(request)
|
||||
if (
|
||||
not csrf_cookie
|
||||
or not submitted_csrf_token
|
||||
or not self._csrf_tokens_match(
|
||||
csrf_cookie, submitted_csrf_token, request.user.id
|
||||
)
|
||||
):
|
||||
response = self._get_error_response(request)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
send = functools.partial(self.send, send=send, scope=scope)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def send(self, message: Message, send: Send, scope: Scope) -> None:
|
||||
request = Request(scope)
|
||||
csrf_cookie = request.cookies.get(self.cookie_name)
|
||||
|
||||
if csrf_cookie is None:
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
|
||||
cookie: http.cookies.BaseCookie = http.cookies.SimpleCookie()
|
||||
cookie_name = self.cookie_name
|
||||
cookie[cookie_name] = self._generate_csrf_token(request.user.id)
|
||||
cookie[cookie_name]["path"] = self.cookie_path
|
||||
cookie[cookie_name]["secure"] = self.cookie_secure
|
||||
cookie[cookie_name]["httponly"] = self.cookie_httponly
|
||||
cookie[cookie_name]["samesite"] = self.cookie_samesite
|
||||
if self.cookie_domain is not None:
|
||||
cookie[cookie_name]["domain"] = self.cookie_domain # pragma: no cover
|
||||
headers.append("set-cookie", cookie.output(header="").strip())
|
||||
|
||||
await send(message)
|
||||
|
||||
def _has_sensitive_cookies(self, cookies: dict[str, str]) -> bool:
|
||||
if not self.sensitive_cookies:
|
||||
return True
|
||||
for sensitive_cookie in self.sensitive_cookies:
|
||||
if sensitive_cookie in cookies:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _url_is_required(self, url: URL) -> bool:
|
||||
if not self.required_urls:
|
||||
return False
|
||||
for required_url in self.required_urls:
|
||||
if required_url.match(url.path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _url_is_exempt(self, url: URL) -> bool:
|
||||
if not self.exempt_urls:
|
||||
return False
|
||||
for exempt_url in self.exempt_urls:
|
||||
if exempt_url.match(url.path):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_submitted_csrf_token(self, request: Request) -> Optional[str]:
|
||||
return request.headers.get(self.header_name)
|
||||
|
||||
def _generate_csrf_token(self, user_id: int | None = None) -> str:
|
||||
obj = {"token": secrets.token_urlsafe(128), "user_id": user_id}
|
||||
return cast(str, self.serializer.dumps(obj))
|
||||
|
||||
def _csrf_tokens_match(
|
||||
self, document_cookie: str, header_cookie: str, user_id: str | None
|
||||
) -> bool:
|
||||
try:
|
||||
decoded_doc_cookie: str = self.serializer.loads(document_cookie)
|
||||
decoded_header_cookie: str = self.serializer.loads(header_cookie)
|
||||
|
||||
# Verify that the tokens match, the user IDs match
|
||||
# and the user_id matches the authenticated user
|
||||
return (
|
||||
secrets.compare_digest(
|
||||
decoded_doc_cookie["token"], decoded_doc_cookie["token"]
|
||||
)
|
||||
and decoded_header_cookie["user_id"] == decoded_header_cookie["user_id"]
|
||||
and decoded_doc_cookie["user_id"] == user_id
|
||||
and decoded_header_cookie["user_id"] == user_id
|
||||
)
|
||||
except BadSignature:
|
||||
return False
|
||||
|
||||
def _get_error_response(self, request: Request) -> Response:
|
||||
return PlainTextResponse(
|
||||
content="CSRF token verification failed", status_code=403
|
||||
)
|
||||
@@ -5,31 +5,11 @@ from joserfc import jwt
|
||||
from joserfc.errors import BadSignatureError
|
||||
from joserfc.jwk import OctKey
|
||||
from starlette.datastructures import MutableHeaders, Secret
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette_csrf.middleware import CSRFMiddleware
|
||||
|
||||
from config import SESSION_MAX_AGE_SECONDS
|
||||
|
||||
|
||||
class CustomCSRFMiddleware(CSRFMiddleware):
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Skip CSRF check if not an HTTP request, like websockets
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return None
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Skip CSRF check if Authorization header is present
|
||||
auth_scheme = request.headers.get("Authorization", "").split(" ", 1)[0].lower()
|
||||
if auth_scheme == "bearer" or auth_scheme == "basic":
|
||||
await self.app(scope, receive, send)
|
||||
return None
|
||||
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
SecretKey = namedtuple("SecretKey", ("encode", "decode"))
|
||||
|
||||
|
||||
@@ -44,7 +44,8 @@ from endpoints import (
|
||||
)
|
||||
from handler.auth.constants import ALGORITHM
|
||||
from handler.auth.hybrid_auth import HybridAuthBackend
|
||||
from handler.auth.middleware import CustomCSRFMiddleware, SessionMiddleware
|
||||
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
|
||||
from handler.auth.middleware.session_middleware import SessionMiddleware
|
||||
from handler.socket_handler import socket_handler
|
||||
from logger.formatter import LOGGING_CONFIG
|
||||
from utils import get_version
|
||||
@@ -90,7 +91,7 @@ app.add_middleware(
|
||||
if not IS_PYTEST_RUN and not DISABLE_CSRF_PROTECTION:
|
||||
# CSRF protection (except endpoints listed in exempt_urls)
|
||||
app.add_middleware(
|
||||
CustomCSRFMiddleware,
|
||||
CSRFMiddleware,
|
||||
cookie_name="romm_csrftoken",
|
||||
secret=ROMM_AUTH_SECRET_KEY,
|
||||
exempt_urls=[re.compile(r"^/api/token.*"), re.compile(r"^/ws")],
|
||||
|
||||
@@ -24,6 +24,7 @@ dependencies = [
|
||||
"fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1",
|
||||
"gunicorn ~= 23.0",
|
||||
"httpx ~= 0.27",
|
||||
"itsdangerous>=2.2.0",
|
||||
"joserfc ~= 1.3.4",
|
||||
"opentelemetry-distro ~= 0.56",
|
||||
"opentelemetry-exporter-otlp ~= 1.36",
|
||||
@@ -46,7 +47,6 @@ dependencies = [
|
||||
"rq-scheduler @ git+https://github.com/adamantike/rq-scheduler.git@feat/script-options-username-ssl",
|
||||
"sentry-sdk ~= 2.32",
|
||||
"starlette ~= 0.49",
|
||||
"starlette-csrf ~= 3.0",
|
||||
"streaming-form-data ~= 1.19",
|
||||
"strsimpy ~= 0.2",
|
||||
"types-colorama ~= 0.4",
|
||||
|
||||
17
uv.lock
generated
17
uv.lock
generated
@@ -1911,6 +1911,7 @@ dependencies = [
|
||||
{ name = "fastapi-pagination", extra = ["sqlalchemy"] },
|
||||
{ name = "gunicorn" },
|
||||
{ name = "httpx" },
|
||||
{ name = "itsdangerous" },
|
||||
{ name = "joserfc" },
|
||||
{ name = "opentelemetry-distro" },
|
||||
{ name = "opentelemetry-exporter-otlp" },
|
||||
@@ -1933,7 +1934,6 @@ dependencies = [
|
||||
{ name = "sentry-sdk" },
|
||||
{ name = "sqlalchemy", extra = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"] },
|
||||
{ name = "starlette" },
|
||||
{ name = "starlette-csrf" },
|
||||
{ name = "streaming-form-data" },
|
||||
{ name = "strsimpy" },
|
||||
{ name = "types-colorama" },
|
||||
@@ -1982,6 +1982,7 @@ requires-dist = [
|
||||
{ name = "httpx", specifier = "~=0.27" },
|
||||
{ name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" },
|
||||
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" },
|
||||
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||
{ name = "joserfc", specifier = "~=1.3.4" },
|
||||
{ name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" },
|
||||
@@ -2013,7 +2014,6 @@ requires-dist = [
|
||||
{ name = "sentry-sdk", specifier = "~=2.32" },
|
||||
{ name = "sqlalchemy", extras = ["mariadb-connector", "mysql-connector", "postgresql-psycopg"], specifier = "~=2.0" },
|
||||
{ name = "starlette", specifier = "~=0.49" },
|
||||
{ name = "starlette-csrf", specifier = "~=3.0" },
|
||||
{ name = "streaming-form-data", specifier = "~=1.19" },
|
||||
{ name = "strsimpy", specifier = "~=0.2" },
|
||||
{ name = "types-colorama", specifier = "~=0.4" },
|
||||
@@ -2198,19 +2198,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "starlette-csrf"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "itsdangerous" },
|
||||
{ name = "starlette" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0f/7c/53c57b4cd76c9a4493a8525d34a76d7e4bbe0ff957de1c53f30241aa757a/starlette_csrf-3.0.0.tar.gz", hash = "sha256:7afaca8c72cc3c726e5942778af53454607ca3e653fd86cd75ee35d8cd1cfa77", size = 8371, upload-time = "2023-06-27T13:23:24.387Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/83/6641e4fdcf33b1cc614a74ecabe5835236a1b2564bf6735db7e35d788795/starlette_csrf-3.0.0-py3-none-any.whl", hash = "sha256:aac29b366e83621d3fc56be690866e16f3c56df91ab5e184b77950540a4e2761", size = 6170, upload-time = "2023-06-27T13:23:25.563Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "streaming-form-data"
|
||||
version = "1.19.1"
|
||||
|
||||
Reference in New Issue
Block a user