array_api_extra.testing.patch_lazy_xp_functions

array_api_extra.testing.patch_lazy_xp_functions(request, monkeypatch, *, xp)

Test lazy execution of functions tagged with lazy_xp_function().

If xp==jax.numpy, search for all functions that have been tagged with lazy_xp_function() in the globals of the module that defines the current test, as well as in the globals of every module listed in that module's lazy_xp_modules variable, and wrap them with jax.jit(). Unwrap them at the end of the test.
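The search, wrap and unwrap cycle can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the marker attribute _lazy_xp_tagged, the helpers tag_lazy(), iter_tagged(), patch_tagged() and fake_jit() (a stand-in for jax.jit()), and the modules mylib and test_mylib are all hypothetical. The undo() callable plays the role of the monkeypatch teardown that restores the original functions.

```python
import types
from typing import Callable, Iterator, Tuple

# Hypothetical marker attribute: the real lazy_xp_function() may record
# its tag differently. Used here for illustration only.
TAG = "_lazy_xp_tagged"


def tag_lazy(func: Callable) -> Callable:
    """Hypothetical stand-in for lazy_xp_function(): mark func in place."""
    setattr(func, TAG, True)
    return func


def iter_tagged(test_module: types.ModuleType) -> Iterator[Tuple[types.ModuleType, str, Callable]]:
    """Yield (module, name, func) for each tagged callable in the test
    module's globals and in every module listed in its lazy_xp_modules."""
    modules = [test_module, *getattr(test_module, "lazy_xp_modules", [])]
    for mod in modules:
        for name, obj in list(vars(mod).items()):
            if callable(obj) and getattr(obj, TAG, False):
                yield mod, name, obj


def patch_tagged(test_module: types.ModuleType, wrap: Callable) -> Callable[[], None]:
    """Replace each tagged function with wrap(func) and return an undo()
    callable that restores the originals (what monkeypatch does at teardown)."""
    saved = list(iter_tagged(test_module))
    for mod, name, func in saved:
        setattr(mod, name, wrap(func))

    def undo() -> None:
        for mod, name, func in reversed(saved):
            setattr(mod, name, func)

    return undo


def fake_jit(func: Callable) -> Callable:
    """Hypothetical stand-in for jax.jit(): just counts calls."""
    def wrapped(*args, **kwargs):
        wrapped.calls += 1
        return func(*args, **kwargs)

    wrapped.calls = 0
    return wrapped


# Hypothetical library module with one tagged function.
mylib = types.ModuleType("mylib")


def double(x):
    return 2 * x


mylib.double = tag_lazy(double)

# Hypothetical test module that lists mylib in lazy_xp_modules.
test_mod = types.ModuleType("test_mylib")
test_mod.lazy_xp_modules = [mylib]

undo = patch_tagged(test_mod, fake_jit)
result = mylib.double(3)       # routed through the fake_jit wrapper
calls = mylib.double.calls
undo()                         # end of test: original restored
print(result, calls, mylib.double is double)  # 6 1 True
```

Recording the originals before wrapping, and restoring them in reverse order, keeps the patch reversible even when the same name is patched in more than one module.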

If xp==dask.array, wrap the same tagged functions with a decorator that disables compute() and persist() and ensures that exceptions and warnings are raised eagerly, at call time.

This function should typically be called by your library’s xp fixture that runs tests on multiple backends:

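import pytest
import numpy, array_api_strict, jax.numpy, dask.array
from array_api_extra.testing import patch_lazy_xp_functions

@pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
def xp(request, monkeypatch):
    patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
    return request.param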

but it can also be called directly by the test itself.

Parameters:
  • request (pytest.FixtureRequest) – Pytest fixture, as acquired by the test itself or by one of its fixtures.

  • monkeypatch (pytest.MonkeyPatch) – Pytest fixture, as acquired by the test itself or by one of its fixtures.

  • xp (array_namespace) – Array namespace to be tested.

Return type:

None

See also

lazy_xp_function

Tag a function to be tested on lazy backends.

pytest.FixtureRequest

request test function parameter.