From 46b8971748bba3a08501c3075e5822e05c05ec1d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Jan 2025 12:59:02 -0600 Subject: [PATCH 1/4] test: add benchmark for moffat init --- tests/jax/test_benchmarks.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index f5eb0802..f73d569a 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -289,3 +289,15 @@ def test_benchmark_invert_ab_noraise(benchmark, kind): lambda: _run_benchmark_invert_ab_noraise(u, v, ab).block_until_ready(), ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_moffat_init(): + return jgs.Moffat(beta=2.5, half_light_radius=0.6, trunc=1.2).scale_radius + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_moffat_init(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_benchmark_moffat_init().block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ") From 555ec84dc3c22794604b49398fe40394388c2488 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Jan 2025 13:09:49 -0600 Subject: [PATCH 2/4] fix: missed 1k loop --- jax_galsim/moffat.py | 60 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 07502dce..77c3b1f9 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -37,7 +37,36 @@ def _hankel(k, beta, rmax): @jax.jit -def _MoffatCalculateSRFromHLR(re, rm, beta): +def _bodymi(xcur, rm, re, beta): + x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 + x = jnp.power(x, 1 / (1 - beta)) + x = jnp.sqrt(x - 1) + return re / x + + +@jax.jit +def _bodymi10(xcur, rm, re, beta): + for _ in range(10): + xcur = _bodymi(xcur, rm, re, beta) + return xcur + + +@jax.jit +def _bodymi100(xcur, rm, re, beta): + for _ in range(10): + xcur = _bodymi10(xcur, rm, re, beta) + return xcur + + +@jax.jit +def _bodymi1000(xcur, rm, re, beta): + for _ in range(10): + xcur = _bodymi100(xcur, rm, re, beta) + return xcur + + +@partial(jax.jit, static_argnames=("nitr",)) +def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=1000): """ The basic equation that is relevant here is the flux of a Moffat profile out to some radius. @@ -54,16 +83,29 @@ def _MoffatCalculateSRFromHLR(re, rm, beta): BUT the case rm==0 is already done, so HERE rm != 0 """ - # fix loop iteration is faster and reach eps=1e-6 (single precision) - def body(i, xcur): - x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 - x = jnp.power(x, 1 / (1 - beta)) - x = jnp.sqrt(x - 1) - return re / x + nitr_1000 = nitr // 1000 + nrem_1000 = nitr % 1000 + nitr_100 = nrem_1000 // 100 + nrem_100 = nrem_1000 % 100 + nitr_10 = nrem_100 // 10 + nrem_10 = nrem_100 % 10 + assert nitr_1000 * 1000 + nitr_100 * 100 + nitr_10 * 10 + nrem_10 == nitr + + xcur = re + + for _ in range(nitr_1000): + xcur = _bodymi1000(xcur, rm, re, beta) + + for _ in range(nitr_100): + xcur = _bodymi100(xcur, rm, re, beta) + + for _ in range(nitr_10): + xcur = _bodymi10(xcur, rm, re, beta) - rd = jax.lax.fori_loop(0, 1000, body, re) + for _ in range(nrem_10): + xcur = _bodymi(xcur, rm, re, beta) - return rd + return xcur @implements(_galsim.Moffat) From 70159b73d77351c729fd8a2f152731ee4730ecdc Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Jan 2025 13:14:53 -0600 Subject: [PATCH 3/4] fix: try simpler code --- jax_galsim/moffat.py | 43 +------------------------------------------ 1 file changed, 1 insertion(+), 42 deletions(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 77c3b1f9..3795aced 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -44,27 +44,6 @@ def _bodymi(xcur, rm, re, beta): return re / x -@jax.jit -def _bodymi10(xcur, rm, re, beta): - for _ in range(10): - xcur = _bodymi(xcur, rm, re, beta) - return xcur - - -@jax.jit -def _bodymi100(xcur, rm, re, beta): - for _ in range(10): - xcur = _bodymi10(xcur, rm, re, beta) - return xcur - - -@jax.jit -def _bodymi1000(xcur, rm, re, beta): - for _ in range(10): - xcur = _bodymi100(xcur, rm, re, beta) - return xcur - - @partial(jax.jit, static_argnames=("nitr",)) def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=1000): """ @@ -82,29 +61,9 @@ def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=1000): nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf. BUT the case rm==0 is already done, so HERE rm != 0 """ - - nitr_1000 = nitr // 1000 - nrem_1000 = nitr % 1000 - nitr_100 = nrem_1000 // 100 - nrem_100 = nrem_1000 % 100 - nitr_10 = nrem_100 // 10 - nrem_10 = nrem_100 % 10 - assert nitr_1000 * 1000 + nitr_100 * 100 + nitr_10 * 10 + nrem_10 == nitr - xcur = re - - for _ in range(nitr_1000): - xcur = _bodymi1000(xcur, rm, re, beta) - - for _ in range(nitr_100): - xcur = _bodymi100(xcur, rm, re, beta) - - for _ in range(nitr_10): - xcur = _bodymi10(xcur, rm, re, beta) - - for _ in range(nrem_10): + for _ in range(nitr): xcur = _bodymi(xcur, rm, re, beta) - return xcur From c604a081d095c11058bd46129b9e6be8f369852b Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Wed, 29 Jan 2025 22:51:27 -0600 Subject: [PATCH 4/4] Update jax_galsim/moffat.py --- jax_galsim/moffat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 3795aced..0563ba78 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -45,7 +45,7 @@ def _bodymi(xcur, rm, re, beta): @partial(jax.jit, static_argnames=("nitr",)) -def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=1000): +def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100): """ The basic equation that is relevant here is the flux of a Moffat profile out to some radius.