Skip to content

Commit 569d0a6

Browse files
committed
FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop
The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming.myst#442
1 parent b9ae277 commit 569d0a6

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

lectures/jax_intro.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,16 +727,27 @@ def compute_call_price_jax(β=β,
727727
728728
s = jnp.full(M, np.log(S0))
729729
h = jnp.full(M, h0)
730-
for t in range(n):
730+
731+
def update(i, state):
732+
s, h, key = state
731733
key, subkey = jax.random.split(key)
732734
Z = jax.random.normal(subkey, (2, M))
733735
s = s + μ + jnp.exp(h) * Z[0, :]
734736
h = ρ * h + ν * Z[1, :]
737+
return s, h, key
738+
739+
s, h, key = jax.lax.fori_loop(0, n, update, (s, h, key))
735740
expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
736741
737742
return β**n * expectation
738743
```
739744

745+
```{note}
746+
We use `jax.lax.fori_loop` instead of a Python `for` loop.
747+
This allows JAX to compile the loop efficiently without unrolling it,
748+
which significantly reduces compilation time for large arrays.
749+
```
750+
740751
Let's run it once to compile it:
741752

742753
```{code-cell} ipython3

0 commit comments

Comments
 (0)