From 2b92b17c6554b50307a14ef5a04a737ebf7d7710 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 13:44:02 -0600 Subject: [PATCH 1/2] feat: add benchmark and test of flux frac for interp images --- tests/jax/test_benchmarks.py | 19 ++++++ tests/jax/test_interpolatedimage_utils.py | 74 +++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index d00cad23..6d62b1b5 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -238,3 +238,22 @@ def test_benchmark_gaussian_init(benchmark, kind): benchmark, kind, lambda: _run_gaussian_bench_init().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_interpimage_flux_frac(img): + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + return jgs.interpolatedimage._flux_frac(img.array, x, y, cenx, ceny) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_interpimage_flux_frac(benchmark, kind): + obj = jgs.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.2, method="no_pixel") + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_interpimage_flux_frac(img).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ") diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 7043615f..84bd8f00 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -16,6 +16,7 @@ from jax_galsim.interpolatedimage import ( _draw_with_interpolant_kval, _draw_with_interpolant_xval, + _flux_frac, ) @@ -354,3 +355,76 @@ def test_interpolatedimage_interpolant_sample(interp): fdev = np.abs(h - yv) / np.abs(np.sqrt(yv)) np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") + + +def test_interpolatedimage_flux_frac(): + obj = jax_galsim.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.05, method="no_pixel") + true_val = [ + 0.02186161, + 0.06551123, + 0.10894079, + 0.15200604, + 0.19456641, + 0.23648629, + 0.27763629, + 0.31789470, + 0.35714823, + 0.39529300, + 0.43223542, + 0.46789303, + 0.50219434, + 0.53507960, + 0.56650090, + 0.59642231, + 0.62481892, + 0.65167749, + 0.67699528, + 0.70077991, + 0.72304893, + 0.74382806, + 0.76315117, + 0.78105938, + 0.79759991, + 0.81282544, + 0.82679272, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + ] + + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + val = _flux_frac(img.array, x, y, cenx, ceny) + np.testing.assert_allclose( + val, + true_val, + rtol=0, + atol=1e-8, + ) From 600cad2c16d6c4c0464dadbf973504c2908b6d5d Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 27 Jan 2025 14:32:51 -0600 Subject: [PATCH 2/2] Update tests/jax/test_interpolatedimage_utils.py --- tests/jax/test_interpolatedimage_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 84bd8f00..56f3ced5 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -426,5 +426,5 @@ def test_interpolatedimage_flux_frac(): val, true_val, rtol=0, - atol=1e-8, + atol=1e-6, )