From ae7020210498b7e0f90cbe46d4bbc8f9b80509d9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Dec 2024 10:47:28 -0600 Subject: [PATCH 1/5] feat: add benchmarks for spergel profile --- tests/jax/test_benchmarks.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 7fe51487..9ef0cb72 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -159,3 +159,27 @@ def _run(): dt = _run_benchmarks(benchmark, kind, _run) print(f"time: {dt:0.4g} ms", end=" ") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel(benchmark, kind): + def _run_jax(): + gal = jgs.Spergel(nu=-0.6, scale_radius=4.0) + psf = jgs.Gaussian(fwhm=0.9) + obj = jgs.Convolve([gal, psf]) + obj.drawImage(nx=51, ny=51, scale=0.2).array.block_until_ready() + + dt = _run_benchmarks(benchmark, kind, _run_jax) + print(f"jax-galsim time: {dt:0.4g} ms", end=" ") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_galsim(benchmark, kind): + def _run(): + gal = _galsim.Spergel(nu=-0.6, scale_radius=4.0) + psf = _galsim.Gaussian(fwhm=0.9) + obj = _galsim.Convolve([gal, psf]) + obj.drawImage(nx=51, ny=51, scale=0.2) + + dt = _run_benchmarks(benchmark, kind, _run) + print(f"galsim time: {dt:0.4g} ms", end=" ") From b85f52d39fe2676b0e7988f2bd9b96c15e57d9bb Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Dec 2024 10:52:10 -0600 Subject: [PATCH 2/5] fix: add dependabot and update codspeed --- .github/dependabot.yml | 10 ++++++++++ .github/workflows/benchmarks.yaml | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..5f454fdf --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + github-actions: + patterns: + - '*' diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml index 176abc12..f9d51a13 100644 --- a/.github/workflows/benchmarks.yaml +++ b/.github/workflows/benchmarks.yaml @@ -38,7 +38,7 @@ jobs: git submodule update --init --recursive - name: Run benchmarks - uses: CodSpeedHQ/action@v2 + uses: CodSpeedHQ/action@v3 with: token: ${{ secrets.CODSPEED_TOKEN }} run: pytest -vvs --codspeed From 407bc379198e024ab78fcaddadd665b24c1f46d1 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Dec 2024 14:15:09 -0600 Subject: [PATCH 3/5] test: add more tests and benchmarks --- jax_galsim/spergel.py | 6 +-- tests/jax/test_benchmarks.py | 77 +++++++++++++++++++++------ tests/jax/test_spergel_comp_galsim.py | 68 +++++++++++++++++++++++ 3 files changed, 133 insertions(+), 18 deletions(-) create mode 100644 tests/jax/test_spergel_comp_galsim.py diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index 5c02a741..11430931 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -168,7 +168,7 @@ def reducedfluxfractionFunc(z, nu, norm): @jax.jit def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0): - """Return radius R enclosing flux fraction alpha in unit of the scale radius r0 + """Return radius R enclosing flux fraction alpha in unit of the scale radius r0 Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm @@ -303,11 +303,11 @@ def _xnorm0(self): @implements(_galsim.spergel.Spergel.calculateFluxRadius) def calculateFluxRadius(self, f): - return calculateFluxRadius(f, self.nu) + return self._r0 * calculateFluxRadius(f, self.nu) @implements(_galsim.spergel.Spergel.calculateIntegratedFlux) def calculateIntegratedFlux(self, r): - return fluxfractionFunc(r, self.nu, 0.0) + return fluxfractionFunc(r / self._r0, self.nu, 0.0) def __hash__(self): return hash( diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 9ef0cb72..15470caf 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -161,25 +161,72 @@ def _run(): print(f"time: {dt:0.4g} ms", end=" ") +def _run_spergel_bench_conv(gsmod): + obj = gsmod.Spergel(nu=-0.6, scale_radius=5) + psf = gsmod.Gaussian(fwhm=0.9) + obj = gsmod.Convolve( + [obj, psf], + gsparams=gsmod.GSParams(minimum_fft_size=2048, maximum_fft_size=2048), + ) + return obj.drawImage(nx=50, ny=50, scale=0.2).array + + +_run_spergel_bench_conv_jit = jax.jit(partial(_run_spergel_bench_conv, jgs)) + + @pytest.mark.parametrize("kind", ["compile", "run"]) -def test_benchmark_spergel(benchmark, kind): - def _run_jax(): - gal = jgs.Spergel(nu=-0.6, scale_radius=4.0) - psf = jgs.Gaussian(fwhm=0.9) - obj = jgs.Convolve([gal, psf]) - obj.drawImage(nx=51, ny=51, scale=0.2).array.block_until_ready() - - dt = _run_benchmarks(benchmark, kind, _run_jax) +def test_benchmark_spergel_conv(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_conv_jit().block_until_ready() + ) print(f"jax-galsim time: {dt:0.4g} ms", end=" ") @pytest.mark.parametrize("kind", ["compile", "run"]) -def test_benchmark_spergel_galsim(benchmark, kind): - def _run(): - gal = _galsim.Spergel(nu=-0.6, scale_radius=4.0) - psf = _galsim.Gaussian(fwhm=0.9) - obj = _galsim.Convolve([gal, psf]) - obj.drawImage(nx=51, ny=51, scale=0.2) +def test_benchmark_spergel_conv_galsim(benchmark, kind): + dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_conv(_galsim)) + print(f"galsim time: {dt:0.4g} ms", end=" ") - dt = _run_benchmarks(benchmark, kind, _run) + +def _run_spergel_bench_xvalue(gsmod): + obj = gsmod.Spergel(nu=-0.6, scale_radius=5) + return obj.drawImage(nx=50, ny=50, scale=0.2, method="no_pixel").array + + +_run_spergel_bench_xvalue_jit = jax.jit(partial(_run_spergel_bench_xvalue, jgs)) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_xvalue(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready() + ) + print(f"jax-galsim time: {dt:0.4g} ms", end=" ") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_xvalue_galsim(benchmark, kind): + dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_xvalue(_galsim)) + print(f"galsim time: {dt:0.4g} ms", end=" ") + + +def _run_spergel_bench_kvalue(gsmod): + obj = gsmod.Spergel(nu=-0.6, scale_radius=5) + return obj.drawKImage(nx=50, ny=50, scale=0.2).array + + +_run_spergel_bench_kvalue_jit = jax.jit(partial(_run_spergel_bench_kvalue, jgs)) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_kvalue(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready() + ) + print(f"jax-galsim time: {dt:0.4g} ms", end=" ") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_kvalue_galsim(benchmark, kind): + dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_kvalue(_galsim)) print(f"galsim time: {dt:0.4g} ms", end=" ") diff --git a/tests/jax/test_spergel_comp_galsim.py b/tests/jax/test_spergel_comp_galsim.py new file mode 100644 index 00000000..cc709136 --- /dev/null +++ b/tests/jax/test_spergel_comp_galsim.py @@ -0,0 +1,68 @@ +import galsim as gs +import numpy as np +import pytest + +import jax_galsim as jgs + + +@pytest.mark.parametrize( + "attr", ["nu", "scale_radius", "maxk", "stepk", "half_light_radius"] +) +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_properties(nu, scale_radius, attr): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + assert s_jgs.gsparams.folding_threshold == s_gs.gsparams.folding_threshold + assert s_jgs.gsparams.stepk_minimum_hlr == s_gs.gsparams.stepk_minimum_hlr + + np.testing.assert_allclose(getattr(s_jgs, attr), getattr(s_gs, attr), rtol=1e-5) + + +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_flux_radius(nu, scale_radius): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose( + s_jgs.calculateFluxRadius(0.8), + s_gs.calculateFluxRadius(0.8), + rtol=1e-5, + ) + + +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_integ_flux(nu, scale_radius): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose( + s_jgs.calculateIntegratedFlux(0.8), + s_gs.calculateIntegratedFlux(0.8), + rtol=1e-5, + ) + + +@pytest.mark.parametrize("kx", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("ky", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_kvalue(nu, scale_radius, kx, ky): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose(s_jgs.kValue(kx, ky), s_gs.kValue(kx, ky), rtol=1e-5) + + +@pytest.mark.parametrize("x", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("y", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_xvalue(nu, scale_radius, x, y): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose(s_jgs.xValue(x, y), s_gs.xValue(x, y), rtol=1e-5) From 1b2baf5bac6a928883c0e6a7b7f356c9461f3281 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Dec 2024 14:17:39 -0600 Subject: [PATCH 4/5] test: mark stepk as xfail for now --- tests/jax/test_spergel_comp_galsim.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_spergel_comp_galsim.py b/tests/jax/test_spergel_comp_galsim.py index cc709136..1a56f557 100644 --- a/tests/jax/test_spergel_comp_galsim.py +++ b/tests/jax/test_spergel_comp_galsim.py @@ -6,7 +6,19 @@ @pytest.mark.parametrize( - "attr", ["nu", "scale_radius", "maxk", "stepk", "half_light_radius"] + "attr", + [ + "nu", + "scale_radius", + "maxk", + pytest.param( + "stepk", + marks=pytest.mark.xfail( + reason="GalSim has a bug in its stepk routine. See https://github.com/GalSim-developers/GalSim/issues/1324" + ), + ), + "half_light_radius", + ], ) @pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) @pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) From 7c93e54759b93d6093626748b35342cb694d833b Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 19 Dec 2024 17:02:51 -0600 Subject: [PATCH 5/5] test: redo benchmarks to just cover jax and add timing comps --- jax_galsim/spergel.py | 20 +++---- tests/jax/test_benchmarks.py | 28 ++-------- tests/jax/test_spergel_comp_galsim.py | 77 +++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 33 deletions(-) diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index 11430931..a5a7389b 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -267,34 +267,34 @@ def scale_radius(self): def _r0(self): return self.scale_radius - @property + @lazy_property def _inv_r0(self): return 1.0 / self._r0 - @property + @lazy_property def _r0_sq(self): return self._r0 * self._r0 - @property + @lazy_property def _inv_r0_sq(self): return self._inv_r0 * self._inv_r0 - @property + @lazy_property @implements(_galsim.spergel.Spergel.half_light_radius) def half_light_radius(self): return self._r0 * calculateFluxRadius(0.5, self.nu) - @property + @lazy_property def _shootxnorm(self): """Normalization for photon shooting""" return 1.0 / (2.0 * jnp.pi * jnp.power(2.0, self.nu) * _gammap1(self.nu)) - @property + @lazy_property def _xnorm(self): """Normalization of xValue""" return self._shootxnorm * self.flux * self._inv_r0_sq - @property + @lazy_property def _xnorm0(self): """return z^nu K_nu(z) for z=0""" return jax.lax.select( @@ -338,13 +338,13 @@ def __str__(self): s += ")" return s - @property + @lazy_property def _maxk(self): """(1+ (k r0)^2)^(-1-nu) = maxk_threshold""" res = jnp.power(self.gsparams.maxk_threshold, -1.0 / (1.0 + self.nu)) - 1.0 return jnp.sqrt(res) / self._r0 - @property + @lazy_property def _stepk(self): R = calculateFluxRadius(1.0 - self.gsparams.folding_threshold, self.nu) R *= self._r0 @@ -352,7 +352,7 @@ def _stepk(self): R = jnp.maximum(R, self.gsparams.stepk_minimum_hlr * self.half_light_radius) return jnp.pi / R - @property + @lazy_property def _max_sb(self): # from SBSpergelImpl.h return jnp.abs(self._xnorm) * self._xnorm0 diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 15470caf..7ef21332 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -179,18 +179,12 @@ def test_benchmark_spergel_conv(benchmark, kind): dt = _run_benchmarks( benchmark, kind, lambda: _run_spergel_bench_conv_jit().block_until_ready() ) - print(f"jax-galsim time: {dt:0.4g} ms", end=" ") - - -@pytest.mark.parametrize("kind", ["compile", "run"]) -def test_benchmark_spergel_conv_galsim(benchmark, kind): - dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_conv(_galsim)) - print(f"galsim time: {dt:0.4g} ms", end=" ") + print(f"time: {dt:0.4g} ms", end=" ") def _run_spergel_bench_xvalue(gsmod): obj = gsmod.Spergel(nu=-0.6, scale_radius=5) - return obj.drawImage(nx=50, ny=50, scale=0.2, method="no_pixel").array + return obj.drawImage(nx=1024, ny=1204, scale=0.05, method="no_pixel").array _run_spergel_bench_xvalue_jit = jax.jit(partial(_run_spergel_bench_xvalue, jgs)) @@ -201,18 +195,12 @@ def test_benchmark_spergel_xvalue(benchmark, kind): dt = _run_benchmarks( benchmark, kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready() ) - print(f"jax-galsim time: {dt:0.4g} ms", end=" ") - - -@pytest.mark.parametrize("kind", ["compile", "run"]) -def test_benchmark_spergel_xvalue_galsim(benchmark, kind): - dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_xvalue(_galsim)) - print(f"galsim time: {dt:0.4g} ms", end=" ") + print(f"time: {dt:0.4g} ms", end=" ") def _run_spergel_bench_kvalue(gsmod): obj = gsmod.Spergel(nu=-0.6, scale_radius=5) - return obj.drawKImage(nx=50, ny=50, scale=0.2).array + return obj.drawKImage(nx=1024, ny=1204, scale=0.05).array _run_spergel_bench_kvalue_jit = jax.jit(partial(_run_spergel_bench_kvalue, jgs)) @@ -223,10 +211,4 @@ def test_benchmark_spergel_kvalue(benchmark, kind): dt = _run_benchmarks( benchmark, kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready() ) - print(f"jax-galsim time: {dt:0.4g} ms", end=" ") - - -@pytest.mark.parametrize("kind", ["compile", "run"]) -def test_benchmark_spergel_kvalue_galsim(benchmark, kind): - dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_kvalue(_galsim)) - print(f"galsim time: {dt:0.4g} ms", end=" ") + print(f"time: {dt:0.4g} ms", end=" ") diff --git a/tests/jax/test_spergel_comp_galsim.py b/tests/jax/test_spergel_comp_galsim.py index 1a56f557..0c0d4f63 100644 --- a/tests/jax/test_spergel_comp_galsim.py +++ b/tests/jax/test_spergel_comp_galsim.py @@ -1,8 +1,19 @@ +import galsim as _galsim import galsim as gs +import jax import numpy as np import pytest +from test_benchmarks import ( + _run_spergel_bench_conv, + _run_spergel_bench_conv_jit, + _run_spergel_bench_kvalue, + _run_spergel_bench_kvalue_jit, + _run_spergel_bench_xvalue, + _run_spergel_bench_xvalue_jit, +) import jax_galsim as jgs +from jax_galsim.core.testing import time_code_block @pytest.mark.parametrize( @@ -78,3 +89,69 @@ def test_spergel_comp_galsim_xvalue(nu, scale_radius, x, y): s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) np.testing.assert_allclose(s_jgs.xValue(x, y), s_gs.xValue(x, y), rtol=1e-5) + + +def _run_time_test(kind, func): + if kind == "compile": + + def _run(): + jax.clear_caches() + func() + + elif kind == "run": + # run once to compile + func() + + def _run(): + func() + + else: + raise ValueError(f"kind={kind} not recognized") + + tot_time = 0 + for _ in range(3): + with time_code_block(quiet=True) as tr: + _run() + tot_time += tr.dt + + return tot_time / 3 + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_spergel_comp_galsim_perf_conv(benchmark, kind): + dt = _run_time_test(kind, lambda: _run_spergel_bench_conv_jit().block_until_ready()) + print(f"\njax-galsim time: {dt:0.4g} ms") + + dt = _run_time_test( + kind, + lambda: _run_spergel_bench_conv(_galsim), + ) + print(f" galsim time: {dt:0.4g} ms") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_spergel_comp_galsim_perf_kvalue(benchmark, kind): + dt = _run_time_test( + kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready() + ) + print(f"\njax-galsim time: {dt:0.4g} ms") + + dt = _run_time_test( + kind, + lambda: _run_spergel_bench_kvalue(_galsim), + ) + print(f" galsim time: {dt:0.4g} ms") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_spergel_comp_galsim_perf_xvalue(benchmark, kind): + dt = _run_time_test( + kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready() + ) + print(f"\njax-galsim time: {dt:0.4g} ms") + + dt = _run_time_test( + kind, + lambda: _run_spergel_bench_xvalue(_galsim), + ) + print(f" galsim time: {dt:0.4g} ms")