diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index 7c0d30ff..a7e21ae5 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -21,6 +21,7 @@ jobs: fail-fast: false matrix: python-version: ["3.10", "3.11", "3.12"] + group: [1, 2, 3, 4] steps: - uses: actions/checkout@v4 @@ -33,7 +34,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest pytest-codspeed + python -m pip install pytest pytest-codspeed pytest-split pytest-randomly python -m pip install . # temp pin until 0.5 is on conda python -m pip install "jax<0.5.0" @@ -41,4 +42,16 @@ jobs: - name: Test with pytest run: | git submodule update --init --recursive - pytest -v --durations=0 + pytest -v --durations=0 \ + -k "not test_fpack" \ + --randomly-seed=42 \ + --splits=4 --group=${{ matrix.group }} --splitting-algorithm least_duration + + build-status: + needs: build + runs-on: ubuntu-latest + + steps: + - name: check status + run: | + echo "Builds all passed!" diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index fdfe6b42..6b85517b 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1008,6 +1008,30 @@ def FitsWCS( FitsWCS._opt_params = {"dir": str, "hdu": int, "compression": str, "text_file": bool} +@jax.jit +def _invert_ab_noraise_loop_body( + x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef +): + # Want Jac^-1 . du + # du + du = horner2d(x, y, ab[0], triangle=True) - u + dv = horner2d(x, y, ab[1], triangle=True) - v + # J + dudx = horner2d(x, y, dudxcoef, triangle=True) + dudy = horner2d(x, y, dudycoef, triangle=True) + dvdx = horner2d(x, y, dvdxcoef, triangle=True) + dvdy = horner2d(x, y, dvdycoef, triangle=True) + # J^-1 . du + det = dudx * dvdy - dudy * dvdx + dx = -(du * dvdy - dv * dudy) / det + dy = -(-du * dvdx + dv * dudx) / det + + x += dx + y += dy + + return x, y, dx, dy + + @jax.jit def _invert_ab_noraise(u, v, ab, abp=None): # get guess from abp if we have it @@ -1029,22 +1053,9 @@ def _invert_ab_noraise(u, v, ab, abp=None): dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:] for _ in range(10): - # Want Jac^-1 . du - # du - du = horner2d(x, y, ab[0], triangle=True) - u - dv = horner2d(x, y, ab[1], triangle=True) - v - # J - dudx = horner2d(x, y, dudxcoef, triangle=True) - dudy = horner2d(x, y, dudycoef, triangle=True) - dvdx = horner2d(x, y, dvdxcoef, triangle=True) - dvdy = horner2d(x, y, dvdycoef, triangle=True) - # J^-1 . du - det = dudx * dvdy - dudy * dvdx - dx = -(du * dvdy - dv * dudy) / det - dy = -(-du * dvdx + dv * dudx) / det - - x += dx - y += dy + x, y, dx, dy = _invert_ab_noraise_loop_body( + x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef + ) x, y = jax.lax.cond( jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12, diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 7557c2be..0efc2e1b 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -61,7 +61,6 @@ allowed_failures: - "module 'jax_galsim.utilities' has no attribute 'dol_to_lod'" - "module 'jax_galsim.utilities' has no attribute 'nCr'" - "module 'jax_galsim' has no attribute 'LookupTable'" - - "module 'jax_galsim.bessel' has no attribute 'j0'" - "module 'jax_galsim.bessel' has no attribute 'j1'" - "module 'jax_galsim.bessel' has no attribute 'jn'" - "module 'jax_galsim.bessel' has no attribute 'jv'" @@ -69,7 +68,6 @@ allowed_failures: - "module 'jax_galsim.bessel' has no attribute 'yv'" - "module 'jax_galsim.bessel' has no attribute 'iv'" - "module 'jax_galsim.bessel' has no attribute 'kn'" - - "module 'jax_galsim.bessel' has no attribute 'kv'" - "module 'jax_galsim.bessel' has no attribute 'j0_root'" - "module 'jax_galsim.bessel' has no attribute 'gammainc'" - "module 'jax_galsim.bessel' has no attribute 'sinc'" @@ -145,6 +143,6 @@ allowed_failures: - "Sensor/photon_ops not yet implemented in drawImage for method != 'phot'" - "module 'jax_galsim' has no attribute 'SiliconSensor'" - "module 'jax_galsim' has no attribute 'set_omp_threads'" - - "module 'jax_galsim' has no attribute 'Spergel'" - "module 'jax_galsim' has no attribute 'LookupTable2D'" - "module 'jax_galsim' has no attribute 'zernike'" + - "Invalid TFORM4: 1PE(7)" # see https://github.com/astropy/astropy/issues/15477