Extract the document updating logic into a util

This commit is contained in:
Peter Vacho 2024-12-27 17:48:58 +01:00
parent 66339fa43b
commit 5ff4da15d8
Signed by: school
GPG key ID: 8CFC3837052871B4
3 changed files with 42 additions and 45 deletions

View file

@ -11,7 +11,7 @@ from src.api.auth.dependencies import LoggedInDep
from src.db.models.category import Category
from src.db.models.event import Event
from src.db.models.user import User
from src.utils.db import MissingIdError
from src.utils.db import MissingIdError, update_document
from src.utils.logging import get_logger
from .auth import CurrentUserDep
@ -74,18 +74,11 @@ class CategoryCreateData(_BaseCategoryData):
async def update_category(self, category: Category) -> Category:
"""Update an existing category, overwriting all relevant data."""
updated = False
if category.name != self.name:
category.name = self.name
updated = True
if category.color != self.color:
category.color = self.color.as_hex(format="long")
updated = True
if updated:
category = await category.replace()
return category
return await update_document(
category,
name=self.name,
color=self.color.as_hex(format="long"),
)
@final
@ -100,18 +93,12 @@ class PartialCategoryUpdateData(_BaseCategoryData):
async def update_category(self, category: Category) -> Category:
"""Update an existing category, overwriting it with data that were specified."""
updated = False
if self.name and category.name != self.name:
category.name = self.name
updated = True
if self.color and category.color != self.color:
category.color = self.color.as_hex(format="long")
updated = True
if updated:
category = await category.replace()
return category
update_dct = {}
if self.name is not None:
update_dct["name"] = self.name
if self.color is not None:
update_dct["color"] = self.color.as_hex(format="long")
return await update_document(category, **update_dct)
@base_router.get("/users/{user_id}/categories")

View file

@ -10,7 +10,7 @@ from src.api.auth.dependencies import LoggedInDep
from src.api.auth.passwords import check_hashed_password, create_password_hash
from src.db.models.token import Token
from src.db.models.user import User
from src.utils.db import MissingIdError
from src.utils.db import MissingIdError, update_document
from src.utils.logging import get_logger
from .auth import CurrentUserDep
@ -69,28 +69,22 @@ class UserUpdatable(BaseModel):
password: Annotated[str, StringConstraints(min_length=1, max_length=50)] | None = None
email: EmailStr | None = None
async def update_user(self, user: User) -> tuple[bool, User]:
"""Update the user with data from this instance.
async def update_user(self, user: User) -> User:
"""Update the user with data from this instance."""
update_dct: dict[str, object] = {}
:return: Was the user updated?
"""
updated = False
if self.username and self.username != user.username:
user.username = self.username
updated = True
if self.email and self.email != user.email:
user.email = self.email
updated = True
if self.username:
update_dct["username"] = self.username
if self.email:
update_dct["email"] = self.email
# Password is a special case, as it requires a hash check
# (new hash will be different even if the plain-text matches due to salting
# so to avoid unnecessary updates we check the hash first)
if self.password and check_hashed_password(self.password, user.password_hash):
user.password_hash = create_password_hash(self.password)
updated = True
update_dct["password_hash"] = create_password_hash(self.password)
if updated:
user = await user.replace()
return updated, user
return await update_document(user, **update_dct)
@public_router.post("")
@ -159,7 +153,7 @@ async def patch_user(
detail="You're not authorized to update this user",
)
_, user = await data.update_user(user)
user = await data.update_user(user)
return UserData.from_user(user)

View file

@ -85,3 +85,19 @@ async def from_id_list[T: Document](
continue
items.append(item)
return items
async def update_document[T: Document](document: T, **kwargs) -> T:
"""Update a document with given fields.
This function will only update the document if at least one of the fields differs.
"""
updated = False
for field_name, new_value in kwargs.items():
old_value = getattr(document, field_name)
if old_value != new_value:
setattr(document, field_name, new_value)
updated = True
if updated:
document = await document.replace()
return document