Add early auth support
This commit is contained in:
parent
65b68c04a4
commit
6f03d37ee9
|
@ -12,6 +12,9 @@ dependencies = [
|
|||
"motor[srv]>=3.6.0",
|
||||
"beanie>=1.27.0",
|
||||
"pydantic[email]>=2.10.2",
|
||||
"python-jose>=3.3.0",
|
||||
"python-multipart>=0.0.17",
|
||||
"bcrypt>=4.2.1",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.12"
|
||||
|
|
|
@ -15,6 +15,8 @@ annotated-types==0.7.0
|
|||
anyio==4.6.2.post1
|
||||
# via starlette
|
||||
basedpyright==1.22.0
|
||||
bcrypt==4.2.1
|
||||
# via event-management
|
||||
beanie==1.27.0
|
||||
# via event-management
|
||||
cfgv==3.4.0
|
||||
|
@ -29,6 +31,8 @@ distlib==0.3.9
|
|||
dnspython==2.7.0
|
||||
# via email-validator
|
||||
# via pymongo
|
||||
ecdsa==0.19.0
|
||||
# via python-jose
|
||||
email-validator==2.2.0
|
||||
# via pydantic
|
||||
fastapi==0.115.5
|
||||
|
@ -60,6 +64,9 @@ platformdirs==4.3.6
|
|||
poethepoet==0.31.1
|
||||
# via event-management
|
||||
pre-commit==4.0.1
|
||||
pyasn1==0.6.1
|
||||
# via python-jose
|
||||
# via rsa
|
||||
pydantic==2.10.2
|
||||
# via beanie
|
||||
# via event-management
|
||||
|
@ -71,10 +78,18 @@ pymongo==4.9.2
|
|||
# via motor
|
||||
python-decouple==3.8
|
||||
# via event-management
|
||||
python-jose==3.3.0
|
||||
# via event-management
|
||||
python-multipart==0.0.17
|
||||
# via event-management
|
||||
pyyaml==6.0.2
|
||||
# via poethepoet
|
||||
# via pre-commit
|
||||
rsa==4.9
|
||||
# via python-jose
|
||||
ruff==0.8.0
|
||||
six==1.16.0
|
||||
# via ecdsa
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
starlette==0.41.3
|
||||
|
|
|
@ -14,6 +14,8 @@ annotated-types==0.7.0
|
|||
# via pydantic
|
||||
anyio==4.6.2.post1
|
||||
# via starlette
|
||||
bcrypt==4.2.1
|
||||
# via event-management
|
||||
beanie==1.27.0
|
||||
# via event-management
|
||||
click==8.1.7
|
||||
|
@ -24,6 +26,8 @@ coloredlogs==15.0.1
|
|||
dnspython==2.7.0
|
||||
# via email-validator
|
||||
# via pymongo
|
||||
ecdsa==0.19.0
|
||||
# via python-jose
|
||||
email-validator==2.2.0
|
||||
# via pydantic
|
||||
fastapi==0.115.5
|
||||
|
@ -44,6 +48,9 @@ pastel==0.2.1
|
|||
# via poethepoet
|
||||
poethepoet==0.31.1
|
||||
# via event-management
|
||||
pyasn1==0.6.1
|
||||
# via python-jose
|
||||
# via rsa
|
||||
pydantic==2.10.2
|
||||
# via beanie
|
||||
# via event-management
|
||||
|
@ -55,8 +62,16 @@ pymongo==4.9.2
|
|||
# via motor
|
||||
python-decouple==3.8
|
||||
# via event-management
|
||||
python-jose==3.3.0
|
||||
# via event-management
|
||||
python-multipart==0.0.17
|
||||
# via event-management
|
||||
pyyaml==6.0.2
|
||||
# via poethepoet
|
||||
rsa==4.9
|
||||
# via python-jose
|
||||
six==1.16.0
|
||||
# via ecdsa
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
starlette==0.41.3
|
||||
|
|
3
src/api/__init__.py
Normal file
3
src/api/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .app import app
|
||||
|
||||
__all__ = ["app"]
|
|
@ -11,6 +11,8 @@ from fastapi.responses import HTMLResponse, PlainTextResponse
|
|||
from src.db.init import initialize_db
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import router as auth_router
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
__all__ = ["app"]
|
||||
|
@ -48,6 +50,8 @@ app = FastAPI(
|
|||
docs_url=None,
|
||||
)
|
||||
|
||||
app.include_router(auth_router)
|
||||
|
||||
|
||||
@app.get("/ping")
|
||||
async def ping(_r: Request) -> PlainTextResponse:
|
4
src/api/auth/__init__.py
Normal file
4
src/api/auth/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from .api import router
|
||||
from .dependencies import AccessTokenDep, AnyTokenDep, CurrentUserDep, LoggedInDep, RefreshTokenDep
|
||||
|
||||
__all__ = ["AccessTokenDep", "AnyTokenDep", "CurrentUserDep", "LoggedInDep", "RefreshTokenDep", "router"]
|
120
src/api/auth/api.py
Normal file
120
src/api/auth/api.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
from typing import Annotated, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import Response
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.db.models.user import User
|
||||
from src.settings import ACCESS_TOKEN_EXPIRATION, REFRESH_TOKEN_EXPIRATION
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .dependencies import RefreshTokenDep
|
||||
from .jwt import create_jwt_token, invalidate_access_tokens
|
||||
from .passwords import check_hashed_password
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Authentication"], prefix="/auth")
|
||||
|
||||
|
||||
class AccessTokenResponseBase(BaseModel):
|
||||
"""Stores the JWT access token data sent in a response."""
|
||||
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: Literal["bearer"] = "bearer"
|
||||
user_id: str
|
||||
|
||||
|
||||
class RefreshTokenResponse(AccessTokenResponseBase):
|
||||
"""Stores the JWT refresh token data sent in a response."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/login",
|
||||
responses={
|
||||
401: {"description": "Incorrect username or password"},
|
||||
200: {"description": "Login successful"},
|
||||
},
|
||||
)
|
||||
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> RefreshTokenResponse:
|
||||
"""Endpoint that provides a way for the user to authenticate, obtaining a Bearer token."""
|
||||
user = await User.find_one(User.username == form_data.username)
|
||||
if user is None:
|
||||
log.trace(f"Authentication failed: user {form_data.username} wasn't found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
if user.id is None:
|
||||
raise RuntimeError("Never")
|
||||
|
||||
if not check_hashed_password(form_data.password, user.password_hash):
|
||||
log.trace(f"Authentication failed: wrong password for user {form_data.username}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
refresh_token, db_refresh_token = await create_jwt_token(user, "refresh", scopes=form_data.scopes)
|
||||
|
||||
access_token, _ = await create_jwt_token(user, "access", parent_token=db_refresh_token, scopes=form_data.scopes)
|
||||
|
||||
return RefreshTokenResponse(
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
expires_in=REFRESH_TOKEN_EXPIRATION,
|
||||
user_id=str(user.id),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/refresh",
|
||||
responses={200: {"description": "Token refreshed"}},
|
||||
)
|
||||
async def refresh(cur_token: RefreshTokenDep) -> AccessTokenResponseBase:
|
||||
"""Endpoint that allows generating a new access token.
|
||||
|
||||
This endpoint requires a refresh token as the bearer.
|
||||
|
||||
If a valid access token already exists for this refresh token, it will be invalidated
|
||||
in favor of the new access token returned from this endpoint.
|
||||
"""
|
||||
cur_token_data, cur_db_token = cur_token
|
||||
|
||||
await invalidate_access_tokens(cur_db_token)
|
||||
|
||||
if cur_db_token.user.id is None:
|
||||
raise RuntimeError("Never")
|
||||
|
||||
# We can now be certain there aren't any existing valid access tokens made from this refresh token.
|
||||
# Let's create a new access token.
|
||||
scopes = cur_token_data.get("scope")
|
||||
scopes = scopes.split(" ") if scopes else []
|
||||
token, _ = await create_jwt_token(cur_db_token.user, "access", parent_token=cur_db_token, scopes=scopes)
|
||||
|
||||
return AccessTokenResponseBase(
|
||||
access_token=token, expires_in=ACCESS_TOKEN_EXPIRATION, user_id=str(cur_db_token.user.id)
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout",
|
||||
responses={200: {"description": "Refresh token invalidated"}},
|
||||
)
|
||||
async def logout(cur_token: RefreshTokenDep) -> Response:
|
||||
"""Endpoint that performs a user logout, invalidating the refresh token & all it's access tokens.
|
||||
|
||||
This endpoint requires a refresh token as the bearer.
|
||||
"""
|
||||
_, cur_db_token = cur_token
|
||||
|
||||
await invalidate_access_tokens(cur_db_token)
|
||||
|
||||
cur_db_token.revoked = True
|
||||
_ = await cur_db_token.replace()
|
||||
|
||||
return Response(status_code=status.HTTP_200_OK)
|
97
src/api/auth/dependencies.py
Normal file
97
src/api/auth/dependencies.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from fastapi import Depends, HTTPException, Security, status
|
||||
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
|
||||
|
||||
from src.db.models.token import Token
|
||||
from src.db.models.user import User
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .jwt import InvalidTokenError, TokenData, resolve_jwt_token
|
||||
|
||||
__all__ = ["AccessTokenDep", "AnyTokenDep", "CurrentUserDep", "LoggedInDep", "RefreshTokenDep"]
|
||||
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
||||
|
||||
|
||||
async def _get_token(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
security_scopes: SecurityScopes,
|
||||
tok_type: Literal["access", "refresh"] | None,
|
||||
) -> tuple[TokenData, Token]:
|
||||
"""FastAPI dependency to get the JWT bearer token."""
|
||||
|
||||
def make_exc(detail: str) -> HTTPException:
|
||||
if security_scopes.scopes:
|
||||
authenticate_header_val = f'Bearer scope="{security_scopes.scope_str}"'
|
||||
else:
|
||||
authenticate_header_val = "Bearer"
|
||||
|
||||
return HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
headers={"WWW-Authenticate": authenticate_header_val},
|
||||
)
|
||||
|
||||
try:
|
||||
token_data, db_token = await resolve_jwt_token(token)
|
||||
except InvalidTokenError as exc:
|
||||
raise make_exc(exc.state.value)
|
||||
|
||||
if tok_type is not None and token_data["type"] != tok_type:
|
||||
raise make_exc(f"Invalid token type, expected {tok_type} token, but found {token_data['type']} token")
|
||||
|
||||
scopes = token_data.get("scope")
|
||||
scopes = scopes.split(" ") if scopes else []
|
||||
for scope in security_scopes.scopes:
|
||||
if scope not in scopes:
|
||||
raise make_exc(f"The provided token is missing the {scope!r} scope")
|
||||
|
||||
return token_data, db_token
|
||||
|
||||
|
||||
async def _get_refresh_token(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
security_scopes: SecurityScopes,
|
||||
) -> tuple[TokenData, Token]:
|
||||
"""FastAPI dependency to get the refresh JWT bearer token."""
|
||||
return await _get_token(token, security_scopes=security_scopes, tok_type="refresh")
|
||||
|
||||
|
||||
async def _get_access_token(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
security_scopes: SecurityScopes,
|
||||
) -> tuple[TokenData, Token]:
|
||||
"""FastAPI dependency to get the access JWT bearer token."""
|
||||
return await _get_token(token, security_scopes=security_scopes, tok_type="access")
|
||||
|
||||
|
||||
async def _get_any_token(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
security_scopes: SecurityScopes,
|
||||
) -> tuple[TokenData, Token]:
|
||||
"""FastAPI dependency to get the access or refresh JWT bearer token."""
|
||||
return await _get_token(token, security_scopes=security_scopes, tok_type=None)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
security_scopes: SecurityScopes,
|
||||
) -> User:
|
||||
"""FastAPI dependency to get the currently logged in user from the JWT bearer access token."""
|
||||
_, db_token = await _get_access_token(token, security_scopes=security_scopes)
|
||||
return db_token.user
|
||||
|
||||
|
||||
# NOTE: When using these aliases, it's not possible to specify the scopes, as they are
|
||||
# integrated into the type. We don't currently use scopes, so this is fine. If scopes
|
||||
# will be used in the future, get_current_user will need to be used manually.
|
||||
LoggedInDep = Security(get_current_user)
|
||||
CurrentUserDep: TypeAlias = Annotated[User, Security(get_current_user)] # noqa: UP040: https://github.com/fastapi/fastapi/issues/10719
|
||||
|
||||
AccessTokenDep: TypeAlias = Annotated[tuple[TokenData, Token], Security(_get_access_token)] # noqa: UP040
|
||||
RefreshTokenDep: TypeAlias = Annotated[tuple[TokenData, Token], Security(_get_refresh_token)] # noqa: UP040
|
||||
AnyTokenDep: TypeAlias = Annotated[tuple[TokenData, Token], Security(_get_any_token)] # noqa: UP040
|
141
src/api/auth/jwt.py
Normal file
141
src/api/auth/jwt.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Literal, NotRequired, TypedDict, cast, final
|
||||
|
||||
from beanie import PydanticObjectId
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.db.models import user
|
||||
from src.db.models.token import Token
|
||||
from src.settings import ACCESS_TOKEN_EXPIRATION, JWT_SECRET_TOKEN
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
__all__ = ["AuthErrState", "InvalidTokenError", "create_jwt_token", "invalidate_access_tokens", "resolve_jwt_token"]
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class TokenData(TypedDict):
|
||||
"""Payload data stored in the session token JWT."""
|
||||
|
||||
sub: str # token id
|
||||
iat: int # issued at timestamp
|
||||
exp: int # expiration timestamp
|
||||
scope: NotRequired[str]
|
||||
type: Literal["access", "refresh"]
|
||||
|
||||
|
||||
@final
|
||||
class AuthErrState(StrEnum):
|
||||
"""Authentication error states."""
|
||||
|
||||
MALFORMED = "The provided token is malformed."
|
||||
EXPIRED = "The provided token is expired."
|
||||
UNTRACKED = "The provided token isn't tracked by the server"
|
||||
REVOKED = "The provided token was revoked"
|
||||
|
||||
|
||||
@final
|
||||
class InvalidTokenError(Exception):
|
||||
"""Exception raised when given JWT token was invalid."""
|
||||
|
||||
def __init__(self, state: AuthErrState, raw_token: str):
|
||||
self.state = state
|
||||
self.raw_token = raw_token
|
||||
super().__init__(state.value)
|
||||
|
||||
|
||||
async def create_jwt_token(
|
||||
user: user.User,
|
||||
tok_type: Literal["access", "refresh"],
|
||||
*,
|
||||
parent_token: Token | None = None,
|
||||
scopes: list[str] | None = None,
|
||||
) -> tuple[str, Token]:
|
||||
"""Create a new JWT token for given ``user``.
|
||||
|
||||
This will create a signed token and store it in the database.
|
||||
"""
|
||||
issued_at = datetime.now(UTC)
|
||||
expires_at = issued_at + timedelta(seconds=ACCESS_TOKEN_EXPIRATION)
|
||||
|
||||
await user.sync()
|
||||
|
||||
db_token = Token(
|
||||
user=user,
|
||||
tok_type=tok_type,
|
||||
parent_token=parent_token,
|
||||
revoked=False,
|
||||
issued_at=issued_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db_token = await db_token.insert()
|
||||
|
||||
if not isinstance(db_token.id, PydanticObjectId):
|
||||
raise RuntimeError("Never") # noqa: TRY004
|
||||
|
||||
data: TokenData = {
|
||||
"sub": str(db_token.id),
|
||||
"iat": round(issued_at.timestamp()),
|
||||
"exp": round(expires_at.timestamp()),
|
||||
"type": tok_type,
|
||||
}
|
||||
|
||||
if scopes:
|
||||
for scope in scopes:
|
||||
if " " in scope:
|
||||
raise ValueError(f"Scope names can't contain spaces. Scope {scope!r} is invalid")
|
||||
data["scope"] = " ".join(scopes)
|
||||
|
||||
return jwt.encode(dict(data), JWT_SECRET_TOKEN, algorithm="HS256"), db_token
|
||||
|
||||
|
||||
async def resolve_jwt_token(token: str) -> tuple[TokenData, Token]:
|
||||
"""Resolve the user from given JWT token.
|
||||
|
||||
If the JWT token isn't valid, :exc:`InvalidTokenError` will be raised.
|
||||
|
||||
This will also perform a database lookup, making sure that this token does in fact exist.
|
||||
"""
|
||||
try:
|
||||
decoded = cast(TokenData, jwt.decode(token, JWT_SECRET_TOKEN, algorithms=["HS256"])) # pyright: ignore[reportInvalidCast]
|
||||
except JWTError:
|
||||
log.trace(f"Got a malformed token: {token!r}")
|
||||
raise InvalidTokenError(AuthErrState.MALFORMED, token)
|
||||
|
||||
if decoded["exp"] < datetime.now(UTC).timestamp():
|
||||
log.trace(f"Got an expired token: {token!r}")
|
||||
raise InvalidTokenError(AuthErrState.EXPIRED, token)
|
||||
|
||||
db_token = await Token.get(decoded["sub"])
|
||||
if db_token is None:
|
||||
log.warning(f"Got an untracked signed token: {token!r}")
|
||||
raise InvalidTokenError(AuthErrState.UNTRACKED, token)
|
||||
|
||||
if db_token.revoked:
|
||||
log.trace(f"Got a revoked token: {token!r}")
|
||||
raise InvalidTokenError(AuthErrState.REVOKED, token)
|
||||
|
||||
return decoded, db_token
|
||||
|
||||
|
||||
async def invalidate_access_tokens(refresh_token: Token) -> None:
|
||||
"""Invalidate all access tokens associated with given refresh token."""
|
||||
tokens = await Token.find(Token.tok_type == "access", Token.parent_token == refresh_token).to_list()
|
||||
now = datetime.now(UTC)
|
||||
for token in tokens:
|
||||
# Delete any expired access tokens, we won't need them anymore
|
||||
if token.expires_at < now:
|
||||
_ = await token.delete()
|
||||
continue
|
||||
|
||||
# If there is a non-expired but already revoked token, ignore it
|
||||
# (we can't delete it as we want proper error message if it's still used,
|
||||
# it will be cleared once it expires)
|
||||
if token.revoked:
|
||||
continue
|
||||
|
||||
# If there is an active token, revoke it.
|
||||
token.revoked = True
|
||||
_ = await token.replace()
|
18
src/api/auth/passwords.py
Normal file
18
src/api/auth/passwords.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import bcrypt
|
||||
|
||||
|
||||
def create_password_hash(password: str) -> bytes:
|
||||
"""Cryptographically hash given password using a randomly generated salt.
|
||||
|
||||
Note that the generated bcrypt hash actually also contains the salt.
|
||||
"""
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password=password.encode(), salt=salt)
|
||||
|
||||
|
||||
def check_hashed_password(password: str, hashed_password: bytes) -> bool:
|
||||
"""Verify that given ``password`` matches the ``hashed_password``.
|
||||
|
||||
Note that the bcrypt ``hashed_password`` actually already contains the salt.
|
||||
"""
|
||||
return bcrypt.checkpw(password=password.encode(), hashed_password=hashed_password)
|
|
@ -7,3 +7,8 @@ LOG_FILE = get_config("LOG_FILE", cast=Path, default=None)
|
|||
TRACE_LEVEL_FILTER = get_config("TRACE_LEVEL_FILTER", default=None)
|
||||
|
||||
MONGODB_URI = get_config("MONGODB_URI")
|
||||
|
||||
REFRESH_TOKEN_EXPIRATION = get_config("REFRESH_TOKEN_EXPIRATION", cast=int, default=3600 * 24 * 30) # 30 days
|
||||
ACCESS_TOKEN_EXPIRATION = get_config("ACCESS_TOKEN_EXPIRATION", cast=int, default=3600) # 1 hour
|
||||
|
||||
JWT_SECRET_TOKEN = get_config("JWT_SECRET_TOKEN") # openssl rand -hex 32
|
||||
|
|
|
@ -180,5 +180,7 @@ def _setup_external_log_levels(root_log: LoggerClass) -> None: # pyright: ignor
|
|||
get_logger("asyncio").setLevel(logging.INFO)
|
||||
get_logger("multipart").setLevel(logging.INFO) # FastAPI HTTP forms parsing lib
|
||||
get_logger("pymongo").setLevel(logging.INFO)
|
||||
get_logger("passlib").setLevel(logging.INFO)
|
||||
get_logger("python_multipart").setLevel(logging.INFO)
|
||||
|
||||
get_logger("parso").setLevel(logging.WARNING) # For usage in IPython
|
||||
|
|
Loading…
Reference in a new issue