array_api_extra.apply_where

array_api_extra.apply_where(cond, args, f1, f2=None, /, *, fill_value=None, xp=None)

Run one of two elementwise functions depending on a condition.

Equivalent to performing f1(*args) if cond else fill_value elementwise when fill_value is given, and f1(*args) if cond else f2(*args) elementwise otherwise.
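As a reading aid for the equivalence above, here is a minimal pure-Python model of the elementwise semantics. It assumes cond and each entry of args are equal-length 1-D lists, and it ignores broadcasting, type promotion and array namespaces. apply_where_ref is a hypothetical name and is not part of array_api_extra. The mutual-exclusion check mirrors the rule stated for f2 and fill_value below.

```python
def apply_where_ref(cond, args, f1, f2=None, /, *, fill_value=None):
    """Hypothetical element-by-element model of apply_where on plain lists.

    Assumes cond and every entry of args are equal-length 1-D lists;
    broadcasting, dtype promotion and array namespaces are ignored.
    """
    if (f2 is None) == (fill_value is None):
        # Exactly one of f2 and fill_value must be supplied
        raise TypeError("provide exactly one of f2 or fill_value")
    out = []
    for i, c in enumerate(cond):
        elems = [arg[i] for arg in args]  # the i-th element of each argument
        if c:
            out.append(f1(*elems))
        elif f2 is not None:
            out.append(f2(*elems))
        else:
            out.append(fill_value)
    return out


result = apply_where_ref(
    [False, True, True],
    ([5, 4, 3], [0, 2, 2]),
    lambda a, b: a // b,
    fill_value=float("nan"),
)
print(result)  # [nan, 2, 1]
```

Because this model works one element at a time, it calls f1 once per selected element. The real function works on whole arrays instead, so its output dtype follows array type promotion rather than Python's mixed-type lists.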

Parameters:
  • cond (array) – The condition, expressed as a boolean array.

  • args (Array or tuple of Arrays) – Argument(s) to f1 (and f2). Must be broadcastable with cond.

  • f1 (callable) – Elementwise function of args, returning a single array. Where cond is True, output will be f1(arg0[cond], arg1[cond], ...).

  • f2 (callable, optional) – Elementwise function of args, returning a single array. Where cond is False, output will be f2(arg0[~cond], arg1[~cond], ...). Mutually exclusive with fill_value.

  • fill_value (Array or scalar, optional) – If provided, the value with which to fill the output array where cond is False. It need not be a scalar, but it must be broadcastable with cond and args. Mutually exclusive with f2; exactly one of f2 and fill_value must be provided.

  • xp (array_namespace, optional) – The standard-compatible namespace for cond and args. Default: infer.
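The f1 and f2 descriptions above imply that each function is called on its own subset of the arguments, and the results are then scattered back into place. The sketch below models that on plain lists and records each call to show what the functions receive. apply_where_masked, floordiv and negate are hypothetical names, and the model assumes equal-length 1-D lists with no broadcasting, whereas the real function operates on arrays.

```python
def apply_where_masked(cond, args, f1, f2):
    """Hypothetical list-based sketch of calling f1 on arg[cond] and f2 on arg[~cond].

    Assumes cond and every entry of args are equal-length 1-D lists, and that
    f1 and f2 take lists and return lists as long as their inputs.
    """
    # Plain-list equivalents of arg[cond] and arg[~cond]
    selected = [[x for x, c in zip(arg, cond) if c] for arg in args]
    rest = [[x for x, c in zip(arg, cond) if not c] for arg in args]
    # Each function is called once, on its own subset only
    out_true = iter(f1(*selected))
    out_false = iter(f2(*rest))
    # Scatter the results back into their original positions
    return [next(out_true) if c else next(out_false) for c in cond]


calls = []


def floordiv(a, b):
    calls.append(("f1", a, b))
    return [x // y for x, y in zip(a, b)]


def negate(a, b):
    calls.append(("f2", a, b))
    return [-x for x in a]


result = apply_where_masked([False, True, True], ([5, 4, 3], [0, 2, 2]), floordiv, negate)
print(result)  # [-5, 2, 1]
print(calls)   # [('f1', [4, 3], [2, 2]), ('f2', [5], [0])]
```

Note that floordiv never sees the zero divisor: it is routed to negate instead.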

Returns:

An array with elements from the output of f1 where cond is True and either the output of f2 or fill_value where cond is False. The returned array has data type determined by type promotion rules between the output of f1 and either fill_value or the output of f2.

Return type:

Array

Notes

xp.where(cond, f1(*args), f2(*args)) requires evaluating f1 on every element, including those where cond is False, and f2 on every element, including those where cond is True. This function instead evaluates each function only on the elements matching its condition, if the backend allows for it.
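To illustrate why evaluating both branches up front can matter, the sketch below contrasts the two strategies on plain Python integers, where division by zero raises ZeroDivisionError. Array backends typically warn or produce a placeholder value instead, which makes the wasted work less visible. eager_where is a hypothetical stand-in for xp.where, and the constant 0 plays the role of fill_value.

```python
def eager_where(cond, x, y):
    """Hypothetical model of xp.where: x and y are fully computed before selection."""
    return [xi if c else yi for c, xi, yi in zip(cond, x, y)]


a = [5, 4, 3]
b = [0, 2, 2]
cond = [bi != 0 for bi in b]

# where-style: floor division runs on every element, including b == 0,
# so the error happens while building the argument, before any selection.
try:
    eager_where(cond, [ai // bi for ai, bi in zip(a, b)], [0] * len(a))
    eager_failed = False
except ZeroDivisionError:
    eager_failed = True

# apply_where-style: floor division only ever sees elements where cond is True.
quotients = iter([ai // bi for ai, bi, c in zip(a, b, cond) if c])
masked = [next(quotients) if c else 0 for c in cond]

print(eager_failed)  # True
print(masked)        # [0, 2, 1]
```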

On Dask, f1 and f2 are applied to the individual chunks and should use functions from the namespace of the chunks.

Examples

>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> a = xp.asarray([5, 4, 3])
>>> b = xp.asarray([0, 2, 2])
>>> def f(a, b):
...     return a // b
>>> xpx.apply_where(b != 0, (a, b), f, fill_value=xp.nan)
array([ nan,  2., 1.])