Source code for napari.utils._dask_utils

"""Dask cache utilities."""

import contextlib
from import Iterator
from typing import Any, Callable, Optional

import dask
import dask.array as da
from dask.cache import Cache

#: dask.cache.Cache, optional : A dask cache for opportunistic caching
#: use :func:`~.resize_dask_cache` to actually register and resize.
#: this is a global cache (all layers will use it), but individual layers
#: can opt out using Layer(..., cache=False)
_DASK_CACHE = Cache(1)

DaskIndexer = Callable[
    [], contextlib.AbstractContextManager[Optional[tuple[dict, Cache]]]

[docs] def resize_dask_cache( nbytes: Optional[int] = None, mem_fraction: Optional[float] = None ) -> Cache: """Create or resize the dask cache used for opportunistic caching. The cache object is an instance of a :class:`Cache`, (which wraps a :class:`cachey.Cache`). See `Dask opportunistic caching <>`_ Parameters ---------- nbytes : int, optional The desired size of the cache, in bytes. If ``None``, the cache size will autodetermined as fraction of the total memory in the system, using ``mem_fraction``. If ``nbytes`` is 0. The cache is turned off. by default, cache size is autodetermined using ``mem_fraction``. mem_fraction : float, optional The fraction (from 0 to 1) of total memory to use for the dask cache. Returns ------- dask_cache : dask.cache.Cache An instance of a Dask Cache Examples -------- >>> from napari.utils import resize_dask_cache >>> cache = resize_dask_cache() # use 25% of total memory by default >>> # dask.Cache wraps cachey.Cache >>> assert isinstance(cache.cache, cachey.Cache) >>> # useful attributes >>> cache.cache.available_bytes # full size of cache >>> cache.cache.total_bytes # currently used bytes """ from psutil import virtual_memory if nbytes is None and mem_fraction is not None: nbytes = int(virtual_memory().total * mem_fraction) avail = _DASK_CACHE.cache.available_bytes # if we don't have a cache already, create one. if avail == 1: # If neither nbytes nor mem_fraction was provided, use default if nbytes is None: nbytes = int(virtual_memory().total * _DEFAULT_MEM_FRACTION) _DASK_CACHE.cache.resize(nbytes) elif nbytes is not None and nbytes != _DASK_CACHE.cache.available_bytes: # if the cache has already been registered, then calling # resize_dask_cache() without supplying either mem_fraction or nbytes # is a no-op: _DASK_CACHE.cache.resize(nbytes) return _DASK_CACHE
def _is_dask_data(data: Any) -> bool: """Return True if data is a dask array or a list/tuple of dask arrays.""" return isinstance(data, da.Array) or ( isinstance(data, and any(isinstance(i, da.Array) for i in data) ) def configure_dask(data: Any, cache: bool = True) -> DaskIndexer: """Spin up cache and return context manager that optimizes Dask indexing. This function determines whether data is a dask array or list of dask arrays and prepares some optimizations if so. When a delayed dask array is given to napari, there are couple things that need to be done to optimize performance. 1. Opportunistic caching needs to be enabled, such that we don't recompute (or "re-read") data that has already been computed or read. 2. Dask task fusion must be turned off to prevent napari from triggering new io on data that has already been read from disk. For example, with a 4D timelapse of 3D stacks, napari may actually *re-read* the entire 3D tiff file every time the Z plane index is changed. Turning of Dask task fusion with `` == False`` prevents this. .. note:: Turning off task fusion requires Dask version 2.15.0 or later. For background and context, see `napari/napari#718 <>`_, `napari/napari#1124 <>`_, and `dask/dask#6084 <>`_. For details on Dask task fusion, see the documentation on `Dask Optimization <>`_. Parameters ---------- data : Any data, as passed to a ``Layer.__init__`` method. Returns ------- ContextManager A context manager that can be used to optimize dask indexing Examples -------- >>> data = dask.array.ones((10,10,10)) >>> optimized_slicing = configure_dask(data) >>> with optimized_slicing(): ... data[0, 2].compute() """ if not _is_dask_data(data): return contextlib.nullcontext _cache = resize_dask_cache() if cache else contextlib.nullcontext() @contextlib.contextmanager def dask_optimized_slicing( memfrac: float = 0.5, ) -> Iterator[tuple[Any, Any]]: opts = {'': False} with dask.config.set(opts) as cfg, _cache as c: yield cfg, c return dask_optimized_slicing