array_api_extra.lazy_apply
- array_api_extra.lazy_apply(func, *args, shape=None, dtype=None, as_numpy=False, xp=None, **kwargs)
Lazily apply an eager function.
If the backend of the input arrays is lazy, e.g. Dask or jitted JAX, the execution of the function is delayed until the graph is materialized; if it’s eager, the function is executed immediately.
- Parameters:
func (callable) –
The function to apply.
It must accept one or more array API compliant arrays as positional arguments. If as_numpy=True, inputs are converted to NumPy before they are passed to func. It must return either a single array-like or a sequence of array-likes.
func must be a pure function, i.e. without side effects, as depending on the backend it may be executed more than once or never.
*args (Array | int | float | complex | bool | None) –
One or more Array API compliant arrays, Python scalars, or None’s.
If as_numpy=True, you need to be able to apply numpy.asarray() to non-None args to convert them to NumPy; read notes below about specific backends.
shape (tuple[int | None, ...] | Sequence[tuple[int | None, ...]], optional) – Output shape or sequence of output shapes, one for each output of func. Default: assume single output and broadcast shapes of the input arrays.
dtype (DType | Sequence[DType], optional) – Output dtype or sequence of output dtypes, one for each output of func. dtype(s) must belong to the same array namespace as the input arrays. Default: infer the result type(s) from the input arrays.
as_numpy (bool, optional) –
If True, convert the input arrays to NumPy before passing them to func. This is particularly useful to make NumPy-only functions, e.g. written in Cython or Numba, work transparently with array API-compliant arrays. Default: False.
xp (array_namespace, optional) – The standard-compatible namespace for args. Default: infer.
**kwargs (Any, optional) – Additional keyword arguments to pass verbatim to func. They cannot contain Array objects.
- Returns:
The result(s) of func applied to the input arrays, wrapped in the same array namespace as the inputs. If shape is omitted or a single tuple[int | None, …], return a single array. Otherwise, return a tuple of arrays.
- Return type:
Array | tuple[Array, …]
Notes
- JAX
This allows applying eager functions to jitted JAX arrays, which are lazy. The function won’t be applied until the JAX array is materialized. When running inside jax.jit, shape must be fully known, i.e. it cannot contain any None elements.
Warning
func must never raise inside jax.jit, as the resulting behavior is undefined.
Using this with as_numpy=False is particularly useful to apply non-jittable JAX functions to arrays on GPU devices. If as_numpy=True, the Transfer guard may prevent arrays on a GPU device from being transferred back to CPU. This is treated as an implicit transfer.
- PyTorch, CuPy
If as_numpy=True, these backends raise by default if you attempt to convert arrays on a GPU device to NumPy.
- Sparse
If as_numpy=True, by default sparse prevents implicit densification through numpy.asarray(). This safety mechanism can be disabled.
- Dask
This allows applying eager functions to Dask arrays. The Dask graph won’t be computed until the user calls compute() or persist() down the line.
The function name will be prominently visible on the user-facing Dask dashboard and on Prometheus metrics, so it is recommended for it to be meaningful.
lazy_apply doesn’t know if func reduces along any axes; also, shape changes are non-trivial in chunked Dask arrays. For these reasons, all inputs will be rechunked into a single chunk.
Warning
The whole operation needs to fit in memory all at once on a single worker.
The outputs will also be returned as a single chunk and you should consider rechunking them into smaller chunks afterwards.
If you want to distribute the calculation across multiple workers, you should use dask.array.map_blocks(), dask.array.map_overlap(), dask.array.blockwise(), or a native Dask wrapper instead of lazy_apply.
- Dask wrapping around other backends
If as_numpy=False, func will receive as input eager arrays of the meta namespace, as defined by the ._meta attribute of the input Dask arrays. The outputs of func will be wrapped by the meta namespace, and then wrapped again by Dask.
- Raises:
ValueError – When xp=jax.numpy, the output shape is unknown (it contains None on one or more axes) and this function was called inside jax.jit.
RuntimeError – When xp=sparse and auto-densification is disabled.
Exception (backend-specific) – When the backend disallows implicit device to host transfers and the input arrays are on a non-CPU device, e.g. on GPU.