import operator
import sys
import warnings
from contextlib import contextmanager
from typing import Any, Callable, ClassVar, Dict, Set, Union
import numpy as np
from napari._pydantic_compat import (
BaseModel,
ModelMetaclass,
PrivateAttr,
main,
utils,
)
from napari.utils.events.event import EmitterGroup, Event
from napari.utils.misc import pick_equality_operator
from napari.utils.translations import trans
# JSON encoders for field types that are not napari-specific.
# A napari type opts in to custom encoding by defining a ``_json_encode``
# method on the class itself; ``EventedMetaclass.__new__`` then registers that
# method in the model's ``json_encoders``.
_BASE_JSON_ENCODERS: Dict[Any, Callable[[Any], Any]] = {
    np.ndarray: lambda array: array.tolist(),
}
@contextmanager
def no_class_attributes():
    """Context in which ``pydantic.main.ClassAttribute`` returns its 2nd argument.

    Due to a very annoying decision by PySide2, all class ``__signature__``
    attributes may only be assigned **once**. (This seems to be regardless of
    whether the class has anything to do with PySide2 or not). Furthermore,
    the PySide2 ``__signature__`` attribute seems to break the python
    descriptor protocol, which means that class attributes that have a
    ``__get__`` method will not be able to successfully retrieve their value
    (instead, the descriptor object itself will be accessed).

    This plays terribly with Pydantic, which assigns a ``ClassAttribute``
    object to the value of ``cls.__signature__`` in ``ModelMetaclass.__new__``
    in order to avoid masking the call signature of object instances that have
    a ``__call__`` method (https://github.com/samuelcolvin/pydantic/pull/1466).

    So, because we only get to set the ``__signature__`` once, this context
    manager basically "opts-out" of pydantic's ``ClassAttribute`` strategy,
    thereby directly setting the ``cls.__signature__`` to an instance of
    ``inspect.Signature``.

    When PySide2 has not been imported, this context does nothing.

    For additional context, see:
    - https://github.com/napari/napari/issues/2264
    - https://github.com/napari/napari/pull/2265
    - https://bugreports.qt.io/browse/PYSIDE-1004
    - https://codereview.qt-project.org/c/pyside/pyside-setup/+/261411
    """
    if "PySide2" not in sys.modules:
        yield
        return

    # monkey patch the pydantic ClassAttribute object
    # the second argument to ClassAttribute is the inspect.Signature object
    def _return2(x, y):
        return y

    # Remember exactly what is installed right now, so the patch is undone
    # correctly even if this context is nested or ClassAttribute had already
    # been replaced by someone else.
    original = main.ClassAttribute
    main.ClassAttribute = _return2
    try:
        yield
    finally:
        # undo our monkey patch
        main.ClassAttribute = original
def _collect_properties(cls, bases, namespace) -> Dict[str, property]:
    """Return the ``name -> property`` mapping that ``cls`` should expose.

    Properties already collected by evented base classes (their
    ``__properties__``) are inherited, unless ``cls`` shadows that name in its
    own class body or declares a field with that name. Properties defined
    directly in ``namespace`` are then added on top.
    """
    properties: Dict[str, property] = {}
    # later updates win, so walk the bases right-to-left: the left-most base
    # takes precedence, as with normal attribute lookup in simple hierarchies.
    for base in reversed(bases):
        properties.update(getattr(base, '__properties__', {}))
    # drop inherited entries this class overrides with something else
    for attr_name in list(properties):
        if attr_name in namespace or attr_name in cls.__fields__:
            del properties[attr_name]
    for attr_name, attr in namespace.items():
        if isinstance(attr, property):
            properties[attr_name] = attr
    return properties


