Add early auth support

This commit is contained in:
Peter Vacho 2024-11-27 23:16:42 +01:00
parent 65b68c04a4
commit 6f03d37ee9
Signed by: school
GPG key ID: 8CFC3837052871B4
12 changed files with 427 additions and 0 deletions

View file

@ -12,6 +12,9 @@ dependencies = [
"motor[srv]>=3.6.0",
"beanie>=1.27.0",
"pydantic[email]>=2.10.2",
"python-jose>=3.3.0",
"python-multipart>=0.0.17",
"bcrypt>=4.2.1",
]
readme = "README.md"
requires-python = ">= 3.12"

View file

@ -15,6 +15,8 @@ annotated-types==0.7.0
anyio==4.6.2.post1
# via starlette
basedpyright==1.22.0
bcrypt==4.2.1
# via event-management
beanie==1.27.0
# via event-management
cfgv==3.4.0
@ -29,6 +31,8 @@ distlib==0.3.9
dnspython==2.7.0
# via email-validator
# via pymongo
ecdsa==0.19.0
# via python-jose
email-validator==2.2.0
# via pydantic
fastapi==0.115.5
@ -60,6 +64,9 @@ platformdirs==4.3.6
poethepoet==0.31.1
# via event-management
pre-commit==4.0.1
pyasn1==0.6.1
# via python-jose
# via rsa
pydantic==2.10.2
# via beanie
# via event-management
@ -71,10 +78,18 @@ pymongo==4.9.2
# via motor
python-decouple==3.8
# via event-management
python-jose==3.3.0
# via event-management
python-multipart==0.0.17
# via event-management
pyyaml==6.0.2
# via poethepoet
# via pre-commit
rsa==4.9
# via python-jose
ruff==0.8.0
six==1.16.0
# via ecdsa
sniffio==1.3.1
# via anyio
starlette==0.41.3

View file

@ -14,6 +14,8 @@ annotated-types==0.7.0
# via pydantic
anyio==4.6.2.post1
# via starlette
bcrypt==4.2.1
# via event-management
beanie==1.27.0
# via event-management
click==8.1.7
@ -24,6 +26,8 @@ coloredlogs==15.0.1
dnspython==2.7.0
# via email-validator
# via pymongo
ecdsa==0.19.0
# via python-jose
email-validator==2.2.0
# via pydantic
fastapi==0.115.5
@ -44,6 +48,9 @@ pastel==0.2.1
# via poethepoet
poethepoet==0.31.1
# via event-management
pyasn1==0.6.1
# via python-jose
# via rsa
pydantic==2.10.2
# via beanie
# via event-management
@ -55,8 +62,16 @@ pymongo==4.9.2
# via motor
python-decouple==3.8
# via event-management
python-jose==3.3.0
# via event-management
python-multipart==0.0.17
# via event-management
pyyaml==6.0.2
# via poethepoet
rsa==4.9
# via python-jose
six==1.16.0
# via ecdsa
sniffio==1.3.1
# via anyio
starlette==0.41.3

3
src/api/__init__.py Normal file
View file

@ -0,0 +1,3 @@
from .app import app
__all__ = ["app"]

View file

@ -11,6 +11,8 @@ from fastapi.responses import HTMLResponse, PlainTextResponse
from src.db.init import initialize_db
from src.utils.logging import get_logger
from .auth import router as auth_router
log = get_logger(__name__)
__all__ = ["app"]
@ -48,6 +50,8 @@ app = FastAPI(
docs_url=None,
)
app.include_router(auth_router)
@app.get("/ping")
async def ping(_r: Request) -> PlainTextResponse:

4
src/api/auth/__init__.py Normal file
View file

@ -0,0 +1,4 @@
from .api import router
from .dependencies import AccessTokenDep, AnyTokenDep, CurrentUserDep, LoggedInDep, RefreshTokenDep
__all__ = ["AccessTokenDep", "AnyTokenDep", "CurrentUserDep", "LoggedInDep", "RefreshTokenDep", "router"]

120
src/api/auth/api.py Normal file
View file

