Extract the document updating logic into a util
This commit is contained in:
parent
66339fa43b
commit
5ff4da15d8
|
@ -11,7 +11,7 @@ from src.api.auth.dependencies import LoggedInDep
|
|||
from src.db.models.category import Category
|
||||
from src.db.models.event import Event
|
||||
from src.db.models.user import User
|
||||
from src.utils.db import MissingIdError
|
||||
from src.utils.db import MissingIdError, update_document
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import CurrentUserDep
|
||||
|
@ -74,18 +74,11 @@ class CategoryCreateData(_BaseCategoryData):
|
|||
|
||||
async def update_category(self, category: Category) -> Category:
|
||||
"""Update an existing category, overwriting all relevant data."""
|
||||
updated = False
|
||||
|
||||
if category.name != self.name:
|
||||
category.name = self.name
|
||||
updated = True
|
||||
if category.color != self.color:
|
||||
category.color = self.color.as_hex(format="long")
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
category = await category.replace()
|
||||
return category
|
||||
return await update_document(
|
||||
category,
|
||||
name=self.name,
|
||||
color=self.color.as_hex(format="long"),
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
|
@ -100,18 +93,12 @@ class PartialCategoryUpdateData(_BaseCategoryData):
|
|||
|
||||
async def update_category(self, category: Category) -> Category:
|
||||
"""Update an existing category, overwriting it with data that were specified."""
|
||||
updated = False
|
||||
|
||||
if self.name and category.name != self.name:
|
||||
category.name = self.name
|
||||
updated = True
|
||||
if self.color and category.color != self.color:
|
||||
category.color = self.color.as_hex(format="long")
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
category = await category.replace()
|
||||
return category
|
||||
update_dct = {}
|
||||
if self.name is not None:
|
||||
update_dct["name"] = self.name
|
||||
if self.color is not None:
|
||||
update_dct["color"] = self.color.as_hex(format="long")
|
||||
return await update_document(category, **update_dct)
|
||||
|
||||
|
||||
@base_router.get("/users/{user_id}/categories")
|
||||
|
|
|
@ -10,7 +10,7 @@ from src.api.auth.dependencies import LoggedInDep
|
|||
from src.api.auth.passwords import check_hashed_password, create_password_hash
|
||||
from src.db.models.token import Token
|
||||
from src.db.models.user import User
|
||||
from src.utils.db import MissingIdError
|
||||
from src.utils.db import MissingIdError, update_document
|
||||
from src.utils.logging import get_logger
|
||||
|
||||
from .auth import CurrentUserDep
|
||||
|
@ -69,28 +69,22 @@ class UserUpdatable(BaseModel):
|
|||
password: Annotated[str, StringConstraints(min_length=1, max_length=50)] | None = None
|
||||
email: EmailStr | None = None
|
||||
|
||||
async def update_user(self, user: User) -> tuple[bool, User]:
|
||||
"""Update the user with data from this instance.
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update the user with data from this instance."""
|
||||
update_dct: dict[str, object] = {}
|
||||
|
||||
:return: Was the user updated?
|
||||
"""
|
||||
updated = False
|
||||
if self.username and self.username != user.username:
|
||||
user.username = self.username
|
||||
updated = True
|
||||
|
||||
if self.email and self.email != user.email:
|
||||
user.email = self.email
|
||||
updated = True
|
||||
if self.username:
|
||||
update_dct["username"] = self.username
|
||||
if self.email:
|
||||
update_dct["email"] = self.email
|
||||
|
||||
# Password is a special case, as it requires a hash check
|
||||
# (new hash will be different even if the plain-text matches due to salting
|
||||
# so to avoid unnecessary updates we check the hash first)
|
||||
if self.password and check_hashed_password(self.password, user.password_hash):
|
||||
user.password_hash = create_password_hash(self.password)
|
||||
updated = True
|
||||
update_dct["password_hash"] = create_password_hash(self.password)
|
||||
|
||||
if updated:
|
||||
user = await user.replace()
|
||||
|
||||
return updated, user
|
||||
return await update_document(user, **update_dct)
|
||||
|
||||
|
||||
@public_router.post("")
|
||||
|
@ -159,7 +153,7 @@ async def patch_user(
|
|||
detail="You're not authorized to update this user",
|
||||
)
|
||||
|
||||
_, user = await data.update_user(user)
|
||||
user = await data.update_user(user)
|
||||
|
||||
return UserData.from_user(user)
|
||||
|
||||
|
|
|
@ -85,3 +85,19 @@ async def from_id_list[T: Document](
|
|||
continue
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
|
||||
async def update_document[T: Document](document: T, **kwargs) -> T:
|
||||
"""Update a document with given fields.
|
||||
|
||||
This function will only update the document if at least one of the fields differs.
|
||||
"""
|
||||
updated = False
|
||||
for field_name, new_value in kwargs.items():
|
||||
old_value = getattr(document, field_name)
|
||||
if old_value != new_value:
|
||||
setattr(document, field_name, new_value)
|
||||
updated = True
|
||||
if updated:
|
||||
document = await document.replace()
|
||||
return document
|
||||
|
|
Loading…
Reference in a new issue