diff --git a/backend/config/__init__.py b/backend/config/__init__.py index 2eca61775..7dcacf82e 100644 --- a/backend/config/__init__.py +++ b/backend/config/__init__.py @@ -109,7 +109,7 @@ FLASHPOINT_API_ENABLED: Final[bool] = safe_str_to_bool( HLTB_API_ENABLED: Final[bool] = safe_str_to_bool(_get_env("HLTB_API_ENABLED")) # AUTH -ROMM_AUTH_SECRET_KEY: Final[str | None] = _get_env("ROMM_AUTH_SECRET_KEY") +ROMM_AUTH_SECRET_KEY: Final[str] = _get_env("ROMM_AUTH_SECRET_KEY", "") if not ROMM_AUTH_SECRET_KEY: raise ValueError("ROMM_AUTH_SECRET_KEY environment variable is not set!") diff --git a/backend/endpoints/auth.py b/backend/endpoints/auth.py index eb3562b85..030f2e8a3 100644 --- a/backend/endpoints/auth.py +++ b/backend/endpoints/auth.py @@ -54,7 +54,8 @@ def login( if not user.enabled: raise UserDisabledException - request.session.update({"iss": "romm:auth", "sub": user.username}) + request.session["iss"] = "romm:auth" + request.session["sub"] = user.username # Update last login and active times now = datetime.now(timezone.utc) @@ -262,7 +263,8 @@ async def auth_openid(request: Request): if not potential_user.enabled: raise UserDisabledException - request.session.update({"iss": "romm:auth", "sub": potential_user.username}) + request.session["iss"] = "romm:auth" + request.session["sub"] = potential_user.username # Update last login and active times now = datetime.now(timezone.utc) @@ -293,7 +295,7 @@ def request_password_reset(username: str = Body(..., embed=True)) -> None: @router.post("/reset-password", status_code=status.HTTP_200_OK) -def reset_password( +async def reset_password( token: str = Body(..., embed=True), new_password: str = Body(..., embed=True), ) -> None: @@ -308,7 +310,7 @@ def reset_password( """ user = auth_handler.verify_password_reset_token(token) - auth_handler.set_user_new_password(user, new_password) + await auth_handler.set_user_new_password(user, new_password) log.info( f"Password was successfully reset for user {hl(user.username, color=CYAN)}." diff --git a/backend/handler/auth/base_handler.py b/backend/handler/auth/base_handler.py index 3bba2456b..d9a9d5afd 100644 --- a/backend/handler/auth/base_handler.py +++ b/backend/handler/auth/base_handler.py @@ -21,11 +21,14 @@ from config import ( from decorators.auth import oauth from exceptions.auth_exceptions import OAuthCredentialsException, UserDisabledException from handler.auth.constants import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY, TokenPurpose +from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware from handler.redis_handler import redis_client from logger.formatter import CYAN from logger.formatter import highlight as hl from logger.logger import log +oct_key = OctKey.import_key(ROMM_AUTH_SECRET_KEY) + class AuthHandler: def __init__(self) -> None: @@ -95,7 +98,7 @@ class AuthHandler: token = jwt.encode( {"alg": ALGORITHM}, to_encode, - OctKey.import_key(ROMM_AUTH_SECRET_KEY), + oct_key, ) log.info( f"Reset password link requested for {hl(user.username, color=CYAN)}. Reset link: {hl(f'{ROMM_BASE_URL}/reset-password?token={token}')}" @@ -119,7 +122,7 @@ class AuthHandler: from handler.database import db_user_handler try: - payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM]) + payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM]) except (BadSignatureError, DecodeError, ValueError) as exc: raise HTTPException(status_code=400, detail="Invalid token") from exc @@ -146,12 +149,12 @@ class AuthHandler: raise HTTPException(status_code=404, detail="User not found") now = datetime.now(timezone.utc).timestamp() - if now > payload.claims.get("exp"): + if now > payload.claims.get("exp", 0.0): raise HTTPException(status_code=400, detail="Token has expired") return user - def set_user_new_password(self, user: Any, new_password: str) -> None: + async def set_user_new_password(self, user: Any, new_password: str) -> None: """ Set the new password for the user. Args: @@ -163,6 +166,7 @@ class AuthHandler: db_user_handler.update_user( user.id, {"hashed_password": self.get_password_hash(new_password)} ) + await RedisSessionMiddleware.clear_user_sessions(user.username) def generate_invite_link_token(self, user: Any, role: str) -> str: """ @@ -192,7 +196,7 @@ class AuthHandler: token = jwt.encode( {"alg": ALGORITHM}, to_encode, - OctKey.import_key(ROMM_AUTH_SECRET_KEY), + oct_key, ) invite_link = f"{ROMM_BASE_URL}/register?token={token}" log.info( @@ -212,9 +216,7 @@ class AuthHandler: str: The JTI (JWT ID) of the token. """ try: - payload = jwt.decode( - token, OctKey.import_key(ROMM_AUTH_SECRET_KEY), algorithms=[ALGORITHM] - ) + payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM]) except (BadSignatureError, DecodeError, ValueError) as exc: raise HTTPException(status_code=400, detail="Invalid token") from exc @@ -256,16 +258,14 @@ class OAuthHandler: return jwt.encode( {"alg": ALGORITHM}, to_encode, - OctKey.import_key(ROMM_AUTH_SECRET_KEY), + oct_key, ) async def get_current_active_user_from_bearer_token(self, token: str): from handler.database import db_user_handler try: - payload = jwt.decode( - token, OctKey.import_key(ROMM_AUTH_SECRET_KEY), algorithms=[ALGORITHM] - ) + payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM]) except (BadSignatureError, DecodeError, ValueError) as exc: raise OAuthCredentialsException from exc diff --git a/backend/handler/auth/middleware/redis_session_middleware.py b/backend/handler/auth/middleware/redis_session_middleware.py new file mode 100644 index 000000000..90793ee17 --- /dev/null +++ b/backend/handler/auth/middleware/redis_session_middleware.py @@ -0,0 +1,92 @@ +import json +import uuid + +from starlette.datastructures import MutableHeaders +from starlette.requests import HTTPConnection +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from config import SESSION_MAX_AGE_SECONDS +from handler.redis_handler import async_cache + + +class RedisSessionMiddleware: + def __init__( + self, + app: ASGIApp, + session_cookie: str = "session", + max_age: int = SESSION_MAX_AGE_SECONDS, + same_site: str = "lax", + https_only: bool = False, + ) -> None: + self.app = app + self.session_cookie = session_cookie + self.max_age = max_age + self.security_flags = "httponly; samesite=" + same_site + if https_only: + self.security_flags += "; secure" + + @staticmethod + async def clear_user_sessions(user_id: str) -> None: + """ + Clears all active sessions for a given user. + """ + session_ids = await async_cache.smembers(f"user_sessions:{user_id}") + if session_ids: + for session_id in session_ids: + await async_cache.delete(f"session:{session_id}") + await async_cache.delete(f"user_sessions:{user_id}") + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + connection = HTTPConnection(scope) + session_id = None # Initialize session_id to None + session_cookie_from_request = connection.cookies.get(self.session_cookie) + + if session_cookie_from_request: + session_id = session_cookie_from_request + session_data = await async_cache.get(f"session:{session_id}") + if session_data: + scope["session"] = json.loads(session_data) + scope["session"]["session_id"] = session_id + else: + scope["session"] = {} + else: + scope["session"] = {} + + async def send_wrapper(message: Message) -> None: + nonlocal session_id + if message["type"] == "http.response.start": + headers = MutableHeaders(scope=message) + # Check for user_id to track user-specific sessions + user_id = scope["session"].get("sub") + + if scope["session"]: + session_id = scope["session"].pop("session_id", None) or str( + uuid.uuid4() + ) # Retrieve or create session_id + session_data_json = json.dumps(scope["session"]) + await async_cache.set( + f"session:{session_id}", session_data_json, ex=self.max_age + ) + + # Add session_id to user set of sessions + if user_id: + await async_cache.sadd(f"user_sessions:{user_id}", session_id) + + header_value = f"{self.session_cookie}={session_id}; path=/; Max-Age={self.max_age}; {self.security_flags}" + headers.append("Set-Cookie", header_value) + elif session_id: + await async_cache.delete(f"session:{session_id}") + # Remove session_id from user set of sessions + if user_id: + await async_cache.srem(f"user_sessions:{user_id}", session_id) + + header_value = f"{self.session_cookie}=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; {self.security_flags}" + headers.append("Set-Cookie", header_value) + + await send(message) + + await self.app(scope, receive, send_wrapper) diff --git a/backend/main.py b/backend/main.py index 0331b78ad..04e541d9d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -42,10 +42,9 @@ from endpoints import ( tasks, user, ) -from handler.auth.constants import ALGORITHM from handler.auth.hybrid_auth import HybridAuthBackend from handler.auth.middleware.csrf_middleware import CSRFMiddleware -from handler.auth.middleware.session_middleware import SessionMiddleware +from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware from handler.socket_handler import socket_handler from logger.formatter import LOGGING_CONFIG from utils import get_version @@ -105,12 +104,10 @@ app.add_middleware( # Enables support for sessions on requests app.add_middleware( - SessionMiddleware, - secret_key=ROMM_AUTH_SECRET_KEY, + RedisSessionMiddleware, session_cookie="romm_session", same_site="lax" if OIDC_ENABLED else "strict", https_only=False, - jwt_alg=ALGORITHM, ) # Sets context vars in request-response cycle diff --git a/backend/tests/endpoints/test_identity.py b/backend/tests/endpoints/test_identity.py index 4f6dd7f20..39afbc424 100644 --- a/backend/tests/endpoints/test_identity.py +++ b/backend/tests/endpoints/test_identity.py @@ -10,9 +10,10 @@ from main import app from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES from handler.auth import oauth_handler +from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware from handler.database.users_handler import DBUsersHandler from handler.redis_handler import sync_cache -from models.user import Role +from models.user import Role, User @pytest.fixture @@ -27,7 +28,7 @@ def clear_cache(): sync_cache.flushall() -def test_login_logout(client, admin_user): +def test_login_logout(client, admin_user: User): response = client.get("/api/login") assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED @@ -45,7 +46,7 @@ def test_login_logout(client, admin_user): assert response.status_code == status.HTTP_200_OK -def test_get_all_users(client, access_token): +def test_get_all_users(client, access_token: str): response = client.get( "/api/users", headers={"Authorization": f"Bearer {access_token}"} ) @@ -56,7 +57,7 @@ def test_get_all_users(client, access_token): assert users[0]["username"] == "test_admin" -def test_get_user(client, access_token, editor_user): +def test_get_user(client, access_token: str, editor_user: User): response = client.get( f"/api/users/{editor_user.id}", headers={"Authorization": f"Bearer {access_token}"}, @@ -68,7 +69,7 @@ def test_get_user(client, access_token, editor_user): @pytest.mark.parametrize("new_user_role", [Role.VIEWER, Role.EDITOR, Role.ADMIN]) -def test_add_user_from_admin_user(client, access_token, new_user_role): +def test_add_user_from_admin_user(client, access_token: str, new_user_role: Role): response = client.post( "/api/users", json={ @@ -98,10 +99,10 @@ def test_add_user_from_admin_user(client, access_token, new_user_role): def test_add_user_from_unauthorized_user( request, client, - admin_user, - fixture_requesting_user, - existing_admin_users, - expected_status_code, + admin_user: User, + fixture_requesting_user: User, + existing_admin_users: list[User], + expected_status_code: int, ): requesting_user = request.getfixturevalue(fixture_requesting_user) @@ -133,7 +134,7 @@ def test_add_user_from_unauthorized_user( assert response.status_code == expected_status_code -def test_add_user_with_existing_username(client, access_token, admin_user): +def test_add_user_with_existing_username(client, access_token: str, admin_user: User): response = client.post( "/api/users", json={ @@ -150,7 +151,7 @@ def test_add_user_with_existing_username(client, access_token, admin_user): assert response["detail"] == f"Username {admin_user.username} already exists" -def test_update_user(client, access_token, editor_user): +def test_update_user(client, access_token: str, editor_user: User): assert editor_user.role == Role.EDITOR response = client.put( @@ -164,9 +165,85 @@ def test_update_user(client, access_token, editor_user): assert user["role"] == "viewer" -def test_delete_user(client, access_token, editor_user): +def test_delete_user(client, access_token: str, editor_user: User): response = client.delete( f"/api/users/{editor_user.id}", headers={"Authorization": f"Bearer {access_token}"}, ) - assert response.status_code == status.HTTP_200_OK + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_password_change_invalidates_sessions(client, admin_user: User): + # Get the user's session cookie + basic_auth = base64.b64encode( + f"{admin_user.username}:test_admin_password".encode("ascii") + ).decode("ascii") + response = client.post( + "/api/login", headers={"Authorization": f"Basic {basic_auth}"} + ) + assert response.status_code == HTTPStatus.OK + old_session_cookie = response.cookies.get("romm_session") + assert old_session_cookie is not None + + # Verify session works + response = client.get("/api/users/me", cookies={"romm_session": old_session_cookie}) + assert response.status_code == HTTPStatus.OK + + # Update the user's password + response = client.put( + f"/api/users/{admin_user.id}", + data={"password": "new_admin_password"}, + headers={"Authorization": f"Basic {basic_auth}"}, + ) + assert response.status_code == HTTPStatus.OK + + # Attempt to access a protected resource using the old session cookie + response = client.get("/api/users/me", cookies={"romm_session": old_session_cookie}) + assert response.status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN] + + # Login with the new credentials + basic_auth_new = base64.b64encode( + f"{admin_user.username}:new_admin_password".encode("ascii") + ).decode("ascii") + response = client.post( + "/api/login", headers={"Authorization": f"Basic {basic_auth_new}"} + ) + assert response.status_code == HTTPStatus.OK + new_session_cookie = response.cookies.get("romm_session") + assert new_session_cookie is not None + assert new_session_cookie != old_session_cookie + + # Attempt to access a protected resource using the new session cookie + response = client.get("/api/users/me", cookies={"romm_session": new_session_cookie}) + assert response.status_code == HTTPStatus.OK + + await RedisSessionMiddleware.clear_user_sessions(admin_user.username) + + +@pytest.mark.asyncio +async def test_logout_invalidates_session(client, admin_user: User): + # Get the user's session cookie + basic_auth = base64.b64encode( + f"{admin_user.username}:test_admin_password".encode("ascii") + ).decode("ascii") + response = client.post( + "/api/login", headers={"Authorization": f"Basic {basic_auth}"} + ) + assert response.status_code == HTTPStatus.OK + session_cookie = response.cookies.get("romm_session") + assert session_cookie is not None + + # Verify session works + response = client.get("/api/users/me", cookies={"romm_session": session_cookie}) + assert response.status_code == HTTPStatus.OK + + # Log out the user + response = client.post("/api/logout", cookies={"romm_session": session_cookie}) + assert response.status_code == HTTPStatus.OK + + # Attempt to access a protected resource using the old session cookie + response = client.get("/api/users/me", cookies={"romm_session": session_cookie}) + assert response.status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN] + + await RedisSessionMiddleware.clear_user_sessions(admin_user.username)