Check for category ownership on event creation

This commit is contained in:
Peter Vacho 2025-01-01 14:48:15 +01:00
parent 5f46c7ee7d
commit d24e31af93
Signed by: school
GPG key ID: 8CFC3837052871B4

View file

@ -12,7 +12,16 @@ from src.db.models.category import Category
from src.db.models.event import Event
from src.db.models.invitation import Invitation
from src.db.models.user import User
from src.utils.db import MissingIdError, UnfetchedLinkError, expr, from_id_list, get_id_list, update_document
from src.db.types import BeanieLink
from src.utils.db import (
IdNotFoundError,
MissingIdError,
UnfetchedLinkError,
expr,
from_id_list,
get_id_list,
update_document,
)
from src.utils.logging import get_logger
from .auth import CurrentUserDep
@ -90,16 +99,37 @@ class EventCreateData(_BaseEventData):
This structure is intended to be used for POST & PUT requests.
"""
async def create_event(self, user: User) -> Event:
async def create_event(self, user: User, category_owner_check: bool = True) -> Event:
"""Create a new event in the database.
If one of the attendees or categories is not found, IdNotFoundError will be raised.
If one of the categories is not found, HTTPException with code 404 will be raised.
If `category_owner_check` is set to True, the function will check if the user
is the owner of all categories. If not, HTTPException with code 403 will be raised.
"""
try:
categories = await from_id_list(self.category_ids, Category)
except IdNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Category with id {exc.item_id} doesn't exist",
)
if category_owner_check:
for category in categories:
await category.fetch_link(Category.user)
if cast(User, category.user).id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only create events with your own categories",
)
event = Event(
user=user,
title=self.title,
description=self.description,
categories=await from_id_list(self.category_ids, Category, link_return=True),
categories=cast(list[BeanieLink[Category]], categories),
start_time=self.start_time,
end_time=self.end_time,
color=self.color.as_hex(format="long"),