diff --git a/src/api/events.py b/src/api/events.py index 4e00dce..2693da2 100644 --- a/src/api/events.py +++ b/src/api/events.py @@ -3,6 +3,7 @@ from typing import Annotated, Literal, cast, final from beanie import Link, PydanticObjectId from beanie.odm.operators.find.comparison import In +from beanie.odm.queries.find import FindMany from fastapi import APIRouter, Body, HTTPException, Query, status from fastapi.responses import Response from pydantic import BaseModel, StringConstraints, field_serializer, field_validator @@ -255,33 +256,45 @@ async def get_user_invited_events( detail="You can only access your own invite categories", ) - # Initial query (all events the user is attending) - query = Event.find() + # Initial queries (all events for the user) + # NOTE: We can't use a single query, as we would need a union query for that + # which as far as I could tell is not supported by Beanie (annoying). + queries: list[FindMany[Event]] = [] if invite_status is None or invite_status == "accepted": - query = query.find(expr(Event.attendees).id == user_id) + query = Event.find(expr(Event.attendees).id == user_id) + queries.append(query) + if invite_status is None or invite_status == "pending": invitations = await Invitation.find( expr(Invitation.invitee).id == user_id, expr(Invitation.status) == "pending", ).to_list() event_ids = get_id_list([invite.event for invite in invitations], allow_links=True) + event_ids = set(event_ids) # Remove duplicates query = Event.find(In(expr(Event.id), event_ids)) + queries.append(query) - # Filter by date-time - if start_from is not None: - query = query.find(expr(Event.start_time) >= start_from) - if start_to is not None: - query = query.find(expr(Event.start_time) <= start_to) - if end_from is not None: - query = query.find(expr(Event.end_time) >= end_from) - if end_to is not None: - query = query.find(expr(Event.end_time) <= end_to) + events: list[Event] = [] + for query in queries: + new_query = query - # We can't set this earlier as the other find calls would reset it - query = query.find(fetch_links=True) + # Filter by date-time + if start_from is not None: + new_query = new_query.find(expr(Event.start_time) >= start_from) + if start_to is not None: + new_query = new_query.find(expr(Event.start_time) <= start_to) + if end_from is not None: + new_query = new_query.find(expr(Event.end_time) >= end_from) + if end_to is not None: + new_query = new_query.find(expr(Event.end_time) <= end_to) + + # We can't set this earlier as the other find calls would reset it + new_query = new_query.find(fetch_links=True) + + cur_events = await new_query.to_list() + events.extend(cur_events) - events = await query.to_list() return [EventData.from_event(event) for event in events]