From 2b92b17c6554b50307a14ef5a04a737ebf7d7710 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 13:44:02 -0600 Subject: [PATCH 1/9] feat: add benchmark and test of flux frac for interp images --- tests/jax/test_benchmarks.py | 19 ++++++ tests/jax/test_interpolatedimage_utils.py | 74 +++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index d00cad23..6d62b1b5 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -238,3 +238,22 @@ def test_benchmark_gaussian_init(benchmark, kind): benchmark, kind, lambda: _run_gaussian_bench_init().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_interpimage_flux_frac(img): + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + return jgs.interpolatedimage._flux_frac(img.array, x, y, cenx, ceny) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_interpimage_flux_frac(benchmark, kind): + obj = jgs.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.2, method="no_pixel") + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_interpimage_flux_frac(img).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ") diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 7043615f..84bd8f00 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -16,6 +16,7 @@ from jax_galsim.interpolatedimage import ( _draw_with_interpolant_kval, _draw_with_interpolant_xval, + _flux_frac, ) @@ -354,3 +355,76 @@ def test_interpolatedimage_interpolant_sample(interp): fdev = np.abs(h - yv) / np.abs(np.sqrt(yv)) np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") + + +def test_interpolatedimage_flux_frac(): + obj = jax_galsim.Gaussian(half_light_radius=0.9).shear(g1=0.1, g2=0.2) + img = obj.drawImage(nx=55, ny=55, scale=0.05, method="no_pixel") + true_val = [ + 0.02186161, + 0.06551123, + 0.10894079, + 0.15200604, + 0.19456641, + 0.23648629, + 0.27763629, + 0.31789470, + 0.35714823, + 0.39529300, + 0.43223542, + 0.46789303, + 0.50219434, + 0.53507960, + 0.56650090, + 0.59642231, + 0.62481892, + 0.65167749, + 0.67699528, + 0.70077991, + 0.72304893, + 0.74382806, + 0.76315117, + 0.78105938, + 0.79759991, + 0.81282544, + 0.82679272, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + 0.83956224, + ] + + x, y = img.get_pixel_centers() + cenx = img.center.x + ceny = img.center.y + val = _flux_frac(img.array, x, y, cenx, ceny) + np.testing.assert_allclose( + val, + true_val, + rtol=0, + atol=1e-8, + ) From 16494d8dbb2c075a0441bbf211d70731c6ac14a6 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 27 Jan 2025 14:30:59 -0600 Subject: [PATCH 2/9] perf: use built-in 1d polynomial eval --- jax_galsim/utilities.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index b03a2835..76808918 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -186,7 +186,6 @@ def unweighted_shape(arg): def horner(x, coef, dtype=None): x = jnp.array(x) coef = jnp.atleast_1d(coef) - res = jnp.zeros_like(x, dtype=dtype) if len(coef.shape) != 1: raise GalSimValueError("coef must be 1-dimensional", coef) @@ -198,16 +197,10 @@ def horner(x, coef, dtype=None): return jax.lax.cond( coef.shape[0] == 0, - lambda x, coef, res: res, - lambda x, coef, res: jax.lax.fori_loop( - 0, - coef.shape[0], - lambda i, args: (args[0] * x + args[1][i], args[1]), - (res, coef[::-1]), - )[0], + lambda x, coef: jnp.zeros_like(x, dtype=dtype), + lambda x, coef: jnp.polyval(jnp.flip(coef), x), x, coef, - res, ) From 9adb991b43fa3c6e2b4328432aa2ca9ec79d9526 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Jan 2025 12:19:15 -0600 Subject: [PATCH 3/9] Update jax_galsim/utilities.py --- jax_galsim/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 76808918..aaf90907 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -198,7 +198,7 @@ def horner(x, coef, dtype=None): return jax.lax.cond( coef.shape[0] == 0, lambda x, coef: jnp.zeros_like(x, dtype=dtype), - lambda x, coef: jnp.polyval(jnp.flip(coef), x), + lambda x, coef: jnp.array(jnp.polyval(jnp.flip(coef), x), dtype=dtype), x, coef, ) From 6c9e2cef4f4d410f4496c56aabce874ef92ba9d2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Jan 2025 13:19:31 -0600 Subject: [PATCH 4/9] fix: get result dtype properly --- jax_galsim/utilities.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index aaf90907..a533d658 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -186,6 +186,7 @@ def unweighted_shape(arg): def horner(x, coef, dtype=None): x = jnp.array(x) coef = jnp.atleast_1d(coef) + res_dtype = jnp.result_type(x, coef, dtype) if len(coef.shape) != 1: raise GalSimValueError("coef must be 1-dimensional", coef) @@ -197,8 +198,8 @@ def horner(x, coef, dtype=None): return jax.lax.cond( coef.shape[0] == 0, - lambda x, coef: jnp.zeros_like(x, dtype=dtype), - lambda x, coef: jnp.array(jnp.polyval(jnp.flip(coef), x), dtype=dtype), + lambda x, coef: jnp.zeros_like(x, dtype=res_dtype), + lambda x, coef: jnp.array(jnp.polyval(jnp.flip(coef), x), dtype=res_dtype), x, coef, ) From 056fcdd72e5f5838756bf65df21dbfadadd6e1b3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Jan 2025 13:42:11 -0600 Subject: [PATCH 5/9] test: add extra test --- tests/jax/test_jax_utilities.py | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/jax/test_jax_utilities.py diff --git a/tests/jax/test_jax_utilities.py b/tests/jax/test_jax_utilities.py new file mode 100644 index 00000000..4f5c093e --- /dev/null +++ b/tests/jax/test_jax_utilities.py @@ -0,0 +1,41 @@ +import jax.numpy as jnp +import numpy as np + +import jax_galsim + + +def test_jax_utilities_horner_dtype(): + rng = np.random.RandomState(1234) + x = rng.uniform(size=20) + coef = [1.2332, 3.43242, 4.1231, -0.2342, 0.4242] + truth = coef[0] + coef[1] * x + coef[2] * x**2 + coef[3] * x**3 + coef[4] * x**4 + + result = jax_galsim.utilities.horner(x, coef, dtype=int) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=float) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=complex) + np.testing.assert_almost_equal(result.real, truth) + np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int32) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int64) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float32) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float64) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex64) + np.testing.assert_almost_equal(result.real, truth) + np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex128) + np.testing.assert_almost_equal(result.real, truth) + np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) From 938989b480e30b000935613e05b2a189015b9c4b Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Jan 2025 14:09:40 -0600 Subject: [PATCH 6/9] perf: use scan here too --- jax_galsim/utilities.py | 74 ++++++++++++----------------------------- 1 file changed, 22 insertions(+), 52 deletions(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index a533d658..34bd07ed 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -183,6 +183,7 @@ def unweighted_shape(arg): @implements(_galsim.utilities.horner) +@functools.partial(jax.jit, static_argnames=("dtype",)) def horner(x, coef, dtype=None): x = jnp.array(x) coef = jnp.atleast_1d(coef) @@ -206,11 +207,13 @@ def horner(x, coef, dtype=None): @implements(_galsim.utilities.horner2d) +@functools.partial(jax.jit, static_argnames=("triangle", "dtype")) def horner2d(x, y, coefs, dtype=None, triangle=False): x = jnp.array(x) y = jnp.array(y) coefs = jnp.atleast_1d(coefs) - res = jnp.zeros_like(x, dtype=dtype) + res_dtype = jnp.result_type(x, coefs, dtype) + res = jnp.zeros_like(x, dtype=res_dtype) if x.shape != y.shape: raise GalSimIncompatibleValuesError("x and y are not the same shape", x=x, y=y) @@ -223,57 +226,24 @@ def horner2d(x, y, coefs, dtype=None, triangle=False): "coefs must be square if triangle is True", coefs=coefs, triangle=triangle ) + coefs = coefs[::-1, :] if triangle: - # this loop in python looks like - # Note: for each power of x, it computes all powers of y - # - # result = np.zeros_like(x) - # temp = np.zeros_like(x) - # - # for i, coef in enumerate(coefs[::-1]): - # result *= x - # _horner(y, coef[:i+1], temp) - # result += temp - - def _body(i, args): - res, coeffs, y = args - # only grab the triangular part - tri_coeffs = jnp.where( - jnp.arange(coeffs.shape[1]) < i + 1, - coeffs[i, :], - jnp.zeros_like(coeffs[i, :]), - ) - res = res * x + horner(y, tri_coeffs, dtype=dtype) - return res, coeffs, y - - res = jax.lax.fori_loop( - 0, - coefs.shape[0], - _body, - (res, coefs[::-1, :], y), - )[0] - else: - # this loop in python looks like - # Note: for each power of x, it computes all powers of y - # - # result = np.zeros_like(x) - # temp = np.zeros_like(x) - # - # for coef in coefs[::-1]: - # result *= x - # _horner(y, coef, temp) - # result += temp - - def _body(i, args): - res, coeffs, y = args - res = res * x + horner(y, coeffs[i], dtype=dtype) - return res, coeffs, y - - res = jax.lax.fori_loop( - 0, - coefs.shape[0], - _body, - (res, coefs[::-1, :], y), - )[0] + coefs = jnp.tril(coefs) + + # this loop in python looks like + # Note: for each power of x, it computes all powers of y + # + # result = np.zeros_like(x) + # temp = np.zeros_like(x) + # + # for coef in coefs[::-1]: + # result *= x + # _horner(y, coef, temp) + # result += temp + + res = jnp.zeros_like(x, dtype=res_dtype) + res, _ = jax.lax.scan( + lambda res, p: (res * x + horner(y, p, dtype=res_dtype), None), res, coefs + ) return res From d8d14e48a63538cd4634e092d17a3c2912279749 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Tue, 28 Jan 2025 16:54:37 -0600 Subject: [PATCH 7/9] Apply suggestions from code review --- jax_galsim/utilities.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 34bd07ed..fd39e03b 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -187,7 +187,10 @@ def unweighted_shape(arg): def horner(x, coef, dtype=None): x = jnp.array(x) coef = jnp.atleast_1d(coef) - res_dtype = jnp.result_type(x, coef, dtype) + if dtype is None: + res_dtype = jnp.result_type(x, coef) + else: + res_dtype = dtype if len(coef.shape) != 1: raise GalSimValueError("coef must be 1-dimensional", coef) @@ -212,7 +215,10 @@ def horner2d(x, y, coefs, dtype=None, triangle=False): x = jnp.array(x) y = jnp.array(y) coefs = jnp.atleast_1d(coefs) - res_dtype = jnp.result_type(x, coefs, dtype) + if dtype is None: + res_dtype = jnp.result_type(x, coefs) + else: + res_dtype = dtype res = jnp.zeros_like(x, dtype=res_dtype) if x.shape != y.shape: @@ -246,4 +252,4 @@ def horner2d(x, y, coefs, dtype=None, triangle=False): lambda res, p: (res * x + horner(y, p, dtype=res_dtype), None), res, coefs ) - return res + return res.astype(res_dtype) From da368b6dc804c3039dcc57a9fae15fd9ba9730f0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Jan 2025 16:57:24 -0600 Subject: [PATCH 8/9] fix: need casts in tests --- tests/jax/test_jax_utilities.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_jax_utilities.py b/tests/jax/test_jax_utilities.py index 4f5c093e..c6cd852c 100644 --- a/tests/jax/test_jax_utilities.py +++ b/tests/jax/test_jax_utilities.py @@ -11,7 +11,7 @@ def test_jax_utilities_horner_dtype(): truth = coef[0] + coef[1] * x + coef[2] * x**2 + coef[3] * x**3 + coef[4] * x**4 result = jax_galsim.utilities.horner(x, coef, dtype=int) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_almost_equal(result, truth.astype(int)) result = jax_galsim.utilities.horner(x, coef, dtype=float) np.testing.assert_almost_equal(result, truth) @@ -21,19 +21,19 @@ def test_jax_utilities_horner_dtype(): np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int32) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_almost_equal(result, truth.astype(np.int32)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int64) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_almost_equal(result, truth.astype(np.int64)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float32) - np.testing.assert_almost_equal(result, truth) + np.testing.assert_almost_equal(result, truth.astype(np.float32)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float64) np.testing.assert_almost_equal(result, truth) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex64) - np.testing.assert_almost_equal(result.real, truth) + np.testing.assert_almost_equal(result.real, truth.astype(np.complex64)) np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex128) From aed2f894f50071ba47c7070327387c063bcfaec4 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 28 Jan 2025 17:02:38 -0600 Subject: [PATCH 9/9] test: explicitly test the dtype --- tests/jax/test_jax_utilities.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/jax/test_jax_utilities.py b/tests/jax/test_jax_utilities.py index c6cd852c..504cbe1d 100644 --- a/tests/jax/test_jax_utilities.py +++ b/tests/jax/test_jax_utilities.py @@ -21,21 +21,27 @@ def test_jax_utilities_horner_dtype(): np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int32) + assert result.dtype == jnp.int32 np.testing.assert_almost_equal(result, truth.astype(np.int32)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int64) + assert result.dtype == jnp.int64 np.testing.assert_almost_equal(result, truth.astype(np.int64)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float32) + assert result.dtype == jnp.float32 np.testing.assert_almost_equal(result, truth.astype(np.float32)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float64) + assert result.dtype == jnp.float64 np.testing.assert_almost_equal(result, truth) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex64) + assert result.dtype == jnp.complex64 np.testing.assert_almost_equal(result.real, truth.astype(np.complex64)) np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex128) + assert result.dtype == jnp.complex128 np.testing.assert_almost_equal(result.real, truth) np.testing.assert_almost_equal(result.imag, np.zeros_like(truth))