array_api_extra.testing.patch_lazy_xp_functions
- array_api_extra.testing.patch_lazy_xp_functions(request, monkeypatch, *, xp)
Test lazy execution of functions tagged with lazy_xp_function().

If xp==jax.numpy, search for all functions which have been tagged with lazy_xp_function() in the globals of the module that defines the current test and wrap them with jax.jit(). Unwrap them at the end of the test.

If xp==dask.array, wrap the functions with a decorator that disables compute() and persist().

This function should typically be called by your library’s xp fixture that runs tests on multiple backends:
    @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array])
    def xp(request, monkeypatch):
        # Patch the tagged functions for the backend under test, then expose it.
        patch_lazy_xp_functions(request, monkeypatch, xp=request.param)
        return request.param
but it can otherwise be called by the test itself too.
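The tag, search, wrap and unwrap cycle described above can be modelled with the standard library alone. The following is a simplified sketch, not array_api_extra's actual implementation: the names tag_lazy, patch_tagged, recording and test_module are hypothetical stand-ins, and the toy wrapper only records calls where the real function would apply jax.jit() or the Dask guard.

```python
import types
from functools import wraps

_TAG = "_lazy_tagged"  # hypothetical marker attribute, for illustration only


def tag_lazy(func):
    """Mark a function so the patcher can find it (stand-in for lazy_xp_function)."""
    setattr(func, _TAG, True)
    return func


def patch_tagged(module, wrapper):
    """Replace each tagged function in the module's globals with wrapper(func).

    Returns a callable that restores the originals; in a real test the
    monkeypatch fixture would undo the patch at teardown instead.
    """
    originals = {
        name: obj
        for name, obj in vars(module).items()
        if callable(obj) and getattr(obj, _TAG, False)
    }
    for name, func in originals.items():
        setattr(module, name, wrapper(func))

    def undo():
        for name, func in originals.items():
            setattr(module, name, func)

    return undo


calls = []


def recording(func):
    """Toy wrapper: record the call, then delegate to the original function."""
    @wraps(func)
    def inner(*args, **kwargs):
        calls.append(func.__name__)
        return func(*args, **kwargs)
    return inner


def double(x):
    return 2 * x


def increment(x):
    return x + 1


# Stand-in for the test module whose globals are searched.
test_module = types.ModuleType("test_module")
test_module.double = tag_lazy(double)   # tagged: will be wrapped
test_module.increment = increment       # untagged: left alone

undo = patch_tagged(test_module, recording)
test_module.double(3)      # goes through the wrapper
test_module.increment(3)   # bypasses the wrapper
undo()
test_module.double(3)      # original function again, not recorded
```

Only the call made while the patch was active is recorded, which mirrors why the functions must be looked up through the test module's globals: a reference to a tagged function obtained some other way would bypass the wrapper.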
- Parameters:
request (pytest.FixtureRequest) – Pytest fixture, as acquired by the test itself or by one of its fixtures.
monkeypatch (pytest.MonkeyPatch) – Pytest fixture, as acquired by the test itself or by one of its fixtures.
xp (module) – Array namespace to be tested.
- Return type:
None
See also
lazy_xp_function
Tag a function to be tested on lazy backends.
pytest.FixtureRequest
request test function parameter.