Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
group: [1, 2, 3, 4]
python-version: ["3.12"]

steps:
- uses: actions/checkout@v4
Expand All @@ -34,18 +33,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-codspeed pytest-split pytest-randomly
python -m pip install pytest pytest-codspeed pytest-randomly
python -m pip install .
# temp pin until 0.5 is on conda
python -m pip install "jax<0.5.0"

- name: Test with pytest
run: |
git submodule update --init --recursive
pytest -v --durations=0 \
-k "not test_fpack" \
--randomly-seed=42 \
--splits=4 --group=${{ matrix.group }} --splitting-algorithm least_duration
pytest -vv --durations=100 --randomly-seed=42

build-status:
needs: build
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ repos:
hooks:
- id: ruff
args: [ --fix ]
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/|dev/notebooks/
- id: ruff-format
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/|dev/notebooks/
174 changes: 174 additions & 0 deletions dev/notebooks/spergel_fixed_point.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d7b8bc37-8799-433c-9399-de95a21a1727",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import galsim\n",
"import numpy as np\n",
"\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "774101b1",
"metadata": {},
"outputs": [],
"source": [
"from jax_galsim.spergel import (\n",
" fz_nup1, _gammap1, _spergel_hlr_pade,\n",
" fluxfractionFunc, fz_nu, calculateFluxRadius,\n",
")\n",
"\n",
"@jax.jit\n",
"def _calculateFluxRadius_newtons_kernel(i, args):\n",
" \"\"\"Newton's method kernel for calculateFluxRadius\n",
"\n",
" Returns\n",
"\n",
" lnz - fluxfractionFunc(z, nu, alpha) / dfluxfractionFunc(z, nu, alpha)_dz / z\n",
"\n",
" which is Newton's kernel but in log space.\n",
" \"\"\"\n",
" lnz, alpha, nu = args\n",
" z = jnp.exp(lnz)\n",
" dn = (jnp.power(2.0, nu) * _gammap1(nu))\n",
" fz = 1.0 - fz_nup1(z, nu) / dn - alpha\n",
" dfzdz = z * fz_nu(z, nu) / dn\n",
"\n",
" # we clip the result to avoid numerical issues near bounds\n",
" lnz = jnp.clip(\n",
" lnz - fz / dfzdz / z,\n",
" min=-100,\n",
" max=100,\n",
" )\n",
"\n",
" return lnz, alpha, nu\n",
"\n",
"\n",
"@jax.jit\n",
"def calculateFluxRadiusNewton(alpha, nu):\n",
" \"\"\"Return radius R enclosing flux fraction alpha in unit of the scale radius r0\n",
"\n",
" Method: Solve F(R/r0=z)/Flux - alpha = 0 using Netwon's method\n",
"\n",
" We can integrate the profile to get\n",
"\n",
" F(R)/F = int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha\n",
"\n",
" So if we define z = R/r0 and f(z) = F(z * r0)/F - alpha, then Newton's method is\n",
"\n",
" z -> z - f(z) / f'(z)\n",
"\n",
" We actually run the method for ln(z) which is\n",
"\n",
" ln(z) -> ln(z) - f(z) / f'(z) / z\n",
"\n",
" Typical use cases include:\n",
"\n",
" - alpha = 1/2 => R = Half-Light-Radius,\n",
" - alpha = 1 - folding-thresold => R used for stepk computation\n",
" \"\"\"\n",
" # seed the iteration with the Pade approximation to the HLR\n",
" # scaled by the fraction of flux to some power\n",
" zalpha = _spergel_hlr_pade(nu) * jnp.sqrt(alpha / 0.5)\n",
" return jnp.exp(jax.lax.fori_loop(\n",
" 0, 100,\n",
" _calculateFluxRadius_newtons_kernel,\n",
" (jnp.log(zalpha), alpha, nu),\n",
" )[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1be23e1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"eps, nu, log10(alpha): 1e-12 -0.84 -12.0\n",
"3.5138887102897e-38 -2.2121720121483927e-17 1.0587911840678754e-21 1.8761616702453412e-07 1.000534100015216e-12\n",
"\n",
"eps, nu, log10(alpha): 1e-12 3.999 -12.0\n",
"3.9966817649384216e-06 -1.576433954596703e-15 3.999832106175669e-06 3.1094518726606304e-16 9.984622740022494e-13\n",
"\n",
"eps, nu, log10(alpha): 1e-12 -0.84 -4.3428487456249e-13\n",
"25.572845945758726 0.0 25.572509765625 -3.3306690738754696e-16 0.999999999999\n",
"\n",
"eps, nu, log10(alpha): 1e-12 3.999 -4.3428487456249e-13\n",
"38.6677767503012 0.0 38.6676025390625 -1.1102230246251565e-16 0.999999999999\n",
"\n",
"eps, nu, log10(alpha): 0.1 -0.84 -1.0\n",
"0.0008333666650951336 6.38378239159465e-16 0.0008333666650951221 3.0531133177191805e-16 0.10000000000000096\n",
"\n",
"eps, nu, log10(alpha): 0.1 3.999 -1.0\n",
"1.3092245672406861 6.38378239159465e-16 1.3092245672406833 8.326672684688674e-17 0.10000000000000037\n",
"\n",
"eps, nu, log10(alpha): 0.1 -0.84 -0.045757490560675115\n",
"1.2147258941802845 0.0 1.214725894180284 -1.1102230246251565e-16 0.9000000000000001\n",
"\n",
"eps, nu, log10(alpha): 0.1 3.999 -0.045757490560675115\n",
"6.899340112339111 -1.1102230246251565e-16 6.899340112339113 -1.1102230246251565e-16 0.8999999999999999\n"
]
}
],
"source": [
"for eps in [1e-12, 0.1]:\n",
" for alpha in [eps, 1.0 - eps]:\n",
" for nu in [-0.84, 3.999]:\n",
"\n",
" print(\"\\neps, nu, log10(alpha):\", eps, nu, np.log10(alpha))\n",
" zfp = calculateFluxRadiusNewton(alpha, nu)\n",
" zbs = calculateFluxRadius(alpha, nu)\n",
" print(\n",
" zfp,\n",
" fluxfractionFunc(zfp, nu, alpha),\n",
" zbs,\n",
" fluxfractionFunc(zbs, nu, alpha),\n",
" galsim.Spergel(nu, scale_radius=1.0).calculateIntegratedFlux(zfp),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29cd9aa2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "jax-galsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 1 addition & 1 deletion jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _func(i, args):
flow = func(low)
fhigh = func(high)
args = (func, low, flow, high, fhigh)
return jax.lax.fori_loop(0, niter, _func, args)[-2]
return jax.lax.fori_loop(0, niter, _func, args, unroll=15)[-2]


# start of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py #
Expand Down
45 changes: 41 additions & 4 deletions jax_galsim/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,10 +1052,47 @@ def _invert_ab_noraise(u, v, ab, abp=None):
dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1]
dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:]

