mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 23:42:07 +01:00
Add new redis-backed session middleware
This commit is contained in:
@@ -109,7 +109,7 @@ FLASHPOINT_API_ENABLED: Final[bool] = safe_str_to_bool(
|
||||
HLTB_API_ENABLED: Final[bool] = safe_str_to_bool(_get_env("HLTB_API_ENABLED"))
|
||||
|
||||
# AUTH
|
||||
ROMM_AUTH_SECRET_KEY: Final[str | None] = _get_env("ROMM_AUTH_SECRET_KEY")
|
||||
ROMM_AUTH_SECRET_KEY: Final[str] = _get_env("ROMM_AUTH_SECRET_KEY", "")
|
||||
if not ROMM_AUTH_SECRET_KEY:
|
||||
raise ValueError("ROMM_AUTH_SECRET_KEY environment variable is not set!")
|
||||
|
||||
|
||||
@@ -54,7 +54,8 @@ def login(
|
||||
if not user.enabled:
|
||||
raise UserDisabledException
|
||||
|
||||
request.session.update({"iss": "romm:auth", "sub": user.username})
|
||||
request.session["iss"] = "romm:auth"
|
||||
request.session["sub"] = user.username
|
||||
|
||||
# Update last login and active times
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -262,7 +263,8 @@ async def auth_openid(request: Request):
|
||||
if not potential_user.enabled:
|
||||
raise UserDisabledException
|
||||
|
||||
request.session.update({"iss": "romm:auth", "sub": potential_user.username})
|
||||
request.session["iss"] = "romm:auth"
|
||||
request.session["sub"] = potential_user.username
|
||||
|
||||
# Update last login and active times
|
||||
now = datetime.now(timezone.utc)
|
||||
@@ -293,7 +295,7 @@ def request_password_reset(username: str = Body(..., embed=True)) -> None:
|
||||
|
||||
|
||||
@router.post("/reset-password", status_code=status.HTTP_200_OK)
|
||||
def reset_password(
|
||||
async def reset_password(
|
||||
token: str = Body(..., embed=True),
|
||||
new_password: str = Body(..., embed=True),
|
||||
) -> None:
|
||||
@@ -308,7 +310,7 @@ def reset_password(
|
||||
"""
|
||||
user = auth_handler.verify_password_reset_token(token)
|
||||
|
||||
auth_handler.set_user_new_password(user, new_password)
|
||||
await auth_handler.set_user_new_password(user, new_password)
|
||||
|
||||
log.info(
|
||||
f"Password was successfully reset for user {hl(user.username, color=CYAN)}."
|
||||
|
||||
@@ -21,11 +21,14 @@ from config import (
|
||||
from decorators.auth import oauth
|
||||
from exceptions.auth_exceptions import OAuthCredentialsException, UserDisabledException
|
||||
from handler.auth.constants import ALGORITHM, DEFAULT_OAUTH_TOKEN_EXPIRY, TokenPurpose
|
||||
from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware
|
||||
from handler.redis_handler import redis_client
|
||||
from logger.formatter import CYAN
|
||||
from logger.formatter import highlight as hl
|
||||
from logger.logger import log
|
||||
|
||||
oct_key = OctKey.import_key(ROMM_AUTH_SECRET_KEY)
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
def __init__(self) -> None:
|
||||
@@ -95,7 +98,7 @@ class AuthHandler:
|
||||
token = jwt.encode(
|
||||
{"alg": ALGORITHM},
|
||||
to_encode,
|
||||
OctKey.import_key(ROMM_AUTH_SECRET_KEY),
|
||||
oct_key,
|
||||
)
|
||||
log.info(
|
||||
f"Reset password link requested for {hl(user.username, color=CYAN)}. Reset link: {hl(f'{ROMM_BASE_URL}/reset-password?token={token}')}"
|
||||
@@ -119,7 +122,7 @@ class AuthHandler:
|
||||
from handler.database import db_user_handler
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
|
||||
except (BadSignatureError, DecodeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid token") from exc
|
||||
|
||||
@@ -146,12 +149,12 @@ class AuthHandler:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
if now > payload.claims.get("exp"):
|
||||
if now > payload.claims.get("exp", 0.0):
|
||||
raise HTTPException(status_code=400, detail="Token has expired")
|
||||
|
||||
return user
|
||||
|
||||
def set_user_new_password(self, user: Any, new_password: str) -> None:
|
||||
async def set_user_new_password(self, user: Any, new_password: str) -> None:
|
||||
"""
|
||||
Set the new password for the user.
|
||||
Args:
|
||||
@@ -163,6 +166,7 @@ class AuthHandler:
|
||||
db_user_handler.update_user(
|
||||
user.id, {"hashed_password": self.get_password_hash(new_password)}
|
||||
)
|
||||
await RedisSessionMiddleware.clear_user_sessions(user.username)
|
||||
|
||||
def generate_invite_link_token(self, user: Any, role: str) -> str:
|
||||
"""
|
||||
@@ -192,7 +196,7 @@ class AuthHandler:
|
||||
token = jwt.encode(
|
||||
{"alg": ALGORITHM},
|
||||
to_encode,
|
||||
OctKey.import_key(ROMM_AUTH_SECRET_KEY),
|
||||
oct_key,
|
||||
)
|
||||
invite_link = f"{ROMM_BASE_URL}/register?token={token}"
|
||||
log.info(
|
||||
@@ -212,9 +216,7 @@ class AuthHandler:
|
||||
str: The JTI (JWT ID) of the token.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, OctKey.import_key(ROMM_AUTH_SECRET_KEY), algorithms=[ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
|
||||
except (BadSignatureError, DecodeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid token") from exc
|
||||
|
||||
@@ -256,16 +258,14 @@ class OAuthHandler:
|
||||
return jwt.encode(
|
||||
{"alg": ALGORITHM},
|
||||
to_encode,
|
||||
OctKey.import_key(ROMM_AUTH_SECRET_KEY),
|
||||
oct_key,
|
||||
)
|
||||
|
||||
async def get_current_active_user_from_bearer_token(self, token: str):
|
||||
from handler.database import db_user_handler
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, OctKey.import_key(ROMM_AUTH_SECRET_KEY), algorithms=[ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, oct_key, algorithms=[ALGORITHM])
|
||||
except (BadSignatureError, DecodeError, ValueError) as exc:
|
||||
raise OAuthCredentialsException from exc
|
||||
|
||||
|
||||
92
backend/handler/auth/middleware/redis_session_middleware.py
Normal file
92
backend/handler/auth/middleware/redis_session_middleware.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from config import SESSION_MAX_AGE_SECONDS
|
||||
from handler.redis_handler import async_cache
|
||||
|
||||
|
||||
class RedisSessionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
session_cookie: str = "session",
|
||||
max_age: int = SESSION_MAX_AGE_SECONDS,
|
||||
same_site: str = "lax",
|
||||
https_only: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.session_cookie = session_cookie
|
||||
self.max_age = max_age
|
||||
self.security_flags = "httponly; samesite=" + same_site
|
||||
if https_only:
|
||||
self.security_flags += "; secure"
|
||||
|
||||
@staticmethod
|
||||
async def clear_user_sessions(user_id: str) -> None:
|
||||
"""
|
||||
Clears all active sessions for a given user.
|
||||
"""
|
||||
session_ids = await async_cache.smembers(f"user_sessions:{user_id}")
|
||||
if session_ids:
|
||||
for session_id in session_ids:
|
||||
await async_cache.delete(f"session:{session_id}")
|
||||
await async_cache.delete(f"user_sessions:{user_id}")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
connection = HTTPConnection(scope)
|
||||
session_id = None # Initialize session_id to None
|
||||
session_cookie_from_request = connection.cookies.get(self.session_cookie)
|
||||
|
||||
if session_cookie_from_request:
|
||||
session_id = session_cookie_from_request
|
||||
session_data = await async_cache.get(f"session:{session_id}")
|
||||
if session_data:
|
||||
scope["session"] = json.loads(session_data)
|
||||
scope["session"]["session_id"] = session_id
|
||||
else:
|
||||
scope["session"] = {}
|
||||
else:
|
||||
scope["session"] = {}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
nonlocal session_id
|
||||
if message["type"] == "http.response.start":
|
||||
headers = MutableHeaders(scope=message)
|
||||
# Check for user_id to track user-specific sessions
|
||||
user_id = scope["session"].get("sub")
|
||||
|
||||
if scope["session"]:
|
||||
session_id = scope["session"].pop("session_id", None) or str(
|
||||
uuid.uuid4()
|
||||
) # Retrieve or create session_id
|
||||
session_data_json = json.dumps(scope["session"])
|
||||
await async_cache.set(
|
||||
f"session:{session_id}", session_data_json, ex=self.max_age
|
||||
)
|
||||
|
||||
# Add session_id to user set of sessions
|
||||
if user_id:
|
||||
await async_cache.sadd(f"user_sessions:{user_id}", session_id)
|
||||
|
||||
header_value = f"{self.session_cookie}={session_id}; path=/; Max-Age={self.max_age}; {self.security_flags}"
|
||||
headers.append("Set-Cookie", header_value)
|
||||
elif session_id:
|
||||
await async_cache.delete(f"session:{session_id}")
|
||||
# Remove session_id from user set of sessions
|
||||
if user_id:
|
||||
await async_cache.srem(f"user_sessions:{user_id}", session_id)
|
||||
|
||||
header_value = f"{self.session_cookie}=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; {self.security_flags}"
|
||||
headers.append("Set-Cookie", header_value)
|
||||
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
@@ -42,10 +42,9 @@ from endpoints import (
|
||||
tasks,
|
||||
user,
|
||||
)
|
||||
from handler.auth.constants import ALGORITHM
|
||||
from handler.auth.hybrid_auth import HybridAuthBackend
|
||||
from handler.auth.middleware.csrf_middleware import CSRFMiddleware
|
||||
from handler.auth.middleware.session_middleware import SessionMiddleware
|
||||
from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware
|
||||
from handler.socket_handler import socket_handler
|
||||
from logger.formatter import LOGGING_CONFIG
|
||||
from utils import get_version
|
||||
@@ -105,12 +104,10 @@ app.add_middleware(
|
||||
|
||||
# Enables support for sessions on requests
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=ROMM_AUTH_SECRET_KEY,
|
||||
RedisSessionMiddleware,
|
||||
session_cookie="romm_session",
|
||||
same_site="lax" if OIDC_ENABLED else "strict",
|
||||
https_only=False,
|
||||
jwt_alg=ALGORITHM,
|
||||
)
|
||||
|
||||
# Sets context vars in request-response cycle
|
||||
|
||||
@@ -10,9 +10,10 @@ from main import app
|
||||
|
||||
from endpoints.auth import ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
from handler.auth import oauth_handler
|
||||
from handler.auth.middleware.redis_session_middleware import RedisSessionMiddleware
|
||||
from handler.database.users_handler import DBUsersHandler
|
||||
from handler.redis_handler import sync_cache
|
||||
from models.user import Role
|
||||
from models.user import Role, User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -27,7 +28,7 @@ def clear_cache():
|
||||
sync_cache.flushall()
|
||||
|
||||
|
||||
def test_login_logout(client, admin_user):
|
||||
def test_login_logout(client, admin_user: User):
|
||||
response = client.get("/api/login")
|
||||
|
||||
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
|
||||
@@ -45,7 +46,7 @@ def test_login_logout(client, admin_user):
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
def test_get_all_users(client, access_token):
|
||||
def test_get_all_users(client, access_token: str):
|
||||
response = client.get(
|
||||
"/api/users", headers={"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
@@ -56,7 +57,7 @@ def test_get_all_users(client, access_token):
|
||||
assert users[0]["username"] == "test_admin"
|
||||
|
||||
|
||||
def test_get_user(client, access_token, editor_user):
|
||||
def test_get_user(client, access_token: str, editor_user: User):
|
||||
response = client.get(
|
||||
f"/api/users/{editor_user.id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
@@ -68,7 +69,7 @@ def test_get_user(client, access_token, editor_user):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("new_user_role", [Role.VIEWER, Role.EDITOR, Role.ADMIN])
|
||||
def test_add_user_from_admin_user(client, access_token, new_user_role):
|
||||
def test_add_user_from_admin_user(client, access_token: str, new_user_role: Role):
|
||||
response = client.post(
|
||||
"/api/users",
|
||||
json={
|
||||
@@ -98,10 +99,10 @@ def test_add_user_from_admin_user(client, access_token, new_user_role):
|
||||
def test_add_user_from_unauthorized_user(
|
||||
request,
|
||||
client,
|
||||
admin_user,
|
||||
fixture_requesting_user,
|
||||
existing_admin_users,
|
||||
expected_status_code,
|
||||
admin_user: User,
|
||||
fixture_requesting_user: User,
|
||||
existing_admin_users: list[User],
|
||||
expected_status_code: int,
|
||||
):
|
||||
requesting_user = request.getfixturevalue(fixture_requesting_user)
|
||||
|
||||
@@ -133,7 +134,7 @@ def test_add_user_from_unauthorized_user(
|
||||
assert response.status_code == expected_status_code
|
||||
|
||||
|
||||
def test_add_user_with_existing_username(client, access_token, admin_user):
|
||||
def test_add_user_with_existing_username(client, access_token: str, admin_user: User):
|
||||
response = client.post(
|
||||
"/api/users",
|
||||
json={
|
||||
@@ -150,7 +151,7 @@ def test_add_user_with_existing_username(client, access_token, admin_user):
|
||||
assert response["detail"] == f"Username {admin_user.username} already exists"
|
||||
|
||||
|
||||
def test_update_user(client, access_token, editor_user):
|
||||
def test_update_user(client, access_token: str, editor_user: User):
|
||||
assert editor_user.role == Role.EDITOR
|
||||
|
||||
response = client.put(
|
||||
@@ -164,9 +165,85 @@ def test_update_user(client, access_token, editor_user):
|
||||
assert user["role"] == "viewer"
|
||||
|
||||
|
||||
def test_delete_user(client, access_token, editor_user):
|
||||
def test_delete_user(client, access_token: str, editor_user: User):
|
||||
response = client.delete(
|
||||
f"/api/users/{editor_user.id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_change_invalidates_sessions(client, admin_user: User):
|
||||
# Get the user's session cookie
|
||||
basic_auth = base64.b64encode(
|
||||
f"{admin_user.username}:test_admin_password".encode("ascii")
|
||||
).decode("ascii")
|
||||
response = client.post(
|
||||
"/api/login", headers={"Authorization": f"Basic {basic_auth}"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
old_session_cookie = response.cookies.get("romm_session")
|
||||
assert old_session_cookie is not None
|
||||
|
||||
# Verify session works
|
||||
response = client.get("/api/users/me", cookies={"romm_session": old_session_cookie})
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
# Update the user's password
|
||||
response = client.put(
|
||||
f"/api/users/{admin_user.id}",
|
||||
data={"password": "new_admin_password"},
|
||||
headers={"Authorization": f"Basic {basic_auth}"},
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
# Attempt to access a protected resource using the old session cookie
|
||||
response = client.get("/api/users/me", cookies={"romm_session": old_session_cookie})
|
||||
assert response.status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]
|
||||
|
||||
# Login with the new credentials
|
||||
basic_auth_new = base64.b64encode(
|
||||
f"{admin_user.username}:new_admin_password".encode("ascii")
|
||||
).decode("ascii")
|
||||
response = client.post(
|
||||
"/api/login", headers={"Authorization": f"Basic {basic_auth_new}"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
new_session_cookie = response.cookies.get("romm_session")
|
||||
assert new_session_cookie is not None
|
||||
assert new_session_cookie != old_session_cookie
|
||||
|
||||
# Attempt to access a protected resource using the new session cookie
|
||||
response = client.get("/api/users/me", cookies={"romm_session": new_session_cookie})
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
await RedisSessionMiddleware.clear_user_sessions(admin_user.username)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logout_invalidates_session(client, admin_user: User):
|
||||
# Get the user's session cookie
|
||||
basic_auth = base64.b64encode(
|
||||
f"{admin_user.username}:test_admin_password".encode("ascii")
|
||||
).decode("ascii")
|
||||
response = client.post(
|
||||
"/api/login", headers={"Authorization": f"Basic {basic_auth}"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
session_cookie = response.cookies.get("romm_session")
|
||||
assert session_cookie is not None
|
||||
|
||||
# Verify session works
|
||||
response = client.get("/api/users/me", cookies={"romm_session": session_cookie})
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
# Log out the user
|
||||
response = client.post("/api/logout", cookies={"romm_session": session_cookie})
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
# Attempt to access a protected resource using the old session cookie
|
||||
response = client.get("/api/users/me", cookies={"romm_session": session_cookie})
|
||||
assert response.status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]
|
||||
|
||||
await RedisSessionMiddleware.clear_user_sessions(admin_user.username)
|
||||
|
||||
Reference in New Issue
Block a user