@ -0,0 +1,120 @@
from typing import Annotated, Literal
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import Response
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from src.db.models.user import User
from src.settings import ACCESS_TOKEN_EXPIRATION, REFRESH_TOKEN_EXPIRATION
from src.utils.logging import get_logger
from .dependencies import RefreshTokenDep
from .jwt import create_jwt_token, invalidate_access_tokens
from .passwords import check_hashed_password
log = get_logger(__name__)
router = APIRouter(tags=["Authentication"], prefix="/auth")
class AccessTokenResponseBase(BaseModel):
"""Stores the JWT access token data sent in a response."""
access_token: str
expires_in: int
token_type: Literal["bearer"] = "bearer"
user_id: str
class RefreshTokenResponse(AccessTokenResponseBase):
"""Stores the JWT refresh token data sent in a response."""
refresh_token: str
@router.post(
"/login",
responses={
401: {"description": "Incorrect username or password"},
200: {"description": "Login successful"},
},
)
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> RefreshTokenResponse:
"""Endpoint that provides a way for the user to authenticate, obtaining a Bearer token."""
user = await User.find_one(User.username == form_data.username)
if user is None:
log.trace(f"Authentication failed: user {form_data.username} wasn't found")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
if user.id is None:
raise RuntimeError("Never")
if not check_hashed_password(form_data.password, user.password_hash):
log.trace(f"Authentication failed: wrong password for user {form_data.username}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
refresh_token, db_refresh_token = await create_jwt_token(user, "refresh", scopes=form_data.scopes)
access_token, _ = await create_jwt_token(user, "access", parent_token=db_refresh_token, scopes=form_data.scopes)
return RefreshTokenResponse(
refresh_token=refresh_token,
access_token=access_token,
expires_in=REFRESH_TOKEN_EXPIRATION,
user_id=str(user.id),
)
@router.post(
"/refresh",
responses={200: {"description": "Token refreshed"}},
)
async def refresh(cur_token: RefreshTokenDep) -> AccessTokenResponseBase:
"""Endpoint that allows generating a new access token.
This endpoint requires a refresh token as the bearer.
If a valid access token already exists for this refresh token, it will be invalidated
in favor of the new access token returned from this endpoint.
"""
cur_token_data, cur_db_token = cur_token
await invalidate_access_tokens(cur_db_token)
if cur_db_token.user.id is None:
raise RuntimeError("Never")
# We can now be certain there aren't any existing valid access tokens made from this refresh token.
# Let's create a new access token.
scopes = cur_token_data.get("scope")
scopes = scopes.split(" ") if scopes else []
token, _ = await create_jwt_token(cur_db_token.user, "access", parent_token=cur_db_token, scopes=scopes)
return AccessTokenResponseBase(
access_token=token, expires_in=ACCESS_TOKEN_EXPIRATION, user_id=str(cur_db_token.user.id)
)
@router.post(
"/logout",
responses={200: {"description": "Refresh token invalidated"}},
)
async def logout(cur_token: RefreshTokenDep) -> Response:
"""Endpoint that performs a user logout, invalidating the refresh token & all it's access tokens.
This endpoint requires a refresh token as the bearer.
"""
_, cur_db_token = cur_token
await invalidate_access_tokens(cur_db_token)
cur_db_token.revoked = True
_ = await cur_db_token.replace()
return Response(status_code=status.HTTP_200_OK)

View file

@ -0,0 +1,97 @@
from typing import Annotated, Literal, TypeAlias
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from src.db.models.token import Token
from src.db.models.user import User
from src.utils.logging import get_logger
from .jwt import InvalidTokenError, TokenData, resolve_jwt_token
__all__ = ["AccessTokenDep", "AnyTokenDep", "CurrentUserDep", "LoggedInDep", "RefreshTokenDep"]
log = get_logger(__name__)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
async def _get_token(
token: Annotated[str, Depends(oauth2_scheme)],
security_scopes: SecurityScopes,
tok_type: Literal["access", "refresh"] | None,
) -> tuple[TokenData, Token]:
"""FastAPI dependency to get the JWT bearer token."""
def make_exc(detail: str) -> HTTPException:
if security_scopes.scopes:
authenticate_header_val = f'Bearer scope="{security_scopes.scope_str}"'
else:
authenticate_header_val = "Bearer"
return HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
headers={"WWW-Authenticate": authenticate_header_val},
)
try:
token_data, db_token = await resolve_jwt_token(token)
except InvalidTokenError as exc:
raise make_exc(exc.state.value)
if tok_type is not None and token_data["type"] != tok_type:
raise make_exc(f"Invalid token type, expected {tok_type} token, but found {token_data['type']} token")
scopes = token_data.get("scope")
scopes = scopes.split(" ") if scopes else []
for scope in security_scopes.scopes:
if scope not in scopes:
raise make_exc(f"The provided token is missing the {scope!r} scope")
return token_data, db_token
async def _get_refresh_token(
token: Annotated[str, Depends(oauth2_scheme)],
security_scopes: SecurityScopes,
) -> tuple[TokenData, Token]:
"""FastAPI dependency to get the refresh JWT bearer token."""
return await _get_token(token, security_scopes=security_scopes, tok_type="refresh")
async def _get_access_token(
token: Annotated[str, Depends(oauth2_scheme)],
security_scopes: SecurityScopes,
) -> tuple[TokenData, Token]:
"""FastAPI dependency to get the access JWT bearer token."""
return await _get_token(token, security_scopes=security_scopes, tok_type="access")
async def _get_any_token(
token: Annotated[str, Depends(oauth2_scheme)],
security_scopes: SecurityScopes,
) -> tuple[TokenData, Token]:
"""FastAPI dependency to get the access or refresh JWT bearer token."""
return await _get_token(token, security_scopes=security_scopes, tok_type=None)
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)],
security_scopes: SecurityScopes,
) -> User:
"""FastAPI dependency to get the currently logged in user from the JWT bearer access token."""
_, db_token = await _get_access_token(token, security_scopes=security_scopes)
return db_token.user
# NOTE: When using these aliases, it's not possible to specify the scopes, as they are
# integrated into the type. We don't currently use scopes, so this is fine. If scopes
# will be used in the future, get_current_user will need to be used manually.
LoggedInDep = Security(get_current_user)
CurrentUserDep: TypeAlias = Annotated[User, Security(get_current_user)] # noqa: UP040: https://github.com/fastapi/fastapi/issues/10719
AccessTokenDep: TypeAlias = Annotated[tuple[TokenData, Token], Security(_get_access_token)] # noqa: UP040
RefreshTokenDep: TypeAlias = Annotated[tuple[TokenData, Token], Security(_get_refresh_token)] # noqa: UP040
AnyTokenDep: TypeAlias = Annotated[tuple[TokenData, Token], Security(_get_any_token)] # noqa: UP040

