diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index d00cad23..53e75091 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -238,3 +238,18 @@ def test_benchmark_gaussian_init(benchmark, kind): benchmark, kind, lambda: _run_gaussian_bench_init().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +@jax.jit +def _run_benchmark_rng_discard(rng): + rng.discard(1000) + return rng._state.key + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_rng_discard(benchmark, kind): + rng = jgs.BaseDeviate(seed=42) + dt = _run_benchmarks( + benchmark, kind, lambda: _run_benchmark_rng_discard(rng).block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ")