diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 3c773dc6..f5eb0802 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -272,3 +272,20 @@ def test_benchmark_rng_discard(benchmark, kind): benchmark, kind, lambda: _run_benchmark_rng_discard(rng).block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_invert_ab_noraise(u, v, ab): + return jgs.fitswcs._invert_ab_noraise(u, v, ab)[0] + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_invert_ab_noraise(benchmark, kind): + u = jnp.arange(1000).astype(jnp.float64) + v = jnp.arange(1000).astype(jnp.float64) + ab = jnp.array([[[-0.5, 0.3], [-0.1, 2.0]], [[-1.0, 0.3], [-0.1, 4.0]]]) + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_invert_ab_noraise(u, v, ab).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ")