Commit fa8d48e
committed
FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop
The compute_call_price_jax function was timing out during builds due to
JAX unrolling the Python for loop during JIT compilation. With large
arrays (M=10,000,000), this causes excessive compilation time.
Solution: Replace the Python for loop with jax.lax.fori_loop, which
compiles the loop efficiently without unrolling.
Same fix as QuantEcon/lecture-python-programming.myst#4421 parent b9ae277 commit fa8d48e
1 file changed
+16
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
727 | 727 | | |
728 | 728 | | |
729 | 729 | | |
730 | | - | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
731 | 733 | | |
732 | 734 | | |
733 | 735 | | |
734 | 736 | | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
735 | 744 | | |
736 | 745 | | |
737 | 746 | | |
738 | 747 | | |
739 | 748 | | |
| 749 | + | |
| 750 | + | |
| 751 | + | |
| 752 | + | |
| 753 | + | |
| 754 | + | |
740 | 755 | | |
741 | 756 | | |
742 | 757 | | |
| |||
0 commit comments