Use MissingIdError & UnfetchedLinkError instead of python exceptions
This commit is contained in:
parent
e42cc4949b
commit
ddefbf3abd
|
@ -8,6 +8,7 @@ from pydantic import BaseModel
|
|||
|
||||
from src.db.models.user import User
|
||||
from src.settings import ACCESS_TOKEN_EXPIRATION, REFRESH_TOKEN_EXPIRATION
|
||||
from src.utils.db import MissingIdError, UnfetchedLinkError
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .dependencies import RefreshTokenDep
|
||||
|
@ -52,7 +53,7 @@ async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> R
|
|||
detail="Incorrect username or password",
|
||||
)
|
||||
if user.id is None:
|
||||
raise RuntimeError("Never")
|
||||
raise MissingIdError(user)
|
||||
|
||||
if not check_hashed_password(form_data.password, user.password_hash):
|
||||
log.trace(f"Authentication failed: wrong password for user {form_data.username}")
|
||||
|
@ -91,10 +92,10 @@ async def refresh(cur_token: RefreshTokenDep) -> AccessTokenResponseBase:
|
|||
await invalidate_access_tokens(cur_db_token)
|
||||
|
||||
if isinstance(cur_db_token.user, Link):
|
||||
raise TypeError("The user attribute must be fetched")
|
||||
raise UnfetchedLinkError(cur_db_token.user)
|
||||
|
||||
if cur_db_token.user.id is None:
|
||||
raise ValueError("Got user without id")
|
||||
raise MissingIdError(cur_db_token.user)
|
||||
|
||||
# We can now be certain there aren't any existing valid access tokens made from this refresh token.
|
||||
# Let's create a new access token.
|
||||
|
|
|
@ -6,6 +6,7 @@ from fastapi.security import OAuth2PasswordBearer, SecurityScopes
|
|||
|
||||
from src.db.models.token import Token
|
||||
from src.db.models.user import User
|
||||
from src.utils.db import UnfetchedLinkError
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .jwt import InvalidTokenError, TokenData, resolve_jwt_token
|
||||
|
@ -86,7 +87,7 @@ async def get_current_user(
|
|||
_, db_token = await _get_access_token(token, security_scopes=security_scopes)
|
||||
|
||||
if isinstance(db_token.user, Link):
|
||||
raise TypeError("The user attribute must be fetched")
|
||||
raise UnfetchedLinkError(db_token.user)
|
||||
|
||||
return db_token.user
|
||||
|
||||
|
|
|
@ -2,12 +2,12 @@ from datetime import UTC, datetime, timedelta
|
|||
from enum import StrEnum
|
||||
from typing import Literal, NotRequired, TypedDict, cast, final
|
||||
|
||||
from beanie import PydanticObjectId
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.db.models import user
|
||||
from src.db.models.token import Token
|
||||
from src.settings import ACCESS_TOKEN_EXPIRATION, JWT_SECRET_TOKEN, REFRESH_TOKEN_EXPIRATION
|
||||
from src.utils.db import MissingIdError
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
__all__ = ["AuthErrState", "InvalidTokenError", "create_jwt_token", "invalidate_access_tokens", "resolve_jwt_token"]
|
||||
|
@ -75,8 +75,8 @@ async def create_jwt_token(
|
|||
)
|
||||
db_token = await db_token.insert()
|
||||
|
||||
if not isinstance(db_token.id, PydanticObjectId):
|
||||
raise RuntimeError("Never") # noqa: TRY004
|
||||
if db_token.id is None:
|
||||
raise MissingIdError(db_token)
|
||||
|
||||
data: TokenData = {
|
||||
"sub": str(db_token.id),
|
||||
|
|
|
@ -11,7 +11,7 @@ from src.api.auth.dependencies import LoggedInDep
|
|||
from src.db.models.category import Category
|
||||
from src.db.models.event import Event
|
||||
from src.db.models.user import User
|
||||
from src.utils.db import UnfetchedLinkError, expr, from_id_list, get_id_list, update_document
|
||||
from src.utils.db import MissingIdError, UnfetchedLinkError, expr, from_id_list, get_id_list, update_document
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import CurrentUserDep
|
||||
|
@ -59,7 +59,7 @@ class EventData(_BaseEventData):
|
|||
raise UnfetchedLinkError(event.user)
|
||||
|
||||
if event.user.id is None:
|
||||
raise ValueError("Got an event without owner user ID")
|
||||
raise MissingIdError(event.user)
|
||||
|
||||
return cls(
|
||||
owner_user_id=event.user.id,
|
||||
|
|
|
@ -64,7 +64,7 @@ async def get_user_sessions(user_id: PydanticObjectId, user: CurrentUserDep) ->
|
|||
Note that you can only access your own sessions.
|
||||
"""
|
||||
if user.id is None:
|
||||
raise RuntimeError("Never")
|
||||
raise MissingIdError(user)
|
||||
if user.id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can only see your own sessions")
|
||||
|
||||
|
|
Loading…
Reference in a new issue