Skip to content

Commit 288e4ae

Browse files
committed
update amss to remove dependency on interpolation
1 parent fbc8121 commit 288e4ae

File tree

2 files changed

+150
-14
lines changed

2 files changed

+150
-14
lines changed

lectures/_static/lecture_specific/amss/recursive_allocation.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,64 @@
1+
import numpy as np
2+
from numba import njit, prange
3+
from quantecon import optimize
4+
5+
@njit
6+
def get_grid_nodes(grid):
7+
"""
8+
Get the actual grid points from a grid tuple.
9+
"""
10+
x_min, x_max, x_num = grid
11+
return np.linspace(x_min, x_max, x_num)
12+
13+
@njit
14+
def linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_val):
15+
"""Helper function for scalar interpolation"""
16+
x_nodes = np.linspace(x_min, x_max, x_num)
17+
18+
# Extrapolation with linear extension
19+
if x_val <= x_nodes[0]:
20+
# Linear extrapolation using first two points
21+
if x_num >= 2:
22+
slope = (y_values[1] - y_values[0]) / (x_nodes[1] - x_nodes[0])
23+
return y_values[0] + slope * (x_val - x_nodes[0])
24+
else:
25+
return y_values[0]
26+
27+
if x_val >= x_nodes[-1]:
28+
# Linear extrapolation using last two points
29+
if x_num >= 2:
30+
slope = (y_values[-1] - y_values[-2]) / (x_nodes[-1] - x_nodes[-2])
31+
return y_values[-1] + slope * (x_val - x_nodes[-1])
32+
else:
33+
return y_values[-1]
34+
35+
# Binary search for the right interval
36+
left = 0
37+
right = x_num - 1
38+
while right - left > 1:
39+
mid = (left + right) // 2
40+
if x_nodes[mid] <= x_val:
41+
left = mid
42+
else:
43+
right = mid
44+
45+
# Linear interpolation
46+
x_left = x_nodes[left]
47+
x_right = x_nodes[right]
48+
y_left = y_values[left]
49+
y_right = y_values[right]
50+
51+
weight = (x_val - x_left) / (x_right - x_left)
52+
return y_left * (1 - weight) + y_right * weight
53+
54+
@njit
55+
def linear_interp_1d(x_grid, y_values, x_query):
56+
"""
57+
Perform 1D linear interpolation.
58+
"""
59+
x_min, x_max, x_num = x_grid
60+
return linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_query[0])
61+
162
class AMSS:
263
# WARNING: THE CODE IS EXTREMELY SENSITIVE TO CHOCIES OF PARAMETERS.
364
# DO NOT CHANGE THE PARAMETERS AND EXPECT IT TO WORK
@@ -78,6 +139,10 @@ def simulate(self, s_hist, b_0):
78139
pref = self.pref
79140
x_grid, g, β, S = self.x_grid, self.g, self.β, self.S
80141
σ_v_star, σ_w_star = self.σ_v_star, self.σ_w_star
142+
Π = self.Π
143+
144+
# Extract the grid tuple from the list
145+
grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid
81146

82147
T = len(s_hist)
83148
s_0 = s_hist[0]
@@ -111,8 +176,8 @@ def simulate(self, s_hist, b_0):
111176
T = np.zeros(S)
112177
for s in range(S):
113178
x_arr = np.array([x_])
114-
l[s] = eval_linear(x_grid, σ_v_star[s_, :, s], x_arr)
115-
T[s] = eval_linear(x_grid, σ_v_star[s_, :, S+s], x_arr)
179+
l[s] = linear_interp_1d(grid_tuple, σ_v_star[s_, :, s], x_arr)
180+
T[s] = linear_interp_1d(grid_tuple, σ_v_star[s_, :, S+s], x_arr)
116181

117182
c = (1 - l) - g
118183
u_c = pref.Uc(c, l)
@@ -135,6 +200,8 @@ def simulate(self, s_hist, b_0):
135200

136201
def obj_factory(Π, β, x_grid, g):
137202
S = len(Π)
203+
# Extract the grid tuple from the list
204+
grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid
138205

139206
@njit
140207
def obj_V(σ, state, V, pref):
@@ -152,7 +219,7 @@ def obj_V(σ, state, V, pref):
152219
V_next = np.zeros(S)
153220

154221
for s in range(S):
155-
V_next[s] = eval_linear(x_grid, V[s], np.array([x[s]]))
222+
V_next[s] = linear_interp_1d(grid_tuple, V[s], np.array([x[s]]))
156223

157224
out = Π[s_] @ (pref.U(c, l) + β * V_next)
158225

@@ -167,7 +234,7 @@ def obj_W(σ, state, V, pref):
167234
c = (1 - l) - g[s_]
168235
x = -pref.Uc(c, l) * (c - T - b_0) + pref.Ul(c, l) * (1 - l)
169236

170-
V_next = eval_linear(x_grid, V[s_], np.array([x]))
237+
V_next = linear_interp_1d(grid_tuple, V[s_], np.array([x]))
171238

172239
out = pref.U(c, l) + β * V_next
173240

@@ -178,9 +245,11 @@ def obj_W(σ, state, V, pref):
178245

179246
def bellman_operator_factory(Π, β, x_grid, g, bounds_v):
180247
obj_V, obj_W = obj_factory(Π, β, x_grid, g)
181-
n = x_grid[0][2]
248+
# Extract the grid tuple from the list
249+
grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid
250+
n = grid_tuple[2]
182251
S = len(Π)
183-
x_nodes = nodes(x_grid)
252+
x_nodes = get_grid_nodes(grid_tuple)
184253