141
src/api/auth/jwt.py Normal file
View file

@ -0,0 +1,141 @@
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from typing import Literal, NotRequired, TypedDict, cast, final
from beanie import PydanticObjectId
from jose import JWTError, jwt
from src.db.models import user
from src.db.models.token import Token
from src.settings import ACCESS_TOKEN_EXPIRATION, JWT_SECRET_TOKEN
from src.utils.logging import get_logger
__all__ = ["AuthErrState", "InvalidTokenError", "create_jwt_token", "invalidate_access_tokens", "resolve_jwt_token"]
log = get_logger(__name__)
@final
class TokenData(TypedDict):
"""Payload data stored in the session token JWT."""
sub: str # token id
iat: int # issued at timestamp
exp: int # expiration timestamp
scope: NotRequired[str]
type: Literal["access", "refresh"]
@final
class AuthErrState(StrEnum):
"""Authentication error states."""
MALFORMED = "The provided token is malformed."
EXPIRED = "The provided token is expired."
UNTRACKED = "The provided token isn't tracked by the server"
REVOKED = "The provided token was revoked"
@final
class InvalidTokenError(Exception):
"""Exception raised when given JWT token was invalid."""
def __init__(self, state: AuthErrState, raw_token: str):
self.state = state
self.raw_token = raw_token
super().__init__(state.value)
async def create_jwt_token(
user: user.User,
tok_type: Literal["access", "refresh"],
*,
parent_token: Token | None = None,
scopes: list[str] | None = None,
) -> tuple[str, Token]:
"""Create a new JWT token for given ``user``.
This will create a signed token and store it in the database.
"""
issued_at = datetime.now(UTC)
expires_at = issued_at + timedelta(seconds=ACCESS_TOKEN_EXPIRATION)
await user.sync()
db_token = Token(
user=user,
tok_type=tok_type,
parent_token=parent_token,
revoked=False,
issued_at=issued_at,
expires_at=expires_at,
)
db_token = await db_token.insert()
if not isinstance(db_token.id, PydanticObjectId):
raise RuntimeError("Never") # noqa: TRY004
data: TokenData = {
"sub": str(db_token.id),
"iat": round(issued_at.timestamp()),
"exp": round(expires_at.timestamp()),
"type": tok_type,
}
if scopes:
for scope in scopes:
if " " in scope:
raise ValueError(f"Scope names can't contain spaces. Scope {scope!r} is invalid")
data["scope"] = " ".join(scopes)
return jwt.encode(dict(data), JWT_SECRET_TOKEN, algorithm="HS256"), db_token
async def resolve_jwt_token(token: str) -> tuple[TokenData, Token]:
"""Resolve the user from given JWT token.
If the JWT token isn't valid, :exc:`InvalidTokenError` will be raised.
This will also perform a database lookup, making sure that this token does in fact exist.
"""
try:
decoded = cast(TokenData, jwt.decode(token, JWT_SECRET_TOKEN, algorithms=["HS256"])) # pyright: ignore[reportInvalidCast]
except JWTError:
log.trace(f"Got a malformed token: {token!r}")
raise InvalidTokenError(AuthErrState.MALFORMED, token)
if decoded["exp"] < datetime.now(UTC).timestamp():
log.trace(f"Got an expired token: {token!r}")
raise InvalidTokenError(AuthErrState.EXPIRED, token)
db_token = await Token.get(decoded["sub"])
if db_token is None:
log.warning(f"Got an untracked signed token: {token!r}")
raise InvalidTokenError(AuthErrState.UNTRACKED, token)
if db_token.revoked:
log.trace(f"Got a revoked token: {token!r}")
raise InvalidTokenError(AuthErrState.REVOKED, token)
return decoded, db_token
async def invalidate_access_tokens(refresh_token: Token) -> None:
"""Invalidate all access tokens associated with given refresh token."""
tokens = await Token.find(Token.tok_type == "access", Token.parent_token == refresh_token).to_list()
now = datetime.now(UTC)
for token in tokens:
# Delete any expired access tokens, we won't need them anymore
if token.expires_at < now:
_ = await token.delete()
continue
# If there is a non-expired but already revoked token, ignore it
# (we can't delete it as we want proper error message if it's still used,
# it will be cleared once it expires)
if token.revoked:
continue
# If there is an active token, revoke it.
token.revoked = True
_ = await token.replace()

