Fix getting invited events

This commit is contained in:
Peter Vacho 2025-01-04 19:53:45 +01:00
parent 1a24a779d9
commit 0cec3463fb
Signed by: school
GPG key ID: 8CFC3837052871B4

View file

@ -3,6 +3,7 @@ from typing import Annotated, Literal, cast, final
from beanie import Link, PydanticObjectId
from beanie.odm.operators.find.comparison import In
from beanie.odm.queries.find import FindMany
from fastapi import APIRouter, Body, HTTPException, Query, status
from fastapi.responses import Response
from pydantic import BaseModel, StringConstraints, field_serializer, field_validator
@ -255,33 +256,45 @@ async def get_user_invited_events(
detail="You can only access your own invite categories",
)
# Initial query (all events the user is attending)
query = Event.find()
# Initial queries (all events for the user)
# NOTE: We can't use a single query, as we would need a union query for that
# which as far as I could tell is not supported by Beanie (annoying).
queries: list[FindMany[Event]] = []
if invite_status is None or invite_status == "accepted":
query = query.find(expr(Event.attendees).id == user_id)
query = Event.find(expr(Event.attendees).id == user_id)
queries.append(query)
if invite_status is None or invite_status == "pending":
invitations = await Invitation.find(
expr(Invitation.invitee).id == user_id,
expr(Invitation.status) == "pending",
).to_list()
event_ids = get_id_list([invite.event for invite in invitations], allow_links=True)
event_ids = set(event_ids) # Remove duplicates
query = Event.find(In(expr(Event.id), event_ids))
queries.append(query)
# Filter by date-time
if start_from is not None:
query = query.find(expr(Event.start_time) >= start_from)
if start_to is not None:
query = query.find(expr(Event.start_time) <= start_to)
if end_from is not None:
query = query.find(expr(Event.end_time) >= end_from)
if end_to is not None:
query = query.find(expr(Event.end_time) <= end_to)
events: list[Event] = []
for query in queries:
new_query = query
# We can't set this earlier as the other find calls would reset it
query = query.find(fetch_links=True)
# Filter by date-time
if start_from is not None:
new_query = new_query.find(expr(Event.start_time) >= start_from)
if start_to is not None:
new_query = new_query.find(expr(Event.start_time) <= start_to)
if end_from is not None:
new_query = new_query.find(expr(Event.end_time) >= end_from)
if end_to is not None:
new_query = new_query.find(expr(Event.end_time) <= end_to)
# We can't set this earlier as the other find calls would reset it
new_query = new_query.find(fetch_links=True)
cur_events = await new_query.to_list()
events.extend(cur_events)
events = await query.to_list()
return [EventData.from_event(event) for event in events]