diff --git a/tests/helpers.py b/tests/helpers.py index 15cf800..161b793 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio import unittest.mock -from typing import Any, Generic, TYPE_CHECKING, TypeVar +from typing import Any, ClassVar, Generic, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec @@ -84,7 +84,7 @@ class UnpropagatingMockMixin(Generic[T_Mock]): obj_name = f"{mock_name}.{kwargs['name']}" if "name" in kwargs else f"{mock_name}()" raise AttributeError(f"Can't access {obj_name}, mock is sealed.") - # Propagate any other children as simple `unittest.mock.Mock` instances + # Propagate any other children as the `child_mock_type` instances # rather than `self.__class__` instances return self.child_mock_type(**kwargs) @@ -97,9 +97,16 @@ class CustomMockMixin(UnpropagatingMockMixin[T_Mock], Generic[T_Mock]): * Allows using the ``spec_set`` attribute as class attribute """ - spec_set = None + spec_set: ClassVar[object] = None def __init__(self, **kwargs: object): + # If `spec_set` is explicitly passed, have it take precedence over the class attribute. + # + # Although this is an edge case, and there usually shouldn't be a need for this. + # This is mostly for the sake of completeness, and to allow for more flexibility. if "spec_set" in kwargs: - self.spec_set = kwargs.pop("spec_set") - super().__init__(spec_set=self.spec_set, **kwargs) # pyright: ignore[reportCallIssue] # Mixin class, this __init__ is valid + spec_set = kwargs.pop("spec_set") + else: + spec_set = self.spec_set + + super().__init__(spec_set=spec_set, **kwargs) # pyright: ignore[reportCallIssue] # Mixin class, this __init__ is valid