From 4764697201b2e81623f04dcac8f17528215bb5fe Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Jan 2025 06:38:50 -0600 Subject: [PATCH 1/2] perf: remove fori loop for fits wcs --- jax_galsim/fitswcs.py | 35 ++++++----------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 5c59da80..fdfe6b42 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1028,9 +1028,7 @@ def _invert_ab_noraise(u, v, ab, abp=None): dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1] dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:] - def _step(i, args): - x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args - + for _ in range(10): # Want Jac^-1 . du # du du = horner2d(x, y, ab[0], triangle=True) - u @@ -1042,32 +1040,11 @@ def _step(i, args): dvdy = horner2d(x, y, dvdycoef, triangle=True) # J^-1 . du det = dudx * dvdy - dudy * dvdx - duu = -(du * dvdy - dv * dudy) / det - dvv = -(-du * dvdx + dv * dudx) / det - - x += duu - y += dvv - - return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef - - x, y, dx, dy = jax.lax.fori_loop( - 0, - 10, - _step, - ( - x, - y, - jnp.zeros_like(x), - jnp.zeros_like(y), - u, - v, - ab, - dudxcoef, - dudycoef, - dvdxcoef, - dvdycoef, - ), - )[0:4] + dx = -(du * dvdy - dv * dudy) / det + dy = -(-du * dvdx + dv * dudx) / det + + x += dx + y += dy x, y = jax.lax.cond( jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12, From dbdde18272e037d7e0ce38844db9475787b29d0d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Jan 2025 06:59:59 -0600 Subject: [PATCH 2/2] test: add benchmark for fits --- tests/jax/test_benchmarks.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 3c773dc6..f5eb0802 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -272,3 +272,20 @@ def test_benchmark_rng_discard(benchmark, kind): benchmark, kind, lambda: _run_benchmark_rng_discard(rng).block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_invert_ab_noraise(u, v, ab): + return jgs.fitswcs._invert_ab_noraise(u, v, ab)[0] + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_invert_ab_noraise(benchmark, kind): + u = jnp.arange(1000).astype(jnp.float64) + v = jnp.arange(1000).astype(jnp.float64) + ab = jnp.array([[[-0.5, 0.3], [-0.1, 2.0]], [[-1.0, 0.3], [-0.1, 4.0]]]) + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_invert_ab_noraise(u, v, ab).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ")