From 2b92b17c6554b50307a14ef5a04a737ebf7d7710 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 13:44:02 -0600 Subject: [PATCH 1/3] feat: add benchmark and test of flux frac for interp images --- tests/jax/test_benchmarks.py | 19 ++++++ tests/jax/test_interpolatedimage_utils.py | 74 +++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index d00cad23..6d62b1b5 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -238,3 +238,22 @@ def test_benchmark_gaussian_init(benchmark, kind): benchmark, kind, lambda: _run_gaussian_bench_init().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_interpimage_flux_frac(img): + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + return jgs.interpolatedimage._flux_frac(img.array, x, y, cenx, ceny) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_interpimage_flux_frac(benchmark, kind): + obj = jgs.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.2, method="no_pixel") + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_interpimage_flux_frac(img).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ") diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 7043615f..84bd8f00 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -16,6 +16,7 @@ from jax_galsim.interpolatedimage import ( _draw_with_interpolant_kval, _draw_with_interpolant_xval, + _flux_frac, ) @@ -354,3 +355,76 @@ def test_interpolatedimage_interpolant_sample(interp): fdev = np.abs(h - yv) / np.abs(np.sqrt(yv)) np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") + + +def test_interpolatedimage_flux_frac(): + obj = jax_galsim.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.05, method="no_pixel") + true_val = [ + 0.02186161, + 0.06551123, + 0.10894079, + 0.15200604, + 0.19456641, + 0.23648629, + 0.27763629, + 0.31789470, + 0.35714823, + 0.39529300, + 0.43223542, + 0.46789303, + 0.50219434, + 0.53507960, + 0.56650090, + 0.59642231, + 0.62481892, + 0.65167749, + 0.67699528, + 0.70077991, + 0.72304893, + 0.74382806, + 0.76315117, + 0.78105938, + 0.79759991, + 0.81282544, + 0.82679272, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + ] + + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + val = _flux_frac(img.array, x, y, cenx, ceny) + np.testing.assert_allclose( + val, + true_val, + rtol=0, + atol=1e-8, + ) From 5788230a286937c03256b8506b816f107609a48b Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 14:05:50 -0600 Subject: [PATCH 2/3] perf: remove fori_loop for flux frac --- jax_galsim/interpolatedimage.py | 37 +++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3620e30b..cc2ff60a 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1168,26 +1168,23 @@ def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): @jax.jit def _flux_frac(a, x, y, cenx, ceny): - def _body(d, args): - res, a, dx, dy, cenx, ceny = args - msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) - - res = res.at[d].set( - jnp.sum( - jnp.where( - msk, - a, - 0.0, - ) - ) - ) - - return [res, a, dx, dy, cenx, ceny] - - res = jnp.zeros(a.shape[0], dtype=float) - jnp.inf - return jax.lax.fori_loop( - 0, a.shape[0], _body, [res, a, x - cenx, y - ceny, cenx, ceny] - )[0] + a = jnp.reshape(a, (a.shape[0], a.shape[1], 1)) + dx = x - cenx + dx = jnp.reshape(dx, (a.shape[0], a.shape[1], 1)) + dy = y - ceny + dy = jnp.reshape(dy, (a.shape[0], a.shape[1], 1)) + d = jnp.arange(a.shape[0]) + d = jnp.reshape(d, (1, 1, -1)) + msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) + res = jnp.sum( + jnp.where( + msk, + a, + 0.0, + ), + axis=(0, 1), + ) + return res @jax.jit From c61506313157cddc8258d86acb1f6ee6837d4c9b Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 14:07:03 -0600 Subject: [PATCH 3/3] fix: use -inf for zeros --- jax_galsim/interpolatedimage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index cc2ff60a..1b54f80b 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1184,6 +1184,7 @@ def _flux_frac(a, x, y, cenx, ceny): ), axis=(0, 1), ) + res = jnp.where(res > 0, res, -jnp.inf) return res