diff --git a/src/api/events.py b/src/api/events.py index 2d8e4dd..f70ab52 100644 --- a/src/api/events.py +++ b/src/api/events.py @@ -12,7 +12,16 @@ from src.db.models.category import Category from src.db.models.event import Event from src.db.models.invitation import Invitation from src.db.models.user import User -from src.utils.db import MissingIdError, UnfetchedLinkError, expr, from_id_list, get_id_list, update_document +from src.db.types import BeanieLink +from src.utils.db import ( + IdNotFoundError, + MissingIdError, + UnfetchedLinkError, + expr, + from_id_list, + get_id_list, + update_document, +) from src.utils.logging import get_logger from .auth import CurrentUserDep @@ -90,16 +99,37 @@ class EventCreateData(_BaseEventData): This structure is intended to be used for POST & PUT requests. """ - async def create_event(self, user: User) -> Event: + async def create_event(self, user: User, category_owner_check: bool = True) -> Event: """Create a new event in the database. - If one of the attendees or categories is not found, IdNotFoundError will be raised. + If one of the categories is not found, HTTPException with code 404 will be raised. + + If `category_owner_check` is set to True, the function will check if the user + is the owner of all categories. If not, HTTPException with code 403 will be raised. """ + try: + categories = await from_id_list(self.category_ids, Category) + except IdNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Category with id {exc.item_id} doesn't exist", + ) + + if category_owner_check: + for category in categories: + await category.fetch_link(Category.user) + + if cast(User, category.user).id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You can only create events with your own categories", + ) + event = Event( user=user, title=self.title, description=self.description, - categories=await from_id_list(self.category_ids, Category, link_return=True), + categories=cast(list[BeanieLink[Category]], categories), start_time=self.start_time, end_time=self.end_time, color=self.color.as_hex(format="long"),