for _ in range(10):
x, y, dx, dy = _invert_ab_noraise_loop_body(
x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef
)
def _step(i, args):
x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args

# Want Jac^-1 . du
# du
du = horner2d(x, y, ab[0], triangle=True) - u
dv = horner2d(x, y, ab[1], triangle=True) - v
# J
dudx = horner2d(x, y, dudxcoef, triangle=True)
dudy = horner2d(x, y, dudycoef, triangle=True)
dvdx = horner2d(x, y, dvdxcoef, triangle=True)
dvdy = horner2d(x, y, dvdycoef, triangle=True)
# J^-1 . du
det = dudx * dvdy - dudy * dvdx
duu = -(du * dvdy - dv * dudy) / det
dvv = -(-du * dvdx + dv * dudx) / det

x += duu
y += dvv

return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef

x, y, dx, dy = jax.lax.fori_loop(
0,
10,
_step,
(
x,
y,
jnp.zeros_like(x),
jnp.zeros_like(y),
u,
v,
ab,
dudxcoef,
dudycoef,
dvdxcoef,
dvdycoef,
),
unroll=True,
)[0:4]

x, y = jax.lax.cond(
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
Expand Down
13 changes: 9 additions & 4 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100):
nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf.
BUT the case rm==0 is already done, so HERE rm != 0
"""
xcur = re
for _ in range(nitr):
xcur = _bodymi(xcur, rm, re, beta)
return xcur

# fix loop iteration is faster and reach eps=1e-6 (single precision)
def body(i, xcur):
x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2
x = jnp.power(x, 1 / (1 - beta))
x = jnp.sqrt(x - 1)
return re / x

return jax.lax.fori_loop(0, 100, body, re, unroll=True)


@implements(_galsim.Moffat)
Expand Down
7 changes: 5 additions & 2 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def reducedfluxfractionFunc(z, nu, norm):


@jax.jit
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=40.0):
"""Return radius R enclosing flux fraction alpha in unit of the scale radius r0

Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm
Expand All @@ -186,7 +186,10 @@ def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
nb. it is supposed that nu is in [-0.85, 4.0] checked in the Spergel class init
"""
return bisect_for_root(
partial(fluxfractionFunc, nu=nu, alpha=alpha), zmin, zmax, niter=75
partial(fluxfractionFunc, nu=nu, alpha=alpha),
zmin,
zmax,
niter=75,
)


Expand Down
14 changes: 14 additions & 0 deletions tests/jax/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,17 @@ def test_benchmark_moffat_init(benchmark, kind):
benchmark, kind, lambda: _run_benchmark_moffat_init().block_until_ready()
)
print(f"time: {dt:0.4g} ms", end=" ")


def _run_benchmark_spergel_calcfluxrad():
return jgs.spergel.calculateFluxRadius(1e-10, 2.0)


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_benchmark_spergel_calcfluxrad(benchmark, kind):
dt = _run_benchmarks(
benchmark,
kind,
lambda: _run_benchmark_spergel_calcfluxrad().block_until_ready(),
)
print(f"time: {dt:0.4g} ms", end=" ")