diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3620e30b..1b54f80b 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1168,26 +1168,24 @@ def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp): @jax.jit def _flux_frac(a, x, y, cenx, ceny): - def _body(d, args): - res, a, dx, dy, cenx, ceny = args - msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) - - res = res.at[d].set( - jnp.sum( - jnp.where( - msk, - a, - 0.0, - ) - ) - ) - - return [res, a, dx, dy, cenx, ceny] - - res = jnp.zeros(a.shape[0], dtype=float) - jnp.inf - return jax.lax.fori_loop( - 0, a.shape[0], _body, [res, a, x - cenx, y - ceny, cenx, ceny] - )[0] + a = jnp.reshape(a, (a.shape[0], a.shape[1], 1)) + dx = x - cenx + dx = jnp.reshape(dx, (a.shape[0], a.shape[1], 1)) + dy = y - ceny + dy = jnp.reshape(dy, (a.shape[0], a.shape[1], 1)) + d = jnp.arange(a.shape[0]) + d = jnp.reshape(d, (1, 1, -1)) + msk = (jnp.abs(dx) <= d) & (jnp.abs(dx) <= d) + res = jnp.sum( + jnp.where( + msk, + a, + 0.0, + ), + axis=(0, 1), + ) + res = jnp.where(res > 0, res, -jnp.inf) + return res @jax.jit