diff --git a/populate_db.py b/populate_db.py index c53244d..2e6dd75 100755 --- a/populate_db.py +++ b/populate_db.py @@ -21,7 +21,7 @@ from src.db.models.category import Category from src.db.models.event import Event from src.db.models.invitation import Invitation from src.db.models.user import User -from src.utils.db import expr +from src.utils.db import expr, from_id_list from src.utils.logging import get_logger log = get_logger(__name__) @@ -98,7 +98,6 @@ EVENTS: list[EventConstructData] = [ end_date=date(2025, 1, 1), end_time=time(12, 30), color=Color("#ff0000"), - attendee_ids=[], ), owner_username="user1", attendee_usernames=["user2"], @@ -114,7 +113,6 @@ EVENTS: list[EventConstructData] = [ end_date=date(2025, 1, 1), end_time=time(12, 30), color=Color("#ff0000"), - attendee_ids=[], ), owner_username="user2", attendee_usernames=[], @@ -196,8 +194,12 @@ async def make_event(event: EventConstructData) -> None: _ = await db_event.delete() event.data.category_ids = category_ids - event.data.attendee_ids = attendee_ids - _ = await event.data.create_event(db_user) + db_event = await event.data.create_event(db_user) + + if len(attendee_ids) > 0: + db_event.attendees = await from_id_list(attendee_ids, User, link_return=True) + db_event = await db_event.replace() + log.info(f"Event {event.data.title!r} created for user {event.owner_username!r}.") diff --git a/src/api/events.py b/src/api/events.py index 056258e..fac955d 100644 --- a/src/api/events.py +++ b/src/api/events.py @@ -36,7 +36,6 @@ class _BaseEventData(BaseModel): end_date: date end_time: time color: Color - attendee_ids: list[PydanticObjectId] @field_validator("color", mode="after") @classmethod @@ -52,6 +51,7 @@ class EventData(_BaseEventData): """Data about an event sent to the user.""" owner_user_id: PydanticObjectId + attendee_ids: list[PydanticObjectId] created_at: datetime @classmethod @@ -99,7 +99,7 @@ class EventCreateData(_BaseEventData): end_date=self.end_date, end_time=self.end_time, color=self.color.as_hex(format="long"), - attendees=await from_id_list(self.attendee_ids, User, link_return=True), + attendees=[], ) return await event.create() @@ -115,7 +115,6 @@ class EventCreateData(_BaseEventData): end_date=self.end_date, end_time=self.end_time, color=self.color.as_hex(format="long"), - attendees=await from_id_list(self.attendee_ids, User), ) @@ -131,7 +130,6 @@ class PartialEventUpdateData(_BaseEventData): end_date: date | None = None # pyright: ignore[reportIncompatibleVariableOverride] end_time: time | None = None # pyright: ignore[reportIncompatibleVariableOverride] color: Color | None = None # pyright: ignore[reportIncompatibleVariableOverride] - attendee_ids: list[PydanticObjectId] | None = None # pyright: ignore[reportIncompatibleVariableOverride] async def update_event(self, event: Event) -> Event: """Update an existing event, overwriting it with data that were specified.""" @@ -154,8 +152,6 @@ class PartialEventUpdateData(_BaseEventData): update_dct["end_time"] = self.end_time if self.color is not None: update_dct["color"] = self.color.as_hex(format="long") - if self.attendee_ids is not None: - update_dct["attendees"] = await from_id_list(self.attendee_ids, User) return await update_document(event, **update_dct)