apply_where
- array_api_extra.apply_where(cond, args, f1, f2=None, /, *, fill_value=None, kwargs=None, xp=None)
- Overloads:
cond (Array), args (Array | tuple[Array, …]), f1 (Callable[…, Array]), f2 (Callable[…, Array]), kwargs (dict[str, Array] | None), xp (ModuleType | None) → Array
cond (Array), args (Array | tuple[Array, …]), f1 (Callable[…, Array]), fill_value (Array | complex), kwargs (dict[str, Array] | None), xp (ModuleType | None) → Array
Run one of two elementwise functions depending on a condition.
Equivalent to f1(*args) if cond else fill_value, performed elementwise, when fill_value is defined; otherwise equivalent to f1(*args) if cond else f2(*args).
- Parameters:
cond (object) – The condition, expressed as a boolean array.
args (object | tuple[object, ...]) – Argument(s) to f1 (and f2). Must be broadcastable with cond.
f1 (Callable[..., object]) – Elementwise function of args, returning a single array. Where cond is True, output will be f1(arg0[cond], arg1[cond], ...).
f2 (Callable[..., object] | None) – Elementwise function of args, returning a single array. Where cond is False, output will be f2(arg0[~cond], arg1[~cond], ...). Mutually exclusive with fill_value.
fill_value (object | complex | None) – If provided, value with which to fill the output array where cond is False. It need not be a scalar, but it must be broadcastable with cond and args. Mutually exclusive with f2; exactly one of the two must be provided.
kwargs (dict[str, object] | None) – Keyword argument(s) to f1 (and f2). Values must be broadcastable with cond.
xp (ModuleType | None) – The standard-compatible namespace for cond and args. Default: infer.
- Returns:
An array with elements from the output of f1 where cond is True, and from either the output of f2 or fill_value where cond is False. The data type of the returned array is determined by the type promotion rules between the output of f1 and either fill_value or the output of f2.
- Return type:
Array
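To make the "performed elementwise" wording concrete, here is a deliberately slow reference sketch in NumPy. apply_where_reference is a hypothetical name, not part of array_api_extra; it assumes NumPy inputs, ignores kwargs and xp, and lets NumPy infer the output dtype from the collected results, which only approximates the type promotion described under Returns.

```python
import numpy as np


def apply_where_reference(cond, args, f1, f2=None, fill_value=None):
    """Hypothetical scalar-loop reference for the documented semantics.

    Deliberately slow: it evaluates "f1(*args) if cond else ..." one
    element at a time. An illustration only, not the library's code.
    """
    if (f2 is None) == (fill_value is None):
        raise TypeError("provide exactly one of f2 and fill_value")
    if not isinstance(args, tuple):
        args = (args,)

    # cond, every argument and (if given) fill_value must broadcast together.
    arrays = [np.asarray(cond), *(np.asarray(a) for a in args)]
    if fill_value is not None:
        arrays.append(np.asarray(fill_value))
    arrays = np.broadcast_arrays(*arrays)
    cond_b, args_b = arrays[0], arrays[1 : 1 + len(args)]
    fill_b = arrays[-1] if fill_value is not None else None

    results = []
    for idx in np.ndindex(cond_b.shape):
        scalars = [a[idx] for a in args_b]
        if cond_b[idx]:
            results.append(f1(*scalars))  # cond is True: use f1
        elif f2 is not None:
            results.append(f2(*scalars))  # cond is False: use f2
        else:
            results.append(fill_b[idx])  # cond is False: use fill_value
    # NumPy infers the output dtype from the collected scalars.
    return np.asarray(results).reshape(cond_b.shape)


a = np.asarray([5, 4, 3])
b = np.asarray([0, 2, 2])
print(apply_where_reference(b != 0, (a, b), lambda x, y: x // y, fill_value=np.nan))
```

With the inputs from the Examples section, this yields nan where b is 0 and a // b elsewhere; f1 is never called on the element where b is 0.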
Notes
xp.where(cond, f1(*args), f2(*args)) requires evaluating f1 even where cond is False, and f2 even where cond is True. This function evaluates each function only on the elements that match its condition, if the backend allows it.
On Dask, f1 and f2 are applied to the individual chunks and should use functions from the namespace of the chunks.
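The difference between xp.where and masked evaluation can be observed with a small NumPy sketch. masked_apply_where is a hypothetical helper, not the library's implementation; it assumes a backend with boolean-mask indexing and assignment (as NumPy has) and ignores kwargs and Dask chunking. np.errstate(divide="raise") turns NumPy's divide-by-zero warning into a FloatingPointError, which shows which approach actually computes a // b where b is 0.

```python
import numpy as np


def masked_apply_where(cond, args, f1, f2=None, fill_value=None):
    """Hypothetical NumPy sketch of masked evaluation (not the library's code).

    f1 only sees the elements where cond is True, and f2 only the elements
    where cond is False, so neither runs where it does not apply.
    """
    if (f2 is None) == (fill_value is None):
        raise TypeError("provide exactly one of f2 and fill_value")
    if not isinstance(args, tuple):
        args = (args,)
    cond, *args = np.broadcast_arrays(np.asarray(cond), *(np.asarray(a) for a in args))

    temp1 = f1(*(a[cond] for a in args))  # f1 on the True elements only
    if f2 is None:
        out = np.empty(cond.shape, dtype=np.result_type(temp1, fill_value))
        out[...] = fill_value  # fill_value broadcasts to cond's shape
    else:
        ncond = ~cond
        temp2 = f2(*(a[ncond] for a in args))  # f2 on the False elements only
        out = np.empty(cond.shape, dtype=np.result_type(temp1, temp2))
        out[ncond] = temp2
    out[cond] = temp1
    return out


a = np.asarray([5, 4, 3])
b = np.asarray([0, 2, 2])

# Raise on division by zero to see which approach evaluates it.
with np.errstate(divide="raise"):
    safe = masked_apply_where(b != 0, (a, b), lambda x, y: x // y, fill_value=np.nan)
    try:
        np.where(b != 0, a // b, np.nan)  # a // b is computed for b == 0 too
    except FloatingPointError:
        print("np.where evaluated a // b where b == 0")
print(safe)
```

The masked version never divides by zero, while the np.where call fails inside the same errstate block because both branches are computed in full before selection.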
Examples
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> a = xp.asarray([5, 4, 3])
>>> b = xp.asarray([0, 2, 2])
>>> def f(a, b):
...     return a // b
>>> xpx.apply_where(b != 0, (a, b), f, fill_value=xp.nan)
array([ nan, 2., 1.])