Skip to content

Commit afc2d36

Browse files
committed
style: use jstac's fori_loop naming conventions
- loop_body -> update - state -> loop_state - Added explicit new_loop_state and final_loop_state variables - More verbose but clearer for first-time fori_loop readers
1 parent 50de16b commit afc2d36

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

lectures/jax_intro.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,15 +833,18 @@ def compute_call_price_jax(β=β,
833833
s = jnp.full(M, np.log(S0))
834834
h = jnp.full(M, h0)
835835
836-
def loop_body(i, state):
837-
s, h, key = state
836+
def update(i, loop_state):
837+
s, h, key = loop_state
838838
key, subkey = jax.random.split(key)
839839
Z = jax.random.normal(subkey, (2, M))
840840
s = s + μ + jnp.exp(h) * Z[0, :]
841841
h = ρ * h + ν * Z[1, :]
842-
return s, h, key
842+
new_loop_state = s, h, key
843+
return new_loop_state
843844
844-
s, h, key = jax.lax.fori_loop(0, n, loop_body, (s, h, key))
845+
loop_state = s, h, key
846+
final_loop_state = jax.lax.fori_loop(0, n, update, loop_state)
847+
s, h, key = final_loop_state
845848
846849
expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
847850

0 commit comments

Comments
 (0)