Fix db links & some other db issues

This addresses several issues regarding the Beanie library (for db
interactions with Mongo).

Most notably, this moves away from using types like:
`Annotated[T | Link[T], Annotated[Link[T], Indexed]]` to annotate link
fields, which not only looks pretty bad in the code, but also doesn't
work and instead creates copies, not references, which was causing a
bunch of issues in the code.

The rest of the issues stem from fixing the above, addressing various
small bugs that turned up after doing some debugging, or new type issues
that this fix introduced (unfortunately, Beanie still doesn't have the
best typing support).
This commit is contained in:
Peter Vacho 2024-12-27 23:02:43 +01:00
parent 3903546f4c
commit 1eaef54d70
Signed by: school
GPG key ID: 8CFC3837052871B4
11 changed files with 139 additions and 51 deletions

View file

@ -120,6 +120,9 @@ async def resolve_jwt_token(token: str) -> tuple[TokenData, Token]:
log.trace(f"Got a revoked token: {token!r}")
raise InvalidTokenError(AuthErrState.REVOKED, token)
# We'll definitely end up needing the user, so we might as well fetch it now
await db_token.fetch_link(Token.user)
return decoded, db_token

View file

@ -1,8 +1,7 @@
from datetime import datetime
from typing import Annotated, final
from typing import Annotated, cast, final
from beanie import PydanticObjectId
from beanie.odm.operators.find.comparison import In
from beanie import Link, PydanticObjectId
from fastapi import APIRouter, Body, HTTPException, Response, status
from pydantic import BaseModel, StringConstraints, field_validator
from pydantic_extra_types.color import Color
@ -11,7 +10,7 @@ from src.api.auth.dependencies import LoggedInDep
from src.db.models.category import Category
from src.db.models.event import Event
from src.db.models.user import User
from src.utils.db import MissingIdError, update_document
from src.utils.db import MissingIdError, expr, update_document
from src.utils.logging import get_logger
from .auth import CurrentUserDep
@ -110,17 +109,17 @@ async def get_user_categories(user_id: PydanticObjectId, user: CurrentUserDep) -
detail="You can only access your own categories",
)
categories = await Category.find(Category.user == user).to_list()
categories = await Category.find(expr(Category.user).id == user.id).to_list()
return [CategoryData.from_category(category) for category in categories]
@categories_router.get("{category_id}")
async def get_category(category_id: PydanticObjectId, user: CurrentUserDep) -> CategoryData:
"""Get a category by ID."""
category = await Category.get(category_id)
category = await Category.get(category_id, fetch_links=True)
if category is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
if category.owner != user:
if cast(User, category.user).id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only access your own categories",
@ -138,7 +137,7 @@ async def delete_category(category_id: PydanticObjectId, user: CurrentUserDep) -
category = await Category.get(category_id)
if category is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
if category.owner != user:
if cast(Link[User], category.user).ref.id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only access your own categories",
@ -147,10 +146,9 @@ async def delete_category(category_id: PydanticObjectId, user: CurrentUserDep) -
# First remove the association to this category by events
# (this may leave the event without a category, however that's fine, as event category is optional
# and it's far more sensible solution than deleting all events with this category)
events = await Event.find(In(category, Event.categories)).to_list()
events = await Event.find(expr(Event.categories).id == category.id, fetch_links=True).to_list()
for event in events:
new_categories = list(event.categories)
new_categories.remove(category)
new_categories = [cat for cat in event.categories if cast(Category, cat).id != category.id]
event.categories = new_categories
await Event.replace_many(events)
@ -177,10 +175,10 @@ async def overwrite_category(
category_data: Annotated[CategoryCreateData, Body()],
) -> CategoryData:
"""Overwrite a specific category."""
category = await Category.get(category_id)
category = await Category.get(category_id, fetch_links=True)
if category is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
if category.owner != user:
if cast(User, category.user).id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only access your own categories",
@ -197,10 +195,10 @@ async def update_category(
category_data: Annotated[PartialCategoryUpdateData, Body()],
) -> CategoryData:
"""Overwrite a specific category."""
category = await Category.get(category_id)
category = await Category.get(category_id, fetch_links=True)
if category is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
if category.owner != user:
if cast(User, category.user).id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only access your own categories",

