Skip to content

Commit 50de16b

Browse files
committed
Fix jax_intro timeout: use lax.fori_loop instead of Python for loop
The compute_call_price_jax function was timing out during cache.yml builds because JAX unrolls Python for loops during JIT compilation. With large arrays (M=10M), this causes excessive compilation time. Solution: Replace Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Fixes cell execution timeout in jax_intro.md
1 parent 9045b9f commit 50de16b

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

lectures/jax_intro.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,16 +832,28 @@ def compute_call_price_jax(β=β,
832832
833833
s = jnp.full(M, np.log(S0))
834834
h = jnp.full(M, h0)
835-
for t in range(n):
835+
836+
def loop_body(i, state):
837+
s, h, key = state
836838
key, subkey = jax.random.split(key)
837839
Z = jax.random.normal(subkey, (2, M))
838840
s = s + μ + jnp.exp(h) * Z[0, :]
839841
h = ρ * h + ν * Z[1, :]
842+
return s, h, key
843+
844+
s, h, key = jax.lax.fori_loop(0, n, loop_body, (s, h, key))
845+
840846
expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
841847
842848
return β**n * expectation
843849
```
844850

851+
```{note}
852+
We use `jax.lax.fori_loop` instead of a Python `for` loop.
853+
This allows JAX to compile the loop efficiently without unrolling it,
854+
which significantly reduces compilation time for large arrays.
855+
```
856+
845857
Let's run it once to compile it:
846858

847859
```{code-cell} ipython3

0 commit comments

Comments
 (0)