diff --git a/src/api/events.py b/src/api/events.py index 662fbdf..4e00dce 100644 --- a/src/api/events.py +++ b/src/api/events.py @@ -1,7 +1,8 @@ from datetime import datetime -from typing import Annotated, cast, final +from typing import Annotated, Literal, cast, final from beanie import Link, PydanticObjectId +from beanie.odm.operators.find.comparison import In from fastapi import APIRouter, Body, HTTPException, Query, status from fastapi.responses import Response from pydantic import BaseModel, StringConstraints, field_serializer, field_validator @@ -232,16 +233,21 @@ async def get_user_invited_events( start_to: Annotated[datetime | None, Query()] = None, end_from: Annotated[datetime | None, Query()] = None, end_to: Annotated[datetime | None, Query()] = None, + invite_status: Annotated[Literal["pending", "accepted"] | None, Query()] = None, ) -> list[EventData]: """Get all events that given user was invited to. - These are the events that the user has already accepted the invite for. - Events pending acceptance are not included. + These are the events that the user has already accepted the invite for + and the events that the user has a pending invite for. Optionally, it's possible to use query params to filter the events based on the date-time of the event. The dates are specified in ISO 8601 format. E.g.: /users/{user_id}/events/invited?start_from=2024-12-29T13:48:53.228234Z + + It's also possible to filter the events based on the invite status. + + E.g.: /users/{user_id}/events/invited?invite_status=accepted """ if user_id != user.id: raise HTTPException( @@ -250,7 +256,17 @@ async def get_user_invited_events( ) # Initial query (all events the user is attending) - query = Event.find(expr(Event.attendees).id == user_id) + query = Event.find() + if invite_status is None or invite_status == "accepted": + query = query.find(expr(Event.attendees).id == user_id) + if invite_status is None or invite_status == "pending": + invitations = await Invitation.find( + expr(Invitation.invitee).id == user_id, + expr(Invitation.status) == "pending", + ).to_list() + event_ids = get_id_list([invite.event for invite in invitations], allow_links=True) + + query = Event.find(In(expr(Event.id), event_ids)) # Filter by date-time if start_from is not None: diff --git a/src/utils/db.py b/src/utils/db.py index 54c3887..2a3828c 100644 --- a/src/utils/db.py +++ b/src/utils/db.py @@ -41,12 +41,19 @@ class IdNotFoundError[T: Document](ValueError): super().__init__(f"Item with id {item_id} of type {typ.__name__} not found") -def get_id_list[T: Document](items: Sequence[T | Link[T]]) -> list[PydanticObjectId]: - """Extracts the ids from a list of db items.""" +def get_id_list[T: Document](items: Sequence[T | Link[T]], *, allow_links: bool = False) -> list[PydanticObjectId]: + """Extracts the ids from a list of db items. + + :param items: List of db items. + :param allow_links: If True, allows returning reference ids of unfetched Link objects. + """ id_list: list[PydanticObjectId] = [] for item in items: if isinstance(item, Link): - raise UnfetchedLinkError(item) + if not allow_links: + raise UnfetchedLinkError(item) + id_list.append(item.ref.id) + continue if item.id is None: raise MissingIdError(item)