
Commit d5e578f

Create JAX-based utility functions and interpolation utilities for AMSS model

Copilot and HumphreyYang committed
Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com>
1 parent 25ca89d, commit d5e578f

5 files changed: +646 -0 lines changed (two binary files in this commit are not shown)
Lines changed: 371 additions & 0 deletions
@@ -0,0 +1,371 @@
"""
JAX-based AMSS model implementation.

Converted from NumPy/Numba classes to JAX pure functions and NamedTuple structures.
"""

import jax.numpy as jnp
from jax import jit, vmap
from scipy.optimize import minimize  # Use scipy for now
from typing import NamedTuple

try:
    from .jax_utilities import UtilityFunctions
    from .jax_interpolation import nodes_from_grid, eval_linear_jax
except ImportError:
    from jax_utilities import UtilityFunctions
    from jax_interpolation import nodes_from_grid, eval_linear_jax

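The sibling modules jax_utilities and jax_interpolation are created in this same commit, but their diffs are not shown on this page. As a rough sketch of the interfaces this file assumes (the CRRA functional form and the jnp.interp-based interpolation are illustrative guesses, not the committed implementations):

    # Hypothetical sketch of the assumed interfaces; not the committed code.
    import jax.numpy as jnp
    from typing import Callable, NamedTuple

    class UtilityFunctions(NamedTuple):
        """Bundle of utility callables: U(c, l), Uc = dU/dc, Ul = dU/dl."""
        U: Callable
        Uc: Callable
        Ul: Callable

    def crra_utility(σ=2.0, γ=2.0):
        """Illustrative CRRA utility over consumption c and leisure l."""
        def U(c, l):
            return c**(1 - σ) / (1 - σ) + l**(1 - γ) / (1 - γ)
        def Uc(c, l):
            return c**(-σ)
        def Ul(c, l):
            return l**(-γ)
        return UtilityFunctions(U=U, Uc=Uc, Ul=Ul)

    def nodes_from_grid(grid):
        """Expand grid parameters (x_min, x_max, x_num) into an array of nodes."""
        x_min, x_max, x_num = grid
        return jnp.linspace(x_min, x_max, x_num)

    def eval_linear_jax(grid, values, points):
        """Piecewise-linear interpolation of values (given on grid) at points."""
        return jnp.interp(points, nodes_from_grid(grid), values)
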
class AMSSState(NamedTuple):
    """State variables for AMSS model."""
    s: int    # Current Markov state
    x: float  # Continuation value state variable


class AMSSParams(NamedTuple):
    """Parameters for AMSS model."""
    β: float                   # Discount factor
    Π: jnp.ndarray             # Markov transition matrix
    g: jnp.ndarray             # Government spending by state
    x_grid: tuple              # Grid parameters (x_min, x_max, x_num)
    bounds_v: jnp.ndarray      # Bounds for optimization
    utility: UtilityFunctions  # Utility functions


class AMSSPolicies(NamedTuple):
    """Policy functions for AMSS model."""
    V: jnp.ndarray         # Value function
    σ_v_star: jnp.ndarray  # Policy function for time t >= 1
    W: jnp.ndarray         # Value function for time 0
    σ_w_star: jnp.ndarray  # Policy function for time 0

@jit
def compute_consumption_leisure(l, g):
    """Compute consumption given leisure and government spending."""
    return (1 - l) - g

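This is the economy's resource constraint, with labor supply $n = 1 - l$:

$$ c_t + g_t = n_t = 1 - l_t $$
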
# Not @jit-ed at the top level: `params` carries Python callables
# (params.utility), which jit cannot accept as traced arguments.
def objective_V(σ, state, V, params: AMSSParams):
    """
    Objective function for time t >= 1 value function iteration.

    Parameters
    ----------
    σ : array
        Policy variables [l_1, ..., l_S, T_1, ..., T_S]
    state : tuple
        Current state (s_, x_)
    V : array
        Current value function
    params : AMSSParams
        Model parameters

    Returns
    -------
    float
        Negative of expected value (for minimization)
    """
    s_, x_ = state
    S = len(params.Π)

    l = σ[:S]
    T = σ[S:]

    c = compute_consumption_leisure(l, params.g)
    u_c = vmap(params.utility.Uc)(c, l)
    Eu_c = params.Π[s_] @ u_c

    # Update the continuation state x for each next-period Markov state
    x = (u_c * x_ / (params.β * Eu_c) -
         u_c * (c - T) +
         vmap(params.utility.Ul)(c, l) * (1 - l))

    # Interpolate next period value function at x(s) for each state s
    V_next = jnp.array([eval_linear_jax(params.x_grid, V[s], jnp.array([x[s]]))[0]
                        for s in range(S)])

    expected_value = params.Π[s_] @ (vmap(params.utility.U)(c, l) + params.β * V_next)

    return -expected_value  # Negative for minimization

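Transcribing the continuation-state update above into an equation, with $\Pi[s_-,\cdot]$ the transition row of the inherited state:

$$ x(s) = \frac{u_c(s)\, x_-}{\beta \sum_{s'} \Pi[s_-, s']\, u_c(s')} - u_c(s)\bigl(c(s) - T(s)\bigr) + u_l(s)\bigl(1 - l(s)\bigr) $$
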
# Not @jit-ed for the same reason as objective_V above.
def objective_W(σ, state, V, params: AMSSParams):
    """
    Objective function for time 0 problem.

    Parameters
    ----------
    σ : array
        Policy variables [l, T]
    state : tuple
        Current state (s_, b_0)
    V : array
        Value function
    params : AMSSParams
        Model parameters

    Returns
    -------
    float
        Negative of value (for minimization)
    """
    s_, b_0 = state
    l, T = σ

    c = compute_consumption_leisure(l, params.g[s_])
    x = (-params.utility.Uc(c, l) * (c - T - b_0) +
         params.utility.Ul(c, l) * (1 - l))

    V_next = eval_linear_jax(params.x_grid, V[s_], jnp.array([x]))[0]
    value = params.utility.U(c, l) + params.β * V_next

    return -value  # Negative for minimization

def solve_bellman_iteration(V, σ_v_star, params: AMSSParams,
                            tol=1e-7, max_iter=1000, print_freq=10):
    """
    Solve the Bellman equation using value function iteration.

    Parameters
    ----------
    V : array
        Initial value function guess
    σ_v_star : array
        Initial policy function guess
    params : AMSSParams
        Model parameters
    tol : float
        Convergence tolerance
    max_iter : int
        Maximum iterations
    print_freq : int
        Print frequency

    Returns
    -------
    tuple
        Updated (V, σ_v_star)
    """
    S = len(params.Π)
    x_nodes = nodes_from_grid(params.x_grid)
    n_x = len(x_nodes)

    # Optimization bounds are fixed, so build them once outside the loops
    bounds = [(params.bounds_v[i, 0], params.bounds_v[i, 1])
              for i in range(len(params.bounds_v))]

    for iteration in range(max_iter):
        V_updated = jnp.zeros_like(V)
        σ_updated = jnp.zeros_like(σ_v_star)

        # Loop over states and grid points
        for s_ in range(S):
            for x_i in range(n_x):
                state = (s_, x_nodes[x_i])
                x0 = σ_v_star[s_, x_i]

                # Minimize the JAX objective with scipy's L-BFGS-B
                result = minimize(
                    lambda σ: objective_V(σ, state, V, params),
                    x0,
                    method='L-BFGS-B',
                    bounds=bounds
                )

                if result.success:
                    V_updated = V_updated.at[s_, x_i].set(-result.fun)
                    σ_updated = σ_updated.at[s_, x_i].set(result.x)
                else:
                    # Keep the previous values if the optimizer fails
                    print(f"Optimization failed at state {s_}, grid point {x_i}")
                    V_updated = V_updated.at[s_, x_i].set(V[s_, x_i])
                    σ_updated = σ_updated.at[s_, x_i].set(σ_v_star[s_, x_i])

        # Check convergence in the sup norm
        error = jnp.max(jnp.abs(V_updated - V))

        if error < tol:
            print(f'Successfully completed VFI after {iteration + 1} iterations')
            return V_updated, σ_updated

        if (iteration + 1) % print_freq == 0:
            print(f'Error at iteration {iteration + 1}: {error}')

        V = V_updated
        σ_v_star = σ_updated

    print(f'VFI did not converge after {max_iter} iterations')
    return V, σ_v_star

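One detail worth noting: scipy's L-BFGS-B treats the objective above as a black box and estimates gradients by finite differences. Since the objective is a JAX function, an exact gradient from jax.grad could be supplied via jac=. A self-contained toy of that pattern (the quadratic objective is only a stand-in, not the model's objective):

    # Pattern sketch: drive scipy's L-BFGS-B with a JAX objective and its exact gradient.
    import numpy as np
    import jax.numpy as jnp
    from jax import grad, jit
    from scipy.optimize import minimize

    @jit
    def f(σ):
        return jnp.sum((σ - 1.0) ** 2)  # stand-in objective

    df = jit(grad(f))  # exact gradient, compiled once

    result = minimize(
        lambda σ: float(f(σ)),                              # scalar objective
        x0=np.zeros(4),
        jac=lambda σ: np.asarray(df(σ), dtype=np.float64),  # exact gradient for scipy
        method='L-BFGS-B',
        bounds=[(-2.0, 2.0)] * 4,
    )
    print(result.x)  # approximately [1. 1. 1. 1.]
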
def solve_time_zero_problem(b_0, V, params: AMSSParams):
    """
    Solve the time 0 problem.

    Parameters
    ----------
    b_0 : float
        Initial debt
    V : array
        Value function from time 1 problem
    params : AMSSParams
        Model parameters

    Returns
    -------
    tuple
        (W, σ_w_star) where W is time 0 values and σ_w_star is time 0 policies
    """
    S = len(params.Π)
    W = jnp.zeros(S)
    σ_w_star = jnp.zeros((S, 2))

    bounds_w = [(-9.0, 1.0), (0.0, 10.0)]

    for s_ in range(S):
        state = (s_, b_0)
        x0 = jnp.array([-0.05, 0.5])  # Initial guess for (l, T)

        result = minimize(
            lambda σ: objective_W(σ, state, V, params),
            x0,
            method='L-BFGS-B',
            bounds=bounds_w
        )

        W = W.at[s_].set(-result.fun)
        σ_w_star = σ_w_star.at[s_].set(result.x)

    print('Successfully solved the time 0 problem.')
    return W, σ_w_star

# Not @jit-ed: `policies` and `params` carry Python callables that jit cannot
# trace, and the Python time loop below would be unrolled at trace time anyway.
def simulate_amss(s_hist, b_0, policies: AMSSPolicies, params: AMSSParams):
    """
    Simulate AMSS model given state history and initial debt.

    Parameters
    ----------
    s_hist : array
        History of Markov states
    b_0 : float
        Initial debt level
    policies : AMSSPolicies
        Solved policy functions
    params : AMSSParams
        Model parameters

    Returns
    -------
    dict
        Simulation results with arrays for c, n, b, τ, g
    """
    T = len(s_hist)
    S = len(params.Π)

    # Pre-allocate arrays
    n_hist = jnp.zeros(T)
    x_hist = jnp.zeros(T)
    c_hist = jnp.zeros(T)
    τ_hist = jnp.zeros(T)
    b_hist = jnp.zeros(T)
    g_hist = jnp.zeros(T)

    # Time 0
    s_0 = s_hist[0]
    l_0, T_0 = policies.σ_w_star[s_0]
    c_0 = compute_consumption_leisure(l_0, params.g[s_0])
    x_0 = (-params.utility.Uc(c_0, l_0) * (c_0 - T_0 - b_0) +
           params.utility.Ul(c_0, l_0) * (1 - l_0))

    n_hist = n_hist.at[0].set(1 - l_0)
    x_hist = x_hist.at[0].set(x_0)
    c_hist = c_hist.at[0].set(c_0)
    τ_hist = τ_hist.at[0].set(1 - params.utility.Ul(c_0, l_0)
                              / params.utility.Uc(c_0, l_0))
    b_hist = b_hist.at[0].set(b_0)
    g_hist = g_hist.at[0].set(params.g[s_0])

    # Time t > 0
    for t in range(T - 1):
        x_ = x_hist[t]
        s_ = s_hist[t]

        # Interpolate policies for all states
        l = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, s],
                                       jnp.array([x_]))[0] for s in range(S)])
        T_vals = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, S + s],
                                            jnp.array([x_]))[0] for s in range(S)])

        c = compute_consumption_leisure(l, params.g)
        u_c = vmap(params.utility.Uc)(c, l)
        Eu_c = params.Π[s_] @ u_c

        x = (u_c * x_ / (params.β * Eu_c) -
             u_c * (c - T_vals) +
             vmap(params.utility.Ul)(c, l) * (1 - l))

        s_next = s_hist[t + 1]
        c_next = c[s_next]
        l_next = l[s_next]

        x_hist = x_hist.at[t + 1].set(x[s_next])
        n_hist = n_hist.at[t + 1].set(1 - l_next)
        c_hist = c_hist.at[t + 1].set(c_next)
        τ_hist = τ_hist.at[t + 1].set(1 - params.utility.Ul(c_next, l_next)
                                      / params.utility.Uc(c_next, l_next))
        b_hist = b_hist.at[t + 1].set(x_ / (params.β * Eu_c))
        g_hist = g_hist.at[t + 1].set(params.g[s_next])

    return {
        'c': c_hist,
        'n': n_hist,
        'b': b_hist,
        'τ': τ_hist,
        'g': g_hist
    }

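The simulated tax rate is the one implied by the household's intratemporal first-order condition with a linear labor tax, $(1 - \tau)\, u_c = u_l$, hence:

$$ \tau_t = 1 - \frac{u_l(c_t, l_t)}{u_c(c_t, l_t)} $$
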
def solve_amss_model(params: AMSSParams, V_init, σ_v_init, b_0,
                     W_init=None, σ_w_init=None, **kwargs):
    """
    Solve the complete AMSS model.

    Parameters
    ----------
    params : AMSSParams
        Model parameters
    V_init : array
        Initial value function guess
    σ_v_init : array
        Initial policy function guess
    b_0 : float
        Initial debt level
    W_init : array, optional
        Initial time 0 value function
    σ_w_init : array, optional
        Initial time 0 policy function
    **kwargs
        Additional arguments for solver

    Returns
    -------
    AMSSPolicies
        Solved policy functions
    """
    print("===============")
    print("Solve time 1 problem")
    print("===============")
    V, σ_v_star = solve_bellman_iteration(V_init, σ_v_init, params, **kwargs)

    print("===============")
    print("Solve time 0 problem")
    print("===============")
    W, σ_w_star = solve_time_zero_problem(b_0, V, params)

    return AMSSPolicies(V=V, σ_v_star=σ_v_star, W=W, σ_w_star=σ_w_star)
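
A minimal end-to-end driver, assuming the hypothetical crra_utility helper sketched near the imports; every numeric value below is an illustrative placeholder, not taken from the commit:

    # Hypothetical driver; all parameter values are illustrative placeholders.
    import jax.numpy as jnp

    S, x_num = 2, 50
    params = AMSSParams(
        β=0.9,
        Π=jnp.full((S, S), 0.5),                  # symmetric two-state Markov chain
        g=jnp.array([0.1, 0.2]),                  # low/high government spending
        x_grid=(-1.5, 1.5, x_num),                # (x_min, x_max, x_num)
        bounds_v=jnp.array([[0.0, 1.0]] * S       # bounds on leisure l_1..l_S
                           + [[0.0, 10.0]] * S),  # bounds on transfers T_1..T_S
        utility=crra_utility(σ=2.0, γ=2.0),       # hypothetical helper sketched above
    )

    # Flat initial guesses: V[s, x_i] and σ[s, x_i] = [l_1..l_S, T_1..T_S]
    V_init = jnp.zeros((S, x_num))
    σ_v_init = jnp.tile(jnp.concatenate([jnp.full(S, 0.5), jnp.zeros(S)]),
                        (S, x_num, 1))

    policies = solve_amss_model(params, V_init, σ_v_init, b_0=0.5)

    s_hist = jnp.array([0, 0, 1, 1, 0])           # a fixed state path for illustration
    sim = simulate_amss(s_hist, 0.5, policies, params)
    print(sim['τ'])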
