array_api_extra.testing.patch_lazy_xp_functions
- array_api_extra.testing.patch_lazy_xp_functions(request, monkeypatch=None, *, xp)
Test lazy execution of functions tagged with lazy_xp_function().
If xp==jax.numpy, search for all functions which have been tagged with lazy_xp_function() in the globals of the module that defines the current test, as well as in the modules listed in the lazy_xp_modules list in the globals of the same module, and wrap them with jax.jit(). Unwrap them at the end of the test.
If xp==dask.array, wrap the functions with a decorator that disables compute() and persist() and ensures that exceptions and warnings are raised eagerly.
This function should typically be called by your library's xp fixture that runs tests on multiple backends:
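    import array_api_strict
    import dask.array
    import jax.numpy
    import numpy
    import pytest
    from array_api_extra.testing import patch_lazy_xp_functions

    @pytest.fixture(params=[
        numpy,
        array_api_strict,
        pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
        pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
    ])
    def xp(request):
        with patch_lazy_xp_functions(request, xp=request.param):
            yield request.param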
but it can also be called directly by the test itself.
- Parameters:
request (pytest.FixtureRequest) – Pytest fixture, as acquired by the test itself or by one of its fixtures.
monkeypatch (pytest.MonkeyPatch) – Deprecated
xp (array_namespace) – Array namespace to be tested.
- Return type:
See also
lazy_xp_function
Tag a function to be tested on lazy backends.
pytest.FixtureRequest
request test function parameter.
Notes
This context manager monkey-patches modules and as such is thread unsafe on Dask and JAX. If you run your test suite with pytest-run-parallel, you should mark these backends with @pytest.mark.thread_unsafe, as shown in the example above.
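The thread-safety caveat follows from module attributes being shared by every thread in the process. The hypothetical sketch below (plain `threading`, no JAX or Dask; `shared_lib`, `concurrent_test` and `patching_test` are illustrative names) shows a second thread observing a patch that a first "test" applied to a shared module while that patch is in effect:

```python
import threading
import types

# Hypothetical shared library module, patched by one "test" and used by another.
shared_lib = types.ModuleType("shared_lib")
shared_lib.func = lambda: "original"

patch_applied = threading.Event()
observed = []


def concurrent_test():
    # Runs in another thread; waits until the patch is active, then calls func.
    patch_applied.wait()
    observed.append(shared_lib.func())


def patching_test(other):
    original = shared_lib.func
    shared_lib.func = lambda: "patched"  # same kind of global swap as monkey-patching
    try:
        patch_applied.set()
        other.join()  # keep the patch in place until the other thread finishes
    finally:
        shared_lib.func = original


worker = threading.Thread(target=concurrent_test)
worker.start()
patching_test(worker)
```

The concurrent test sees the patched function rather than the original, which is why the JAX and Dask parameters in the fixture example carry the thread_unsafe mark.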