From 2d3f40c9080711d2a9b6655f3b0a599642e1bd82 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 11:19:58 -0600 Subject: [PATCH 1/2] test: add benchmark for rng discard --- tests/jax/test_benchmarks.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index d00cad23..d36e5a98 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -238,3 +238,16 @@ def test_benchmark_gaussian_init(benchmark, kind): benchmark, kind, lambda: _run_gaussian_bench_init().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +@jax.jit +def _run_benchmark_rng_discard(rng): + return rng.discard(1000) + + +def test_benchmark_rng_discard(benchmark): + rng = jgs.BaseDeviate(seed=42) + dt = _run_benchmarks( + benchmark, "run", lambda: _run_benchmark_rng_discard(rng).block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ") From 6f15cdadcb4440cc84073d0b047685d0f8fdb679 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 11:23:00 -0600 Subject: [PATCH 2/2] fix: make benchmark work --- tests/jax/test_benchmarks.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index d36e5a98..53e75091 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -242,12 +242,14 @@ def test_benchmark_gaussian_init(benchmark, kind): @jax.jit def _run_benchmark_rng_discard(rng): - return rng.discard(1000) + rng.discard(1000) + return rng._state.key -def test_benchmark_rng_discard(benchmark): +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_rng_discard(benchmark, kind): rng = jgs.BaseDeviate(seed=42) dt = _run_benchmarks( - benchmark, "run", lambda: _run_benchmark_rng_discard(rng).block_until_ready() + benchmark, kind, lambda: _run_benchmark_rng_discard(rng).block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ")