Commit 569d0a6
committed
FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop
The compute_call_price_jax function was timing out during builds due to
JAX unrolling the Python for loop during JIT compilation. With large
arrays (M=10,000,000), this causes excessive compilation time.
Solution: Replace the Python for loop with jax.lax.fori_loop, which
compiles the loop efficiently without unrolling.
Same fix as QuantEcon/lecture-python-programming.myst#4421 parent b9ae277 commit 569d0a6
1 file changed
+12
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
727 | 727 | | |
728 | 728 | | |
729 | 729 | | |
730 | | - | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
731 | 733 | | |
732 | 734 | | |
733 | 735 | | |
734 | 736 | | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
735 | 740 | | |
736 | 741 | | |
737 | 742 | | |
738 | 743 | | |
739 | 744 | | |
| 745 | + | |
| 746 | + | |
| 747 | + | |
| 748 | + | |
| 749 | + | |
| 750 | + | |
740 | 751 | | |
741 | 752 | | |
742 | 753 | | |
| |||
0 commit comments