diff --git a/backend/handler/auth/middleware/csrf_middleware.py b/backend/handler/auth/middleware/csrf_middleware.py index d6a04466e..902532572 100644 --- a/backend/handler/auth/middleware/csrf_middleware.py +++ b/backend/handler/auth/middleware/csrf_middleware.py @@ -134,7 +134,7 @@ class CSRFMiddleware: return cast(str, self.serializer.dumps(obj)) def _csrf_tokens_match( - self, document_cookie: str, header_cookie: str, user_id: str | None + self, document_cookie: str, header_cookie: str, user_id: int | None ) -> bool: try: decoded_doc_cookie: str = self.serializer.loads(document_cookie) diff --git a/backend/tests/handler/auth/test_csrf_middleware.py b/backend/tests/handler/auth/test_csrf_middleware.py new file mode 100644 index 000000000..8ec8fadbc --- /dev/null +++ b/backend/tests/handler/auth/test_csrf_middleware.py @@ -0,0 +1,248 @@ +import re + +from itsdangerous import URLSafeSerializer +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse +from starlette.routing import Route +from starlette.testclient import TestClient + +from config import ROMM_AUTH_SECRET_KEY +from handler.auth.constants import ALGORITHM +from handler.auth.hybrid_auth import HybridAuthBackend +from handler.auth.middleware.csrf_middleware import CSRFMiddleware +from handler.auth.middleware.session_middleware import SessionMiddleware + + +# Test app factory # +def create_test_app(**csrf_kwargs) -> Starlette: + """Return a Starlette app wired with CSRFMiddleware.""" + + async def get_handler(request: Request) -> PlainTextResponse: + return PlainTextResponse("OK") + + async def post_handler(request: Request) -> JSONResponse: + return JSONResponse({"status": "success"}) + + async def post_echo(request: Request) -> JSONResponse: + """Return the CSRF token that was sent.""" + token = request.headers.get(csrf_kwargs.get("header_name", "x-csrftoken")) + return JSONResponse({"token": token}) + + routes = [ + Route("/get", get_handler, methods=["GET"]), + Route("/post", post_handler, methods=["POST"]), + Route("/echo", post_echo, methods=["POST"]), + ] + middleware = [ + Middleware(CSRFMiddleware, secret="test-secret", **csrf_kwargs), + Middleware(AuthenticationMiddleware, backend=HybridAuthBackend()), + Middleware( + SessionMiddleware, + secret_key=ROMM_AUTH_SECRET_KEY, + session_cookie="romm_session", + same_site="strict", + https_only=False, + jwt_alg=ALGORITHM, + ), + ] + return Starlette(routes=routes, middleware=middleware) + + +class TestCSRFMiddleware: + def test_csrf_cookie_set_on_first_get(self) -> None: + """A GET request should set the CSRF cookie if none exists.""" + app = create_test_app() + client = TestClient(app) + + response = client.get("/get") + assert response.status_code == 200 + assert "csrftoken" in response.cookies + + def test_post_with_valid_token_succeeds(self) -> None: + """POST with correct CSRF header should pass.""" + app = create_test_app() + client = TestClient(app) + + # Obtain cookie + resp = client.get("/get") + cookie = resp.cookies["csrftoken"] + + # Post with token + resp = client.post("/post", headers={"x-csrftoken": cookie}) + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + def test_post_without_cookie_fails(self) -> None: + """POST without CSRF cookie should fail.""" + app = create_test_app() + client = TestClient(app) + + resp = client.post("/post") + assert resp.status_code == 403 + assert "CSRF token verification failed" in resp.text + + def test_post_without_header_fails(self) -> None: + """POST without CSRF header should fail.""" + app = create_test_app() + client = TestClient(app) + + # Obtain cookie but don't send header + client.get("/get") + resp = client.post("/post") + assert resp.status_code == 403 + + def test_post_with_bad_signature_fails(self) -> None: + """POST with tampered token should fail.""" + app = create_test_app() + client = TestClient(app) + + client.get("/get") + bad_token = "tampered-token" + resp = client.post("/post", headers={"x-csrftoken": bad_token}) + assert resp.status_code == 403 + + def test_safe_methods_bypass_csrf(self) -> None: + """GET/HEAD/OPTIONS/TRACE should never require CSRF.""" + app = create_test_app() + client = TestClient(app) + + for method in ("GET", "HEAD", "OPTIONS", "TRACE"): + resp = client.request(method, "/post") + assert resp.status_code == 200 + + def test_custom_header_name(self) -> None: + """Middleware should read the token from the configured header.""" + header_name = "x-xsrf-token" + app = create_test_app(header_name=header_name) + client = TestClient(app) + + cookie_resp = client.get("/get") + token = cookie_resp.cookies["csrftoken"] + + # Send with custom header + resp = client.post("/echo", headers={header_name: token}) + assert resp.status_code == 200 + assert resp.json()["token"] == token + + def test_custom_cookie_name(self) -> None: + """Middleware should use the configured cookie name.""" + cookie_name = "my_token" + app = create_test_app(cookie_name=cookie_name) + client = TestClient(app) + + resp = client.get("/get") + assert cookie_name in resp.cookies + + def test_cookie_attributes(self) -> None: + """Verify Secure, HttpOnly, SameSite, Path, Domain attributes.""" + app = create_test_app( + cookie_secure=True, + cookie_httponly=True, + cookie_samesite="strict", + cookie_path="/app", + cookie_domain=".example.com", + ) + client = TestClient(app) + + resp = client.get("/get") + set_cookie = resp.headers["set-cookie"] + assert "secure" in set_cookie + assert "httponly" in set_cookie + assert "samesite=strict" in set_cookie + assert "path=/app" in set_cookie + assert "domain=.example.com" in set_cookie + + def test_exempt_urls(self) -> None: + """POST to exempt URLs should not require CSRF.""" + app = create_test_app(exempt_urls=[re.compile(r"^/post$")]) + client = TestClient(app) + + # No cookie/header needed + resp = client.post("/post") + assert resp.status_code == 200 + + def test_required_urls(self) -> None: + """POST to required URLs should always require CSRF even for safe methods.""" + app = create_test_app(required_urls=[re.compile(r"^/get$")], safe_methods=set()) + client = TestClient(app) + + # GET now requires token + resp = client.get("/get") + assert resp.status_code == 403 + + def test_sensitive_cookies(self) -> None: + """If no sensitive cookies exist, CSRF is not enforced.""" + app = create_test_app(sensitive_cookies={"session"}) + client = TestClient(app) + + # No sensitive cookie β†’ POST allowed + resp = client.post("/post") + assert resp.status_code == 200 + + # Add sensitive cookie β†’ POST blocked + client.cookies.set("session", "abc123") + resp = client.post("/post") + assert resp.status_code == 403 + + # Bypass rules # + def test_bearer_auth_bypass(self) -> None: + """Requests with Bearer/Basic Authorization header bypass CSRF.""" + app = create_test_app() + client = TestClient(app) + + resp = client.post("/post", headers={"Authorization": "Bearer token"}) + assert resp.status_code == 200 + + def test_non_http_scope_bypass(self) -> None: + """WebSocket (or other non-HTTP) scopes should pass through.""" + # Manual ASGI call; TestClient doesn't expose WebSocket easily + scope = {"type": "websocket", "path": "/ws", "headers": []} + receive = lambda: {} # noqa: E731 + send = lambda msg: None # noqa: E731 + + async def dummy_app(scope, receive, send): + await send({"type": "websocket.accept"}) + + middleware = CSRFMiddleware(dummy_app, secret="test") + import asyncio + + asyncio.run(middleware(scope, receive, send)) # should not raise + + def test_token_generation_and_validation(self) -> None: + """Ensure tokens are signed and validated correctly.""" + app = create_test_app() + client = TestClient(app) + + # Extract cookie + resp = client.get("/get") + cookie = resp.cookies["csrftoken"] + + # Verify signature + serializer = URLSafeSerializer("test-secret", "csrftoken") + payload = serializer.loads(cookie) + assert "token" in payload + assert "user_id" in payload + + # Send same token back + resp = client.post("/echo", headers={"x-csrftoken": cookie}) + assert resp.status_code == 200 + assert resp.json()["token"] == cookie + + def test_user_id_mismatch_fails(self) -> None: + """Tokens issued for one user must not validate for another.""" + # We simulate two users by calling _generate_csrf_token with different IDs + mw = CSRFMiddleware(app=lambda s, r, se: None, secret="test") + user1_token = mw._generate_csrf_token(user_id=1) + mw._generate_csrf_token(user_id=2) + + # user1_token should not validate for user_id=2 + assert not mw._csrf_tokens_match(user1_token, user1_token, user_id=2) + + def test_bad_signature_returns_false(self) -> None: + """_csrf_tokens_match should return False on BadSignature.""" + mw = CSRFMiddleware(app=lambda s, r, se: None, secret="test") + ok = mw._csrf_tokens_match("bad-token", "bad-token", user_id=None) + assert ok is False diff --git a/backend/tests/handler/auth/test_session_middleware.py b/backend/tests/handler/auth/test_session_middleware.py new file mode 100644 index 000000000..38f99584a --- /dev/null +++ b/backend/tests/handler/auth/test_session_middleware.py @@ -0,0 +1,366 @@ +""" +Test suite for SessionMiddleware using JWT-based session management. +""" + +import time +from typing import Any, Dict + +from joserfc import jwt +from joserfc.jwk import OctKey +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse +from starlette.routing import Route +from starlette.testclient import TestClient + +from handler.auth.middleware.session_middleware import SessionMiddleware + + +def create_test_app(**middleware_kwargs) -> Starlette: + def get_session_data(request: Request) -> Dict[str, Any]: + """Extract session data from request""" + return dict(request.session) + + def set_session_data(request: Request, data: Dict[str, Any]) -> None: + """Set session data""" + request.session.update(data) + + def clear_session(request: Request) -> None: + """Clear session data""" + request.session.clear() + + # Define test app routes + async def homepage(request: Request) -> PlainTextResponse: + """Basic route that sets default session data.""" + request.session.setdefault("visited", 0) + request.session["visited"] += 1 + return PlainTextResponse("OK") + + async def set_session(request: Request) -> JSONResponse: + """Set specific session data.""" + data = await request.json() + request.session.update(data) + return JSONResponse({"session": dict(request.session)}) + + async def get_session(request: Request) -> JSONResponse: + """Get current session data.""" + return JSONResponse({"session": dict(request.session)}) + + async def clear_session_route(request: Request) -> PlainTextResponse: + """Clear the session.""" + request.session.clear() + return PlainTextResponse("Session cleared") + + async def modify_session(request: Request) -> JSONResponse: + """Modify session with provided data.""" + data = await request.json() + for key, value in data.items(): + if value is None: + request.session.pop(key, None) + else: + request.session[key] = value + return JSONResponse({"session": dict(request.session)}) + + """Create a test app with SessionMiddleware.""" + routes = [ + Route("/", homepage), + Route("/set", set_session, methods=["POST"]), + Route("/get", get_session), + Route("/clear", clear_session_route), + Route("/modify", modify_session, methods=["POST"]), + ] + kwargs = {"secret_key": "test-secret-key", **middleware_kwargs} + middleware = [Middleware(SessionMiddleware, **kwargs)] + return Starlette(routes=routes, middleware=middleware) + + +class TestSessionMiddleware: + def test_session_creation(self) -> None: + """Test that a session cookie is set on the first request.""" + app = create_test_app() + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + assert "session" in response.cookies + assert response.text == "OK" + + def test_session_reading(self) -> None: + """Test that session data can be read from the cookie on subsequent requests.""" + app = create_test_app() + client = TestClient(app) + + # First request sets session + response = client.get("/") + assert response.status_code == 200 + + # Second request should read the session + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert "visited" in data["session"] + assert data["session"]["visited"] == 1 + + def test_session_modification(self) -> None: + """Test that session data can be modified and persisted across requests.""" + app = create_test_app() + client = TestClient(app) + + response = client.post("/set", json={"user": "test_user", "role": "admin"}) + assert response.status_code == 200 + assert response.json()["session"]["user"] == "test_user" + + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert data["session"]["user"] == "test_user" + assert data["session"]["role"] == "admin" + + def test_session_clearing(self) -> None: + """Test that clearing the session removes the cookie.""" + app = create_test_app() + client = TestClient(app) + + response = client.post("/set", json={"user": "test_user"}) + assert response.status_code == 200 + + # Clear session + response = client.get("/clear") + assert response.status_code == 200 + + # Verify session is cleared + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert data["session"] == {} + + def test_session_max_age(self) -> None: + """Test that the Max-Age attribute is set correctly.""" + max_age = 3600 + app = create_test_app(max_age=max_age) + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + + set_cookie_header = response.headers.get("set-cookie", "") + assert f"Max-Age={max_age}" in set_cookie_header + + def test_session_https_only(self) -> None: + """Test that the secure flag is set when https_only is True.""" + app = create_test_app(https_only=True) + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + + set_cookie_header = response.headers.get("set-cookie", "") + assert "secure" in set_cookie_header + + def test_session_same_site(self) -> None: + """Test that the samesite attribute is set correctly.""" + app = create_test_app(same_site="strict") + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + + set_cookie_header = response.headers.get("set-cookie", "") + assert "samesite=strict" in set_cookie_header + + def test_session_expiration_past(self) -> None: + """Test that an expired session is not loaded and a new one is created.""" + app = create_test_app() + client = TestClient(app) + + # Build a token that expired one hour ago + expired = int(time.time()) - 3600 + payload = {"user": "test_user", "exp": expired} + key = OctKey.import_key("test-secret-key") + token = jwt.encode({"alg": "HS256"}, payload, key=key) + + response = client.get("/get", cookies={"session": token}) + assert response.status_code == 200 + # middleware must reject the expired token β†’ empty session + assert response.json()["session"] == {} + + def test_session_not_before_future(self) -> None: + """Test that a session with a 'not before' claim in the future is ignored.""" + app = create_test_app() + client = TestClient(app) + + # Build a token that is not valid until tomorrow + nbf = int(time.time()) + 86400 + payload = {"user": "test_user", "nbf": nbf} + key = OctKey.import_key("test-secret-key") + token = jwt.encode({"alg": "HS256"}, payload, key=key) + + response = client.get("/get", cookies={"session": token}) + assert response.status_code == 200 + # middleware must reject the future token β†’ empty session + assert response.json()["session"] == {} + + def test_session_bad_signature(self) -> None: + """Test that a session with a bad signature is ignored and a new session is created.""" + app = create_test_app() + client = TestClient(app) + + # Create a session + response = client.post("/set", json={"user": "test_user"}) + assert response.status_code == 200 + + # Tamper with the cookie (simulate bad signature) + tampered_cookie = response.cookies["session"][:-10] + "tampereddata" + response = client.get("/get", cookies={"session": tampered_cookie}) + assert response.status_code == 200 + data = response.json() + + # Should have a new empty session, not the tampered one + assert "user" not in data["session"] + + def test_session_cookie_name(self) -> None: + """Test that custom session cookie name works.""" + custom_cookie_name = "my_session" + app = create_test_app(session_cookie=custom_cookie_name) + client = TestClient(app) + + response = client.get("/") + assert response.status_code == 200 + assert custom_cookie_name in response.cookies + assert "session" not in response.cookies + + def test_session_jwt_algorithm(self) -> None: + """Test that custom JWT algorithm works.""" + app = create_test_app(jwt_alg="HS256") + client = TestClient(app) + + response = client.post("/set", json={"user": "test_user"}) + assert response.status_code == 200 + + token = response.cookies["session"] + decoded = jwt.decode( + token, + key=OctKey.import_key("test-secret-key"), + algorithms=["HS256"], + ) + assert decoded.header["alg"] == "HS256" + + def test_session_data_modification(self) -> None: + """Test that session data can be modified correctly.""" + app = create_test_app() + client = TestClient(app) + + # Set initial data + response = client.post("/set", json={"counter": 1, "name": "Alice"}) + assert response.status_code == 200 + + # Modify data + response = client.post("/modify", json={"counter": 2, "name": None}) + assert response.status_code == 200 + data = response.json() + assert data["session"]["counter"] == 2 + assert "name" not in data["session"] # Should be removed + + # Verify changes persist + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert data["session"]["counter"] == 2 + assert "name" not in data["session"] + + def test_session_empty_on_first_visit(self) -> None: + """Test that session is empty on first visit.""" + app = create_test_app() + client = TestClient(app) + + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert data["session"] == {} + + def test_session_visits_counter(self) -> None: + """Test that the visits counter increments correctly.""" + app = create_test_app() + client = TestClient(app) + + # First visit + response = client.get("/") + assert response.status_code == 200 + + response = client.get("/get") + data = response.json() + assert data["session"]["visited"] == 1 + + # Second visit + response = client.get("/") + assert response.status_code == 200 + + response = client.get("/get") + data = response.json() + assert data["session"]["visited"] == 2 + + def test_session_with_special_characters(self) -> None: + """Test that session data with special characters works correctly.""" + app = create_test_app() + client = TestClient(app) + + special_data = { + "emoji": "πŸš€", + "unicode": "δ½ ε₯½δΈ–η•Œ", + "special": "test@#$%^&*()", + "nested": {"list": [1, 2, 3], "dict": {"a": 1}}, + } + + response = client.post("/set", json=special_data) + assert response.status_code == 200 + + response = client.get("/get") + assert response.status_code == 200 + + data = response.json() + assert data["session"]["emoji"] == special_data["emoji"] + assert data["session"]["unicode"] == special_data["unicode"] + assert data["session"]["special"] == special_data["special"] + assert data["session"]["nested"] == special_data["nested"] + + +def test_full_session_lifecycle(): + """Test the complete lifecycle of a session.""" + app = create_test_app(max_age=3600) + client = TestClient(app) + + # Start with no session + response = client.get("/get") + assert response.status_code == 200 + assert response.json()["session"] == {} + + # Create a session + response = client.post("/set", json={"user": "test_user", "role": "user"}) + assert response.status_code == 200 + + # Verify session persists + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert data["session"]["user"] == "test_user" + assert data["session"]["role"] == "user" + + # Modify session + response = client.post("/modify", json={"role": "admin"}) + assert response.status_code == 200 + + # Verify modification + response = client.get("/get") + assert response.status_code == 200 + data = response.json() + assert data["session"]["role"] == "admin" + + # Clear session + response = client.get("/clear") + assert response.status_code == 200 + + # Verify session is cleared + response = client.get("/get") + assert response.status_code == 200 + assert response.json()["session"] == {} diff --git a/pyproject.toml b/pyproject.toml index c66d66f1c..81cc5950a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1", "gunicorn ~= 23.0", "httpx ~= 0.27", - "itsdangerous>=2.2.0", + "itsdangerous ~= 2.2", "joserfc ~= 1.3.4", "opentelemetry-distro ~= 0.56", "opentelemetry-exporter-otlp ~= 1.36", diff --git a/uv.lock b/uv.lock index 4083570ab..fb5f93452 100644 --- a/uv.lock +++ b/uv.lock @@ -1982,7 +1982,7 @@ requires-dist = [ { name = "httpx", specifier = "~=0.27" }, { name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" }, - { name = "itsdangerous", specifier = ">=2.2.0" }, + { name = "itsdangerous", specifier = "~=2.2" }, { name = "joserfc", specifier = "~=1.3.4" }, { name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" }, { name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" },