diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index b03a2835..fd39e03b 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -183,10 +183,14 @@ def unweighted_shape(arg): @implements(_galsim.utilities.horner) +@functools.partial(jax.jit, static_argnames=("dtype",)) def horner(x, coef, dtype=None): x = jnp.array(x) coef = jnp.atleast_1d(coef) - res = jnp.zeros_like(x, dtype=dtype) + if dtype is None: + res_dtype = jnp.result_type(x, coef) + else: + res_dtype = dtype if len(coef.shape) != 1: raise GalSimValueError("coef must be 1-dimensional", coef) @@ -198,25 +202,24 @@ def horner(x, coef, dtype=None): return jax.lax.cond( coef.shape[0] == 0, - lambda x, coef, res: res, - lambda x, coef, res: jax.lax.fori_loop( - 0, - coef.shape[0], - lambda i, args: (args[0] * x + args[1][i], args[1]), - (res, coef[::-1]), - )[0], + lambda x, coef: jnp.zeros_like(x, dtype=res_dtype), + lambda x, coef: jnp.array(jnp.polyval(jnp.flip(coef), x), dtype=res_dtype), x, coef, - res, ) @implements(_galsim.utilities.horner2d) +@functools.partial(jax.jit, static_argnames=("triangle", "dtype")) def horner2d(x, y, coefs, dtype=None, triangle=False): x = jnp.array(x) y = jnp.array(y) coefs = jnp.atleast_1d(coefs) - res = jnp.zeros_like(x, dtype=dtype) + if dtype is None: + res_dtype = jnp.result_type(x, coefs) + else: + res_dtype = dtype + res = jnp.zeros_like(x, dtype=res_dtype) if x.shape != y.shape: raise GalSimIncompatibleValuesError("x and y are not the same shape", x=x, y=y) @@ -229,57 +232,24 @@ def horner2d(x, y, coefs, dtype=None, triangle=False): "coefs must be square if triangle is True", coefs=coefs, triangle=triangle ) + coefs = coefs[::-1, :] if triangle: - # this loop in python looks like - # Note: for each power of x, it computes all powers of y - # - # result = np.zeros_like(x) - # temp = np.zeros_like(x) - # - # for i, coef in enumerate(coefs[::-1]): - # result *= x - # _horner(y, coef[:i+1], temp) - # result += temp - - def _body(i, args): - res, coeffs, y = args - # only grab the triangular part - tri_coeffs = jnp.where( - jnp.arange(coeffs.shape[1]) < i + 1, - coeffs[i, :], - jnp.zeros_like(coeffs[i, :]), - ) - res = res * x + horner(y, tri_coeffs, dtype=dtype) - return res, coeffs, y - - res = jax.lax.fori_loop( - 0, - coefs.shape[0], - _body, - (res, coefs[::-1, :], y), - )[0] - else: - # this loop in python looks like - # Note: for each power of x, it computes all powers of y - # - # result = np.zeros_like(x) - # temp = np.zeros_like(x) - # - # for coef in coefs[::-1]: - # result *= x - # _horner(y, coef, temp) - # result += temp - - def _body(i, args): - res, coeffs, y = args - res = res * x + horner(y, coeffs[i], dtype=dtype) - return res, coeffs, y - - res = jax.lax.fori_loop( - 0, - coefs.shape[0], - _body, - (res, coefs[::-1, :], y), - )[0] - - return res + coefs = jnp.tril(coefs) + + # this loop in python looks like + # Note: for each power of x, it computes all powers of y + # + # result = np.zeros_like(x) + # temp = np.zeros_like(x) + # + # for coef in coefs[::-1]: + # result *= x + # _horner(y, coef, temp) + # result += temp + + res = jnp.zeros_like(x, dtype=res_dtype) + res, _ = jax.lax.scan( + lambda res, p: (res * x + horner(y, p, dtype=res_dtype), None), res, coefs + ) + + return res.astype(res_dtype) diff --git a/tests/jax/test_jax_utilities.py b/tests/jax/test_jax_utilities.py new file mode 100644 index 00000000..504cbe1d --- /dev/null +++ b/tests/jax/test_jax_utilities.py @@ -0,0 +1,47 @@ +import jax.numpy as jnp +import numpy as np + +import jax_galsim + + +def test_jax_utilities_horner_dtype(): + rng = np.random.RandomState(1234) + x = rng.uniform(size=20) + coef = [1.2332, 3.43242, 4.1231, -0.2342, 0.4242] + truth = coef[0] + coef[1] * x + coef[2] * x**2 + coef[3] * x**3 + coef[4] * x**4 + + result = jax_galsim.utilities.horner(x, coef, dtype=int) + np.testing.assert_almost_equal(result, truth.astype(int)) + + result = jax_galsim.utilities.horner(x, coef, dtype=float) + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=complex) + np.testing.assert_almost_equal(result.real, truth) + np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int32) + assert result.dtype == jnp.int32 + np.testing.assert_almost_equal(result, truth.astype(np.int32)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int64) + assert result.dtype == jnp.int64 + np.testing.assert_almost_equal(result, truth.astype(np.int64)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float32) + assert result.dtype == jnp.float32 + np.testing.assert_almost_equal(result, truth.astype(np.float32)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float64) + assert result.dtype == jnp.float64 + np.testing.assert_almost_equal(result, truth) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex64) + assert result.dtype == jnp.complex64 + np.testing.assert_almost_equal(result.real, truth.astype(np.complex64)) + np.testing.assert_almost_equal(result.imag, np.zeros_like(truth)) + + result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex128) + assert result.dtype == jnp.complex128 + np.testing.assert_almost_equal(result.real, truth) + np.testing.assert_almost_equal(result.imag, np.zeros_like(truth))