Skip to content

Commit 4e0b558

Browse files
authored
Merge pull request #295 from rossbar/matplotlib-nopyplot
MAINT: Replace `plt.<cmd>` pattern with explicit fig/ax calls
2 parents 86f231e + 2f42627 commit 4e0b558

File tree

7 files changed

+75
-74
lines changed

7 files changed

+75
-74
lines changed

content/mooreslaw-tutorial.md

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -289,20 +289,23 @@ The style sheet replicates
289289
```{code-cell}
290290
transistor_count_predicted = np.exp(B) * np.exp(A * year)
291291
transistor_Moores_law = Moores_law(year)
292+
292293
plt.style.use("fivethirtyeight")
293-
plt.semilogy(year, transistor_count, "s", label="MOS transistor count")
294-
plt.semilogy(year, transistor_count_predicted, label="linear regression")
294+
295+
fig, ax = plt.subplots()
296+
ax.semilogy(year, transistor_count, "s", label="MOS transistor count")
297+
ax.semilogy(year, transistor_count_predicted, label="linear regression")
295298
296299
297-
plt.plot(year, transistor_Moores_law, label="Moore's Law")
298-
plt.title(
300+
ax.plot(year, transistor_Moores_law, label="Moore's Law")
301+
ax.set_title(
299302
"MOS transistor count per microprocessor\n"
300303
+ "every two years \n"
301304
+ "Transistor count was x{:.2f} higher".format(np.exp(A * 2))
302305
)
303-
plt.xlabel("year introduced")
304-
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
305-
plt.ylabel("# of transistors\nper microprocessor")
306+
ax.set_xlabel("year introduced")
307+
ax.set_ylabel("# of transistors\nper microprocessor")
308+
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
306309
```
307310