class EventedMetaclass(ModelMetaclass):
    """pydantic ModelMetaclass that preps "equality checking" operations.

    A metaclass is the thing that "constructs" a class, and ``ModelMetaclass``
    is where pydantic puts a lot of its type introspection and ``ModelField``
    creation logic. Here, we simply tack on one more function, that builds a
    ``cls.__eq_operators__`` dict which is mapping of field name to a function
    that can be called to check equality of the value of that field with some
    other object. (used in ``EventedModel.__eq__``)

    It also records the class's properties (including those inherited from
    evented base classes) in ``cls.__properties__`` and the field -> dependent
    property mapping in ``cls.__field_dependents__``.

    This happens only once, when an ``EventedModel`` class is created (and not
    when each instance of an ``EventedModel`` is instantiated).
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        with no_class_attributes():
            cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        cls.__eq_operators__ = {}
        for n, f in cls.__fields__.items():
            cls.__eq_operators__[n] = pick_equality_operator(f.type_)
            # If a field type has a _json_encode method, add it to the json
            # encoders for this model.
            # NOTE: a _json_encode field must return an object that can be
            # passed to json.dumps ... but it needn't return a string.
            if hasattr(f.type_, '_json_encode'):
                encoder = f.type_._json_encode
                cls.__config__.json_encoders[f.type_] = encoder
                # also add it to the base config
                # required for pydantic>=1.8.0 due to:
                # https://github.com/samuelcolvin/pydantic/pull/2064
                EventedModel.__config__.json_encoders[f.type_] = encoder

        # check for properties defined on the class (or inherited from an
        # evented base class), so we can allow them in
        # EventedModel.__setattr__ and create events
        cls.__properties__ = _collect_properties(cls, bases, namespace)
        for prop_name, prop in cls.__properties__.items():
            # determine compare operator from the getter's return annotation;
            # missing getters and string (forward-ref) annotations are skipped
            annotations = getattr(prop.fget, '__annotations__', {})
            if 'return' in annotations and not isinstance(
                annotations['return'], str
            ):
                cls.__eq_operators__[prop_name] = pick_equality_operator(
                    annotations['return']
                )

        cls.__field_dependents__ = _get_field_dependents(cls)
        return cls
def _iter_code_names(code):
    """Yield ``co_names`` of ``code`` and of every code object nested in it.

    Comprehensions, generator expressions, lambdas and inner functions are
    compiled into separate code objects stored in ``co_consts`` (list/set/dict
    comprehensions are only inlined from Python 3.12), so an access such as
    ``[self.a * i for i in ...]`` would otherwise go unnoticed.
    """
    yield from code.co_names
    for const in code.co_consts:
        # nested code objects share the type of their parent (types.CodeType)
        if isinstance(const, type(code)):
            yield from _iter_code_names(const)


def _update_dependents_from_property_code(
    cls, prop_name, prop, deps, visited=()
):
    """Recursively find all the dependents of a property by inspecting the code object.

    Update the given deps dictionary (``field name -> set of property names``)
    with the new findings: every model field referenced by ``prop``'s getter,
    directly or through other properties of ``cls``, gets ``prop_name`` added
    to its set.

    Properties without an inspectable getter (write-only properties, or
    getters that are not Python functions) contribute nothing; declare their
    dependencies explicitly in ``Config.dependencies`` instead.
    """
    code = getattr(getattr(prop, 'fget', None), '__code__', None)
    if code is None:
        return
    for name in _iter_code_names(code):
        if name in cls.__fields__:
            deps.setdefault(name, set()).add(prop_name)
        elif name in cls.__properties__ and name not in visited:
            # to avoid infinite recursion, we shouldn't re-check getter we've already seen
            visited = visited + (name,)
            # sub_prop is the new property, but we leave prop_name the same
            sub_prop = cls.__properties__[name]
            _update_dependents_from_property_code(
                cls, prop_name, sub_prop, deps, visited
            )
def _get_field_dependents(cls: 'EventedModel') -> Dict[str, Set[str]]:
    """Map each model field name to the set of property names depending on it.

    The mapping lets a model emit an event for a computed property whenever
    a field it depends on changes (e.g. ``@property`` 'c' depends on model
    fields 'a' and 'b'). Dependencies are either declared explicitly in the
    model's ``Config.dependencies`` or, when none are declared, guessed by
    inspecting the code of each property getter.

    Note: accessing a field with `getattr()` instead of dot notation won't
    be automatically detected.

    Raises
    ------
    ValueError
        If ``Config.dependencies`` names something that is not a property.

    Warns
    -----
    UserWarning
        For every declared dependency that is not a field of the model (the
        dependency is still recorded).

    Examples
    --------
    class MyModel(EventedModel):
        a: int = 1
        b: int = 1

        @property
        def c(self) -> List[int]:
            return [self.a, self.b]

        @c.setter
        def c(self, val: Sequence[int]):
            self.a, self.b = val

        @property
        def d(self) -> int:
            return sum(self.c)

        @d.setter
        def d(self, val: int):
            self.c = [val // 2, val // 2]

        class Config:
            dependencies={
                'c': ['a', 'b'],
                'd': ['a', 'b']
            }
    """
    properties = cls.__properties__
    if not properties:
        return {}

    dependents: Dict[str, Set[str]] = {}
    declared = getattr(cls.__config__, 'dependencies', None)

    if not declared:
        # nothing declared: glean dependencies from each getter's code object
        for name, prop in properties.items():
            _update_dependents_from_property_code(cls, name, prop, dependents)
        return dependents

    for name, field_names in declared.items():
        if name not in properties:
            raise ValueError(
                'Fields with dependencies must be properties. '
                f'{name!r} is not.'
            )
        for field_name in field_names:
            if field_name not in cls.__fields__:
                warnings.warn(f"Unrecognized field dependency: {field_name}")
            dependents.setdefault(field_name, set()).add(name)
    return dependents
class EventedModel(BaseModel, metaclass=EventedMetaclass):
    """A Model subclass that emits an event whenever a field value is changed.

    Note: unlike pydantic's default, ``Config.validate_all`` is enabled below,
    so default Field values are validated too (see napari#4138); they should
    still be correctly typed.
    """

    # add private attributes for event emission
    _events: EmitterGroup = PrivateAttr(default_factory=EmitterGroup)

    # mapping of name -> property obj for methods that are properties
    __properties__: ClassVar[Dict[str, property]]
    # mapping of field name -> dependent set of property names
    # when field is changed, an event for dependent properties will be emitted.
    __field_dependents__: ClassVar[Dict[str, Set[str]]]
    # mapping of field/property name -> equality function (built by the
    # metaclass; used by __setattr__ and __eq__)
    __eq_operators__: ClassVar[Dict[str, Callable[[Any, Any], bool]]]
    __slots__: ClassVar[Set[str]] = {"__weakref__"}  # type: ignore

    # pydantic BaseModel configuration. see:
    # https://pydantic-docs.helpmanual.io/usage/model_config/
    class Config:
        # whether to allow arbitrary user types for fields (they are validated
        # simply by checking if the value is an instance of the type). If
        # False, RuntimeError will be raised on model declaration
        arbitrary_types_allowed = True
        # whether to perform validation on assignment to attributes
        validate_assignment = True
        # whether to treat any underscore non-class var attrs as private
        # https://pydantic-docs.helpmanual.io/usage/models/#private-model-attributes
        underscore_attrs_are_private = True
        # whether to validate field defaults (pydantic's default: False)
        validate_all = True
        # https://pydantic-docs.helpmanual.io/usage/exporting_models/#modeljson
        # NOTE: json_encoders are also added EventedMetaclass.__new__ if the
        # field declares a _json_encode method.
        json_encoders = _BASE_JSON_ENCODERS
        # extra = Extra.forbid

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        self._events.source = self
        # add event emitters for each field which is mutable
        field_events = [
            name
            for name, field in self.__fields__.items()
            if field.field_info.allow_mutation
        ]

        # ... and one emitter per property
        self._events.add(
            **dict.fromkeys(field_events + list(self.__properties__))
        )

        # while seemingly redundant, this next line is very important to maintain
        # correct sources; see https://github.com/napari/napari/pull/4138
        # we solve it by re-setting the source after initial validation, which allows
        # us to use `validate_all = True`
        self._reset_event_source()

    def _super_setattr_(self, name: str, value: Any) -> None:
        """Set ``name`` via its property setter, or via pydantic otherwise."""
        # pydantic will raise a ValueError if extra fields are not allowed
        # so we first check to see if this field is a property
        # if so, we use it instead.
        if name in self.__properties__:
            setter = self.__properties__[name].fset
            if setter is None:
                # raise same error as normal properties
                raise AttributeError(f"can't set attribute '{name}'")
            setter(self, value)
        else:
            super().__setattr__(name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, emitting events for it and dependent properties.

        Events are emitted only when the value actually changed, as judged by
        the operator in ``__eq_operators__`` (``operator.eq`` by default).
        """
        if name not in getattr(self, 'events', {}):
            # fallback to default behavior
            self._super_setattr_(name, value)
            return

        field_dep = self.__field_dependents__.get(name, {})
        has_callbacks = {
            dep: bool(getattr(self.events, dep).callbacks)
            for dep in field_dep
        }
        emitter = getattr(self.events, name)
        # equality comparisons may be expensive, so just avoid them if
        # event has no callbacks connected
        if not (
            emitter.callbacks
            or self._events.callbacks
            or any(has_callbacks.values())
        ):
            self._super_setattr_(name, value)
            return

        dep_with_callbacks = [
            dep for dep, has_cb in has_callbacks.items() if has_cb
        ]

        # grab current value (and current values of observed dependents)
        before = getattr(self, name, object())
        before_deps = {}
        with warnings.catch_warnings():
            # we still need to check for deprecated properties
            warnings.simplefilter("ignore", DeprecationWarning)
            for dep in dep_with_callbacks:
                before_deps[dep] = getattr(self, dep, object())

        # set value using original setter
        self._super_setattr_(name, value)

        # if different we emit the event with new value
        after = getattr(self, name)
        after_deps = {}
        with warnings.catch_warnings():
            # we still need to check for deprecated properties
            warnings.simplefilter("ignore", DeprecationWarning)
            for dep in dep_with_callbacks:
                after_deps[dep] = getattr(self, dep, object())

        are_equal = self.__eq_operators__.get(name, operator.eq)
        if are_equal(after, before):
            # no change
            return
        emitter(value=after)  # emit event

        # emit events for any dependent computed properties as well
        for dep, value_ in before_deps.items():
            if dep in self.__eq_operators__:
                are_equal = self.__eq_operators__[dep]
            else:
                are_equal = pick_equality_operator(after_deps[dep])
            if not are_equal(after_deps[dep], value_):
                getattr(self.events, dep)(value=after_deps[dep])

    # expose the private EmitterGroup publically
    @property
    def events(self) -> EmitterGroup:
        return self._events

    def _reset_event_source(self):
        """
        set the event sources of self and all the children to the correct values
        """
        # events are all messed up due to objects being probably
        # recreated arbitrarily during validation
        self.events.source = self
        for name in self.__fields__:
            child = getattr(self, name)
            if isinstance(child, EventedModel):
                # TODO: this isinstance check should be EventedMutables in the future
                child._reset_event_source()
            elif name in self.events.emitters:
                getattr(self.events, name).source = self

    @property
    def _defaults(self):
        # (possibly nested) default values, as computed by ``get_defaults``
        return get_defaults(self)

    def reset(self):
        """Reset the state of the model to default values.

        Nested ``EventedModel`` fields are reset in place; other fields are
        re-assigned only if both the model and the field allow mutation.
        """
        for name, value in self._defaults.items():
            if isinstance(value, EventedModel):
                getattr(self, name).reset()
            elif (
                self.__config__.allow_mutation
                and self.__fields__[name].field_info.allow_mutation
            ):
                setattr(self, name, value)

    def update(
        self, values: Union['EventedModel', dict], recurse: bool = True
    ) -> None:
        """Update a model in place.

        Parameters
        ----------
        values : dict, napari.utils.events.EventedModel
            Values to update the model with. If an instance of this model's
            class (or a subclass of it) is passed, it is first converted to a
            dictionary. The keys of this dictionary must be found as
            attributes on the current model.
        recurse : bool
            If True, recursively update fields that are EventedModels.
            Otherwise, just update the immediate fields of this EventedModel,
            which is useful when the declared field type (e.g. ``Union``) can have
            different realized types with different fields.

        Raises
        ------
        TypeError
            If ``values`` is neither a dict nor an instance of this class.
        """
        if isinstance(values, self.__class__):
            values = values.dict()
        if not isinstance(values, dict):
            raise TypeError(
                trans._(
                    "Unsupported update from {values}",
                    deferred=True,
                    values=type(values),
                )
            )

        # block the group-level event while individual fields are updated...
        with self.events.blocker() as block:
            for key, value in values.items():
                field = getattr(self, key)
                if isinstance(field, EventedModel) and recurse:
                    field.update(value, recurse=recurse)
                else:
                    setattr(self, key, value)

        # ...then emit a single group event if anything was emitted meanwhile
        if block.count:
            self.events(Event(self))

    def __eq__(self, other) -> bool:
        """Check equality with another object.

        We override the pydantic approach (which just checks
        ``self.dict() == other.dict()``) to accommodate more complicated types
        like arrays, whose truth value is often ambiguous. ``__eq_operators__``
        is constructed in ``EventedMetaclass.__new__``
        """
        if self is other:
            return True
        if not isinstance(other, EventedModel):
            return self.dict() == other
        if self.__class__ != other.__class__:
            return False
        for f_name in self.__fields__:
            eq = self.__eq_operators__[f_name]
            if not eq(getattr(self, f_name), getattr(other, f_name)):
                return False
        return True

    @contextmanager
    def enums_as_values(self, as_values: bool = True):
        """Temporarily override how enums are retrieved.

        Note: this sets ``use_enum_values`` on the class-level ``Config``
        object, so the override is visible to every model sharing that
        ``Config`` while the context is active.

        Parameters
        ----------
        as_values : bool, optional
            Whether enums should be shown as values (or as enum objects),
            by default `True`
        """
        null = object()
        before = getattr(self.Config, 'use_enum_values', null)
        self.Config.use_enum_values = as_values
        try:
            yield
        finally:
            # restore the previous value, or remove the attribute if it was
            # not set on Config before entering
            if before is not null:
                self.Config.use_enum_values = before
            else:
                delattr(self.Config, 'use_enum_values')
def get_defaults(obj: BaseModel):
    """Return a ``{field_name: default}`` mapping for the fields of ``obj``.

    A field whose default is ``None`` but whose declared type is itself a
    pydantic model is expanded recursively into that model's own defaults.
    """

    def _default_for(model_field):
        default = model_field.get_default()
        if default is not None:
            return default
        if isinstance(model_field.type_, main.ModelMetaclass):
            return get_defaults(model_field.type_)
        return None

    return {
        field_name: _default_for(model_field)
        for field_name, model_field in obj.__fields__.items()
    }