diff --git a/src/api/invitations.py b/src/api/invitations.py index 6e3bcb2..861f5a4 100644 --- a/src/api/invitations.py +++ b/src/api/invitations.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime -from typing import Literal, cast, final +from typing import Literal, cast, final, overload from beanie import Link, PydanticObjectId from fastapi import APIRouter, HTTPException, Response, status @@ -7,6 +7,7 @@ from pydantic import BaseModel from src.db.models.event import Event from src.db.models.invitation import Invitation +from src.db.models.notification import Notification from src.db.models.user import User from src.utils.db import IdNotFoundError, MissingIdError, UnfetchedLinkError, expr from src.utils.logging import get_logger @@ -77,10 +78,26 @@ class InvitationCreateData(_BaseInvitationData): This structure is intended to be used for POST requests. """ - async def create_invitation(self, user: User) -> Invitation: + @overload + async def create_invitation( + self, + user: User, + create_notification: Literal[True] = True, + ) -> tuple[Invitation, Notification]: ... + + @overload + async def create_invitation(self, user: User, create_notification: Literal[False]) -> Invitation: ... + + async def create_invitation( + self, + user: User, + create_notification: bool = True, + ) -> Invitation | tuple[Invitation, Notification]: """Create a new invitation in the database. If one the event or invitee doesn't exist, the IdNotFoundError will be raised. + If create_notification is True, a notification will be created for the invitee. In most cases, + this is the desired behavior. """ event = await Event.get(self.event_id) if event is None: @@ -96,7 +113,20 @@ class InvitationCreateData(_BaseInvitationData): invitor=user, status="pending", ) - return await invitation.create() + invitation = await invitation.create() + + if create_notification: + notification = Notification( + user=invitee, + event_type="invitation", + message="new-invitation", + data=invitation.id, + read=False, + ) + notification = await notification.create() + return invitation, notification + + return invitation @base_router.get("/users/{user_id}/invitations") @@ -155,7 +185,7 @@ async def get_invitation(invitation_id: PydanticObjectId, user: CurrentUserDep) @invitations_router.post("") async def create_invitation(data: InvitationCreateData, user: CurrentUserDep) -> InvitationData: """Create a new invitation.""" - invitation = await data.create_invitation(user) + invitation, _notification = await data.create_invitation(user) return InvitationData.from_invitation(invitation) @@ -197,6 +227,16 @@ async def accept_invitation(invitation_id: PydanticObjectId, user: CurrentUserDe invitation.responded_at = datetime.now(tz=UTC) _ = await invitation.save() + # Send back a notification to the invitor about the acceptance + notification = Notification( + user=invitation.invitor, + event_type="invitation", + message="invitation-accepted", + data=invitation.id, + read=False, + ) + notification = await notification.create() + return Response(status.HTTP_204_NO_CONTENT) @@ -220,6 +260,16 @@ async def decline_invitation(invitation_id: PydanticObjectId, user: CurrentUserD invitation.responded_at = datetime.now(tz=UTC) _ = await invitation.save() + # Send back a notification to the invitor about the decline + notification = Notification( + user=invitation.invitor, + event_type="invitation", + message="invitation-declined", + data=invitation.id, + read=False, + ) + notification = await notification.create() + return Response(status_code=status.HTTP_204_NO_CONTENT)