add tests for middlewares

This commit is contained in:
Georges-Antoine Assi
2025-11-17 23:40:00 -05:00
parent 551ff72a8a
commit 6a1a344ba2
5 changed files with 617 additions and 3 deletions

View File

@@ -134,7 +134,7 @@ class CSRFMiddleware:
return cast(str, self.serializer.dumps(obj))
def _csrf_tokens_match(
self, document_cookie: str, header_cookie: str, user_id: str | None
self, document_cookie: str, header_cookie: str, user_id: int | None
) -> bool:
try:
decoded_doc_cookie: str = self.serializer.loads(document_cookie)

View File

@@ -0,0 +1,248 @@
import re
from itsdangerous import URLSafeSerializer
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from config import ROMM_AUTH_SECRET_KEY
from handler.auth.constants import ALGORITHM
from handler.auth.hybrid_auth import HybridAuthBackend
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
from handler.auth.middleware.session_middleware import SessionMiddleware
# Test app factory #
def create_test_app(**csrf_kwargs) -> Starlette:
"""Return a Starlette app wired with CSRFMiddleware."""
async def get_handler(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK")
async def post_handler(request: Request) -> JSONResponse:
return JSONResponse({"status": "success"})
async def post_echo(request: Request) -> JSONResponse:
"""Return the CSRF token that was sent."""
token = request.headers.get(csrf_kwargs.get("header_name", "x-csrftoken"))
return JSONResponse({"token": token})
routes = [
Route("/get", get_handler, methods=["GET"]),
Route("/post", post_handler, methods=["POST"]),
Route("/echo", post_echo, methods=["POST"]),
]
middleware = [
Middleware(CSRFMiddleware, secret="test-secret", **csrf_kwargs),
Middleware(AuthenticationMiddleware, backend=HybridAuthBackend()),
Middleware(
SessionMiddleware,
secret_key=ROMM_AUTH_SECRET_KEY,
session_cookie="romm_session",
same_site="strict",
https_only=False,
jwt_alg=ALGORITHM,
),
]
return Starlette(routes=routes, middleware=middleware)
class TestCSRFMiddleware:
def test_csrf_cookie_set_on_first_get(self) -> None:
"""A GET request should set the CSRF cookie if none exists."""
app = create_test_app()
client = TestClient(app)
response = client.get("/get")
assert response.status_code == 200
assert "csrftoken" in response.cookies
def test_post_with_valid_token_succeeds(self) -> None:
"""POST with correct CSRF header should pass."""
app = create_test_app()
client = TestClient(app)
# Obtain cookie
resp = client.get("/get")
cookie = resp.cookies["csrftoken"]
# Post with token
resp = client.post("/post", headers={"x-csrftoken": cookie})
assert resp.status_code == 200
assert resp.json()["status"] == "success"
def test_post_without_cookie_fails(self) -> None:
"""POST without CSRF cookie should fail."""
app = create_test_app()
client = TestClient(app)
resp = client.post("/post")
assert resp.status_code == 403
assert "CSRF token verification failed" in resp.text
def test_post_without_header_fails(self) -> None:
"""POST without CSRF header should fail."""
app = create_test_app()
client = TestClient(app)
# Obtain cookie but don't send header
client.get("/get")
resp = client.post("/post")
assert resp.status_code == 403
def test_post_with_bad_signature_fails(self) -> None:
"""POST with tampered token should fail."""
app = create_test_app()
client = TestClient(app)
client.get("/get")
bad_token = "tampered-token"
resp = client.post("/post", headers={"x-csrftoken": bad_token})
assert resp.status_code == 403
def test_safe_methods_bypass_csrf(self) -> None:
"""GET/HEAD/OPTIONS/TRACE should never require CSRF."""
app = create_test_app()
client = TestClient(app)
for method in ("GET", "HEAD", "OPTIONS", "TRACE"):
resp = client.request(method, "/post")
assert resp.status_code == 200
def test_custom_header_name(self) -> None:
"""Middleware should read the token from the configured header."""
header_name = "x-xsrf-token"
app = create_test_app(header_name=header_name)
client = TestClient(app)
cookie_resp = client.get("/get")
token = cookie_resp.cookies["csrftoken"]
# Send with custom header
resp = client.post("/echo", headers={header_name: token})
assert resp.status_code == 200
assert resp.json()["token"] == token
def test_custom_cookie_name(self) -> None:
"""Middleware should use the configured cookie name."""
cookie_name = "my_token"
app = create_test_app(cookie_name=cookie_name)
client = TestClient(app)
resp = client.get("/get")
assert cookie_name in resp.cookies
def test_cookie_attributes(self) -> None:
"""Verify Secure, HttpOnly, SameSite, Path, Domain attributes."""
app = create_test_app(
cookie_secure=True,
cookie_httponly=True,
cookie_samesite="strict",
cookie_path="/app",
cookie_domain=".example.com",
)
client = TestClient(app)
resp = client.get("/get")
set_cookie = resp.headers["set-cookie"]
assert "secure" in set_cookie
assert "httponly" in set_cookie
assert "samesite=strict" in set_cookie
assert "path=/app" in set_cookie
assert "domain=.example.com" in set_cookie
def test_exempt_urls(self) -> None:
"""POST to exempt URLs should not require CSRF."""
app = create_test_app(exempt_urls=[re.compile(r"^/post$")])
client = TestClient(app)
# No cookie/header needed
resp = client.post("/post")
assert resp.status_code == 200
def test_required_urls(self) -> None:
"""POST to required URLs should always require CSRF even for safe methods."""
app = create_test_app(required_urls=[re.compile(r"^/get$")], safe_methods=set())
client = TestClient(app)
# GET now requires token
resp = client.get("/get")
assert resp.status_code == 403
def test_sensitive_cookies(self) -> None:
"""If no sensitive cookies exist, CSRF is not enforced."""
app = create_test_app(sensitive_cookies={"session"})
client = TestClient(app)
# No sensitive cookie → POST allowed
resp = client.post("/post")
assert resp.status_code == 200
# Add sensitive cookie → POST blocked
client.cookies.set("session", "abc123")
resp = client.post("/post")
assert resp.status_code == 403
# Bypass rules #
def test_bearer_auth_bypass(self) -> None:
"""Requests with Bearer/Basic Authorization header bypass CSRF."""
app = create_test_app()
client = TestClient(app)
resp = client.post("/post", headers={"Authorization": "Bearer token"})
assert resp.status_code == 200
def test_non_http_scope_bypass(self) -> None:
"""WebSocket (or other non-HTTP) scopes should pass through."""
# Manual ASGI call; TestClient doesn't expose WebSocket easily
scope = {"type": "websocket", "path": "/ws", "headers": []}
receive = lambda: {} # noqa: E731
send = lambda msg: None # noqa: E731
async def dummy_app(scope, receive, send):
await send({"type": "websocket.accept"})
middleware = CSRFMiddleware(dummy_app, secret="test")
import asyncio
asyncio.run(middleware(scope, receive, send)) # should not raise
def test_token_generation_and_validation(self) -> None:
"""Ensure tokens are signed and validated correctly."""
app = create_test_app()
client = TestClient(app)
# Extract cookie
resp = client.get("/get")
cookie = resp.cookies["csrftoken"]
# Verify signature
serializer = URLSafeSerializer("test-secret", "csrftoken")
payload = serializer.loads(cookie)
assert "token" in payload
assert "user_id" in payload
# Send same token back
resp = client.post("/echo", headers={"x-csrftoken": cookie})
assert resp.status_code == 200
assert resp.json()["token"] == cookie
def test_user_id_mismatch_fails(self) -> None:
"""Tokens issued for one user must not validate for another."""
# We simulate two users by calling _generate_csrf_token with different IDs
mw = CSRFMiddleware(app=lambda s, r, se: None, secret="test")
user1_token = mw._generate_csrf_token(user_id=1)
mw._generate_csrf_token(user_id=2)
# user1_token should not validate for user_id=2
assert not mw._csrf_tokens_match(user1_token, user1_token, user_id=2)
def test_bad_signature_returns_false(self) -> None:
"""_csrf_tokens_match should return False on BadSignature."""
mw = CSRFMiddleware(app=lambda s, r, se: None, secret="test")
ok = mw._csrf_tokens_match("bad-token", "bad-token", user_id=None)
assert ok is False

View File

@@ -0,0 +1,366 @@
"""
Test suite for SessionMiddleware using JWT-based session management.
"""
import time
from typing import Any, Dict
from joserfc import jwt
from joserfc.jwk import OctKey
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from handler.auth.middleware.session_middleware import SessionMiddleware
def create_test_app(**middleware_kwargs) -> Starlette:
def get_session_data(request: Request) -> Dict[str, Any]:
"""Extract session data from request"""
return dict(request.session)
def set_session_data(request: Request, data: Dict[str, Any]) -> None:
"""Set session data"""
request.session.update(data)
def clear_session(request: Request) -> None:
"""Clear session data"""
request.session.clear()
# Define test app routes
async def homepage(request: Request) -> PlainTextResponse:
"""Basic route that sets default session data."""
request.session.setdefault("visited", 0)
request.session["visited"] += 1
return PlainTextResponse("OK")
async def set_session(request: Request) -> JSONResponse:
"""Set specific session data."""
data = await request.json()
request.session.update(data)
return JSONResponse({"session": dict(request.session)})
async def get_session(request: Request) -> JSONResponse:
"""Get current session data."""
return JSONResponse({"session": dict(request.session)})
async def clear_session_route(request: Request) -> PlainTextResponse:
"""Clear the session."""
request.session.clear()
return PlainTextResponse("Session cleared")
async def modify_session(request: Request) -> JSONResponse:
"""Modify session with provided data."""
data = await request.json()
for key, value in data.items():
if value is None:
request.session.pop(key, None)
else:
request.session[key] = value
return JSONResponse({"session": dict(request.session)})
"""Create a test app with SessionMiddleware."""
routes = [
Route("/", homepage),
Route("/set", set_session, methods=["POST"]),
Route("/get", get_session),
Route("/clear", clear_session_route),
Route("/modify", modify_session, methods=["POST"]),
]
kwargs = {"secret_key": "test-secret-key", **middleware_kwargs}
middleware = [Middleware(SessionMiddleware, **kwargs)]
return Starlette(routes=routes, middleware=middleware)
class TestSessionMiddleware:
def test_session_creation(self) -> None:
"""Test that a session cookie is set on the first request."""
app = create_test_app()
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert "session" in response.cookies
assert response.text == "OK"
def test_session_reading(self) -> None:
"""Test that session data can be read from the cookie on subsequent requests."""
app = create_test_app()
client = TestClient(app)
# First request sets session
response = client.get("/")
assert response.status_code == 200
# Second request should read the session
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert "visited" in data["session"]
assert data["session"]["visited"] == 1
def test_session_modification(self) -> None:
"""Test that session data can be modified and persisted across requests."""
app = create_test_app()
client = TestClient(app)
response = client.post("/set", json={"user": "test_user", "role": "admin"})
assert response.status_code == 200
assert response.json()["session"]["user"] == "test_user"
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"]["user"] == "test_user"
assert data["session"]["role"] == "admin"
def test_session_clearing(self) -> None:
"""Test that clearing the session removes the cookie."""
app = create_test_app()
client = TestClient(app)
response = client.post("/set", json={"user": "test_user"})
assert response.status_code == 200
# Clear session
response = client.get("/clear")
assert response.status_code == 200
# Verify session is cleared
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"] == {}
def test_session_max_age(self) -> None:
"""Test that the Max-Age attribute is set correctly."""
max_age = 3600
app = create_test_app(max_age=max_age)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
set_cookie_header = response.headers.get("set-cookie", "")
assert f"Max-Age={max_age}" in set_cookie_header
def test_session_https_only(self) -> None:
"""Test that the secure flag is set when https_only is True."""
app = create_test_app(https_only=True)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
set_cookie_header = response.headers.get("set-cookie", "")
assert "secure" in set_cookie_header
def test_session_same_site(self) -> None:
"""Test that the samesite attribute is set correctly."""
app = create_test_app(same_site="strict")
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
set_cookie_header = response.headers.get("set-cookie", "")
assert "samesite=strict" in set_cookie_header
def test_session_expiration_past(self) -> None:
"""Test that an expired session is not loaded and a new one is created."""
app = create_test_app()
client = TestClient(app)
# Build a token that expired one hour ago
expired = int(time.time()) - 3600
payload = {"user": "test_user", "exp": expired}
key = OctKey.import_key("test-secret-key")
token = jwt.encode({"alg": "HS256"}, payload, key=key)
response = client.get("/get", cookies={"session": token})
assert response.status_code == 200
# middleware must reject the expired token → empty session
assert response.json()["session"] == {}
def test_session_not_before_future(self) -> None:
"""Test that a session with a 'not before' claim in the future is ignored."""
app = create_test_app()
client = TestClient(app)
# Build a token that is not valid until tomorrow
nbf = int(time.time()) + 86400
payload = {"user": "test_user", "nbf": nbf}
key = OctKey.import_key("test-secret-key")
token = jwt.encode({"alg": "HS256"}, payload, key=key)
response = client.get("/get", cookies={"session": token})
assert response.status_code == 200
# middleware must reject the future token → empty session
assert response.json()["session"] == {}
def test_session_bad_signature(self) -> None:
"""Test that a session with a bad signature is ignored and a new session is created."""
app = create_test_app()
client = TestClient(app)
# Create a session
response = client.post("/set", json={"user": "test_user"})
assert response.status_code == 200
# Tamper with the cookie (simulate bad signature)
tampered_cookie = response.cookies["session"][:-10] + "tampereddata"
response = client.get("/get", cookies={"session": tampered_cookie})
assert response.status_code == 200
data = response.json()
# Should have a new empty session, not the tampered one
assert "user" not in data["session"]
def test_session_cookie_name(self) -> None:
"""Test that custom session cookie name works."""
custom_cookie_name = "my_session"
app = create_test_app(session_cookie=custom_cookie_name)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert custom_cookie_name in response.cookies
assert "session" not in response.cookies
def test_session_jwt_algorithm(self) -> None:
"""Test that custom JWT algorithm works."""
app = create_test_app(jwt_alg="HS256")
client = TestClient(app)
response = client.post("/set", json={"user": "test_user"})
assert response.status_code == 200
token = response.cookies["session"]
decoded = jwt.decode(
token,
key=OctKey.import_key("test-secret-key"),
algorithms=["HS256"],
)
assert decoded.header["alg"] == "HS256"
def test_session_data_modification(self) -> None:
"""Test that session data can be modified correctly."""
app = create_test_app()
client = TestClient(app)
# Set initial data
response = client.post("/set", json={"counter": 1, "name": "Alice"})
assert response.status_code == 200
# Modify data
response = client.post("/modify", json={"counter": 2, "name": None})
assert response.status_code == 200
data = response.json()
assert data["session"]["counter"] == 2
assert "name" not in data["session"] # Should be removed
# Verify changes persist
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"]["counter"] == 2
assert "name" not in data["session"]
def test_session_empty_on_first_visit(self) -> None:
"""Test that session is empty on first visit."""
app = create_test_app()
client = TestClient(app)
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"] == {}
def test_session_visits_counter(self) -> None:
"""Test that the visits counter increments correctly."""
app = create_test_app()
client = TestClient(app)
# First visit
response = client.get("/")
assert response.status_code == 200
response = client.get("/get")
data = response.json()
assert data["session"]["visited"] == 1
# Second visit
response = client.get("/")
assert response.status_code == 200
response = client.get("/get")
data = response.json()
assert data["session"]["visited"] == 2
def test_session_with_special_characters(self) -> None:
"""Test that session data with special characters works correctly."""
app = create_test_app()
client = TestClient(app)
special_data = {
"emoji": "🚀",
"unicode": "你好世界",
"special": "test@#$%^&*()",
"nested": {"list": [1, 2, 3], "dict": {"a": 1}},
}
response = client.post("/set", json=special_data)
assert response.status_code == 200
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"]["emoji"] == special_data["emoji"]
assert data["session"]["unicode"] == special_data["unicode"]
assert data["session"]["special"] == special_data["special"]
assert data["session"]["nested"] == special_data["nested"]
def test_full_session_lifecycle():
"""Test the complete lifecycle of a session."""
app = create_test_app(max_age=3600)
client = TestClient(app)
# Start with no session
response = client.get("/get")
assert response.status_code == 200
assert response.json()["session"] == {}
# Create a session
response = client.post("/set", json={"user": "test_user", "role": "user"})
assert response.status_code == 200
# Verify session persists
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"]["user"] == "test_user"
assert data["session"]["role"] == "user"
# Modify session
response = client.post("/modify", json={"role": "admin"})
assert response.status_code == 200
# Verify modification
response = client.get("/get")
assert response.status_code == 200
data = response.json()
assert data["session"]["role"] == "admin"
# Clear session
response = client.get("/clear")
assert response.status_code == 200
# Verify session is cleared
response = client.get("/get")
assert response.status_code == 200
assert response.json()["session"] == {}

View File

@@ -24,7 +24,7 @@ dependencies = [
"fastapi[standard-no-fastapi-cloud-cli] ~= 0.121.1",
"gunicorn ~= 23.0",
"httpx ~= 0.27",
"itsdangerous>=2.2.0",
"itsdangerous ~= 2.2",
"joserfc ~= 1.3.4",
"opentelemetry-distro ~= 0.56",
"opentelemetry-exporter-otlp ~= 1.36",

2
uv.lock generated
View File

@@ -1982,7 +1982,7 @@ requires-dist = [
{ name = "httpx", specifier = "~=0.27" },
{ name = "ipdb", marker = "extra == 'dev'", specifier = "~=0.13" },
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "~=6.29" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "itsdangerous", specifier = "~=2.2" },
{ name = "joserfc", specifier = "~=1.3.4" },
{ name = "memray", marker = "extra == 'dev'", specifier = "~=1.15" },
{ name = "mypy", marker = "extra == 'dev'", specifier = "~=1.13" },