array_api_extra.testing.lazy_xp_function

array_api_extra.testing.lazy_xp_function(func, *, allow_dask_compute=False, jax_jit=True, static_argnums=Deprecated.DEPRECATED, static_argnames=Deprecated.DEPRECATED)

Tag a function to be tested on lazy backends.

Tag a function so that, when tests are executed with xp=jax.numpy, the function is replaced with a jitted version of itself, and when they are executed with xp=dask.array, the function raises if it attempts to materialize the graph. This will later be expanded to provide test coverage for other lazy backends.

In order for the tag to be effective, the test or a fixture must call patch_lazy_xp_functions().
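The split between tagging and patching can be pictured with a minimal sketch. It is not array_api_extra's implementation. `tag`, `patched`, `_TAGS` and the wrapper factory are hypothetical names, and the sketch assumes only what the text above states: tagging records options, and nothing is replaced until a patching step runs inside the test.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Union

# Hypothetical registry: tagged function -> options recorded at tag time.
_TAGS: dict[Callable[..., Any], dict[str, Any]] = {}


def tag(func: Callable[..., Any], *, allow_dask_compute: Union[bool, int] = False,
        jax_jit: bool = True) -> None:
    """Only record the options; func itself is left untouched (returns None)."""
    _TAGS[func] = {"allow_dask_compute": allow_dask_compute, "jax_jit": jax_jit}


def _is_tagged(obj: Any) -> bool:
    # Identity check, so unhashable namespace values are skipped safely.
    return any(obj is f for f in _TAGS)


@contextmanager
def patched(namespace: dict[str, Any], make_wrapper: Callable[..., Any]) -> Iterator[None]:
    """Swap tagged functions in namespace for wrappers, then restore them."""
    originals = {name: obj for name, obj in namespace.items() if _is_tagged(obj)}
    try:
        for name, func in originals.items():
            namespace[name] = make_wrapper(func, _TAGS[func])
        yield
    finally:
        namespace.update(originals)  # undo step, similar to pytest's monkeypatch


# Usage: tagging alone changes nothing; the swap only happens inside `patched`.
def myfunc(x):
    return [v * 2 for v in x]


tag(myfunc, allow_dask_compute=1)
seen_options = []


def make_wrapper(func, options):
    def wrapper(*args, **kwargs):
        seen_options.append(options)  # a real backend wrapper would act on these
        return func(*args, **kwargs)

    return wrapper


ns = {"myfunc": myfunc, "data": [1, 2]}
with patched(ns, make_wrapper):
    result = ns["myfunc"](ns["data"])
```

In this sketch, keeping the two steps apart lets tagging run once when the test module is imported, while the wrapping stays scoped to a single test.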

Parameters:
  • func (callable) – Function to be tested.

  • allow_dask_compute (bool | int, optional) –

    Whether func is allowed to internally materialize the Dask graph, or the maximum number of times it is allowed to do so. This is typically triggered by bool(), float(), or np.asarray().

    Set to 1 if you know that func converts its input parameters to NumPy and you want to let it do so, at least for the time being, knowing that this is extremely detrimental to performance.

    If a test needs a value higher than 1 to pass, this is a warning sign that the conversion to NumPy/bool/float happens multiple times, which translates into multiple computations of the whole graph. Short of making the function fully lazy, you should at least add explicit calls to np.asarray() early in the function. Note: the allow_dask_compute counter resets after each call to func, so a test function that invokes func multiple times should still pass with this parameter set to 1.

    Set to True to allow func to materialize the graph an unlimited number of times.

    Default: False, meaning that func must be fully lazy and never materialize the graph.

  • jax_jit (bool, optional) –

    Set to True to replace func with a smart variant of jax.jit(func) after calling the patch_lazy_xp_functions() test helper with xp=jax.numpy. This is the default behaviour. Set to False if func is only compatible with eager (non-jitted) JAX.

    Unlike with vanilla jax.jit, all arguments and return types that are not JAX arrays are treated as static; the function can accept and return arbitrary wrappers around JAX arrays. The reason for this difference is that, in practice, most users won't wrap the function directly with jax.jit; rather, they will call it from their own code, which is itself wrapped by jax.jit, and consume the function's outputs internally.

    In other words, the pattern that is being tested is:

    >>> @jax.jit
    ... def user_func(x):
    ...     y = user_prepares_inputs(x)
    ...     z = func(y, some_static_arg=True)
    ...     return user_consumes(z)
    

    Default: True.

  • static_argnums (Deprecated) – Deprecated; ignored

  • static_argnames (Deprecated) – Deprecated; ignored
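The jax_jit rule that non-array arguments are static can also be illustrated without JAX. In this hypothetical sketch, `FakeArray` stands in for a JAX array and `fake_jit` imitates a single aspect of jax.jit, its trace cache: arrays contribute only their shape to the cache key, while every other argument contributes its value. As with vanilla jax.jit's static arguments, the sketch assumes those values are hashable; it covers inputs only and ignores the return side.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for a JAX array."""

    values: tuple[float, ...]

    @property
    def shape(self) -> tuple[int, ...]:
        return (len(self.values),)


def _cache_key_part(arg: Any) -> tuple[str, Any]:
    # Arrays are dynamic: only their shape affects tracing, not their values.
    if isinstance(arg, FakeArray):
        return ("array", arg.shape)
    # Everything else is static: its value is part of the key (so it must be hashable).
    return ("static", arg)


def fake_jit(func: Callable[..., Any]) -> Callable[..., Any]:
    """Imitates only jax.jit's trace caching; nothing is compiled."""
    cache_keys: set[tuple[Any, ...]] = set()

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        key = (
            tuple(_cache_key_part(a) for a in args),
            tuple(sorted((k, _cache_key_part(v)) for k, v in kwargs.items())),
        )
        cache_keys.add(key)  # a new key would mean a new trace and compilation
        return func(*args, **kwargs)

    wrapper.cache_keys = cache_keys  # type: ignore[attr-defined]
    return wrapper


@fake_jit
def scale(x: FakeArray, factor: float, *, clip: bool = False) -> FakeArray:
    out = tuple(v * factor for v in x.values)
    return FakeArray(tuple(min(v, 1.0) for v in out) if clip else out)


scale(FakeArray((1.0, 2.0)), 2.0)
scale(FakeArray((3.0, 4.0)), 2.0)  # same shape, same static value: reuses the trace
scale(FakeArray((3.0, 4.0)), 3.0)  # new static value: new trace
scale(FakeArray((3.0, 4.0)), 2.0, clip=True)  # new static keyword: new trace
```

One consequence visible here: under this rule a Python scalar such as `factor` counts as static, so each distinct value costs a new trace, whereas changing the values of an array with the same shape does not.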

Return type:

None

See also

patch_lazy_xp_functions

Companion function to call from the test or fixture.

jax.jit

JAX function to compile a function for performance.

Examples

In test_mymodule.py:

from array_api_extra.testing import lazy_xp_function
from mymodule import myfunc

lazy_xp_function(myfunc)

def test_myfunc(xp):
    a = xp.asarray([1, 2])
    # When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)`
    # When xp=dask.array, crash on compute() or persist()
    b = myfunc(a)

Notes

In order for this tag to be effective, the function under test must be imported into the test module globals without its namespace; alternatively, its namespace must be declared in a lazy_xp_modules list in the test module globals.

Example 1:

from mymodule import myfunc

lazy_xp_function(myfunc)

def test_myfunc(xp):
    x = myfunc(xp.asarray([1, 2]))

Example 2:

import mymodule

lazy_xp_modules = [mymodule]
lazy_xp_function(mymodule.myfunc)

def test_myfunc(xp):
    x = mymodule.myfunc(xp.asarray([1, 2]))
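Why only these two patterns work follows from Python's name resolution: a bare name in a test function is looked up in the test module's globals at call time, and `mymodule.myfunc` is looked up on the module object at call time. The sketch below simulates both patterns with throwaway module objects. `mymodule`, `test_module`, `wrap` and `patch_namespaces` are hypothetical, and patching is reduced to plain attribute assignment, assuming the helper rebinds tagged names in the test module and in each module listed in lazy_xp_modules.

```python
import types


def _myfunc(x):
    return x + 1


# Hypothetical module holding the function under test.
mymodule = types.ModuleType("mymodule")
mymodule.myfunc = _myfunc

# Hypothetical test module combining both patterns.
test_module = types.ModuleType("test_module")
test_module.myfunc = mymodule.myfunc  # pattern 1: from mymodule import myfunc
test_module.mymodule = mymodule  # pattern 2: import mymodule ...
test_module.lazy_xp_modules = [mymodule]  # ... plus lazy_xp_modules
exec(
    "def test_myfunc():\n"
    "    return myfunc(1), mymodule.myfunc(1)\n",
    vars(test_module),  # the test's globals are the test module's namespace
)


def wrap(func):
    def wrapper(*args, **kwargs):
        return ("wrapped", func(*args, **kwargs))

    return wrapper


def patch_namespaces(test_mod, tagged):
    """Rebind tagged names in the test module and in each listed module."""
    for mod in [test_mod, *getattr(test_mod, "lazy_xp_modules", [])]:
        for name, obj in list(vars(mod).items()):
            if obj is tagged:
                setattr(mod, name, wrap(obj))


patch_namespaces(test_module, _myfunc)
print(test_module.test_myfunc())  # (('wrapped', 2), ('wrapped', 2))
```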

A test function can circumvent this monkey-patching system by reaching the function through a namespace not covered by the two patterns above. Review your test code to make sure this only happens intentionally.

Example 1:

import mymodule
from mymodule import myfunc

lazy_xp_function(myfunc)

def test_myfunc(xp):
    a = xp.asarray([1, 2])
    b = myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array
    c = mymodule.myfunc(a)  # This is not

Example 2:

import mymodule

class naked:
    myfunc = mymodule.myfunc

lazy_xp_modules = [mymodule]
lazy_xp_function(mymodule.myfunc)

def test_myfunc(xp):
    a = xp.asarray([1, 2])
    b = mymodule.myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array
    c = naked.myfunc(a)  # This is not