View file

@ -1,7 +1,7 @@
from datetime import date, datetime, time
from typing import Annotated, final
from beanie import PydanticObjectId
from beanie import DeleteRules, Link, PydanticObjectId
from fastapi import APIRouter, Body, HTTPException, status
from fastapi.responses import Response
from pydantic import BaseModel, StringConstraints, field_validator
@ -11,7 +11,7 @@ from src.api.auth.dependencies import LoggedInDep
from src.db.models.category import Category
from src.db.models.event import Event
from src.db.models.user import User
from src.utils.db import from_id_list, get_id_list, update_document
from src.utils.db import UnfetchedLinkError, expr, from_id_list, get_id_list, update_document
from src.utils.logging import get_logger
from .auth import CurrentUserDep
@ -55,6 +55,9 @@ class EventData(_BaseEventData):
@classmethod
def from_event(cls, event: Event) -> "EventData":
if isinstance(event.user, Link):
raise UnfetchedLinkError(event.user)
if event.user.id is None:
raise ValueError("Got an event without owner user ID")
@ -89,13 +92,13 @@ class EventCreateData(_BaseEventData):
user=user,
title=self.title,
description=self.description,
categories=await from_id_list(self.category_ids, Category),
categories=await from_id_list(self.category_ids, Category, link_return=True),
start_date=self.start_date,
start_time=self.start_time,
end_date=self.end_date,
end_time=self.end_time,
color=self.color.as_hex(format="long"),
attendees=await from_id_list(self.attendee_ids, User),
attendees=await from_id_list(self.attendee_ids, User, link_return=True),
)
return await event.create()
@ -165,7 +168,7 @@ async def get_user_events(user_id: PydanticObjectId, user: CurrentUserDep) -> li
detail="You can only access your own categories",
)
events = await (Event.find(Event.user == user, fetch_links=True)).to_list()
events = await (Event.find(expr(Event.user).id == user.id, fetch_links=True)).to_list()
return [EventData.from_event(event) for event in events]
@ -196,9 +199,7 @@ async def delete_event(event_id: PydanticObjectId, user: CurrentUserDep) -> Resp
detail="You can only access your own events",
)
# TODO: First remove any associations to this event
_ = await event.delete()
_ = await event.delete(link_rules=DeleteRules.DELETE_LINKS)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@ -220,7 +221,7 @@ async def overwrite_event(
event_data: Annotated[EventCreateData, Body()],
) -> EventData:
"""Overwrite a specific event."""
event = await Event.get(event_id)
event = await Event.get(event_id, fetch_links=True)
if event is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Event with given id doesn't exist")
if event.owner != user:
@ -241,7 +242,7 @@ async def partial_update_event(
) -> EventData:
"""Partially update an event."""
# 1. Fetch the event
event = await Event.get(event_id)
event = await Event.get(event_id, fetch_links=True)
if event is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel
from src.api.auth.jwt import invalidate_access_tokens
from src.db.models.token import Token
from src.utils.db import MissingIdError, UnfetchedLinkError
from src.utils.db import MissingIdError, UnfetchedLinkError, expr
from src.utils.logging import get_logger
from .auth import AnyTokenDep, CurrentUserDep
@ -68,7 +68,7 @@ async def get_user_sessions(user_id: PydanticObjectId, user: CurrentUserDep) ->
if user.id != user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can only see your own sessions")
tokens = await Token.find(Token.user == user).to_list()
tokens = await Token.find(expr(Token.user).id == user.id, fetch_links=True).to_list()
return [UserSession.from_token(token) for token in tokens]
@ -76,13 +76,18 @@ async def get_user_sessions(user_id: PydanticObjectId, user: CurrentUserDep) ->
async def get_current_session(token_data: AnyTokenDep) -> UserSession:
"""Get information about the current session."""
_, db_token = token_data
# The token dependency does already fetch the user, however we'll also need
# the parent token data, so let's fetch that here
await db_token.fetch_link(Token.parent_token)
return UserSession.from_token(db_token)
@sessions_router.get("/{session_id}")
async def get_session(session_id: PydanticObjectId, user: CurrentUserDep) -> UserSession:
"""Get details about given session."""
token = await Token.get(session_id)
token = await Token.get(session_id, fetch_links=True)
if token is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No such session")
if token.user != user:
@ -96,7 +101,7 @@ async def delete_session(session_id: PydanticObjectId, user: CurrentUserDep) ->
If this is the active session, this will log you out.
"""
token = await Token.get(session_id)
token = await Token.get(session_id, fetch_links=True)
if token is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No such session")
if token.user != user:

View file

@ -1,17 +1,18 @@
from datetime import UTC, datetime
from typing import Annotated, ClassVar, final
from beanie import Document, Indexed, Link
from beanie import Document, Indexed
from pydantic import Field
from src.db.models.user import User
from src.db.types import BeanieLink
@final
class Category(Document):
"""Category table."""
user: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
user: Annotated[BeanieLink[User], Indexed()]
name: str # TODO: Should this be unique?
color: str # Stored as a hex string (# + 6 characters)
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View file

@ -1,28 +1,28 @@
from collections.abc import Sequence
from datetime import UTC, date, datetime, time
from typing import Annotated, ClassVar, final
from beanie import Document, Indexed, Link
from beanie import Document, Indexed
from pydantic import Field
from src.db.models.category import Category
from src.db.models.user import User
from src.db.types import BeanieLink
@final
class Event(Document):
"""Event table."""
user: Annotated[User, Annotated[Link[User], Indexed()]]
user: Annotated[BeanieLink[User], Indexed()]
title: str
description: str
categories: Annotated[Sequence[Category | Link[Category]], Annotated[list[Link[Category]], Indexed()]]
categories: Annotated[list[BeanieLink[Category]], Indexed()]
start_date: Annotated[date, Indexed()]
start_time: time
end_date: Annotated[date, Indexed()]
end_time: time
color: str # Stored as a hex string (# + 6 characters)
attendees: Annotated[Sequence[User | Link[User]], Annotated[list[Link[User]], Indexed()]]
attendees: Annotated[list[BeanieLink[User]], Indexed()]
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@final

View file

@ -1,19 +1,20 @@
from datetime import UTC, datetime
from typing import Annotated, ClassVar, Literal, final
from beanie import Document, Indexed, Link
from beanie import Document, Indexed
from pydantic import Field
from src.db.models.event import Event
from src.db.models.user import User
from src.db.types import BeanieLink
@final
class Invitation(Document):
"""Invitation table."""
event: Annotated[Event | Link[User], Annotated[Link[Event], Indexed()]]
invitee: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
event: Annotated[BeanieLink[Event], Indexed()]
invitee: Annotated[BeanieLink[User], Indexed()]
status: Literal["accepted", "declined", "pending"] = "pending"
sent_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
responded_at: datetime | None = None

View file

@ -1,17 +1,18 @@
from datetime import UTC, datetime
from typing import Annotated, Any, ClassVar, Literal, final
from beanie import Document, Indexed, Link
from beanie import Document, Indexed
from pydantic import Field
from src.db.models.user import User
from src.db.types import BeanieLink
@final
class Notification(Document):
"""Notification table."""
user: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
user: Annotated[BeanieLink[User], Indexed()]
event_type: Annotated[Literal["reminder", "invitation"], Indexed()]
message: str
data: Any

View file

@ -1,20 +1,21 @@
from datetime import UTC, datetime
from typing import Annotated, ClassVar, Literal, final
from beanie import Document, Indexed, Link
from beanie import Document, Indexed
from pydantic import Field
from pymongo import ASCENDING
from src.db.models.user import User
from src.db.types import BeanieLink
@final
class Token(Document):
"""Token table."""
user: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
user: Annotated[BeanieLink[User], Indexed()]
tok_type: Literal["access", "refresh"]
parent_token: Annotated["Token | Link[Token] | None", Annotated[Link["Token"] | None, Indexed()]] = None
parent_token: Annotated[BeanieLink["Token"] | None, Indexed()] = None
revoked: bool = False
issued_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
expires_at: datetime

18
src/db/types.py Normal file
View file

@ -0,0 +1,18 @@
from typing import TYPE_CHECKING
from beanie import Link
__all__ = ["BeanieLink"]
# This is a workaround to get Beanie dynamic type inspection it uses for links to work well
# with type checkers. This cannot be worked around with Annotated, as beanie doesn't properly
# support this in many cases. This is a known issue with beanie, and has this temporary
# workaround. See:
#
# https://github.com/BeanieODM/beanie/issues/820#issuecomment-2563964071
if TYPE_CHECKING:
type BeanieLink[T] = T | Link[T]
else:
BeanieLink = Link

View file

@ -1,7 +1,10 @@
from collections.abc import Sequence
from typing import Literal, final, overload
from typing import Any, Literal, final, overload
from beanie import Document, Link, PydanticObjectId
from beanie.odm.fields import ExpressionField
from src.db.types import BeanieLink
@final
@ -57,31 +60,67 @@ def get_id_list[T: Document](items: Sequence[T | Link[T]]) -> list[PydanticObjec
async def from_id_list[T: Document](
id_list: Sequence[PydanticObjectId],
typ: type[T],
*,
raise_on_missing: Literal[True] = True,
) -> list[T]: ...
link_return: Literal[True],
) -> list[BeanieLink[T]]: ...
@overload
async def from_id_list[T: Document, U](
id_list: Sequence[PydanticObjectId],
typ: type[T],
*,
missing_obj: U = None,
raise_on_missing: Literal[False],
link_return: Literal[True],
) -> list[BeanieLink[T] | U]: ...
@overload
async def from_id_list[T: Document](
id_list: Sequence[PydanticObjectId],
typ: type[T],
*,
raise_on_missing: Literal[True] = True,
link_return: Literal[False] = False,
) -> list[T]: ...
@overload
async def from_id_list[T: Document, U](
id_list: Sequence[PydanticObjectId],
typ: type[T],
*,
missing_obj: U = None,
raise_on_missing: Literal[False],
) -> list[T | None]: ...
link_return: Literal[False] = False,
) -> list[T | U]: ...
async def from_id_list[T: Document](
id_list: Sequence[PydanticObjectId],
typ: type[T],
*,
missing_obj: object = None,
raise_on_missing: bool = True,
) -> list[T] | list[T | None]:
"""Constructs a list of db objects from a list of ids."""
items: list[T | None] = []
link_return: bool = False, # pyright: ignore[reportUnusedParameter] # typing-only parameter
) -> list[Any]:
"""Constructs a list of db objects from a list of ids.
:param id_list: List of ids to fetch.
:param typ: Type of the objects to fetch.
:param missing_obj: Object to append when the db object is not found.
:param raise_on_missing: If True, raises an exception when an object is not found.
:param link_return: If True, returns a list of links instead of objects.
"""
items: list[Any] = []
for item_id in id_list:
item = await typ.get(item_id)
if item is None:
if raise_on_missing:
raise IdNotFoundError(item_id, typ)
items.append(None)
items.append(missing_obj)
continue
items.append(item)
return items
@ -95,9 +134,29 @@ async def update_document[T: Document](document: T, **kwargs) -> T:
updated = False
for field_name, new_value in kwargs.items():
old_value = getattr(document, field_name)
if isinstance(old_value, Link):
raise UnfetchedLinkError(old_value)
if old_value != new_value:
setattr(document, field_name, new_value)
updated = True
if updated:
document = await document.replace()
return document
def expr(ex: Any) -> ExpressionField:
    """Narrow a document class attribute to ``ExpressionField`` for type checkers.

    Beanie's typing does not distinguish between accessing an attribute on the
    document class itself and accessing it on a document instance. Type checkers
    therefore treat class-level access as if it were instance access and infer the
    wrong type, which often leads to type errors.

    Functionally this is just a cast to ``ExpressionField``, guarded by a runtime
    ``isinstance`` check to catch accidental misuse.

    :param ex: Attribute accessed on a document class, e.g. ``Category.user``.
    :return: The same object that was passed in, typed as ``ExpressionField``.
    :raises TypeError: If ``ex`` is not an ``ExpressionField`` instance.
    """
    if isinstance(ex, ExpressionField):
        return ex
    raise TypeError(f"Expected ExpressionField, got {type(ex).__name__}")