# Source code for napari.utils.events.evented_model

import operator
import sys
import warnings
from contextlib import contextmanager
from typing import Any, Callable, ClassVar, Dict, Set, Union

import numpy as np

from napari._pydantic_compat import (
    BaseModel,
    ModelMetaclass,
    PrivateAttr,
    main,
    utils,
)
from napari.utils.events.event import EmitterGroup, Event
from napari.utils.misc import pick_equality_operator
from napari.utils.translations import trans

# encoders for non-napari specific field types.  To declare a custom encoder
# for a napari type, add a `_json_encode` method to the class itself.
# it will be added to the model json_encoders in :func:`EventedMetaclass.__new__`
_BASE_JSON_ENCODERS = {np.ndarray: lambda arr: arr.tolist()}


@contextmanager
def no_class_attributes():
    """Context in which pydantic's ``ClassAttribute`` just returns its value.

    Inside this context ``pydantic.main.ClassAttribute(name, value)`` is
    replaced by a function that returns ``value`` (its second argument)
    instead of wrapping it in a descriptor.

    Due to a very annoying decision by PySide2, all class ``__signature__``
    attributes may only be assigned **once**.  (This seems to be regardless of
    whether the class has anything to do with PySide2 or not).  Furthermore,
    the PySide2 ``__signature__`` attribute seems to break the python
    descriptor protocol, which means that class attributes that have a
    ``__get__`` method will not be able to successfully retrieve their value
    (instead, the descriptor object itself will be accessed).

    This plays terribly with Pydantic, which assigns a ``ClassAttribute``
    object to the value of ``cls.__signature__`` in ``ModelMetaclass.__new__``
    in order to avoid masking the call signature of object instances that have
    a ``__call__`` method (https://github.com/samuelcolvin/pydantic/pull/1466).

    So, because we only get to set the ``__signature__`` once, this context
    manager basically "opts-out" of pydantic's ``ClassAttribute`` strategy,
    thereby directly setting the ``cls.__signature__`` to an instance of
    ``inspect.Signature``.

    The patch is only applied when ``PySide2`` has already been imported;
    otherwise this context manager does nothing.

    For additional context, see:
    - https://github.com/napari/napari/issues/2264
    - https://github.com/napari/napari/pull/2265
    - https://bugreports.qt.io/browse/PYSIDE-1004
    - https://codereview.qt-project.org/c/pyside/pyside-setup/+/261411
    """

    if "PySide2" not in sys.modules:
        yield
        return

    # monkey patch the pydantic ClassAttribute object
    # the second argument to ClassAttribute is the inspect.Signature object
    def _return2(x, y):
        return y

    # Remember what is installed right now and restore exactly that on exit,
    # rather than assuming it is ``utils.ClassAttribute``.  This keeps a nested
    # use of this context manager from removing the outer patch early.
    original = main.ClassAttribute
    main.ClassAttribute = _return2
    try:
        yield
    finally:
        # undo our monkey patch
        main.ClassAttribute = original


