Fix db links & some other db issues
This addresses several issues regarding the Beanie library (for db interactions with Mongo). Most notably, this moves away from using types like: `Annotated[T | Link[T], Annotated[Link[T], Indexed]]` to annotate link fields, which not only looks pretty bad in the code, but also doesn't work and instead creates copies, not references, which was causing a bunch of issues in the code. The rest of the issues stem from fixing the above, addressing various small bugs that turned up after some doing debugging, or new type issues that this fix introduced (unfortunately, Beanie still doesn't have the best typing support).
This commit is contained in:
parent
3903546f4c
commit
1eaef54d70
|
@ -120,6 +120,9 @@ async def resolve_jwt_token(token: str) -> tuple[TokenData, Token]:
|
|||
log.trace(f"Got a revoked token: {token!r}")
|
||||
raise InvalidTokenError(AuthErrState.REVOKED, token)
|
||||
|
||||
# We'll definitely end up needing the user, so we might as well fetch it now
|
||||
await db_token.fetch_link(Token.user)
|
||||
|
||||
return decoded, db_token
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import Annotated, final
|
||||
from typing import Annotated, cast, final
|
||||
|
||||
from beanie import PydanticObjectId
|
||||
from beanie.odm.operators.find.comparison import In
|
||||
from beanie import Link, PydanticObjectId
|
||||
from fastapi import APIRouter, Body, HTTPException, Response, status
|
||||
from pydantic import BaseModel, StringConstraints, field_validator
|
||||
from pydantic_extra_types.color import Color
|
||||
|
@ -11,7 +10,7 @@ from src.api.auth.dependencies import LoggedInDep
|
|||
from src.db.models.category import Category
|
||||
from src.db.models.event import Event
|
||||
from src.db.models.user import User
|
||||
from src.utils.db import MissingIdError, update_document
|
||||
from src.utils.db import MissingIdError, expr, update_document
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import CurrentUserDep
|
||||
|
@ -110,17 +109,17 @@ async def get_user_categories(user_id: PydanticObjectId, user: CurrentUserDep) -
|
|||
detail="You can only access your own categories",
|
||||
)
|
||||
|
||||
categories = await Category.find(Category.user == user).to_list()
|
||||
categories = await Category.find(expr(Category.user).id == user.id).to_list()
|
||||
return [CategoryData.from_category(category) for category in categories]
|
||||
|
||||
|
||||
@categories_router.get("{category_id}")
|
||||
async def get_category(category_id: PydanticObjectId, user: CurrentUserDep) -> CategoryData:
|
||||
"""Get a category by ID."""
|
||||
category = await Category.get(category_id)
|
||||
category = await Category.get(category_id, fetch_links=True)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
|
||||
if category.owner != user:
|
||||
if cast(User, category.user).id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only access your own categories",
|
||||
|
@ -138,7 +137,7 @@ async def delete_category(category_id: PydanticObjectId, user: CurrentUserDep) -
|
|||
category = await Category.get(category_id)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
|
||||
if category.owner != user:
|
||||
if cast(Link[User], category.user).ref.id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only access your own categories",
|
||||
|
@ -147,10 +146,9 @@ async def delete_category(category_id: PydanticObjectId, user: CurrentUserDep) -
|
|||
# First remove the association to this category by events
|
||||
# (this may leave the event without a category, however that's fine, as event category is optional
|
||||
# and it's far more sensible solution than deleting all events with this category)
|
||||
events = await Event.find(In(category, Event.categories)).to_list()
|
||||
events = await Event.find(expr(Event.categories).id == category.id, fetch_links=True).to_list()
|
||||
for event in events:
|
||||
new_categories = list(event.categories)
|
||||
new_categories.remove(category)
|
||||
new_categories = [cat for cat in event.categories if cast(Category, cat).id != category.id]
|
||||
event.categories = new_categories
|
||||
await Event.replace_many(events)
|
||||
|
||||
|
@ -177,10 +175,10 @@ async def overwrite_category(
|
|||
category_data: Annotated[CategoryCreateData, Body()],
|
||||
) -> CategoryData:
|
||||
"""Overwrite a specific category."""
|
||||
category = await Category.get(category_id)
|
||||
category = await Category.get(category_id, fetch_links=True)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
|
||||
if category.owner != user:
|
||||
if cast(User, category.user).id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only access your own categories",
|
||||
|
@ -197,10 +195,10 @@ async def update_category(
|
|||
category_data: Annotated[PartialCategoryUpdateData, Body()],
|
||||
) -> CategoryData:
|
||||
"""Overwrite a specific category."""
|
||||
category = await Category.get(category_id)
|
||||
category = await Category.get(category_id, fetch_links=True)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category with given id doesn't exist")
|
||||
if category.owner != user:
|
||||
if cast(User, category.user).id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only access your own categories",
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import date, datetime, time
|
||||
from typing import Annotated, final
|
||||
|
||||
from beanie import PydanticObjectId
|
||||
from beanie import DeleteRules, Link, PydanticObjectId
|
||||
from fastapi import APIRouter, Body, HTTPException, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, StringConstraints, field_validator
|
||||
|
@ -11,7 +11,7 @@ from src.api.auth.dependencies import LoggedInDep
|
|||
from src.db.models.category import Category
|
||||
from src.db.models.event import Event
|
||||
from src.db.models.user import User
|
||||
from src.utils.db import from_id_list, get_id_list, update_document
|
||||
from src.utils.db import UnfetchedLinkError, expr, from_id_list, get_id_list, update_document
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import CurrentUserDep
|
||||
|
@ -55,6 +55,9 @@ class EventData(_BaseEventData):
|
|||
|
||||
@classmethod
|
||||
def from_event(cls, event: Event) -> "EventData":
|
||||
if isinstance(event.user, Link):
|
||||
raise UnfetchedLinkError(event.user)
|
||||
|
||||
if event.user.id is None:
|
||||
raise ValueError("Got an event without owner user ID")
|
||||
|
||||
|
@ -89,13 +92,13 @@ class EventCreateData(_BaseEventData):
|
|||
user=user,
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
categories=await from_id_list(self.category_ids, Category),
|
||||
categories=await from_id_list(self.category_ids, Category, link_return=True),
|
||||
start_date=self.start_date,
|
||||
start_time=self.start_time,
|
||||
end_date=self.end_date,
|
||||
end_time=self.end_time,
|
||||
color=self.color.as_hex(format="long"),
|
||||
attendees=await from_id_list(self.attendee_ids, User),
|
||||
attendees=await from_id_list(self.attendee_ids, User, link_return=True),
|
||||
)
|
||||
return await event.create()
|
||||
|
||||
|
@ -165,7 +168,7 @@ async def get_user_events(user_id: PydanticObjectId, user: CurrentUserDep) -> li
|
|||
detail="You can only access your own categories",
|
||||
)
|
||||
|
||||
events = await (Event.find(Event.user == user, fetch_links=True)).to_list()
|
||||
events = await (Event.find(expr(Event.user).id == user.id, fetch_links=True)).to_list()
|
||||
return [EventData.from_event(event) for event in events]
|
||||
|
||||
|
||||
|
@ -196,9 +199,7 @@ async def delete_event(event_id: PydanticObjectId, user: CurrentUserDep) -> Resp
|
|||
detail="You can only access your own events",
|
||||
)
|
||||
|
||||
# TODO: First remove any associations to this event
|
||||
|
||||
_ = await event.delete()
|
||||
_ = await event.delete(link_rules=DeleteRules.DELETE_LINKS)
|
||||
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
@ -220,7 +221,7 @@ async def overwrite_event(
|
|||
event_data: Annotated[EventCreateData, Body()],
|
||||
) -> EventData:
|
||||
"""Overwrite a specific event."""
|
||||
event = await Event.get(event_id)
|
||||
event = await Event.get(event_id, fetch_links=True)
|
||||
if event is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Event with given id doesn't exist")
|
||||
if event.owner != user:
|
||||
|
@ -241,7 +242,7 @@ async def partial_update_event(
|
|||
) -> EventData:
|
||||
"""Partially update an event."""
|
||||
# 1. Fetch the event
|
||||
event = await Event.get(event_id)
|
||||
event = await Event.get(event_id, fetch_links=True)
|
||||
if event is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
|
@ -7,7 +7,7 @@ from pydantic import BaseModel
|
|||
|
||||
from src.api.auth.jwt import invalidate_access_tokens
|
||||
from src.db.models.token import Token
|
||||
from src.utils.db import MissingIdError, UnfetchedLinkError
|
||||
from src.utils.db import MissingIdError, UnfetchedLinkError, expr
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import AnyTokenDep, CurrentUserDep
|
||||
|
@ -68,7 +68,7 @@ async def get_user_sessions(user_id: PydanticObjectId, user: CurrentUserDep) ->
|
|||
if user.id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You can only see your own sessions")
|
||||
|
||||
tokens = await Token.find(Token.user == user).to_list()
|
||||
tokens = await Token.find(expr(Token.user).id == user.id, fetch_links=True).to_list()
|
||||
return [UserSession.from_token(token) for token in tokens]
|
||||
|
||||
|
||||
|
@ -76,13 +76,18 @@ async def get_user_sessions(user_id: PydanticObjectId, user: CurrentUserDep) ->
|
|||
async def get_current_session(token_data: AnyTokenDep) -> UserSession:
|
||||
"""Get information about the current session."""
|
||||
_, db_token = token_data
|
||||
|
||||
# The token dependency does already fetch the user, however we'll also need
|
||||
# the parent token data, so let's fetch that here
|
||||
await db_token.fetch_link(Token.parent_token)
|
||||
|
||||
return UserSession.from_token(db_token)
|
||||
|
||||
|
||||
@sessions_router.get("/{session_id}")
|
||||
async def get_session(session_id: PydanticObjectId, user: CurrentUserDep) -> UserSession:
|
||||
"""Get details about given session."""
|
||||
token = await Token.get(session_id)
|
||||
token = await Token.get(session_id, fetch_links=True)
|
||||
if token is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No such session")
|
||||
if token.user != user:
|
||||
|
@ -96,7 +101,7 @@ async def delete_session(session_id: PydanticObjectId, user: CurrentUserDep) ->
|
|||
|
||||
If this is the active session, this will log you out.
|
||||
"""
|
||||
token = await Token.get(session_id)
|
||||
token = await Token.get(session_id, fetch_links=True)
|
||||
if token is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No such session")
|
||||
if token.user != user:
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, ClassVar, final
|
||||
|
||||
from beanie import Document, Indexed, Link
|
||||
from beanie import Document, Indexed
|
||||
from pydantic import Field
|
||||
|
||||
from src.db.models.user import User
|
||||
from src.db.types import BeanieLink
|
||||
|
||||
|
||||
@final
|
||||
class Category(Document):
|
||||
"""Category table."""
|
||||
|
||||
user: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
|
||||
user: Annotated[BeanieLink[User], Indexed()]
|
||||
name: str # TODO: Should this be unique?
|
||||
color: str # Stored as a hex string (# + 6 characters)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
|
|
@ -1,28 +1,28 @@
|
|||
from collections.abc import Sequence
|
||||
from datetime import UTC, date, datetime, time
|
||||
from typing import Annotated, ClassVar, final
|
||||
|
||||
from beanie import Document, Indexed, Link
|
||||
from beanie import Document, Indexed
|
||||
from pydantic import Field
|
||||
|
||||
from src.db.models.category import Category
|
||||
from src.db.models.user import User
|
||||
from src.db.types import BeanieLink
|
||||
|
||||
|
||||
@final
|
||||
class Event(Document):
|
||||
"""Event table."""
|
||||
|
||||
user: Annotated[User, Annotated[Link[User], Indexed()]]
|
||||
user: Annotated[BeanieLink[User], Indexed()]
|
||||
title: str
|
||||
description: str
|
||||
categories: Annotated[Sequence[Category | Link[Category]], Annotated[list[Link[Category]], Indexed()]]
|
||||
categories: Annotated[list[BeanieLink[Category]], Indexed()]
|
||||
start_date: Annotated[date, Indexed()]
|
||||
start_time: time
|
||||
end_date: Annotated[date, Indexed()]
|
||||
end_time: time
|
||||
color: str # Stored as a hex string (# + 6 characters)
|
||||
attendees: Annotated[Sequence[User | Link[User]], Annotated[list[Link[User]], Indexed()]]
|
||||
attendees: Annotated[list[BeanieLink[User]], Indexed()]
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@final
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, ClassVar, Literal, final
|
||||
|
||||
from beanie import Document, Indexed, Link
|
||||
from beanie import Document, Indexed
|
||||
from pydantic import Field
|
||||
|
||||
from src.db.models.event import Event
|
||||
from src.db.models.user import User
|
||||
from src.db.types import BeanieLink
|
||||
|
||||
|
||||
@final
|
||||
class Invitation(Document):
|
||||
"""Invitation table."""
|
||||
|
||||
event: Annotated[Event | Link[User], Annotated[Link[Event], Indexed()]]
|
||||
invitee: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
|
||||
event: Annotated[BeanieLink[Event], Indexed()]
|
||||
invitee: Annotated[BeanieLink[User], Indexed()]
|
||||
status: Literal["accepted", "declined", "pending"] = "pending"
|
||||
sent_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
responded_at: datetime | None = None
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any, ClassVar, Literal, final
|
||||
|
||||
from beanie import Document, Indexed, Link
|
||||
from beanie import Document, Indexed
|
||||
from pydantic import Field
|
||||
|
||||
from src.db.models.user import User
|
||||
from src.db.types import BeanieLink
|
||||
|
||||
|
||||
@final
|
||||
class Notification(Document):
|
||||
"""Notification table."""
|
||||
|
||||
user: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
|
||||
user: Annotated[BeanieLink[User], Indexed()]
|
||||
event_type: Annotated[Literal["reminder", "invitation"], Indexed()]
|
||||
message: str
|
||||
data: Any
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Annotated, ClassVar, Literal, final
|
||||
|
||||
from beanie import Document, Indexed, Link
|
||||
from beanie import Document, Indexed
|
||||
from pydantic import Field
|
||||
from pymongo import ASCENDING
|
||||
|
||||
from src.db.models.user import User
|
||||
from src.db.types import BeanieLink
|
||||
|
||||
|
||||
@final
|
||||
class Token(Document):
|
||||
"""Token table."""
|
||||
|
||||
user: Annotated[User | Link[User], Annotated[Link[User], Indexed()]]
|
||||
user: Annotated[BeanieLink[User], Indexed()]
|
||||
tok_type: Literal["access", "refresh"]
|
||||
parent_token: Annotated["Token | Link[Token] | None", Annotated[Link["Token"] | None, Indexed()]] = None
|
||||
parent_token: Annotated[BeanieLink["Token"] | None, Indexed()] = None
|
||||
revoked: bool = False
|
||||
issued_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
expires_at: datetime
|
||||
|
|
18
src/db/types.py
Normal file
18
src/db/types.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from beanie import Link
|
||||
|
||||
__all__ = ["BeanieLink"]
|
||||
|
||||
|
||||
# This is a workaround to get Beanie dynamic type inspection it uses for links to work well
|
||||
# with type checkers. This cannot be worked around with Annotated, as beanie doesn't properly
|
||||
# support this in many cases. This is a known issue with beanie, and has this temporary
|
||||
# workaround. See:
|
||||
#
|
||||
# https://github.com/BeanieODM/beanie/issues/820#issuecomment-2563964071
|
||||
|
||||
if TYPE_CHECKING:
|
||||
type BeanieLink[T] = T | Link[T]
|
||||
else:
|
||||
BeanieLink = Link
|
|
@ -1,7 +1,10 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Literal, final, overload
|
||||
from typing import Any, Literal, final, overload
|
||||
|
||||
from beanie import Document, Link, PydanticObjectId
|
||||
from beanie.odm.fields import ExpressionField
|
||||
|
||||
from src.db.types import BeanieLink
|
||||
|
||||
|
||||
@final
|
||||
|
@ -57,31 +60,67 @@ def get_id_list[T: Document](items: Sequence[T | Link[T]]) -> list[PydanticObjec
|
|||
async def from_id_list[T: Document](
|
||||
id_list: Sequence[PydanticObjectId],
|
||||
typ: type[T],
|
||||
*,
|
||||
raise_on_missing: Literal[True] = True,
|
||||
) -> list[T]: ...
|
||||
link_return: Literal[True],
|
||||
) -> list[BeanieLink[T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def from_id_list[T: Document, U](
|
||||
id_list: Sequence[PydanticObjectId],
|
||||
typ: type[T],
|
||||
*,
|
||||
missing_obj: U = None,
|
||||
raise_on_missing: Literal[False],
|
||||
link_return: Literal[True],
|
||||
) -> list[BeanieLink[T] | U]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def from_id_list[T: Document](
|
||||
id_list: Sequence[PydanticObjectId],
|
||||
typ: type[T],
|
||||
*,
|
||||
raise_on_missing: Literal[True] = True,
|
||||
link_return: Literal[False] = False,
|
||||
) -> list[T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def from_id_list[T: Document, U](
|
||||
id_list: Sequence[PydanticObjectId],
|
||||
typ: type[T],
|
||||
*,
|
||||
missing_obj: U = None,
|
||||
raise_on_missing: Literal[False],
|
||||
) -> list[T | None]: ...
|
||||
link_return: Literal[False] = False,
|
||||
) -> list[T | U]: ...
|
||||
|
||||
|
||||
async def from_id_list[T: Document](
|
||||
id_list: Sequence[PydanticObjectId],
|
||||
typ: type[T],
|
||||
*,
|
||||
missing_obj: object = None,
|
||||
raise_on_missing: bool = True,
|
||||
) -> list[T] | list[T | None]:
|
||||
"""Constructs a list of db objects from a list of ids."""
|
||||
items: list[T | None] = []
|
||||
link_return: bool = False, # pyright: ignore[reportUnusedParameter] # typing-only parameter
|
||||
) -> list[Any]:
|
||||
"""Constructs a list of db objects from a list of ids.
|
||||
|
||||
:param id_list: List of ids to fetch.
|
||||
:param typ: Type of the objects to fetch.
|
||||
:param missing_obj: Object to append when the db object is not found.
|
||||
:param raise_on_missing: If True, raises an exception when an object is not found.
|
||||
:param link_return: If True, returns a list of links instead of objects.
|
||||
"""
|
||||
items: list[Any] = []
|
||||
for item_id in id_list:
|
||||
item = await typ.get(item_id)
|
||||
if item is None:
|
||||
if raise_on_missing:
|
||||
raise IdNotFoundError(item_id, typ)
|
||||
items.append(None)
|
||||
items.append(missing_obj)
|
||||
continue
|
||||
items.append(item)
|
||||
return items
|
||||
|
@ -95,9 +134,29 @@ async def update_document[T: Document](document: T, **kwargs) -> T:
|
|||
updated = False
|
||||
for field_name, new_value in kwargs.items():
|
||||
old_value = getattr(document, field_name)
|
||||
if isinstance(old_value, Link):
|
||||
raise UnfetchedLinkError(old_value)
|
||||
if old_value != new_value:
|
||||
setattr(document, field_name, new_value)
|
||||
updated = True
|
||||
if updated:
|
||||
document = await document.replace()
|
||||
return document
|
||||
|
||||
|
||||
def expr(ex: Any) -> ExpressionField:
|
||||
"""Convert type to ExpressionField.
|
||||
|
||||
When accessing Beanie documents, there is a difference between attempting to access
|
||||
an attribute of the document class type itself and an attribute of the document instance.
|
||||
|
||||
Unfortunately, Beanie doesn't provide a way to distinguish between these two cases type-wise
|
||||
and misinterprets the former as an attempt to access an attribute of the document instance,
|
||||
which results in type checker using an incorrect type, often leading to type errors.
|
||||
|
||||
This function is essentially just a cast to ExpressionField, with a runtime isinstance check
|
||||
as a safeguard against accidental misuse.
|
||||
"""
|
||||
if not isinstance(ex, ExpressionField):
|
||||
raise TypeError(f"Expected ExpressionField, got {type(ex).__name__}")
|
||||
return ex
|
||||
|
|
Loading…
Reference in a new issue