18
src/api/auth/passwords.py Normal file
View file

@ -0,0 +1,18 @@
import bcrypt
def create_password_hash(password: str) -> bytes:
"""Cryptographically hash given password using a randomly generated salt.
Note that the generated bcrypt hash actually also contains the salt.
"""
salt = bcrypt.gensalt()
return bcrypt.hashpw(password=password.encode(), salt=salt)
def check_hashed_password(password: str, hashed_password: bytes) -> bool:
"""Verify that given ``password`` matches the ``hashed_password``.
Note that the bcrypt ``hashed_password`` actually already contains the salt.
"""
return bcrypt.checkpw(password=password.encode(), hashed_password=hashed_password)

View file

@ -7,3 +7,8 @@ LOG_FILE = get_config("LOG_FILE", cast=Path, default=None)
TRACE_LEVEL_FILTER = get_config("TRACE_LEVEL_FILTER", default=None)
MONGODB_URI = get_config("MONGODB_URI")
REFRESH_TOKEN_EXPIRATION = get_config("REFRESH_TOKEN_EXPIRATION", cast=int, default=3600 * 24 * 30) # 30 days
ACCESS_TOKEN_EXPIRATION = get_config("ACCESS_TOKEN_EXPIRATION", cast=int, default=3600) # 1 hour
JWT_SECRET_TOKEN = get_config("JWT_SECRET_TOKEN") # openssl rand -hex 32

View file

@ -180,5 +180,7 @@ def _setup_external_log_levels(root_log: LoggerClass) -> None: # pyright: ignor
get_logger("asyncio").setLevel(logging.INFO)
get_logger("multipart").setLevel(logging.INFO) # FastAPI HTTP forms parsing lib
get_logger("pymongo").setLevel(logging.INFO)
get_logger("passlib").setLevel(logging.INFO)
get_logger("python_multipart").setLevel(logging.INFO)
get_logger("parso").setLevel(logging.WARNING) # For usage in IPython