class EventedMetaclass(ModelMetaclass):
    """pydantic ModelMetaclass that preps "equality checking" operations.

    A metaclass is the thing that "constructs" a class, and ``ModelMetaclass``
    is where pydantic puts a lot of its type introspection and ``ModelField``
    creation logic.  Here, we simply tack on one more function, that builds a
    ``cls.__eq_operators__`` dict which maps each field name to a function
    that can be called to check equality of the value of that field with some
    other object.  (used in ``EventedModel.__eq__``)

    It also records the properties of the class in ``cls.__properties__``
    (including properties inherited from EventedModel bases), and the
    mapping of field name -> dependent property names in
    ``cls.__field_dependents__``.

    This happens only once, when an ``EventedModel`` class is created (and not
    when each instance of an ``EventedModel`` is instantiated).
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        with no_class_attributes():
            cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        cls.__eq_operators__ = {}
        for n, f in cls.__fields__.items():
            cls.__eq_operators__[n] = pick_equality_operator(f.type_)
            # If a field type has a _json_encode method, add it to the json
            # encoders for this model.
            # NOTE: a _json_encode field must return an object that can be
            # passed to json.dumps ... but it needn't return a string.
            if hasattr(f.type_, '_json_encode'):
                encoder = f.type_._json_encode
                cls.__config__.json_encoders[f.type_] = encoder
                # also add it to the base config
                # required for pydantic>=1.8.0 due to:
                # https://github.com/samuelcolvin/pydantic/pull/2064
                EventedModel.__config__.json_encoders[f.type_] = encoder

        # Collect properties so EventedModel.__setattr__ can route assignments
        # through their setters and create events for them.  Start with the
        # properties already collected on the bases (iterated in reverse so
        # that earlier bases take precedence, as in the MRO) ...
        properties: Dict[str, property] = {}
        for base in reversed(bases):
            properties.update(getattr(base, '__properties__', {}))
        # ... drop inherited ones that this class shadows with a field or with
        # a non-property attribute ...
        for inherited in list(properties):
            if inherited in cls.__fields__ or (
                inherited in namespace
                and not isinstance(namespace[inherited], property)
            ):
                del properties[inherited]
        # ... and add (or override with) the properties defined right here.
        for attr_name, attr in namespace.items():
            if isinstance(attr, property):
                properties[attr_name] = attr
        cls.__properties__ = properties

        # Determine the compare operator of each property from its getter's
        # return annotation; string (forward-ref) annotations are skipped.
        for prop_name, prop in properties.items():
            annotations = getattr(prop.fget, '__annotations__', None) or {}
            if 'return' in annotations and not isinstance(
                annotations['return'], str
            ):
                cls.__eq_operators__[prop_name] = pick_equality_operator(
                    annotations['return']
                )

        cls.__field_dependents__ = _get_field_dependents(cls)
        return cls


def _update_dependents_from_property_code(
    cls, prop_name, prop, deps, visited=()
):
    """Recursively find all the dependents of a property by inspecting the code object.

    Update the given deps dictionary with the new findings.
    """
    for name in prop.fget.__code__.co_names:
        if name in cls.__fields__:
            deps.setdefault(name, set()).add(prop_name)
        elif name in cls.__properties__ and name not in visited:
            # to avoid infinite recursion, we shouldn't re-check getter we've already seen
            visited = visited + (name,)
            # sub_prop is the new property, but we leave prop_name the same
            sub_prop = cls.__properties__[name]
            _update_dependents_from_property_code(
                cls, prop_name, sub_prop, deps, visited
            )


def _get_field_dependents(cls: 'EventedModel') -> Dict[str, Set[str]]:
    """Return mapping of field name -> dependent set of property names.

    Dependencies will be guessed by inspecting the code of each property
    in order to emit an event for a computed property when a model field
    that it depends on changes (e.g: @property 'c' depends on model fields
    'a' and 'b'). Alternatvely, dependencies may be declared excplicitly
    in the Model Config.

    Note: accessing a field with `getattr()` instead of dot notation won't
    be automatically detected.

    Examples
    --------
        class MyModel(EventedModel):
            a: int = 1
            b: int = 1

            @property
            def c(self) -> List[int]:
                return [self.a, self.b]

            @c.setter
            def c(self, val: Sequence[int]):
                self.a, self.b = val

            @property
            def d(self) -> int:
                return sum(self.c)

            @d.setter
            def d(self, val: int):
                self.c = [val // 2, val // 2]

            class Config:
                dependencies={
                    'c': ['a', 'b'],
                    'd': ['a', 'b']
                }
    """
    if not cls.__properties__:
        return {}

    deps: Dict[str, Set[str]] = {}

    _deps = getattr(cls.__config__, 'dependencies', None)
    if _deps:
        for prop_name, fields in _deps.items():
            if prop_name not in cls.__properties__:
                raise ValueError(
                    'Fields with dependencies must be properties. '
                    f'{prop_name!r} is not.'
                )
            for field in fields:
                if field not in cls.__fields__:
                    warnings.warn(f"Unrecognized field dependency: {field}")
                deps.setdefault(field, set()).add(prop_name)
    else:
        # if dependencies haven't been explicitly defined, we can glean
        # them from the property.fget code object:
        for prop_name, prop in cls.__properties__.items():
            _update_dependents_from_property_code(cls, prop_name, prop, deps)
    return deps


class EventedModel(BaseModel, metaclass=EventedMetaclass):
    """A Model subclass that emits an event whenever a field value is changed.

    Note: As per the standard pydantic behavior, default Field values are
    not validated (#4138) and should be correctly typed.
    """

    # add private attributes for event emission
    _events: EmitterGroup = PrivateAttr(default_factory=EmitterGroup)

    # mapping of name -> property obj for methods that are properties
    __properties__: ClassVar[Dict[str, property]]
    # mapping of field name -> dependent set of property names
    # when field is changed, an event for dependent properties will be emitted.
    __field_dependents__: ClassVar[Dict[str, Set[str]]]
    # mapping of field/property name -> equality function (see EventedMetaclass)
    __eq_operators__: ClassVar[Dict[str, Callable[[Any, Any], bool]]]
    __slots__: ClassVar[Set[str]] = {"__weakref__"}  # type: ignore

    # pydantic BaseModel configuration.  see:
    # https://pydantic-docs.helpmanual.io/usage/model_config/

    class Config:
        # whether to allow arbitrary user types for fields (they are validated
        # simply by checking if the value is an instance of the type). If
        # False, RuntimeError will be raised on model declaration
        arbitrary_types_allowed = True
        # whether to perform validation on assignment to attributes
        validate_assignment = True
        # whether to treat any underscore non-class var attrs as private
        # https://pydantic-docs.helpmanual.io/usage/models/#private-model-attributes
        underscore_attrs_are_private = True
        # whether to validate field defaults (default: False)
        validate_all = True
        # https://pydantic-docs.helpmanual.io/usage/exporting_models/#modeljson
        # NOTE: json_encoders are also added in EventedMetaclass.__new__ if
        # the field declares a _json_encode method.
        json_encoders = _BASE_JSON_ENCODERS
        # extra = Extra.forbid

    def __init__(self, **kwargs) -> None:
        """Validate ``kwargs`` via pydantic, then create one event emitter per
        mutable field and per property."""
        super().__init__(**kwargs)

        self._events.source = self
        # add event emitters for each field which is mutable
        field_events = [
            name
            for name, field in self.__fields__.items()
            if field.field_info.allow_mutation
        ]

        self._events.add(
            **dict.fromkeys(field_events + list(self.__properties__))
        )

        # while seemingly redundant, this next line is very important to maintain
        # correct sources; see https://github.com/napari/napari/pull/4138
        # we solve it by re-setting the source after initial validation, which allows
        # us to use `validate_all = True`
        self._reset_event_source()

    def _super_setattr_(self, name: str, value: Any) -> None:
        """Set ``name`` without emitting events, using a property setter when
        ``name`` is a known property.

        Raises
        ------
        AttributeError
            If ``name`` is a property without a setter.
        """
        # pydantic will raise a ValueError if extra fields are not allowed
        # so we first check to see if this field is a property
        # if so, we use it instead.
        if name in self.__properties__:
            setter = self.__properties__[name].fset
            if setter is None:
                # raise same error as normal properties
                raise AttributeError(f"can't set attribute '{name}'")
            setter(self, value)
        else:
            super().__setattr__(name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set ``name`` and emit its event (and dependent property events)
        when the value actually changed."""
        if name not in getattr(self, 'events', {}):
            # fallback to default behavior
            self._super_setattr_(name, value)
            return

        # grab current value
        field_dep = self.__field_dependents__.get(name, {})
        has_callbacks = {
            dep: bool(getattr(self.events, dep).callbacks) for dep in field_dep
        }
        emitter = getattr(self.events, name)
        # equality comparisons may be expensive, so just avoid them if
        # event has no callbacks connected
        if not (
            emitter.callbacks
            or self._events.callbacks
            or any(has_callbacks.values())
        ):
            self._super_setattr_(name, value)
            return

        dep_with_callbacks = [
            dep for dep, has_cb in has_callbacks.items() if has_cb
        ]

        before = getattr(self, name, object())
        before_deps = {}
        with warnings.catch_warnings():
            # we still need to check for deprecated properties
            warnings.simplefilter("ignore", DeprecationWarning)
            for dep in dep_with_callbacks:
                before_deps[dep] = getattr(self, dep, object())

        # set value using original setter
        self._super_setattr_(name, value)

        # if different we emit the event with new value
        after = getattr(self, name)
        after_deps = {}
        with warnings.catch_warnings():
            # we still need to check for deprecated properties
            warnings.simplefilter("ignore", DeprecationWarning)
            for dep in dep_with_callbacks:
                after_deps[dep] = getattr(self, dep, object())

        are_equal = self.__eq_operators__.get(name, operator.eq)
        if are_equal(after, before):
            # no change
            return
        emitter(value=after)  # emit event

        # emit events for any dependent computed properties as well
        for dep, value_ in before_deps.items():
            if dep in self.__eq_operators__:
                are_equal = self.__eq_operators__[dep]
            else:
                are_equal = pick_equality_operator(after_deps[dep])
            if not are_equal(after_deps[dep], value_):
                getattr(self.events, dep)(value=after_deps[dep])

    # expose the private EmitterGroup publicly
    @property
    def events(self) -> EmitterGroup:
        """The EmitterGroup holding this model's event emitters."""
        return self._events

    def _reset_event_source(self):
        """
        set the event sources of self and all the children to the correct values
        """
        # events are all messed up due to objects being probably
        # recreated arbitrarily during validation
        self.events.source = self
        for name in self.__fields__:
            child = getattr(self, name)
            if isinstance(child, EventedModel):
                # TODO: this isinstance check should be EventedMutables in the future
                child._reset_event_source()
            elif name in self.events.emitters:
                getattr(self.events, name).source = self

    @property
    def _defaults(self):
        """Default values of this model's fields (see ``get_defaults``)."""
        return get_defaults(self)

    def reset(self):
        """Reset the state of the model to default values."""
        for name, value in self._defaults.items():
            if isinstance(value, EventedModel):
                getattr(self, name).reset()
            elif (
                self.__config__.allow_mutation
                and self.__fields__[name].field_info.allow_mutation
            ):
                setattr(self, name, value)

    def update(
        self, values: Union['EventedModel', dict], recurse: bool = True
    ) -> None:
        """Update a model in place.

        Parameters
        ----------
        values : dict, napari.utils.events.EventedModel
            Values to update the model with. If an EventedModel is passed it is
            first converted to a dictionary. The keys of this dictionary must
            be found as attributes on the current model.
        recurse : bool
            If True, recursively update fields that are EventedModels.
            Otherwise, just update the immediate fields of this EventedModel,
            which is useful when the declared field type (e.g. ``Union``) can
            have different realized types with different fields.

        Raises
        ------
        TypeError
            If ``values`` is neither a dict nor an EventedModel.
        """
        # Any EventedModel is accepted (as documented), not only instances of
        # this exact class; its fields are then applied by name.
        if isinstance(values, EventedModel):
            values = values.dict()
        if not isinstance(values, dict):
            raise TypeError(
                trans._(
                    "Unsupported update from {values}",
                    deferred=True,
                    values=type(values),
                )
            )

        # Per-field events are blocked while the batch is applied; afterwards
        # a single model-level event is emitted if ``block.count`` is nonzero
        # (presumably the number of suppressed emissions).
        with self.events.blocker() as block:
            for key, value in values.items():
                field = getattr(self, key)
                if isinstance(field, EventedModel) and recurse:
                    field.update(value, recurse=recurse)
                else:
                    setattr(self, key, value)

        if block.count:
            self.events(Event(self))

    def __eq__(self, other) -> bool:
        """Check equality with another object.

        We override the pydantic approach (which just checks
        ``self.dict() == other.dict()``) to accommodate more complicated types
        like arrays, whose truth value is often ambiguous. ``__eq_operators__``
        is constructed in ``EventedMetaclass.__new__``
        """
        if self is other:
            return True
        if not isinstance(other, EventedModel):
            return self.dict() == other
        if self.__class__ != other.__class__:
            return False
        for f_name in self.__fields__:
            eq = self.__eq_operators__[f_name]
            if not eq(getattr(self, f_name), getattr(other, f_name)):
                return False
        return True

    @contextmanager
    def enums_as_values(self, as_values: bool = True):
        """Temporarily override how enums are retrieved.

        NOTE: this sets ``use_enum_values`` on the class-level ``Config``, so
        the override is visible to every model sharing that ``Config`` while
        the context is active.

        Parameters
        ----------
        as_values : bool, optional
            Whether enums should be shown as values (or as enum objects),
            by default `True`
        """
        null = object()
        before = getattr(self.Config, 'use_enum_values', null)
        self.Config.use_enum_values = as_values
        try:
            yield
        finally:
            if before is not null:
                self.Config.use_enum_values = before
            else:
                delattr(self.Config, 'use_enum_values')
def get_defaults(obj: BaseModel):
    """Return a mapping of field name -> default value for a pydantic model.

    Suppose a field's default is ``None`` and its type is itself a pydantic
    model class.  The defaults of that nested model are then collected
    recursively and used in its place.
    """

    def _default_for(model_field):
        default = model_field.get_default()
        if default is None and isinstance(
            model_field.type_, main.ModelMetaclass
        ):
            return get_defaults(model_field.type_)
        return default

    return {
        field_name: _default_for(model_field)
        for field_name, model_field in obj.__fields__.items()
    }