mirror of
https://github.com/rommapp/romm.git
synced 2026-02-18 23:42:07 +01:00
133 lines
3.6 KiB
Python
133 lines
3.6 KiB
Python
from datetime import datetime, timedelta
|
|
from typing import Optional, Final, Any
|
|
from jose import JWTError, jwt
|
|
from fastapi import HTTPException, status, Security
|
|
from fastapi.param_functions import Form
|
|
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
|
from fastapi.security.http import HTTPBasic
|
|
from fastapi.types import DecoratedCallable
|
|
from starlette.authentication import requires
|
|
|
|
from config import ROMM_AUTH_SECRET_KEY
|
|
|
|
ALGORITHM: Final = "HS256"
|
|
DEFAULT_OAUTH_TOKEN_EXPIRY: Final = 15
|
|
|
|
DEFAULT_SCOPES_MAP: Final = {
|
|
"me.read": "View your profile",
|
|
"me.write": "Modify your profile",
|
|
"roms.read": "View ROMs",
|
|
"platforms.read": "View platforms",
|
|
}
|
|
|
|
WRITE_SCOPES_MAP: Final = {
|
|
"roms.write": "Modify ROMs",
|
|
"platforms.write": "Modify platforms",
|
|
}
|
|
|
|
FULL_SCOPES_MAP: Final = {
|
|
"users.read": "View users",
|
|
"users.write": "Modify users",
|
|
}
|
|
|
|
DEFAULT_SCOPES: Final = list(DEFAULT_SCOPES_MAP.keys())
|
|
WRITE_SCOPES: Final = DEFAULT_SCOPES + list(WRITE_SCOPES_MAP.keys())
|
|
FULL_SCOPES: Final = WRITE_SCOPES + list(FULL_SCOPES_MAP.keys())
|
|
|
|
credentials_exception = HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Could not validate credentials",
|
|
headers={"WWW-Authenticate": "Bearer"},
|
|
)
|
|
|
|
|
|
def create_oauth_token(data: dict, expires_delta: timedelta | None = None):
|
|
to_encode = data.copy()
|
|
|
|
if expires_delta:
|
|
expire = datetime.utcnow() + expires_delta
|
|
else:
|
|
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_OAUTH_TOKEN_EXPIRY)
|
|
|
|
to_encode.update({"exp": expire})
|
|
|
|
return jwt.encode(to_encode, ROMM_AUTH_SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
async def get_current_active_user_from_bearer_token(token: str):
|
|
from handler import dbh
|
|
|
|
try:
|
|
payload = jwt.decode(token, ROMM_AUTH_SECRET_KEY, algorithms=[ALGORITHM])
|
|
except JWTError:
|
|
raise credentials_exception
|
|
|
|
username = payload.get("sub")
|
|
if username is None:
|
|
raise credentials_exception
|
|
|
|
user = dbh.get_user_by_username(username)
|
|
if user is None:
|
|
raise credentials_exception
|
|
|
|
if not user.enabled:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
|
|
)
|
|
|
|
return user, payload
|
|
|
|
|
|
class OAuth2RequestForm:
|
|
def __init__(
|
|
self,
|
|
grant_type: str = Form(default="password"),
|
|
scope: str = Form(default=""),
|
|
username: Optional[str] = Form(default=None),
|
|
password: Optional[str] = Form(default=None),
|
|
client_id: Optional[str] = Form(default=None),
|
|
client_secret: Optional[str] = Form(default=None),
|
|
refresh_token: Optional[str] = Form(default=None),
|
|
):
|
|
self.grant_type = grant_type
|
|
self.scopes = scope.split()
|
|
self.username = username
|
|
self.password = password
|
|
self.client_id = client_id
|
|
self.client_secret = client_secret
|
|
self.refresh_token = refresh_token
|
|
|
|
|
|
oauth2_password_bearer = OAuth2PasswordBearer(
|
|
tokenUrl="/token",
|
|
auto_error=False,
|
|
scopes={
|
|
**DEFAULT_SCOPES_MAP,
|
|
**WRITE_SCOPES_MAP,
|
|
**FULL_SCOPES_MAP,
|
|
},
|
|
)
|
|
|
|
|
|
def protected_route(
|
|
method: Any,
|
|
path: str,
|
|
scopes: list[str] = [],
|
|
**kwargs,
|
|
):
|
|
def decorator(func: DecoratedCallable):
|
|
fn = requires(scopes)(func)
|
|
return method(
|
|
path,
|
|
dependencies=[
|
|
Security(
|
|
dependency=oauth2_password_bearer,
|
|
scopes=scopes,
|
|
),
|
|
Security(dependency=HTTPBasic(auto_error=False)),
|
|
],
|
|
**kwargs,
|
|
)(fn)
|
|
|
|
return decorator
|