array_api_extra.testing.lazy_xp_function
- array_api_extra.testing.lazy_xp_function(func, *, allow_dask_compute=0, jax_jit=True, static_argnums=None, static_argnames=None)
Tag a function to be tested on lazy backends.
Tag a function, which must be imported in the test module globals, so that when any tests defined in the same module are executed with xp=jax.numpy the function is replaced with a jitted version of itself, and when they are executed with xp=dask.array the function will raise if it attempts to materialize the graph. This will later be expanded to provide test coverage for other lazy backends.
In order for the tag to be effective, the test or a fixture must call patch_lazy_xp_functions().
- Parameters:
func (callable) – Function to be tested.
allow_dask_compute (int, optional) –
Number of times func is allowed to internally materialize the Dask graph. This is typically triggered by bool(), float(), or np.asarray().
Set to 1 if you are aware that func converts the input parameters to numpy and want to let it do so at least for the time being, knowing that it is going to be extremely detrimental for performance.
If a test needs values higher than 1 to pass, it is a canary that the conversion to numpy/bool/float is happening multiple times, which translates to multiple computations of the whole graph. Short of making the function fully lazy, you should at least add explicit calls to np.asarray() early in the function.
Note: the counter of allow_dask_compute resets after each call to func, so a test function that invokes func multiple times should still work with this parameter set to 1.
Default: 0, meaning that func must be fully lazy and never materialize the graph.
jax_jit (bool, optional) – Set to True to replace func with jax.jit(func) after calling the patch_lazy_xp_functions() test helper with xp=jax.numpy. Set to False if func is only compatible with eager (non-jitted) JAX. Default: True.
static_argnums (int | Sequence[int], optional) – Passed to jax.jit. Positional arguments to treat as static (compile-time constant). Default: infer from static_argnames using inspect.signature(func).
static_argnames (str | Iterable[str], optional) – Passed to jax.jit. Named arguments to treat as static (compile-time constant). Default: infer from static_argnums using inspect.signature(func).
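The two static_* defaults above each say they are inferred from the other using inspect.signature(func). The following is a minimal sketch of what such an inference could look like, using only the standard library. The helper name infer_static_args and the sample function myfunc are hypothetical and are not part of array_api_extra or JAX; the real inference may treat edge cases (keyword-only parameters, negative indices) differently.

```python
import inspect


def infer_static_args(func, static_argnums=None, static_argnames=None):
    """Hypothetical helper: derive whichever of the two arguments is missing."""
    params = inspect.signature(func).parameters
    # Only parameters that can be passed positionally have an index.
    positional = [
        name
        for name, p in params.items()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]

    # Normalise the scalar forms (int / str) to tuples.
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    elif static_argnums is not None:
        static_argnums = tuple(static_argnums)
    if isinstance(static_argnames, str):
        static_argnames = (static_argnames,)
    elif static_argnames is not None:
        static_argnames = tuple(static_argnames)

    if static_argnums is None and static_argnames is not None:
        # Names -> positions, for names that can be passed positionally.
        static_argnums = tuple(
            i for i, name in enumerate(positional) if name in static_argnames
        )
    elif static_argnames is None and static_argnums is not None:
        # Positions -> names.
        static_argnames = tuple(positional[i] for i in static_argnums)

    return static_argnums, static_argnames


def myfunc(x, axis, keepdims=False):  # hypothetical function under test
    return x


from_names = infer_static_args(myfunc, static_argnames="axis")
from_nums = infer_static_args(myfunc, static_argnums=(1, 2))
print(from_names)  # ((1,), ('axis',))
print(from_nums)   # ((1, 2), ('axis', 'keepdims'))
```

In practice this means it is usually enough to pass one of the two parameters, whichever is more natural for the function's call sites.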
- Return type:
See also
patch_lazy_xp_functions
Companion function to call from the test or fixture.
jax.jit
JAX function to compile a function for performance.
Examples
In test_mymodule.py:

    from array_api_extra.testing import lazy_xp_function
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        # When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
        # When xp=dask.array, crash on compute() or persist()
        b = myfunc(a)
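The allow_dask_compute budget described under Parameters can be hard to picture without a Dask installation. The sketch below is a toy model of that accounting, written with the standard library only. ToyLazyArray, ComputeCounter and guard are hypothetical names; they are not Dask or array_api_extra objects and do not reproduce how the library actually intercepts computation. The model only illustrates the documented rules: each bool() or float() conversion counts as one materialization, exceeding the budget raises, and the counter resets on every call to func.

```python
import functools


class ComputeCounter:
    """Counts materializations and enforces an optional budget."""

    def __init__(self):
        self.count = 0
        self.limit = None  # no budget outside of a guarded call

    def record(self):
        self.count += 1
        if self.limit is not None and self.count > self.limit:
            raise AssertionError(
                f"materialized {self.count} times; only {self.limit} allowed"
            )


class ToyLazyArray:
    """Stand-in for a lazy array. This is NOT a Dask array."""

    def __init__(self, value, counter):
        self._value = value
        self._counter = counter

    def _materialize(self):
        self._counter.record()
        return self._value

    def __bool__(self):
        return bool(self._materialize())

    def __float__(self):
        return float(self._materialize())


def guard(func, counter, allow_compute=0):
    """Wrap func so that it may materialize at most allow_compute times."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        counter.count = 0  # the budget resets on every call to func
        counter.limit = allow_compute
        try:
            return func(*args, **kwargs)
        finally:
            counter.limit = None

    return wrapper


def eager_once(a):
    return float(a) + 1.0  # one materialization


def eager_twice(a):
    return float(a) if bool(a) else 0.0  # bool() first, then float()


counter = ComputeCounter()
x = ToyLazyArray(3.0, counter)

once_ok = guard(eager_once, counter, allow_compute=1)
results = [once_ok(x), once_ok(x)]  # two calls, each within its own budget

outcomes = {}
for label, func, budget in [
    ("strict", eager_once, 0),
    ("twice", eager_twice, 1),
]:
    try:
        guard(func, counter, allow_compute=budget)(x)
        outcomes[label] = "passed"
    except AssertionError:
        outcomes[label] = "raised"

print(results)   # [4.0, 4.0]
print(outcomes)  # {'strict': 'raised', 'twice': 'raised'}
```

The "twice" case corresponds to the situation the parameter description warns about: a function that needs a budget above 1 is converting its input more than once.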
Notes
A test function can circumvent this monkey-patching system by calling func as an attribute of the original module. You need to sanitize your code to make sure this does not happen.
Example:

    import mymodule
    from array_api_extra.testing import lazy_xp_function
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        b = myfunc(a)  # This is jitted when xp=jax.numpy
        c = mymodule.myfunc(a)  # This is not
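To see why the attribute access escapes the patch, the following is a simplified model of name rebinding, using only the standard library. It assumes, in line with the requirement that func be imported in the test module globals, that patching replaces the name bound in the test module and leaves the original module untouched. The modules built with types.ModuleType and the wrapped function are hypothetical stand-ins, and the loop is not the library's actual patching code.

```python
import types

# Hypothetical stand-in for `mymodule`.
mymodule = types.ModuleType("mymodule")


def myfunc(a):
    return "original"


mymodule.myfunc = myfunc

# Hypothetical stand-in for the test module and its two imports.
test_module = types.ModuleType("test_mymodule")
test_module.mymodule = mymodule       # import mymodule
test_module.myfunc = mymodule.myfunc  # from mymodule import myfunc


def wrapped(a):
    # Stands in for the jitted or compute-guarded replacement.
    return "wrapped"


# Toy patch: rebind every test-module global that *is* the tagged function.
for name, obj in list(vars(test_module).items()):
    if obj is myfunc:
        setattr(test_module, name, wrapped)

via_global = test_module.myfunc(None)              # goes through the patch
via_attribute = test_module.mymodule.myfunc(None)  # bypasses the patch
print(via_global, via_attribute)  # wrapped original
```

In this model the attribute lookup is resolved on the original module at call time, so it never sees the rebound name; this is the same pattern as c = mymodule.myfunc(a) in the example above.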