308311
_A scatter plot of MOS transistor count per microprocessor every two years with a red line for the ordinary least squares prediction and an orange line for Moore's law._
@@ -346,19 +349,20 @@ y = np.linspace(2016.5, 2017.5)
346349
your_model2017 = np.exp(B) * np.exp(A * y)
347350
Moore_Model2017 = Moores_law(y)
348351
349-
plt.plot(
352+
fig, ax = plt.subplots()
353+
ax.plot(
350354
2017 * np.ones(np.sum(year == 2017)),
351355
transistor_count2017,
352356
"ro",
353357
label="2017",
354358
alpha=0.2,
355359
)
356-
plt.plot(2017, transistor_count2017.mean(), "g+", markersize=20, mew=6)
360+
ax.plot(2017, transistor_count2017.mean(), "g+", markersize=20, mew=6)
357361
358-
plt.plot(y, your_model2017, label="Your prediction")
359-
plt.plot(y, Moore_Model2017, label="Moores law")
360-
plt.ylabel("# of transistors\nper microprocessor")
361-
plt.legend()
362+
ax.plot(y, your_model2017, label="Your prediction")
363+
ax.plot(y, Moore_Model2017, label="Moores law")
364+
ax.set_ylabel("# of transistors\nper microprocessor")
365+
ax.legend()
362366
```
363367

364368
The result is that your model is close to the mean, but Gordon

content/tutorial-deep-learning-on-mnist.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,10 @@ import matplotlib.pyplot as plt
166166
# Take the 60,000th image (indexed at 59,999) from the training set,
167167
# reshape from (784, ) to (28, 28) to have a valid shape for displaying purposes.
168168
mnist_image = x_train[59999, :].reshape(28, 28)
169+
170+
fig, ax = plt.subplots()
169171
# Set the color mapping to grayscale to have a black background.
170-
plt.imshow(mnist_image, cmap="gray")
171-
# Display the image.
172-
plt.show()
172+
ax.imshow(mnist_image, cmap="gray")
173173
```
174174

175175
```{code-cell}
@@ -586,7 +586,6 @@ for ax, metrics, title in zip(
586586
ax.set_title(title)
587587
ax.set_xlabel("Epochs")
588588
ax.legend()
589-
plt.show()
590589
```
591590

592591
_The training and testing error is shown above in the left and right

content/tutorial-ma.md

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,11 @@ First of all, we can plot the whole set of data we have and see what it looks li
131131
import matplotlib.pyplot as plt
132132
133133
selected_dates = [0, 3, 11, 13]
134-
plt.plot(dates, nbcases.T, "--")
135-
plt.xticks(selected_dates, dates[selected_dates])
136-
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
134+
135+
fig, ax = plt.subplots()
136+
ax.plot(dates, nbcases.T, "--")
137+
ax.set_xticks(selected_dates, dates[selected_dates])
138+
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
137139
```
138140

139141
The graph has a strange shape from January 24th to February 1st. It would be interesting to know where this data comes from. If we look at the `locations` array we extracted from the `.csv` file, we can see that we have two columns, where the first would contain regions and the second would contain the name of the country. However, only the first few rows contain data for the first column (province names in China). Following that, we only have country names. So it would make sense to group all the data from China into a single row. For this, we'll select from the `nbcases` array only the rows for which the second entry of the `locations` array corresponds to China. Next, we'll use the [numpy.sum](https://numpy.org/devdocs/reference/generated/numpy.sum.html#numpy.sum) function to sum all the selected rows (`axis=0`). Note also that row 35 corresponds to the total counts for the whole country for each date. Since we want to calculate the sum ourselves from the provinces data, we have to remove that row first from both `locations` and `nbcases`:
@@ -183,9 +185,10 @@ Let's try and see what the data looks like excluding the first row (data from th
183185
closely:
184186

185187
```{code-cell}
186-
plt.plot(dates, nbcases_ma[1:].T, "--")
187-
plt.xticks(selected_dates, dates[selected_dates])
188-
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
188+
fig, ax = plt.subplots()
189+
ax.plot(dates, nbcases_ma[1:].T, "--")
190+
ax.set_xticks(selected_dates, dates[selected_dates])
191+
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
189192
```
190193

191194
Now that our data has been masked, let's try summing up all the cases in China:
@@ -232,9 +235,10 @@ china_total
232235
We can replace the data with this information and plot a new graph, focusing on Mainland China:
233236

234237
```{code-cell}
235-
plt.plot(dates, china_total.T, "--")
236-
plt.xticks(selected_dates, dates[selected_dates])
237-
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
238+
fig, ax = plt.subplots()
239+
ax.plot(dates, china_total.T, "--")
240+
ax.set_xticks(selected_dates, dates[selected_dates])
241+
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
238242
```
239243

240244
It's clear that masked arrays are the right solution here. We cannot represent the missing data without mischaracterizing the evolution of the curve.
@@ -271,21 +275,25 @@ package to create a cubic polynomial model that fits the data as best as possibl
271275
```{code-cell}
272276
t = np.arange(len(china_total))
273277
model = np.polynomial.Polynomial.fit(t[~china_total.mask], valid, deg=3)
274-
plt.plot(t, china_total)
275-
plt.plot(t, model(t), "--")
278+
279+
fig, ax = plt.subplots()
280+
ax.plot(t, china_total)
281+
ax.plot(t, model(t), "--")
276282
```
277283

278284
This plot is not so readable since the lines seem to be over each other, so let's summarize in a more elaborate plot. We'll plot the real data when
279285
available, and show the cubic fit for unavailable data, using this fit to compute an estimate to the observed number of cases on January 28th 2020, 7 days after the beginning of the records:
280286

281287
```{code-cell}
282-
plt.plot(t, china_total)
283-
plt.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
284-
plt.plot(7, model(7), "r*")
285-
plt.xticks([0, 7, 13], dates[[0, 7, 13]])
286-
plt.yticks([0, model(7), 10000, 17500])
287-
plt.legend(["Mainland China", "Cubic estimate", "7 days after start"])
288-
plt.title(
288+
fig, ax = plt.subplots()
289+
ax.plot(t, china_total)
290+
ax.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
291+
ax.plot(7, model(7), "r*")
292+
293+
ax.set_xticks([0, 7, 13], dates[[0, 7, 13]])
294+
ax.set_yticks([0, model(7), 10000, 17500])
295+
ax.legend(["Mainland China", "Cubic estimate", "7 days after start"])
296+
ax.set_title(
289297
"COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China\n"
290298
"Cubic estimate for 7 days after start"
291299
)

content/tutorial-plotting-fractals.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ mesh = x + (1j * y)
219219
220220
output = divergence_rate(mesh)
221221
222-
fig = plt.figure(figsize=(5, 5))
223-
ax = plt.axes()
222+
fig, ax = plt.subplots(figsize=(5, 5))
224223
225224
ax.set_title('$f(z) = z^2 -1$')
226225
ax.set_xlabel('Real axis')
@@ -273,8 +272,7 @@ We will also write a function that we will use to create our fractal plots:
273272
```{code-cell} ipython3
274273
def plot_fractal(fractal, title='Fractal', figsize=(6, 6), cmap='rainbow', extent=[-2, 2, -2, 2]):
275274
276-
plt.figure(figsize=figsize)
277-
ax = plt.axes()
275+
fig, ax = plt.subplots(figsize=figsize)
278276
279277
ax.set_title(f'${title}$')
280278
ax.set_xlabel('Real axis')

content/tutorial-static_equilibrium.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ d3.quiver(x, y, z, u, v, w, color="r", label="forceA")
9797
u, v, w = forceB
9898
d3.quiver(x, y, z, u, v, w, color="b", label="forceB")
9999
100-
plt.legend()
101-
plt.show()
100+
d3.legend()
102101
```
103102

104103
There are two forces emanating from a single point. In order to simplify this problem, you can add them together to find the sum of forces. Note that both `forceA` and `forceB` are three-dimensional vectors, represented by NumPy as arrays with three components. Because NumPy is meant to simplify and optimize operations between vectors, you can easily compute the sum of these two vectors as follows:
@@ -129,8 +128,7 @@ d3.quiver(x, y, z, u, v, w, color="b", label="forceB")
129128
u, v, w = forceC
130129
d3.quiver(x, y, z, u, v, w, color="g", label="forceC")
131130
132-
plt.legend()
133-
plt.show()
131+
d3.legend()
134132
```
135133

136134
However, the goal is equilibrium.
@@ -172,8 +170,6 @@ x, y, z = np.array([0, 0, 0])
172170
173171
u, v, w = forceA + forceB + R # add them all together for sum of forces
174172
d3.quiver(x, y, z, u, v, w)
175-
176-
plt.show()
177173
```
178174

179175
The empty graph signifies that there are no outlying forces. This denotes a system in equilibrium.

content/tutorial-svd.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ import matplotlib.pyplot as plt
7474
```
7575

7676
```{code-cell}
77-
plt.imshow(img)
78-
plt.show()
77+
fig, ax = plt.subplots()
78+
ax.imshow(img)
7979
```
8080

8181
### Shape, axis and array properties
@@ -196,8 +196,8 @@ To see if this makes sense in our image, we should use a colormap from `matplotl
196196
In our case, we are approximating the grayscale portion of the image, so we will use the colormap `gray`:
197197

198198
```{code-cell}
199-
plt.imshow(img_gray, cmap="gray")
200-
plt.show()
199+
fig, ax = plt.subplots()
200+
ax.imshow(img_gray, cmap="gray")
201201
```
202202

203203
Now, applying the [linalg.svd](https://numpy.org/devdocs/reference/generated/numpy.linalg.svd.html#numpy.linalg.svd) function to this matrix, we obtain the following decomposition:
@@ -259,8 +259,8 @@ np.allclose(img_gray, U @ Sigma @ Vt)
259259
To see if an approximation is reasonable, we can check the values in `s`:
260260

261261
```{code-cell}
262-
plt.plot(s)
263-
plt.show()
262+
fig, ax = plt.subplots()
263+
ax.plot(s)
264264
```
265265

266266
In the graph, we can see that although we have 768 singular values in `s`, most of those (after the 150th entry or so) are pretty small. So it might make sense to use only the information related to the first (say, 50) *singular values* to build a more economical approximation to our image.
@@ -282,8 +282,8 @@ approx = U @ Sigma[:, :k] @ Vt[:k, :]
282282
Note that we had to use only the first `k` rows of `Vt`, since all other rows would be multiplied by the zeros corresponding to the singular values we eliminated from this approximation.
283283

284284
```{code-cell}
285-
plt.imshow(approx, cmap="gray")
286-
plt.show()
285+
fig, ax = plt.subplots()
286+
ax.imshow(approx, cmap="gray")
287287
```
288288

289289
Now, you can go ahead and repeat this experiment with other values of `k`, and each of your experiments should give you a slightly better (or worse) image depending on the value you choose.
@@ -362,8 +362,9 @@ Since `imshow` expects values in the range, we can use `clip` to excise the floa
362362

363363
```{code-cell}
364364
reconstructed = np.clip(reconstructed, 0, 1)
365-
plt.imshow(np.transpose(reconstructed, (1, 2, 0)))
366-
plt.show()
365+
366+
fig, ax = plt.subplots()
367+
ax.imshow(np.transpose(reconstructed, (1, 2, 0)))
367368
```
368369

369370
```{note}
@@ -391,8 +392,8 @@ approx_img.shape
391392
which is not the right shape for showing the image. Finally, reordering the axes back to our original shape of `(768, 1024, 3)`, we can see our approximation:
392393

393394
```{code-cell}
394-
plt.imshow(np.transpose(np.clip(approx_img, 0, 1), (1, 2, 0)))
395-
plt.show()
395+
fig, ax = plt.subplots()
396+
ax.imshow(np.transpose(np.clip(approx_img, 0, 1), (1, 2, 0)))
396397
```
397398

398399
Even though the image is not as sharp, using a small number of `k` singular values (compared to the original set of 768 values), we can recover many of the distinguishing features from this image.

content/tutorial-x-ray-image-processing.md

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ print(xray_image.dtype)
134134
```{code-cell}
135135
import matplotlib.pyplot as plt
136136
137-
plt.imshow(xray_image, cmap="gray")
138-
plt.axis("off")
139-
plt.show()
137+
fig, ax = plt.subplots()
138+
ax.imshow(xray_image, cmap="gray")
139+
ax.set_axis_off()
140140
```
141141

142142
## Combine images into a multidimensional array to demonstrate progression
@@ -239,7 +239,6 @@ axes[1].set_title("Laplacian-Gaussian (edges)")
239239
axes[1].imshow(xray_image_laplace_gaussian, cmap="gray")
240240
for i in axes:
241241
i.axis("off")
242-
plt.show()
243242
```
244243

245244
### The Gaussian gradient magnitude method
@@ -272,7 +271,6 @@ axes[1].set_title("Gaussian gradient (edges)")
272271
axes[1].imshow(x_ray_image_gaussian_gradient, cmap="gray")
273272
for i in axes:
274273
i.axis("off")
275-
plt.show()
276274
```
277275

278276
### The Sobel-Feldman operator (the Sobel filter)
@@ -337,7 +335,6 @@ axes[2].set_title("Sobel (edges) - CMRmap")
337335
axes[2].imshow(xray_image_sobel, cmap="CMRmap")
338336
for i in axes:
339337
i.axis("off")
340-
plt.show()
341338
```
342339

343340
### The Canny filter
@@ -398,7 +395,6 @@ axes[3].set_title("Canny (edges) - terrain")
398395
axes[3].imshow(xray_image_canny, cmap="terrain")
399396
for i in axes:
400397
i.axis("off")
401-
plt.show()
402398
```
403399

404400
## Apply masks to X-rays with `np.where()`
@@ -437,9 +433,9 @@ pixel_intensity_distribution = ndimage.histogram(
437433
xray_image, min=np.min(xray_image), max=np.max(xray_image), bins=256
438434
)
439435
440-
plt.plot(pixel_intensity_distribution)
441-
plt.title("Pixel intensity distribution")
442-
plt.show()
436+
fig, ax = plt.subplots()
437+
ax.plot(pixel_intensity_distribution)
438+
ax.set_title("Pixel intensity distribution")
443439
```
444440

445441
As the pixel intensity distribution suggests, there are many low (between around
@@ -454,19 +450,19 @@ a certain threshold:
454450
# Return the original image if true, `0` otherwise
455451
xray_image_mask_noisy = np.where(xray_image > 150, xray_image, 0)
456452
457-
plt.imshow(xray_image_mask_noisy, cmap="gray")
458-
plt.axis("off")
459-
plt.show()
453+
fig, ax = plt.subplots()
454+
ax.imshow(xray_image_mask_noisy, cmap="gray")
455+
ax.set_axis_off()
460456
```
461457

462458
```{code-cell}
463459
# The threshold is "greater than 150"
464460
# Return `1` if true, `0` otherwise
465461
xray_image_mask_less_noisy = np.where(xray_image > 150, 1, 0)
466462
467-
plt.imshow(xray_image_mask_less_noisy, cmap="gray")
468-
plt.axis("off")
469-
plt.show()
463+
fig, ax = plt.subplots()
464+
ax.imshow(xray_image_mask_less_noisy, cmap="gray")
465+
ax.set_axis_off()
470466
```
471467

472468
## Compare the results
@@ -499,7 +495,6 @@ axes[8].set_title("Mask (> 150, less noisy)")
499495
axes[8].imshow(xray_image_mask_less_noisy, cmap="gray")
500496
for i in axes:
501497
i.axis("off")
502-
plt.show()
503498
```
504499

505500
## Next steps

0 commit comments

Comments
 (0)