Skip to content

Commit 1991775

Browse files
authored
FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop (#249)
The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming.myst#442
1 parent b9ae277 commit 1991775

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

lectures/jax_intro.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,16 +727,31 @@ def compute_call_price_jax(β=β,
727727
728728
s = jnp.full(M, np.log(S0))
729729
h = jnp.full(M, h0)
730-
for t in range(n):
730+
731+
def update(i, loop_state):
732+
s, h, key = loop_state
731733
key, subkey = jax.random.split(key)
732734
Z = jax.random.normal(subkey, (2, M))
733735
s = s + μ + jnp.exp(h) * Z[0, :]
734736
h = ρ * h + ν * Z[1, :]
737+
new_loop_state = s, h, key
738+
return new_loop_state
739+
740+
initial_loop_state = s, h, key
741+
final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state)
742+
s, h, key = final_loop_state
743+
735744
expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
736745
737746
return β**n * expectation
738747
```
739748

749+
```{note}
750+
We use `jax.lax.fori_loop` instead of a Python `for` loop.
751+
This allows JAX to compile the loop efficiently without unrolling it,
752+
which significantly reduces compilation time for large arrays.
753+
```
754+
740755
Let's run it once to compile it:
741756

742757
```{code-cell} ipython3

0 commit comments

Comments
 (0)