1+ """
2+ JAX-based AMSS model implementation.
3+ Converted from NumPy/Numba classes to JAX pure functions and NamedTuple structures.
4+ """
5+
6+ import jax .numpy as jnp
7+ from jax import jit , grad , vmap
8+ import jax
9+ from scipy .optimize import minimize # Use scipy for now
10+ from typing import NamedTuple , Callable
11+ try :
12+ from .jax_utilities import UtilityFunctions
13+ from .jax_interpolation import nodes_from_grid , eval_linear_jax
14+ except ImportError :
15+ from jax_utilities import UtilityFunctions
16+ from jax_interpolation import nodes_from_grid , eval_linear_jax
17+
18+
19+ class AMSSState (NamedTuple ):
20+ """State variables for AMSS model."""
21+ s : int # Current Markov state
22+ x : float # Continuation value state variable
23+
24+
25+ class AMSSParams (NamedTuple ):
26+ """Parameters for AMSS model."""
27+ β : float # Discount factor
28+ Π : jnp .ndarray # Markov transition matrix
29+ g : jnp .ndarray # Government spending by state
30+ x_grid : tuple # Grid parameters (x_min, x_max, x_num)
31+ bounds_v : jnp .ndarray # Bounds for optimization
32+ utility : UtilityFunctions # Utility functions
33+
34+
35+ class AMSSPolicies (NamedTuple ):
36+ """Policy functions for AMSS model."""
37+ V : jnp .ndarray # Value function
38+ σ_v_star : jnp .ndarray # Policy function for time t >= 1
39+ W : jnp .ndarray # Value function for time 0
40+ σ_w_star : jnp .ndarray # Policy function for time 0
41+
42+
43+ @jit
44+ def compute_consumption_leisure (l , g ):
45+ """Compute consumption given leisure and government spending."""
46+ return (1 - l ) - g
47+
48+
49+ @jit
50+ def objective_V (σ , state , V , params : AMSSParams ):
51+ """
52+ Objective function for time t >= 1 value function iteration.
53+
54+ Parameters
55+ ----------
56+ σ : array
57+ Policy variables [l_1, ..., l_S, T_1, ..., T_S]
58+ state : tuple
59+ Current state (s_, x_)
60+ V : array
61+ Current value function
62+ params : AMSSParams
63+ Model parameters
64+
65+ Returns
66+ -------
67+ float
68+ Negative of expected value (for minimization)
69+ """
70+ s_ , x_ = state
71+ S = len (params .Π )
72+
73+ l = σ [:S ]
74+ T = σ [S :]
75+
76+ c = compute_consumption_leisure (l , params .g )
77+ u_c = vmap (params .utility .Uc )(c , l )
78+ Eu_c = params .Π [s_ ] @ u_c
79+
80+ x = (u_c * x_ / (params .β * Eu_c ) -
81+ u_c * (c - T ) +
82+ vmap (params .utility .Ul )(c , l ) * (1 - l ))
83+
84+ # Interpolate next period value function
85+ x_nodes = nodes_from_grid (params .x_grid )
86+ V_next = jnp .array ([eval_linear_jax (params .x_grid , V [s ], jnp .array ([x [s ]]))[0 ]
87+ for s in range (S )])
88+
89+ expected_value = params .Π [s_ ] @ (vmap (params .utility .U )(c , l ) + params .β * V_next )
90+
91+ return - expected_value # Negative for minimization
92+
93+
94+ @jit
95+ def objective_W (σ , state , V , params : AMSSParams ):
96+ """
97+ Objective function for time 0 problem.
98+
99+ Parameters
100+ ----------
101+ σ : array
102+ Policy variables [l, T]
103+ state : tuple
104+ Current state (s_, b_0)
105+ V : array
106+ Value function
107+ params : AMSSParams
108+ Model parameters
109+
110+ Returns
111+ -------
112+ float
113+ Negative of value (for minimization)
114+ """
115+ s_ , b_0 = state
116+ l , T = σ
117+
118+ c = compute_consumption_leisure (l , params .g [s_ ])
119+ x = (- params .utility .Uc (c , l ) * (c - T - b_0 ) +
120+ params .utility .Ul (c , l ) * (1 - l ))
121+
122+ V_next = eval_linear_jax (params .x_grid , V [s_ ], jnp .array ([x ]))[0 ]
123+ value = params .utility .U (c , l ) + params .β * V_next
124+
125+ return - value # Negative for minimization
126+
127+
128+ def solve_bellman_iteration (V , σ_v_star , params : AMSSParams ,
129+ tol = 1e-7 , max_iter = 1000 , print_freq = 10 ):
130+ """
131+ Solve the Bellman equation using value function iteration.
132+
133+ Parameters
134+ ----------
135+ V : array
136+ Initial value function guess
137+ σ_v_star : array
138+ Initial policy function guess
139+ params : AMSSParams
140+ Model parameters
141+ tol : float
142+ Convergence tolerance
143+ max_iter : int
144+ Maximum iterations
145+ print_freq : int
146+ Print frequency
147+
148+ Returns
149+ -------
150+ tuple
151+ Updated (V, σ_v_star)
152+ """
153+ S = len (params .Π )
154+ x_nodes = nodes_from_grid (params .x_grid )
155+ n_x = len (x_nodes )
156+
157+ V_new = jnp .zeros_like (V )
158+
159+ for iteration in range (max_iter ):
160+ V_updated = jnp .zeros_like (V )
161+ σ_updated = jnp .zeros_like (σ_v_star )
162+
163+ # Loop over states and grid points
164+ for s_ in range (S ):
165+ for x_i in range (n_x ):
166+ state = (s_ , x_nodes [x_i ])
167+ x0 = σ_v_star [s_ , x_i ]
168+
169+ # Optimize using JAX
170+ bounds = [(params .bounds_v [i , 0 ], params .bounds_v [i , 1 ])
171+ for i in range (len (params .bounds_v ))]
172+
173+ # Simple optimization using scipy-like interface
174+ result = minimize (
175+ lambda σ : objective_V (σ , state , V , params ),
176+ x0 ,
177+ method = 'L-BFGS-B' ,
178+ bounds = bounds
179+ )
180+
181+ if result .success :
182+ V_updated = V_updated .at [s_ , x_i ].set (- result .fun )
183+ σ_updated = σ_updated .at [s_ , x_i ].set (result .x )
184+ else :
185+ print (f"Optimization failed at state { s_ } , grid point { x_i } " )
186+ V_updated = V_updated .at [s_ , x_i ].set (V [s_ , x_i ])
187+ σ_updated = σ_updated .at [s_ , x_i ].set (σ_v_star [s_ , x_i ])
188+
189+ # Check convergence
190+ error = jnp .max (jnp .abs (V_updated - V ))
191+
192+ if error < tol :
193+ print (f'Successfully completed VFI after { iteration + 1 } iterations' )
194+ return V_updated , σ_updated
195+
196+ if (iteration + 1 ) % print_freq == 0 :
197+ print (f'Error at iteration { iteration + 1 } : { error } ' )
198+
199+ V = V_updated
200+ σ_v_star = σ_updated
201+
202+ print (f'VFI did not converge after { max_iter } iterations' )
203+ return V , σ_v_star
204+
205+
206+ def solve_time_zero_problem (b_0 , V , params : AMSSParams ):
207+ """
208+ Solve the time 0 problem.
209+
210+ Parameters
211+ ----------
212+ b_0 : float
213+ Initial debt
214+ V : array
215+ Value function from time 1 problem
216+ params : AMSSParams
217+ Model parameters
218+
219+ Returns
220+ -------
221+ tuple
222+ (W, σ_w_star) where W is time 0 values and σ_w_star is time 0 policies
223+ """
224+ S = len (params .Π )
225+ W = jnp .zeros (S )
226+ σ_w_star = jnp .zeros ((S , 2 ))
227+
228+ bounds_w = [(- 9.0 , 1.0 ), (0.0 , 10.0 )]
229+
230+ for s_ in range (S ):
231+ state = (s_ , b_0 )
232+ x0 = jnp .array ([- 0.05 , 0.5 ]) # Initial guess
233+
234+ result = minimize (
235+ lambda σ : objective_W (σ , state , V , params ),
236+ x0 ,
237+ method = 'L-BFGS-B' ,
238+ bounds = bounds_w
239+ )
240+
241+ W = W .at [s_ ].set (- result .fun )
242+ σ_w_star = σ_w_star .at [s_ ].set (result .x )
243+
244+ print ('Successfully solved the time 0 problem.' )
245+ return W , σ_w_star
246+
247+
248+ @jit
249+ def simulate_amss (s_hist , b_0 , policies : AMSSPolicies , params : AMSSParams ):
250+ """
251+ Simulate AMSS model given state history and initial debt.
252+
253+ Parameters
254+ ----------
255+ s_hist : array
256+ History of Markov states
257+ b_0 : float
258+ Initial debt level
259+ policies : AMSSPolicies
260+ Solved policy functions
261+ params : AMSSParams
262+ Model parameters
263+
264+ Returns
265+ -------
266+ dict
267+ Simulation results with arrays for c, n, b, τ, g
268+ """
269+ T = len (s_hist )
270+ S = len (params .Π )
271+ x_nodes = nodes_from_grid (params .x_grid )
272+
273+ # Pre-allocate arrays
274+ n_hist = jnp .zeros (T )
275+ x_hist = jnp .zeros (T )
276+ c_hist = jnp .zeros (T )
277+ τ_hist = jnp .zeros (T )
278+ b_hist = jnp .zeros (T )
279+ g_hist = jnp .zeros (T )
280+
281+ # Time 0
282+ s_0 = s_hist [0 ]
283+ l_0 , T_0 = policies .σ_w_star [s_0 ]
284+ c_0 = compute_consumption_leisure (l_0 , params .g [s_0 ])
285+ x_0 = (- params .utility .Uc (c_0 , l_0 ) * (c_0 - T_0 - b_0 ) +
286+ params .utility .Ul (c_0 , l_0 ) * (1 - l_0 ))
287+
288+ n_hist = n_hist .at [0 ].set (1 - l_0 )
289+ x_hist = x_hist .at [0 ].set (x_0 )
290+ c_hist = c_hist .at [0 ].set (c_0 )
291+ τ_hist = τ_hist .at [0 ].set (1 - params .utility .Ul (c_0 , l_0 ) / params .utility .Uc (c_0 , l_0 ))
292+ b_hist = b_hist .at [0 ].set (b_0 )
293+ g_hist = g_hist .at [0 ].set (params .g [s_0 ])
294+
295+ # Time t > 0
296+ for t in range (T - 1 ):
297+ x_ = x_hist [t ]
298+ s_ = s_hist [t ]
299+
300+ # Interpolate policies for all states
301+ l = jnp .array ([eval_linear_jax (params .x_grid , policies .σ_v_star [s_ , :, s ],
302+ jnp .array ([x_ ]))[0 ] for s in range (S )])
303+ T_vals = jnp .array ([eval_linear_jax (params .x_grid , policies .σ_v_star [s_ , :, S + s ],
304+ jnp .array ([x_ ]))[0 ] for s in range (S )])
305+
306+ c = compute_consumption_leisure (l , params .g )
307+ u_c = vmap (params .utility .Uc )(c , l )
308+ Eu_c = params .Π [s_ ] @ u_c
309+
310+ x = (u_c * x_ / (params .β * Eu_c ) -
311+ u_c * (c - T_vals ) +
312+ vmap (params .utility .Ul )(c , l ) * (1 - l ))
313+
314+ s_next = s_hist [t + 1 ]
315+ c_next = c [s_next ]
316+ l_next = l [s_next ]
317+
318+ x_hist = x_hist .at [t + 1 ].set (x [s_next ])
319+ n_hist = n_hist .at [t + 1 ].set (1 - l_next )
320+ c_hist = c_hist .at [t + 1 ].set (c_next )
321+ τ_hist = τ_hist .at [t + 1 ].set (1 - params .utility .Ul (c_next , l_next ) / params .utility .Uc (c_next , l_next ))
322+ b_hist = b_hist .at [t + 1 ].set (x_ / (params .β * Eu_c ))
323+ g_hist = g_hist .at [t + 1 ].set (params .g [s_next ])
324+
325+ return {
326+ 'c' : c_hist ,
327+ 'n' : n_hist ,
328+ 'b' : b_hist ,
329+ 'τ' : τ_hist ,
330+ 'g' : g_hist
331+ }
332+
333+
334+ def solve_amss_model (params : AMSSParams , V_init , σ_v_init , b_0 ,
335+ W_init = None , σ_w_init = None , ** kwargs ):
336+ """
337+ Solve the complete AMSS model.
338+
339+ Parameters
340+ ----------
341+ params : AMSSParams
342+ Model parameters
343+ V_init : array
344+ Initial value function guess
345+ σ_v_init : array
346+ Initial policy function guess
347+ b_0 : float
348+ Initial debt level
349+ W_init : array, optional
350+ Initial time 0 value function
351+ σ_w_init : array, optional
352+ Initial time 0 policy function
353+ **kwargs
354+ Additional arguments for solver
355+
356+ Returns
357+ -------
358+ AMSSPolicies
359+ Solved policy functions
360+ """
361+ print ("===============" )
362+ print ("Solve time 1 problem" )
363+ print ("===============" )
364+ V , σ_v_star = solve_bellman_iteration (V_init , σ_v_init , params , ** kwargs )
365+
366+ print ("===============" )
367+ print ("Solve time 0 problem" )
368+ print ("===============" )
369+ W , σ_w_star = solve_time_zero_problem (b_0 , V , params )
370+
371+ return AMSSPolicies (V = V , σ_v_star = σ_v_star , W = W , σ_w_star = σ_w_star )
0 commit comments