array_api_extra.testing.patch_lazy_xp_functions

array_api_extra.testing.patch_lazy_xp_functions(request, monkeypatch, *, xp)

Test lazy execution of functions tagged with lazy_xp_function().

If xp==jax.numpy, search for all functions which have been tagged with lazy_xp_function() in the globals of the module that defines the current test and wrap them with jax.jit(). Unwrap them at the end of the test.
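
The tag, wrap, and unwrap cycle described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: `tag`, `counting_wrapper`, `patch_tagged`, and `unpatch` are hypothetical names, a plain dictionary stands in for the test module's globals, and a call-recording wrapper stands in for jax.jit().

```python
import types

# Hypothetical registry of tagged functions; the real bookkeeping may differ.
_TAGGED = set()


def tag(func):
    """Stand-in for lazy_xp_function(): mark func as one to be wrapped."""
    _TAGGED.add(func)
    return func


def counting_wrapper(func, calls):
    """Stand-in for jax.jit(): wrap func and record each call by name."""
    def wrapped(*args, **kwargs):
        calls.append(func.__name__)
        return func(*args, **kwargs)
    return wrapped


def patch_tagged(namespace, calls):
    """Wrap every tagged function found in namespace; return the originals."""
    originals = {}
    for name, obj in list(namespace.items()):
        if isinstance(obj, types.FunctionType) and obj in _TAGGED:
            originals[name] = obj
            namespace[name] = counting_wrapper(obj, calls)
    return originals


def unpatch(namespace, originals):
    """Put the original functions back, as happens at the end of a test."""
    namespace.update(originals)


def double(x):
    return 2 * x


def untouched(x):
    return x


tag(double)

# A dictionary playing the role of the test module's globals.
module_globals = {"double": double, "untouched": untouched}
calls = []

originals = patch_tagged(module_globals, calls)
result = module_globals["double"](21)   # goes through the wrapper
module_globals["untouched"](1)          # not tagged, so not wrapped
unpatch(module_globals, originals)
restored = module_globals["double"] is double
```

Only the tagged function is replaced, and the namespace is returned to its original state afterwards, which is the role the monkeypatch fixture plays in a real test.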

If xp==dask.array, wrap the functions with a decorator that disables compute() and persist().
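
As a rough illustration of a decorator that disables compute(), the sketch below uses a hypothetical `FakeLazyArray` class in place of a Dask array and makes its compute() method raise while the decorated function runs. The names `forbid_compute` and `no_compute` are invented for this example, and the library may well enforce the restriction by a different mechanism.

```python
import contextlib
import functools


class FakeLazyArray:
    """Hypothetical stand-in for a lazy array; not a real dask.array class."""

    def __init__(self, value):
        self._value = value

    def __add__(self, other):
        # Stays "lazy": returns another FakeLazyArray instead of a plain value.
        return FakeLazyArray(self._value + other)

    def compute(self):
        return self._value


@contextlib.contextmanager
def forbid_compute():
    """Temporarily replace FakeLazyArray.compute with a method that raises."""
    original = FakeLazyArray.compute

    def _raise(self):
        raise AssertionError("compute() called inside a lazy function")

    FakeLazyArray.compute = _raise
    try:
        yield
    finally:
        FakeLazyArray.compute = original


def no_compute(func):
    """Decorator: run func with compute() disabled."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with forbid_compute():
            return func(*args, **kwargs)
    return wrapper


@no_compute
def lazy_ok(x):
    return x + 1              # never materialises the array


@no_compute
def eager_bad(x):
    return x.compute() + 1    # materialises the array, so it should fail


# compute() works again once the decorated function has returned.
ok = lazy_ok(FakeLazyArray(1)).compute()

try:
    eager_bad(FakeLazyArray(1))
    raised = False
except AssertionError:
    raised = True
```

A function that stays lazy passes through unchanged, while one that materialises its input fails loudly, which is the kind of regression this test helper is meant to surface.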

This function should typically be called by your library’s xp fixture that runs tests on multiple backends:

@pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
def xp(request, monkeypatch):
    patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
    return request.param

but it can also be called by the test itself.

Parameters:
  • request (pytest.FixtureRequest) – Pytest fixture, as acquired by the test itself or by one of its fixtures.

  • monkeypatch (pytest.MonkeyPatch) – Pytest fixture, as acquired by the test itself or by one of its fixtures.

  • xp (module) – Array namespace to be tested.

Return type:

None

See also

lazy_xp_function

Tag a function to be tested on lazy backends.

pytest.FixtureRequest

Type of the request test function parameter.