array_api_extra.testing.lazy_xp_function
array_api_extra.testing.lazy_xp_function(func, *, allow_dask_compute=0, jax_jit=True, static_argnums=None, static_argnames=None)
Tag a function to be tested on lazy backends.
Tag a function so that, when any tests are executed with xp=jax.numpy, the function is replaced with a jitted version of itself, and when it is executed with xp=dask.array, the function raises if it attempts to materialize the graph. This will later be expanded to provide test coverage for other lazy backends.
In order for the tag to be effective, the test or a fixture must call patch_lazy_xp_functions().
Parameters:
func (callable) – Function to be tested.
allow_dask_compute (int, optional) – Number of times func is allowed to internally materialize the Dask graph. This is typically triggered by bool(), float(), or np.asarray().
Set to 1 if you are aware that func converts the input parameters to NumPy and want to let it do so, at least for the time being, knowing that it is going to be extremely detrimental for performance.
If a test needs values higher than 1 to pass, it is a canary that the conversion to NumPy/bool/float is happening multiple times, which translates to multiple computations of the whole graph. Short of making the function fully lazy, you should at least add explicit calls to np.asarray() early in the function, so that each input is materialized once up front instead of repeatedly.
Note: the counter of allow_dask_compute resets after each call to func, so a test function that invokes func multiple times should still work with this parameter set to 1.
Default: 0, meaning that func must be fully lazy and never materialize the graph.
jax_jit (bool, optional) – Set to True to replace func with jax.jit(func) after calling the patch_lazy_xp_functions() test helper with xp=jax.numpy. Set to False if func is only compatible with eager (non-jitted) JAX. Default: True.
static_argnums (int | Sequence[int], optional) – Passed to jax.jit. Positional arguments to treat as static (compile-time constant). Default: infer from static_argnames using inspect.signature(func).
static_argnames (str | Iterable[str], optional) – Passed to jax.jit. Named arguments to treat as static (compile-time constant). Default: infer from static_argnums using inspect.signature(func).
Return type: None
See also
patch_lazy_xp_functions
Companion function to call from the test or fixture.
jax.jit
JAX function to compile a function for performance.
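JAX documents that when only one of static_argnums and static_argnames is given, jax.jit uses inspect.signature(func) to find the positional arguments that correspond to the other. The helper infer_static_args below is a hypothetical, simplified sketch of that inference in plain Python, not JAX's own code; negative indices and *args are deliberately not handled, and pad is an invented example function.

```python
import inspect

_POK = inspect.Parameter.POSITIONAL_OR_KEYWORD
_POSITIONAL = (inspect.Parameter.POSITIONAL_ONLY, _POK)


def infer_static_args(func, static_argnums=None, static_argnames=None):
    """Fill in whichever of static_argnums / static_argnames was omitted.

    Only parameters that can be passed both positionally and by keyword
    are matched between the two forms.
    """
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    if isinstance(static_argnames, str):
        static_argnames = (static_argnames,)
    if static_argnums is None and static_argnames is None:
        return (), ()  # nothing is static
    if static_argnums is not None and static_argnames is not None:
        return tuple(static_argnums), tuple(static_argnames)  # nothing to infer

    # Parameters that have a positional index, in signature order.
    positional = [
        p for p in inspect.signature(func).parameters.values()
        if p.kind in _POSITIONAL
    ]
    if static_argnums is None:
        names = tuple(static_argnames)
        nums = tuple(
            i for i, p in enumerate(positional) if p.kind is _POK and p.name in names
        )
    else:
        nums = tuple(static_argnums)
        names = tuple(
            p.name for i, p in enumerate(positional) if p.kind is _POK and i in nums
        )
    return nums, names


def pad(x, width, /, axis=0, *, mode="constant"):
    # Hypothetical function: x and width are positional-only, mode is keyword-only.
    return x


print(infer_static_args(pad, static_argnames="axis"))  # ((2,), ('axis',))
print(infer_static_args(pad, static_argnums=2))        # ((2,), ('axis',))
```

Keyword-only parameters such as mode have no positional index, so naming them adds nothing to static_argnums; positional-only parameters such as width cannot be passed by name, so they never appear among the inferred names.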
Examples
In test_mymodule.py:

    from array_api_extra.testing import lazy_xp_function
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        # When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
        # When xp=dask.array, crash on compute() or persist()
        b = myfunc(a)
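The split between tagging and patching in the example above can be pictured with a small, self-contained sketch. It is not the library's implementation: tag, patch, counting_wrapper and the _lazy_xp_tags attribute are hypothetical names, and counting_wrapper merely stands in for jax.jit or the Dask compute guard. The point it shows is that tagging only records options; nothing changes until a separate patch step rewrites the namespace.

```python
from functools import wraps

TAG_ATTR = "_lazy_xp_tags"  # hypothetical attribute name, for this sketch only


def tag(func, **options):
    """Step 1: record options on the function object. Nothing is wrapped yet."""
    setattr(func, TAG_ATTR, options)


def patch(namespace, make_wrapper):
    """Step 2: swap each tagged function in `namespace` (for example a test
    module's globals()) for make_wrapper(func, options)."""
    for name, obj in list(namespace.items()):
        options = getattr(obj, TAG_ATTR, None)
        if options is not None:
            namespace[name] = make_wrapper(obj, options)


def counting_wrapper(func, options):
    """Stand-in for jax.jit or a Dask compute guard: it only counts calls."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)

    wrapper.calls = 0
    wrapper.options = options
    return wrapper


def myfunc(x):
    return x * 2


namespace = {"myfunc": myfunc}
tag(myfunc, allow_dask_compute=0)
before = namespace["myfunc"]  # still the original: tagging alone is inert
patch(namespace, counting_wrapper)
after = namespace["myfunc"]   # now the wrapper
```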
Notes
In order for this tag to be effective, the function under test must be imported into the test module globals without its namespace; alternatively, its namespace must be declared in a lazy_xp_modules list in the test module globals.
Example 1:
    from array_api_extra.testing import lazy_xp_function
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        x = myfunc(xp.asarray([1, 2]))
Example 2:
    from array_api_extra.testing import lazy_xp_function
    import mymodule

    lazy_xp_modules = [mymodule]
    lazy_xp_function(mymodule.myfunc)

    def test_myfunc(xp):
        x = mymodule.myfunc(xp.asarray([1, 2]))
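Why only these two patterns work follows from Python's name binding: `from mymodule import myfunc` copies a reference into the test module, so replacing one binding leaves the other untouched. The sketch below uses types.ModuleType stand-ins (mymodule, test_mod and wrapped are hypothetical) and assumes, as the examples suggest, that the patcher rewrites attributes of the test module and of each module listed in lazy_xp_modules.

```python
import types


def _myfunc(x):
    return x + 1


# Stand-in for `mymodule`.
mymodule = types.ModuleType("mymodule")
mymodule.myfunc = _myfunc

# Stand-in for the test module. `from mymodule import myfunc` copies the
# current binding into the test module's own namespace (Example 1), while
# `import mymodule` only binds the module object (Example 2).
test_mod = types.ModuleType("test_mymodule")
test_mod.myfunc = mymodule.myfunc
test_mod.mymodule = mymodule


def wrapped(func):
    """Stand-in for the jitted or compute-guarded replacement."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


# Rewriting only the test module's globals reaches Example 1 but not Example 2.
test_mod.myfunc = wrapped(test_mod.myfunc)
only_globals = (test_mod.myfunc is _myfunc, test_mod.mymodule.myfunc is _myfunc)
print(only_globals)  # (False, True)

# Listing mymodule in lazy_xp_modules also rewrites the module attribute,
# which every `mymodule.myfunc` lookup sees at call time.
mymodule.myfunc = wrapped(mymodule.myfunc)
print(test_mod.mymodule.myfunc is _myfunc)  # False
```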
A test function can circumvent this monkey-patching system by reaching func through a namespace that matches neither of the two patterns above. Sanitize your code to make sure this only happens intentionally.
Example 1:
    from array_api_extra.testing import lazy_xp_function
    import mymodule
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        b = myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array
        c = mymodule.myfunc(a)  # This is not
Example 2:
    from array_api_extra.testing import lazy_xp_function
    import mymodule

    class naked:
        myfunc = mymodule.myfunc

    lazy_xp_modules = [mymodule]
    lazy_xp_function(mymodule.myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        b = mymodule.myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array
        c = naked.myfunc(a)  # This is not
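Part of that sanitizing can be automated. The hypothetical helper find_unpatched_aliases below scans a test module for classes and modules that hold a tagged function but are neither the test module itself nor listed in its lazy_xp_modules. It assumes tags live in a _lazy_xp_tags attribute, which is an invention of this sketch rather than the library's storage, and it only looks one attribute level deep.

```python
import types

TAG_ATTR = "_lazy_xp_tags"  # hypothetical tag attribute, invented for this sketch


def find_unpatched_aliases(test_module):
    """Return "owner.attr" paths under which a tagged function is reachable
    from the test module's globals without being rewritten by the patcher."""
    patched = [test_module, *getattr(test_module, "lazy_xp_modules", [])]
    found = []
    for name, owner in vars(test_module).items():
        if not isinstance(owner, (type, types.ModuleType)):
            continue  # plain globals are rewritten directly
        if any(owner is m for m in patched):
            continue  # this namespace is rewritten too
        for attr, value in vars(owner).items():
            if getattr(value, TAG_ATTR, None) is not None:
                found.append(f"{name}.{attr}")
    return found


# Rebuild the two examples above with stand-in modules.
def myfunc(x):
    return x


setattr(myfunc, TAG_ATTR, {"allow_dask_compute": 0})  # what tagging would record
mymodule = types.ModuleType("mymodule")
mymodule.myfunc = myfunc

# Example 1: `import mymodule` plus `from mymodule import myfunc`, no lazy_xp_modules.
example1 = types.ModuleType("test_example1")
example1.mymodule = mymodule
example1.myfunc = myfunc
print(find_unpatched_aliases(example1))  # ['mymodule.myfunc']


# Example 2: mymodule is listed, but the `naked` class keeps its own reference.
class naked:
    myfunc = mymodule.myfunc


example2 = types.ModuleType("test_example2")
example2.mymodule = mymodule
example2.naked = naked
example2.lazy_xp_modules = [mymodule]
print(find_unpatched_aliases(example2))  # ['naked.myfunc']
```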