185254
@njit(parallel=True)
186255
def T_v(V, V_new, σ_star, pref):
@@ -209,4 +278,4 @@ def T_w(W, σ_star, V, b_0, pref):
209278
W[s_] = res.fun
210279
σ_star[s_] = res.x
211280

212-
return T_v, T_w
281+
return T_v, T_w

lectures/amss.md

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ In addition to what's in Anaconda, this lecture will need the following librarie
2727
tags: [hide-output]
2828
---
2929
!pip install --upgrade quantecon
30-
!pip install interpolation
3130
```
3231

3332
## Overview
@@ -38,12 +37,80 @@ Let's start with following imports:
3837
import numpy as np
3938
import matplotlib.pyplot as plt
4039
from scipy.optimize import root
41-
from interpolation.splines import eval_linear, UCGrid, nodes
4240
from quantecon import optimize, MarkovChain
4341
from numba import njit, prange, float64
4442
from numba.experimental import jitclass
4543
```
4644

45+
Now let's define numba-compatible interpolation functions for this lecture.
46+
47+
We will soon use the following interpolation functions to interpolate the value function and the policy functions
48+
49+
```{code-cell} ipython
50+
:tags: [collapse-40]
51+
52+
@njit
53+
def get_grid_nodes(grid):
54+
"""
55+
Get the actual grid points from a grid tuple.
56+
"""
57+
x_min, x_max, x_num = grid
58+
return np.linspace(x_min, x_max, x_num)
59+
60+
@njit
61+
def linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_val):
62+
"""Helper function for scalar interpolation"""
63+
x_nodes = np.linspace(x_min, x_max, x_num)
64+
65+
# Extrapolation with linear extension
66+
if x_val <= x_nodes[0]:
67+
# Linear extrapolation using first two points
68+
if x_num >= 2:
69+
slope = (y_values[1] - y_values[0]) \
70+
/ (x_nodes[1] - x_nodes[0])
71+
return y_values[0] + slope * (x_val - x_nodes[0])
72+
else:
73+
return y_values[0]
74+
75+
if x_val >= x_nodes[-1]:
76+
# Linear extrapolation using last two points
77+
if x_num >= 2:
78+
slope = (y_values[-1] - y_values[-2]) \
79+
/ (x_nodes[-1] - x_nodes[-2])
80+
return y_values[-1] + slope * (x_val - x_nodes[-1])
81+
else:
82+
return y_values[-1]
83+
84+
# Binary search for the right interval
85+
left = 0
86+
right = x_num - 1
87+
while right - left > 1:
88+
mid = (left + right) // 2
89+
if x_nodes[mid] <= x_val:
90+
left = mid
91+
else:
92+
right = mid
93+
94+
# Linear interpolation
95+
x_left = x_nodes[left]
96+
x_right = x_nodes[right]
97+
y_left = y_values[left]
98+
y_right = y_values[right]
99+
100+
weight = (x_val - x_left) / (x_right - x_left)
101+
return y_left * (1 - weight) + y_right * weight
102+
103+
@njit
104+
def linear_interp_1d(x_grid, y_values, x_query):
105+
"""
106+
Perform 1D linear interpolation.
107+
"""
108+
x_min, x_max, x_num = x_grid
109+
return linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_query[0])
110+
```
111+
112+
Let's start with following imports:
113+
47114
In {doc}`an earlier lecture <opt_tax_recur>`, we described a model of
48115
optimal taxation with state-contingent debt due to
49116
Robert E. Lucas, Jr., and Nancy Stokey {cite}`LucasStokey1983`.
@@ -774,7 +841,7 @@ x_min = -1.5555
774841
x_max = 17.339
775842
x_num = 300
776843
777-
x_grid = UCGrid((x_min, x_max, x_num))
844+
x_grid = [(x_min, x_max, x_num)]
778845
779846
crra_pref = CRRAutility(β=β, σ=σ, γ=γ)
780847
@@ -788,7 +855,7 @@ amss_model = AMSS(crra_pref, β, Π, g, x_grid, bounds_v)
788855
```{code-cell} python3
789856
# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS
790857
V = np.zeros((len(Π), x_num))
791-
V[:] = -nodes(x_grid).T ** 2
858+
V[:] = -get_grid_nodes(x_grid[0]) ** 2
792859
793860
σ_v_star = np.ones((S, x_num, S * 2))
794861
σ_v_star[:, :, :S] = 0.0
@@ -914,14 +981,14 @@ x_min = -3.4107
914981
x_max = 3.709
915982
x_num = 300
916983
917-
x_grid = UCGrid((x_min, x_max, x_num))
984+
x_grid = [(x_min, x_max, x_num)]
918985
log_pref = LogUtility(β=β, ψ=ψ)
919986
920987
S = len(Π)
921988
bounds_v = np.vstack([np.zeros(2 * S), np.hstack([1 - g, np.ones(S)]) ]).T
922989
923990
V = np.zeros((len(Π), x_num))
924-
V[:] = -(nodes(x_grid).T + x_max) ** 2 / 14
991+
V[:] = -(get_grid_nodes(x_grid[0]) + x_max) ** 2 / 14
925992
926993
σ_v_star = 1 - np.full((S, x_num, S * 2), 0.55)
927994
@@ -1026,4 +1093,4 @@ lump-sum transfers to the private sector.
10261093
problem, there exists another realization $\tilde s^t$ with
10271094
the same history up until the previous period, i.e., $\tilde s^{t-1}=
10281095
s^{t-1}$, but where the multiplier on constraint {eq}`AMSS_46` takes a positive value, so
1029-
$\gamma_t(\tilde s^t)>0$.
1096+
$\gamma_t(\tilde s^t)>0$.

0 commit comments

Comments
 (0)