diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 53e75091..3c773dc6 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -240,6 +240,25 @@ def test_benchmark_gaussian_init(benchmark, kind): print(f"time: {dt:0.4g} ms", end=" ") +def _run_benchmark_interpimage_flux_frac(img): + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + return jgs.interpolatedimage._flux_frac(img.array, x, y, cenx, ceny) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_interpimage_flux_frac(benchmark, kind): + obj = jgs.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.2, method="no_pixel") + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_interpimage_flux_frac(img).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ") + + @jax.jit def _run_benchmark_rng_discard(rng): rng.discard(1000) diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 7043615f..56f3ced5 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -16,6 +16,7 @@ from jax_galsim.interpolatedimage import ( _draw_with_interpolant_kval, _draw_with_interpolant_xval, + _flux_frac, ) @@ -354,3 +355,76 @@ def test_interpolatedimage_interpolant_sample(interp): fdev = np.abs(h - yv) / np.abs(np.sqrt(yv)) np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") + + +def test_interpolatedimage_flux_frac(): + obj = jax_galsim.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.05, method="no_pixel") + true_val = [ + 0.02186161, + 0.06551123, + 0.10894079, + 0.15200604, + 0.19456641, + 0.23648629, + 0.27763629, + 0.31789470, + 0.35714823, + 0.39529300, + 0.43223542, + 0.46789303, + 0.50219434, + 0.53507960, + 0.56650090, + 0.59642231, + 0.62481892, + 0.65167749, + 0.67699528, + 0.70077991, + 0.72304893, + 0.74382806, + 0.76315117, + 0.78105938, + 0.79759991, + 0.81282544, + 0.82679272, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + ] + + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + val = _flux_frac(img.array, x, y, cenx, ceny) + np.testing.assert_allclose( + val, + true_val, + rtol=0, + atol=1e-6, + )