diff --git a/populate_db.py b/populate_db.py new file mode 100755 index 0000000..c53244d --- /dev/null +++ b/populate_db.py @@ -0,0 +1,265 @@ +#!./.venv/bin/python +# +# This script is used to populate the database with some initial data. +# It's mostly useful for testing purposes, in order to have some data to work with. +# You should NOT run this script in production. + +import asyncio +from datetime import date, time +from typing import cast, final + +from beanie import PydanticObjectId +from pydantic import BaseModel +from pydantic_extra_types.color import Color + +from src.api.categories import CategoryCreateData +from src.api.events import EventCreateData +from src.api.invitations import InvitationCreateData +from src.api.users import RegisterData +from src.db.init import initialize_db +from src.db.models.category import Category +from src.db.models.event import Event +from src.db.models.invitation import Invitation +from src.db.models.user import User +from src.utils.db import expr +from src.utils.logging import get_logger + +log = get_logger(__name__) + + +@final +class CategoryConstructData(BaseModel): + """Data necessary to create a new category.""" + + data: CategoryCreateData + owner_username: str + + +@final +class EventConstructData(BaseModel): + """Data necessary to create a new event.""" + + data: EventCreateData + owner_username: str + attendee_usernames: list[str] = [] + category_names: list[str] = [] + + +@final +class InvitationConstructData(BaseModel): + """Data necessary to create a new invitation.""" + + event_name: str + invitee_username: str + invitor_username: str + create_notificaton: bool + + +USERS: list[RegisterData] = [ + RegisterData( + username="user1", + email="user1@example.org", + password="test1", # noqa: S106 + ), + RegisterData( + username="user2", + email="user2@example.org", + password="test2", # noqa: S106 + ), +] + +CATEGORIES: list[CategoryConstructData] = [ + CategoryConstructData( + data=CategoryCreateData(name="Category 1", color=Color("#ff0000")), + owner_username="user1", + ), + CategoryConstructData( + data=CategoryCreateData(name="Category 2", color=Color("blue")), + owner_username="user1", + ), + CategoryConstructData( + data=CategoryCreateData(name="Category A", color=Color("#ff00ff")), + owner_username="user2", + ), + CategoryConstructData( + data=CategoryCreateData(name="Category B", color=Color("magenta")), + owner_username="user2", + ), +] + +EVENTS: list[EventConstructData] = [ + EventConstructData( + data=EventCreateData( + title="Event 1", + description="Description 1", + category_ids=[], + start_date=date(2025, 1, 1), + start_time=time(12, 00), + end_date=date(2025, 1, 1), + end_time=time(12, 30), + color=Color("#ff0000"), + attendee_ids=[], + ), + owner_username="user1", + attendee_usernames=["user2"], + category_names=["Category 1"], + ), + EventConstructData( + data=EventCreateData( + title="Event 2", + description="Description 2", + category_ids=[], + start_date=date(2025, 1, 1), + start_time=time(12, 00), + end_date=date(2025, 1, 1), + end_time=time(12, 30), + color=Color("#ff0000"), + attendee_ids=[], + ), + owner_username="user2", + attendee_usernames=[], + category_names=["Category A", "Category B"], + ), +] + +INVITATIONS: list[InvitationConstructData] = [ + InvitationConstructData( + event_name="Event 2", + invitee_username="user1", + invitor_username="user2", + create_notificaton=True, + ), +] + + +async def make_user(user: RegisterData) -> None: + """Create a new user with configured credentials.""" + db_user = await User.find_one(User.username == user.username) + if db_user is not None: + log.info(f"User {db_user.username!r} was already created. Deleting it.") + _ = await db_user.delete() + + _ = await user.create_user() + log.info(f"User {user.username!r} created; password: {user.password!r}, email: {user.email!r}.") + + +async def make_category(category: CategoryConstructData) -> None: + """Create a new category with configured data.""" + db_user = await User.find_one(User.username == category.owner_username) + if db_user is None: + log.error(f"User {category.owner_username!r} not found, failed to create category.") + return + + # The category name isn't actually unique, so this is not exactly correct + # but for the sake of simplicity, we'll assume it is. + # Note that category names might become unique eventually anyways + db_category = await Category.find(Category.name == category.data.name, Category.user == db_user).first_or_none() + if db_category is not None: + log.info(f"Category {db_category.name!r} for user {category.owner_username!r} already exists. Deleting it.") + _ = await db_category.delete() + + _ = await category.data.create_category(db_user) + log.info(f"Category {category.data.name!r} created for user {category.owner_username!r}.") + + +async def make_event(event: EventConstructData) -> None: + """Create a new event with configured data.""" + db_user = await User.find_one(User.username == event.owner_username) + if db_user is None: + log.error(f"User {event.owner_username!r} not found, failed to create event.") + return + + attendee_ids = [] + for attendee_username in event.attendee_usernames: + db_attendee = await User.find_one(User.username == attendee_username) + if db_attendee is None: + log.error(f"Attendee user {attendee_username!r} not found, failed to create event.") + return + attendee_ids.append(db_attendee.id) + + category_ids = [] + for category_name in event.category_names: + db_category = await Category.find( + Category.name == category_name, + Category.user == User.link_from_id(db_user.id), + ).first_or_none() + if db_category is None: + log.error(f"Category {category_name!r} not found, failed to create event.") + return + category_ids.append(db_category.id) + + # The event title isn't actually unique, so this is not exactly correct + # but for the sake of simplicity, we'll assume it is. + db_event = await Event.find(Event.title == event.data.title, Event.user == db_user).first_or_none() + if db_event is not None: + log.info(f"Event {db_event.title!r} for user {event.owner_username!r} already exists. Deleting it.") + _ = await db_event.delete() + + event.data.category_ids = category_ids + event.data.attendee_ids = attendee_ids + _ = await event.data.create_event(db_user) + log.info(f"Event {event.data.title!r} created for user {event.owner_username!r}.") + + +async def make_invitation(invitation: InvitationConstructData) -> None: + """Create a new invitation with configured data.""" + db_invitor = await User.find_one(User.username == invitation.invitor_username) + if db_invitor is None: + log.error(f"Invitor user {invitation.invitor_username!r} not found, failed to create invitation.") + return + + db_invitee = await User.find_one(User.username == invitation.invitee_username) + if db_invitee is None: + log.error(f"Invitee user {invitation.invitee_username!r} not found, failed to create invitation.") + return + + db_event = await Event.find_one(Event.title == invitation.event_name, expr(Event.user).id == db_invitor.id) + if db_event is None: + log.error(f"Event {invitation.event_name!r} not found, failed to create invitation.") + return + + db_invitation = await Invitation.find_one( + expr(Invitation.event).id == db_event.id, + expr(Invitation.invitee).id == db_invitee.id, + expr(Invitation.invitor).id == db_invitor.id, + ) + if db_invitation is not None: + log.info( + f"Invitation for event {invitation.event_name!r} created by {invitation.invitor_username!r} " + f"for {invitation.invitee_username!r} already exists. Deleting it.", + ) + _ = await db_invitation.delete() + + invitation_data = InvitationCreateData( + event_id=cast(PydanticObjectId, db_event.id), + invitee_id=cast(PydanticObjectId, db_invitee.id), + ) + + _ = await invitation_data.create_invitation(db_invitor, create_notification=invitation.create_notificaton) + log.info( + f"Invitation for event {invitation.event_name!r} created by {invitation.invitor_username!r} " + f"for {invitation.invitee_username!r}.", + ) + + +async def main() -> None: + """Create a new user with configured credentials.""" + db_session, _ = await initialize_db() + + for user in USERS: + await make_user(user) + + for category in CATEGORIES: + await make_category(category) + + for event in EVENTS: + await make_event(event) + + for invitation in INVITATIONS: + await make_invitation(invitation) + + db_session.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/api/invitations.py b/src/api/invitations.py index b05e79f..32ded97 100644 --- a/src/api/invitations.py +++ b/src/api/invitations.py @@ -88,6 +88,13 @@ class InvitationCreateData(_BaseInvitationData): @overload async def create_invitation(self, user: User, create_notification: Literal[False]) -> Invitation: ... + @overload + async def create_invitation( + self, + user: User, + create_notification: bool, + ) -> tuple[Invitation, Notification] | Invitation: ... + async def create_invitation( self, user: User,