Commit 300ec30

misc

1 parent 267ffc3 commit 300ec30

File tree: 3 files changed, +208 -134 lines

lectures/_toc.yml

Lines changed: 1 addition & 1 deletion
@@ -23,8 +23,8 @@ parts:
   numbered: true
   chapters:
   - file: numba
-  - file: numpy_vs_numba_vs_jax
   - file: jax_intro
+  - file: numpy_vs_numba_vs_jax
 - caption: Working with Data
   numbered: true
   chapters:

lectures/jax_intro.md

Lines changed: 81 additions & 66 deletions
@@ -11,7 +11,7 @@ kernelspec:
   name: python3
 ---

-# An Introduction to JAX
+# JAX

 In addition to what's in Anaconda, this lecture will need the following libraries:

@@ -26,35 +26,41 @@ This lecture provides a short introduction to [Google JAX](https://github.com/ja
 Here we are focused on using JAX on the CPU, rather than on accelerators such as
 GPUs or TPUs.

-This means we will only see a small amount of the possible benefits from using
-JAX.
+This means we will only see a small amount of the possible benefits from using JAX.

-At the same time, JAX computing on the CPU is a good place to start, since the
-JAX just-in-time compiler seamlessly handles transitions across different
-hardware platforms.
+However, JAX seamlessly handles transitions across different hardware platforms.

-(In other words, if you do want to shift to using GPUs, you will almost never
-need to modify your code.)
+As a result, if you run this code on a machine with a GPU and a GPU-aware
+version of JAX installed, your code will be automatically accelerated and you
+will receive the full benefits.

 For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).


 ## JAX as a NumPy Replacement

-One way to use JAX is as a plug-in NumPy replacement. Let's look at the
-similarities and differences.
+One of the attractive features of JAX is that, whenever possible, it conforms to
+the NumPy API for array operations.
+
+This means that, to a large extent, we can use JAX as a drop-in NumPy replacement.
+
+Let's look at the similarities and differences between JAX and NumPy.

 ### Similarities

-The following import is standard, replacing `import numpy as np`:
+We'll use the following imports:

 ```{code-cell} ipython3
 import jax
-import jax.numpy as jnp
 import quantecon as qe
 ```

+In addition, we replace `import numpy as np` with
+
+```{code-cell} ipython3
+import jax.numpy as jnp
+```
+
 Now we can use `jnp` in place of `np` for the usual array operations:

 ```{code-cell} ipython3
@@ -101,20 +107,22 @@ B = jnp.identity(2)
 A @ B
 ```

-```{code-cell} ipython3
-from jax.numpy import linalg
-```
+JAX's array interface also provides the `linalg` subpackage:

 ```{code-cell} ipython3
-linalg.inv(B)  # Inverse of identity is identity
+jnp.linalg.inv(B)  # Inverse of identity is identity
 ```

 ```{code-cell} ipython3
-linalg.eigh(B)  # Computes eigenvalues and eigenvectors
+jnp.linalg.eigh(B)  # Computes eigenvalues and eigenvectors
 ```


 ### Differences

+Let's now look at some differences between JAX and NumPy array operations.
+
+#### Precision

 One difference between NumPy and JAX is that JAX uses 32 bit floats by default.

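As a quick check of the default-precision point, a minimal sketch; the `jax.config` toggle is standard JAX and is not part of this commit:

```python
import jax
import jax.numpy as jnp

print(jnp.ones(3).dtype)                   # float32: JAX's default

# Opt in to 64-bit floats; do this before creating any arrays
jax.config.update("jax_enable_x64", True)
print(jnp.ones(3).dtype)                   # float64
```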
@@ -136,6 +144,8 @@ Let's check this works:
 jnp.ones(3)
 ```

+#### Immutability
+
 As a NumPy replacement, a more significant difference is that arrays are treated as **immutable**.

 For example, with NumPy we can write
@@ -170,13 +180,13 @@ In line with immutability, JAX does not support inplace operations:

 ```{code-cell} ipython3
 a = np.array((2, 1))
-a.sort()
+a.sort()  # With NumPy, sort() mutates a in place
 a
 ```

 ```{code-cell} ipython3
 a = jnp.array((2, 1))
-a_new = a.sort()
+a_new = a.sort()  # In JAX, sort() instead returns a new sorted array
 a, a_new
 ```

@@ -185,7 +195,9 @@ The designers of JAX chose to make arrays immutable because JAX uses a

 This design choice has important implications, which we explore next!

-We should note, however, that, JAX does provide a version of in-place array modification
+#### A workaround
+
+We note that JAX does provide a version of in-place array modification
 using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).

 ```{code-cell} ipython3
@@ -199,14 +211,14 @@ a = a.at[0].set(1)
 a
 ```

-Obviously, there are downsides to using `at`.
+Obviously, there are downsides to using `at`:

-The syntax is not very pretty and we want to avoid creating fresh arrays in memory every time we change a single value.
+* The syntax is cumbersome and
+* we want to avoid creating fresh arrays in memory every time we change a single value!

 Hence, for the most part, we try to avoid this syntax.

-(Although it can in fact be efficient inside JIT-compiled functions -- but let's
-put this aside for now.)
+(Although it can in fact be efficient inside JIT-compiled functions -- but let's put this aside for now.)

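For reference, a minimal sketch of the `at` syntax in action; both `set` and `add` are part of JAX's documented `at` API:

```python
import jax.numpy as jnp

a = jnp.zeros(3)
b = a.at[0].set(1.0)   # functional update: returns a modified copy
c = a.at[1].add(5.0)   # in-place-style addition, also returns a copy
print(a)               # [0. 0. 0.]: the original is unchanged
print(b)               # [1. 0. 0.]
print(c)               # [0. 5. 0.]
```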

 ## Functional Programming
@@ -217,29 +229,32 @@ From JAX's documentation:

 In other words, JAX assumes a functional programming style.

-The major implication is that JAX functions should be pure.
+### Pure functions

+The major implication is that JAX functions should be pure.

-Pure functions have the following characteristics:
+**Pure functions** have the following characteristics:

 1. *Deterministic*
 2. *No side effects*

-Deterministic means
+**Deterministic** means

 * Same input $\implies$ same output
 * Outputs do not depend on global state

 In particular, pure functions will always return the same result if invoked with the same inputs.

-No side effects means that the function
+**No side effects** means that the function

 * Won't change global state
 * Won't modify data passed to the function (immutable data)


 ### Examples

-Here's an example of a non-pure function
+Here's an example of a *non-pure* function

 ```{code-cell} ipython3
 tax_rate = 0.1
@@ -248,52 +263,55 @@ prices = [10.0, 20.0]
 def add_tax(prices):
     for i, price in enumerate(prices):
         prices[i] = price * (1 + tax_rate)
-    print('Modified prices: ', prices)
+    print('Post-tax prices: ', prices)
     return prices
 ```

 This function fails to be pure because

 * side effects --- it modifies the global variable `prices`
 * non-deterministic --- a change to the global variable `tax_rate` will modify
-  function outputs, even with the same inputs.
+  function outputs, even with the same input array `prices`.

-Here's a pure version
+Here's a *pure* version

 ```{code-cell} ipython3
 tax_rate = 0.1
 prices = (10.0, 20.0)

 def add_tax_pure(prices, tax_rate):
-    return [price * (1 + tax_rate) for price in prices]
+    new_prices = [price * (1 + tax_rate) for price in prices]
+    return new_prices
 ```

 This pure version makes all dependencies explicit through function arguments, and doesn't modify any external state.

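A short usage sketch of the pure version (not part of the commit): because every dependency enters as an argument, repeated calls with the same inputs must agree.

```python
tax_rate = 0.1
prices = (10.0, 20.0)

def add_tax_pure(prices, tax_rate):
    new_prices = [price * (1 + tax_rate) for price in prices]
    return new_prices

# Same inputs give the same output, and the global tuple is untouched
assert add_tax_pure(prices, tax_rate) == add_tax_pure(prices, tax_rate)
print(prices)  # (10.0, 20.0): unchanged
```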
 Now that we understand what pure functions are, let's explore how JAX's approach to random numbers maintains this purity.


-## Random Numbers
+## Random numbers

 Random numbers are rather different in JAX, compared to what you find in NumPy
 or Matlab.

 At first you might find the syntax rather verbose.

-But actually it makes a lot of sense:
+But you will soon realize that the syntax and semantics are necessary in order
+to maintain the functional programming style we just discussed.
+
+Moreover, full control of random state is essential for parallel programming,
+such as when we want to run independent experiments across multiple threads.

-* maintains the functional programming style we just discussed, and
-* makes the control of random state explicit and convenient for running over
-  multiple threads --- essential for parallelization.

 ### Random number generation

-In JAX, the state of the random number generator needs to be controlled explicitly.
+In JAX, the state of the random number generator is controlled explicitly.

 First we produce a key, which seeds the random number generator.

 ```{code-cell} ipython3
-key = jax.random.PRNGKey(1)
+seed = 1234
+key = jax.random.PRNGKey(seed)
 ```

 Now we can use the key to generate some random numbers:
@@ -340,7 +358,8 @@ def gen_random_matrices(key, n=2, k=3):
 ```

 ```{code-cell} ipython3
-key = jax.random.PRNGKey(1)
+seed = 42
+key = jax.random.PRNGKey(seed)
 matrices = gen_random_matrices(key)
 ```

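The two cells above rely on reproducibility: a given key always produces the same draws. A minimal sketch of that property (not part of the commit):

```python
import jax

key = jax.random.PRNGKey(42)
x1 = jax.random.normal(key, (3,))
x2 = jax.random.normal(key, (3,))
print((x1 == x2).all())  # True: the same key always gives the same draws
```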
@@ -358,15 +377,16 @@ def gen_random_matrices(key, n=2, k=3):
 ```

 ```{code-cell} ipython3
-key = jax.random.PRNGKey(1)
+key = jax.random.PRNGKey(seed)
 matrices = gen_random_matrices(key)
 ```


 ### Why explicit random state?

-Why does JAX require this somewhat verbose approach to random number generation.
+Why does JAX require this somewhat verbose approach to random number generation?

-The reason is to maintain pure functions.
+One reason is to maintain pure functions.

 Let's see how random number generation relates to pure functions by comparing NumPy and JAX.

@@ -378,12 +398,11 @@ Each time we call a random function, this state is updated:

 ```{code-cell} ipython3
 np.random.seed(42)
-print(np.random.randn())
-print(np.random.randn())
-print(np.random.randn())
+print(np.random.randn())  # Updates state of random number generator
+print(np.random.randn())  # Updates state of random number generator
 ```

-Notice that each call returns a different value, even though we're calling the same function with the same inputs (no arguments).
+Each call returns a different value, even though we're calling the same function with the same inputs (no arguments).

 This function is *not pure* because:

@@ -416,14 +435,7 @@ random_sum_jax(key)
 random_sum_jax(key)
 ```

-Different keys give different results:
-
-```{code-cell} ipython3
-key1 = jax.random.PRNGKey(1)
-key2 = jax.random.PRNGKey(2)
-print(random_sum_jax(key1))
-print(random_sum_jax(key2))
-```
+To get new draws we need to supply a new key.

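The standard tool for supplying new keys is `jax.random.split`, which deterministically derives fresh subkeys from an old key; this hunk does not show it, so here is a minimal sketch:

```python
import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)   # derive a fresh subkey
draw1 = jax.random.normal(subkey, (3,))
key, subkey = jax.random.split(key)   # split again for new draws
draw2 = jax.random.normal(subkey, (3,))
print(draw1, draw2)                   # different draws from different subkeys
```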
 The function `random_sum_jax` is pure because:

@@ -437,7 +449,7 @@ The explicitness of JAX brings significant benefits:
 * Debugging: No hidden state makes code easier to reason about
 * JIT compatibility: The compiler can optimize pure functions more aggressively

-The last point about JIT compatibility is explained in the next section.
+The last point is expanded on in the next section.


 ## JIT compilation
@@ -517,7 +529,10 @@ equivalent, which ran on the CPU.
 Even if you are running on a machine with many CPUs, the second JAX run should
 be substantially faster.

-But notice also that the second time is shorter than the first.
+Also, typically, the second run is faster than the first.
+
+(This might not be noticeable on the CPU but it should definitely be noticeable
+on the GPU.)

 This is because even built-in functions like `jnp.cos` are JIT-compiled --- and the
 first run includes compile time.
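The compile-then-cache behavior described here is easy to verify; a minimal sketch, assuming the same `jax`, `jnp`, and `qe` imports the lecture uses:

```python
import jax
import jax.numpy as jnp
import quantecon as qe

@jax.jit
def f(x):
    return jnp.cos(x) ** 2 + jnp.sin(x) ** 2

x = jnp.linspace(0, 10, 1_000_000)

with qe.Timer():                      # first call: includes compile time
    jax.block_until_ready(f(x))

with qe.Timer():                      # second call: reuses the cached binary
    jax.block_until_ready(f(x))
```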
@@ -534,9 +549,10 @@ requires matching the size of the task to the available hardware.
 That's why JAX waits to see the size of the array before compiling --- which
 requires a JIT-compiled approach instead of supplying precompiled binaries.


 #### Changing array sizes

-Here we change the input size and see the run time increase and then fall again.
+Here we change the input size and watch the runtimes.

 ```{code-cell}
 x = jnp.linspace(0, 10, n + 1)
@@ -555,10 +571,13 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```

+Typically, the run time increases and then falls again (this will be more obvious on the GPU).
+
 This is because the JIT compiler specializes on array size to exploit
 parallelization --- and hence generates fresh compiled code when the array size
 changes.

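One way to watch the size specialization happen: a Python-level `print` inside a jitted function runs only while JAX traces (compiles) it, so it fires once per new input shape. A minimal sketch (not part of the commit):

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    print(f"compiling for shape {x.shape}")  # runs only during tracing
    return jnp.cos(x) ** 2

g(jnp.ones(10))   # prints: compiling for shape (10,)
g(jnp.ones(10))   # silent: the cached compilation is reused
g(jnp.ones(20))   # prints again: a new shape triggers recompilation
```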

 ### Evaluating a more complicated function

 Let's try the same thing with a more complex function.
@@ -583,10 +602,6 @@ with qe.Timer():
     y = f(x)
 ```

-```{code-cell}
-with qe.Timer():
-    y = f(x)
-```


 #### With JAX
@@ -668,7 +683,7 @@ def f(x):

 Now that we've seen how powerful JIT compilation can be, it's important to understand its relationship with pure functions.

-JAX will not usually throw errors when compiling impure functions but execution becomes unpredictable.
+While JAX will not usually throw errors when compiling impure functions, execution becomes unpredictable.

 Here's an illustration of this fact, using global variables:

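The commit's own illustration sits outside this hunk. As a sketch of the behavior being described, relying only on standard JAX semantics (globals are captured as constants at trace time):

```python
import jax
import jax.numpy as jnp

a = 1.0  # global state read inside the function below

@jax.jit
def add_global(x):
    return x + a  # impure: depends on the global a

print(add_global(jnp.array(1.0)))  # 2.0: a was captured during compilation
a = 10.0
print(add_global(jnp.array(1.0)))  # still 2.0: the cached code ignores the new a
```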
@@ -778,7 +793,7 @@ We defer further exploration of automatic differentiation with JAX until {doc}`j
 :label: jax_intro_ex2
 ```

-In the Exercise section of {doc}`a lecture on Numba <numba>`, we used Monte
+In the Exercise section of {doc}`our lecture on Numba <numba>`, we used Monte
 Carlo to price a European call option.

 The code was accelerated by Numba-based multithreading.
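For orientation, a minimal JAX sketch of Monte Carlo pricing of a European call under geometric Brownian motion; the parameter values below are illustrative placeholders, not those of the Numba lecture:

```python
import jax
import jax.numpy as jnp

S, K, T, r, sigma = 100.0, 105.0, 1.0, 0.05, 0.2  # hypothetical parameters

@jax.jit
def call_price(key, n=1_000_000):
    Z = jax.random.normal(key, (n,))                # standard normal draws
    ST = S * jnp.exp((r - 0.5 * sigma**2) * T + sigma * jnp.sqrt(T) * Z)
    payoff = jnp.maximum(ST - K, 0.0)               # call payoff at maturity
    return jnp.exp(-r * T) * payoff.mean()          # discounted MC estimate

print(call_price(jax.random.PRNGKey(0)))
```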
