@@ -11,7 +11,7 @@ kernelspec:
1111 name: python3
1212---
1313
14- # An Introduction to JAX
14+ # JAX
1515
1616In addition to what's in Anaconda, this lecture will need the following libraries:
1717
@@ -26,35 +26,41 @@ This lecture provides a short introduction to [Google JAX](https://github.com/ja
2626Here we are focused on using JAX on the CPU, rather than on accelerators such as
2727GPUs or TPUs.
2828
29- This means we will only see a small amount of the possible benefits from using
30- JAX.
29+ This means we will only see a small amount of the possible benefits from using JAX.
3130
32- At the same time, JAX computing on the CPU is a good place to start, since the
33- JAX just-in-time compiler seamlessly handles transitions across different
34- hardware platforms.
31+ However, JAX seamlessly handles transitions across different hardware platforms.
3532
36- (In other words, if you do want to shift to using GPUs, you will almost never
37- need to modify your code.)
33+ As a result, if you run this code on a machine with a GPU and a GPU-aware
34+ version of JAX installed, your code will be automatically accelerated and you
35+ will see the full speed gains.
3836
3937For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
4038
4139
4240## JAX as a NumPy Replacement
4341
44- One way to use JAX is as a plug-in NumPy replacement. Let's look at the
45- similarities and differences .
42+ One of the attractive features of JAX is that, whenever possible, it conforms to
43+ the NumPy API for array operations.
4644
47- ### Similarities
45+ This means that, to a large extent, we can use JAX as a drop-in NumPy replacement.
46+
47+ Let's look at the similarities and differences between JAX and NumPy.
4848
49+ ### Similarities
4950
50- The following import is standard, replacing ` import numpy as np ` :
51+ We'll use the following imports:
5152
5253``` {code-cell} ipython3
5354import jax
54- import jax.numpy as jnp
5555import quantecon as qe
5656```
5757
58+ In addition, we replace ` import numpy as np ` with
59+
60+ ``` {code-cell} ipython3
61+ import jax.numpy as jnp
62+ ```
63+
5864Now we can use ` jnp ` in place of ` np ` for the usual array operations:
5965
6066``` {code-cell} ipython3
@@ -101,20 +107,22 @@ B = jnp.identity(2)
101107A @ B
102108```
103109
104- ``` {code-cell} ipython3
105- from jax.numpy import linalg
106- ```
110+ JAX's array interface also provides the ` linalg ` subpackage:
107111
108112``` {code-cell} ipython3
109- linalg.inv(B) # Inverse of identity is identity
113+ jnp.linalg.inv(B) # Inverse of identity is identity
110114```
111115
112116``` {code-cell} ipython3
113- linalg.eigh(B) # Computes eigenvalues and eigenvectors
117+ jnp.linalg.eigh(B) # Computes eigenvalues and eigenvectors
114118```
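
As a further check that the NumPy API carries over, here's a small sketch that
solves a linear system with `jnp.linalg.solve` (the arrays are our own
illustration, not part of the examples above):

```{code-cell} ipython3
A = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 0.0])
jnp.linalg.solve(A, b)  # Same call signature as np.linalg.solve
```
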
115119
120+
116121### Differences
117122
123+ Let's now look at some differences between JAX and NumPy array operations.
124+
125+ #### Precision
118126
119127One difference between NumPy and JAX is that JAX uses 32 bit floats by default.
120128
@@ -136,6 +144,8 @@ Let's check this works:
136144jnp.ones(3)
137145```
138146
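
If you want to confirm the precision programmatically, one small check is to
inspect the array's `dtype` attribute:

```{code-cell} ipython3
jnp.ones(3).dtype  # float64 if 64-bit mode is enabled above, otherwise float32
```
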
147+ #### Immutability
148+
139149As a NumPy replacement, a more significant difference is that arrays are treated as ** immutable** .
140150
141151For example, with NumPy we can write
@@ -170,13 +180,13 @@ In line with immutability, JAX does not support inplace operations:
170180
171181``` {code-cell} ipython3
172182a = np.array((2, 1))
173- a.sort()
183+ a.sort() # Sorts a in place
174184a
175185```
176186
177187``` {code-cell} ipython3
178188a = jnp.array((2, 1))
179- a_new = a.sort()
189+ a_new = a.sort() # Unlike NumPy, sort does not mutate a; it returns a new sorted array
180190a, a_new
181191```
182192
@@ -185,7 +195,9 @@ The designers of JAX chose to make arrays immutable because JAX uses a
185195
186196This design choice has important implications, which we explore next!
187197
188- We should note, however, that, JAX does provide a version of in-place array modification
198+ #### A workaround
199+
200+ We note that JAX does provide a version of in-place array modification
189201using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
190202
191203``` {code-cell} ipython3
@@ -199,14 +211,14 @@ a = a.at[0].set(1)
199211a
200212```
201213
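
The `at` method supports other functional updates as well; for example, here's
a minimal sketch using `add`:

```{code-cell} ipython3
a = jnp.array((2, 1))
a.at[0].add(10)  # Returns a new array with a[0] incremented; a is unchanged
```
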
202- Obviously, there are downsides to using ` at ` .
214+ Obviously, there are downsides to using ` at ` :
203215
204- The syntax is not very pretty and we want to avoid creating fresh arrays in memory every time we change a single value.
216+ * The syntax is cumbersome and
217+ * we want to avoid creating fresh arrays in memory every time we change a single value!
205218
206219Hence, for the most part, we try to avoid this syntax.
207220
208- (Although it can in fact be efficient inside JIT-compiled functions -- but let's
209- put this aside for now.)
221+ (Although it can in fact be efficient inside JIT-compiled functions -- but let's put this aside for now.)
210222
211223
212224## Functional Programming
@@ -217,29 +229,32 @@ From JAX's documentation:
217229
218230In other words, JAX assumes a functional programming style.
219231
220- The major implication is that JAX functions should be pure.
232+ ### Pure functions
221233
234+ The major implication is that JAX functions should be pure.
222235
223- Pure functions have the following characteristics:
236+ ** Pure functions** have the following characteristics:
224237
2252381 . * Deterministic*
2262392 . * No side effects*
227240
228- Deterministic means
241+ ** Deterministic** means
229242
230243* Same input $\implies$ same output
231244* Outputs do not depend on global state
232245
233246In particular, pure functions will always return the same result if invoked with the same inputs.
234247
235- No side effects means that the function
248+ ** No side effects** means that the function
236249
237250* Won't change global state
238251* Won't modify data passed to the function (immutable data)
239252
253+
254+
240255### Examples
241256
242- Here's an example of a non-pure function
257+ Here's an example of a * non-pure* function
243258
244259``` {code-cell} ipython3
245260tax_rate = 0.1
@@ -248,52 +263,55 @@ prices = [10.0, 20.0]
248263def add_tax(prices):
249264 for i, price in enumerate(prices):
250265 prices[i] = price * (1 + tax_rate)
251- print('Modified prices: ', prices)
266+ print('Post-tax prices: ', prices)
252267 return prices
253268```
254269
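
To see the impurity in action, here's a quick check (reusing the names defined
above) showing that a call overwrites the global `prices`:

```{code-cell} ipython3
add_tax(prices)
prices  # The global list now holds post-tax prices
```
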
255270This function fails to be pure because
256271
257272* side effects --- it modifies the global variable ` prices `
258273* non-deterministic --- a change to the global variable ` tax_rate ` will modify
259- function outputs, even with the same inputs .
274+ function outputs, even with the same input list ` prices ` .
260275
261- Here's a pure version
276+ Here's a * pure* version
262277
263278``` {code-cell} ipython3
264279tax_rate = 0.1
265280prices = (10.0, 20.0)
266281
267282def add_tax_pure(prices, tax_rate):
268- return [price * (1 + tax_rate) for price in prices]
283+ new_prices = [price * (1 + tax_rate) for price in prices]
284+ return new_prices
269285```
270286
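
As a quick sanity check (again reusing the names above), the pure version
returns a new list and leaves its input untouched:

```{code-cell} ipython3
print(add_tax_pure(prices, tax_rate))  # New list with tax applied
print(prices)                          # The original data is unchanged
```
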
271287This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state.
272288
273289Now that we understand what pure functions are, let's explore how JAX's approach to random numbers maintains this purity.
274290
275291
276- ## Random Numbers
292+ ## Random numbers
277293
278294Random numbers are rather different in JAX, compared to what you find in NumPy
279295or Matlab.
280296
281297At first you might find the syntax rather verbose.
282298
283- But actually it makes a lot of sense:
299+ But you will soon realize that the syntax and semantics are necessary in order
300+ to maintain the functional programming style we just discussed.
301+
302+ Moreover, full control of random state is
303+ essential for parallel programming, such as when we want to run independent experiments across multiple threads.
284304
285- * maintains the functional programming style we just discussed, and
286- * makes the control of random state explicit and convenient for running over
287- multiple threads --- essential for parallelization.
288305
289306### Random number generation
290307
291- In JAX, the state of the random number generator needs to be controlled explicitly.
308+ In JAX, the state of the random number generator is controlled explicitly.
292309
293310First we produce a key, which seeds the random number generator.
294311
295312``` {code-cell} ipython3
296- key = jax.random.PRNGKey(1)
313+ seed = 1234
314+ key = jax.random.PRNGKey(seed)
297315```
298316
299317Now we can use the key to generate some random numbers:
@@ -340,7 +358,8 @@ def gen_random_matrices(key, n=2, k=3):
340358```
341359
342360``` {code-cell} ipython3
343- key = jax.random.PRNGKey(1)
361+ seed = 42
362+ key = jax.random.PRNGKey(seed)
344363matrices = gen_random_matrices(key)
345364```
346365
@@ -358,15 +377,16 @@ def gen_random_matrices(key, n=2, k=3):
358377```
359378
360379``` {code-cell} ipython3
361- key = jax.random.PRNGKey(1 )
380+ key = jax.random.PRNGKey(seed)
362381matrices = gen_random_matrices(key)
363382```
364383
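
Functions like `gen_random_matrices` need fresh randomness for each matrix they
generate; the standard tool for this is `jax.random.split`, which turns one key
into independent subkeys. Here's a minimal sketch (the draws are our own
example):

```{code-cell} ipython3
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)  # Two fresh, independent keys
jax.random.normal(subkey, (2, 2))    # Draw using the subkey
```
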
384+
365385### Why explicit random state?
366386
367- Why does JAX require this somewhat verbose approach to random number generation.
387+ Why does JAX require this somewhat verbose approach to random number generation?
368388
369- The reason is to maintain pure functions.
389+ One reason is to maintain pure functions.
370390
371391Let's see how random number generation relates to pure functions by comparing NumPy and JAX.
372392
@@ -378,12 +398,11 @@ Each time we call a random function, this state is updated:
378398
379399``` {code-cell} ipython3
380400np.random.seed(42)
381- print(np.random.randn())
382- print(np.random.randn())
383- print(np.random.randn())
401+ print(np.random.randn()) # Updates state of random number generator
402+ print(np.random.randn()) # Updates state of random number generator
384403```
385404
386- Notice that each call returns a different value, even though we're calling the same function with the same inputs (no arguments).
405+ Each call returns a different value, even though we're calling the same function with the same inputs (no arguments).
387406
388407This function is * not pure* because:
389408
@@ -416,14 +435,7 @@ random_sum_jax(key)
416435random_sum_jax(key)
417436```
418437
419- Different keys give different results:
420-
421- ``` {code-cell} ipython3
422- key1 = jax.random.PRNGKey(1)
423- key2 = jax.random.PRNGKey(2)
424- print(random_sum_jax(key1))
425- print(random_sum_jax(key2))
426- ```
438+ To get new draws we need to supply a new key.
427439
428440The function ` random_sum_jax ` is pure because:
429441
@@ -437,7 +449,7 @@ The explicitness of JAX brings significant benefits:
437449* Debugging: No hidden state makes code easier to reason about
438450* JIT compatibility: The compiler can optimize pure functions more aggressively
439451
440- The last point about JIT compatibility is explained in the next section.
452+ The last point is expanded on in the next section.
441453
442454
443455## JIT compilation
@@ -517,7 +529,10 @@ equivalent, which ran on the CPU.
517529Even if you are running on a machine with many CPUs, the second JAX run should
518530be substantially faster than the NumPy equivalent.
519531
520- But notice also that the second time is shorter than the first.
532+ Also, typically, the second run is faster than the first.
533+
534+ (This might not be noticeable on the CPU but it should definitely be noticeable on
535+ the GPU.)
521536
522537This is because even built-in functions like ` jnp.cos ` are JIT-compiled --- and the
523538first run includes compile time.
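
The same compile-then-run pattern appears if we JIT-compile a function of our
own. Here's a minimal sketch (the function `f_demo` and its input are our own
illustration):

```{code-cell}
@jax.jit
def f_demo(x):
    return jnp.cos(2 * x) + jnp.sin(x)

x_demo = jnp.linspace(0, 10, 1_000_000)

with qe.Timer():
    jax.block_until_ready(f_demo(x_demo))  # First call includes compile time

with qe.Timer():
    jax.block_until_ready(f_demo(x_demo))  # Second call reuses compiled code
```
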
@@ -534,9 +549,10 @@ requires matching the size of the task to the available hardware.
534549That's why JAX waits to see the size of the array before compiling --- which
535550requires a JIT-compiled approach instead of supplying precompiled binaries.
536551
552+
537553#### Changing array sizes
538554
539- Here we change the input size and see the run time increase and then fall again.
555+ Here we change the input size and watch the run times.
540556
541557``` {code-cell}
542558x = jnp.linspace(0, 10, n + 1)
@@ -555,10 +571,13 @@ with qe.Timer():
555571 jax.block_until_ready(y);
556572```
557573
574+ Typically, the run time increases and then falls again (this will be more obvious on the GPU).
575+
558576This is because the JIT compiler specializes on array size to exploit
559577parallelization --- and hence generates fresh compiled code when the array size
560578changes.
561579
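
As a quick check (a sketch reusing the names above, and assuming the function
being timed is `jnp.cos`), rerunning at a size we've already used should be
fast again, since the compiled code for that size is cached:

```{code-cell}
x = jnp.linspace(0, 10, n + 1)  # A size we've already compiled for
with qe.Timer():
    y = jnp.cos(x)
    jax.block_until_ready(y);
```
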
580+
562581### Evaluating a more complicated function
563582
564583Let's try the same thing with a more complex function.
@@ -583,10 +602,6 @@ with qe.Timer():
583602 y = f(x)
584603```
585604
586- ``` {code-cell}
587- with qe.Timer():
588- y = f(x)
589- ```
590605
591606
592607#### With JAX
@@ -668,7 +683,7 @@ def f(x):
668683
669684Now that we've seen how powerful JIT compilation can be, it's important to understand its relationship with pure functions.
670685
671- JAX will not usually throw errors when compiling impure functions but execution becomes unpredictable.
686+ While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.
672687
673688Here's an illustration of this fact, using global variables:
674689
@@ -778,7 +793,7 @@ We defer further exploration of automatic differentiation with JAX until {doc}`j
778793:label: jax_intro_ex2
779794```
780795
781- In the Exercise section of {doc}` a lecture on Numba <numba>` , we used Monte
796+ In the Exercise section of {doc}` our lecture on Numba <numba>` , we used Monte
782797Carlo to price a European call option.
783798
784799The code was accelerated by Numba-based multithreading.