Skip to content
96 changes: 33 additions & 63 deletions jax_galsim/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ def unweighted_shape(arg):


@implements(_galsim.utilities.horner)
@functools.partial(jax.jit, static_argnames=("dtype",))
def horner(x, coef, dtype=None):
x = jnp.array(x)
coef = jnp.atleast_1d(coef)
res = jnp.zeros_like(x, dtype=dtype)
if dtype is None:
res_dtype = jnp.result_type(x, coef)
else:
res_dtype = dtype

if len(coef.shape) != 1:
raise GalSimValueError("coef must be 1-dimensional", coef)
Expand All @@ -198,25 +202,24 @@ def horner(x, coef, dtype=None):

return jax.lax.cond(
coef.shape[0] == 0,
lambda x, coef, res: res,
lambda x, coef, res: jax.lax.fori_loop(
0,
coef.shape[0],
lambda i, args: (args[0] * x + args[1][i], args[1]),
(res, coef[::-1]),
)[0],
lambda x, coef: jnp.zeros_like(x, dtype=res_dtype),
lambda x, coef: jnp.array(jnp.polyval(jnp.flip(coef), x), dtype=res_dtype),
x,
coef,
res,
)


@implements(_galsim.utilities.horner2d)
@functools.partial(jax.jit, static_argnames=("triangle", "dtype"))
def horner2d(x, y, coefs, dtype=None, triangle=False):
x = jnp.array(x)
y = jnp.array(y)
coefs = jnp.atleast_1d(coefs)
res = jnp.zeros_like(x, dtype=dtype)
if dtype is None:
res_dtype = jnp.result_type(x, coefs)
else:
res_dtype = dtype
res = jnp.zeros_like(x, dtype=res_dtype)

if x.shape != y.shape:
raise GalSimIncompatibleValuesError("x and y are not the same shape", x=x, y=y)
Expand All @@ -229,57 +232,24 @@ def horner2d(x, y, coefs, dtype=None, triangle=False):
"coefs must be square if triangle is True", coefs=coefs, triangle=triangle
)

coefs = coefs[::-1, :]
if triangle:
# this loop in python looks like
# Note: for each power of x, it computes all powers of y
#
# result = np.zeros_like(x)
# temp = np.zeros_like(x)
#
# for i, coef in enumerate(coefs[::-1]):
# result *= x
# _horner(y, coef[:i+1], temp)
# result += temp

def _body(i, args):
res, coeffs, y = args
# only grab the triangular part
tri_coeffs = jnp.where(
jnp.arange(coeffs.shape[1]) < i + 1,
coeffs[i, :],
jnp.zeros_like(coeffs[i, :]),
)
res = res * x + horner(y, tri_coeffs, dtype=dtype)
return res, coeffs, y

res = jax.lax.fori_loop(
0,
coefs.shape[0],
_body,
(res, coefs[::-1, :], y),
)[0]
else:
# this loop in python looks like
# Note: for each power of x, it computes all powers of y
#
# result = np.zeros_like(x)
# temp = np.zeros_like(x)
#
# for coef in coefs[::-1]:
# result *= x
# _horner(y, coef, temp)
# result += temp

def _body(i, args):
res, coeffs, y = args
res = res * x + horner(y, coeffs[i], dtype=dtype)
return res, coeffs, y

res = jax.lax.fori_loop(
0,
coefs.shape[0],
_body,
(res, coefs[::-1, :], y),
)[0]

return res
coefs = jnp.tril(coefs)

# this loop in python looks like
# Note: for each power of x, it computes all powers of y
#
# result = np.zeros_like(x)
# temp = np.zeros_like(x)
#
# for coef in coefs[::-1]:
# result *= x
# _horner(y, coef, temp)
# result += temp

res = jnp.zeros_like(x, dtype=res_dtype)
res, _ = jax.lax.scan(
lambda res, p: (res * x + horner(y, p, dtype=res_dtype), None), res, coefs
)

return res.astype(res_dtype)
47 changes: 47 additions & 0 deletions tests/jax/test_jax_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax.numpy as jnp
import numpy as np

import jax_galsim


def test_jax_utilities_horner_dtype():
rng = np.random.RandomState(1234)
x = rng.uniform(size=20)
coef = [1.2332, 3.43242, 4.1231, -0.2342, 0.4242]
truth = coef[0] + coef[1] * x + coef[2] * x**2 + coef[3] * x**3 + coef[4] * x**4

result = jax_galsim.utilities.horner(x, coef, dtype=int)
np.testing.assert_almost_equal(result, truth.astype(int))

result = jax_galsim.utilities.horner(x, coef, dtype=float)
np.testing.assert_almost_equal(result, truth)

result = jax_galsim.utilities.horner(x, coef, dtype=complex)
np.testing.assert_almost_equal(result.real, truth)
np.testing.assert_almost_equal(result.imag, np.zeros_like(truth))

result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int32)
assert result.dtype == jnp.int32
np.testing.assert_almost_equal(result, truth.astype(np.int32))

result = jax_galsim.utilities.horner(x, coef, dtype=jnp.int64)
assert result.dtype == jnp.int64
np.testing.assert_almost_equal(result, truth.astype(np.int64))

result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float32)
assert result.dtype == jnp.float32
np.testing.assert_almost_equal(result, truth.astype(np.float32))

result = jax_galsim.utilities.horner(x, coef, dtype=jnp.float64)
assert result.dtype == jnp.float64
np.testing.assert_almost_equal(result, truth)

result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex64)
assert result.dtype == jnp.complex64
np.testing.assert_almost_equal(result.real, truth.astype(np.complex64))
np.testing.assert_almost_equal(result.imag, np.zeros_like(truth))

result = jax_galsim.utilities.horner(x, coef, dtype=jnp.complex128)
assert result.dtype == jnp.complex128
np.testing.assert_almost_equal(result.real, truth)
np.testing.assert_almost_equal(result.imag, np.zeros_like(truth))