diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 5c59da80..fdfe6b42 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1028,9 +1028,7 @@ def _invert_ab_noraise(u, v, ab, abp=None): dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1] dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:] - def _step(i, args): - x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args - + for _ in range(10): # Want Jac^-1 . du # du du = horner2d(x, y, ab[0], triangle=True) - u @@ -1042,32 +1040,11 @@ def _step(i, args): dvdy = horner2d(x, y, dvdycoef, triangle=True) # J^-1 . du det = dudx * dvdy - dudy * dvdx - duu = -(du * dvdy - dv * dudy) / det - dvv = -(-du * dvdx + dv * dudx) / det - - x += duu - y += dvv - - return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef - - x, y, dx, dy = jax.lax.fori_loop( - 0, - 10, - _step, - ( - x, - y, - jnp.zeros_like(x), - jnp.zeros_like(y), - u, - v, - ab, - dudxcoef, - dudycoef, - dvdxcoef, - dvdycoef, - ), - )[0:4] + dx = -(du * dvdy - dv * dudy) / det + dy = -(-du * dvdx + dv * dudx) / det + + x += dx + y += dy x, y = jax.lax.cond( jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,