"""
Simplified JAX-based AMSS model implementation for demonstration.
Shows key JAX concepts: NamedTuple, JIT, grad, vmap.
"""

import jax.numpy as jnp
from jax import jit, grad, vmap
from typing import NamedTuple

try:
    from .jax_utilities import *
    from .jax_interpolation import *
except ImportError:
    from jax_utilities import *
    from jax_interpolation import *


class AMSSSimpleParams(NamedTuple):
    """Simplified AMSS model parameters."""
    β: float
    Π: jnp.ndarray
    g: jnp.ndarray
    utility: UtilityFunctions


class AMSSSimpleState(NamedTuple):
    """State for simplified AMSS model."""
    c: jnp.ndarray  # Consumption by state
    l: jnp.ndarray  # Leisure by state
    τ: jnp.ndarray  # Tax rates by state


@jit
def compute_state_variables(c, l, g):
    """Compute derived state variables."""
    n = 1 - l  # Labor
    y = n  # Output (assuming unit productivity)
    return {'n': n, 'y': y, 'budget_residual': y - g - c}


@jit
def compute_tax_rates(c, l, crra_params: CRRAUtilityParams):
    """Compute tax rates using marginal utilities."""
    Uc_vals = vmap(lambda ci, li: crra_utility_c(ci, li, crra_params))(c, l)
    Ul_vals = vmap(lambda ci, li: crra_utility_l(ci, li, crra_params))(c, l)
    return 1 - Ul_vals / Uc_vals
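

# Illustrative sketch (not used by the solvers below): the same marginal
# utilities can be computed with automatic differentiation via jax.grad,
# assuming `crra_utility(c, l, params)` imported from jax_utilities is the
# period utility whose analytic derivatives crra_utility_c / crra_utility_l
# are used above.
@jit
def compute_tax_rates_autodiff(c, l, crra_params: CRRAUtilityParams):
    """Tax rates computed with jax.grad instead of analytic marginal utilities."""
    u = lambda ci, li: crra_utility(ci, li, crra_params)
    Uc_vals = vmap(grad(u, argnums=0))(c, l)  # ∂u/∂c for each state
    Ul_vals = vmap(grad(u, argnums=1))(c, l)  # ∂u/∂l for each state
    return 1 - Ul_vals / Uc_vals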


@jit
def ramsey_objective(allocations, crra_params: CRRAUtilityParams, g):
    """
    Simplified Ramsey objective function.

    Parameters
    ----------
    allocations : array
        Concatenated [c_values, l_values] for all states
    crra_params : CRRAUtilityParams
        Utility parameters
    g : array
        Government spending by state

    Returns
    -------
    float
        Negative of social welfare (for minimization)
    """
    S = len(g)
    c = allocations[:S]
    l = allocations[S:]

    # Compute utilities for each state
    utilities = vmap(lambda ci, li: crra_utility(ci, li, crra_params))(c, l)

    # Expected utility: for simplicity, use uniform weights across states
    # (a stationary-distribution alternative is sketched after this function)
    expected_utility = jnp.mean(utilities)

    return -expected_utility
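

# Illustrative sketch (assumption, not called by ramsey_objective above): the
# uniform weights could be replaced by the stationary distribution of the
# Markov matrix Π, approximated here by simple power iteration.
def stationary_distribution(Π, n_iter=1000):
    """Approximate the stationary distribution of a row-stochastic matrix Π."""
    dist = jnp.ones(Π.shape[0]) / Π.shape[0]
    for _ in range(n_iter):
        dist = dist @ Π  # One step of the chain: row vector times Π
    return dist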


@jit
def budget_constraint(allocations, crra_params: CRRAUtilityParams, g):
    """
    Government budget constraint.

    Parameters
    ----------
    allocations : array
        Concatenated [c_values, l_values] for all states
    crra_params : CRRAUtilityParams
        Utility parameters
    g : array
        Government spending by state

    Returns
    -------
    array
        Budget constraint violations
    """
    S = len(g)
    c = allocations[:S]
    l = allocations[S:]

    n = 1 - l  # Labor
    τ = compute_tax_rates(c, l, crra_params)

    # Budget constraint: τ * n >= g (simplified, no debt dynamics)
    return τ * n - g


@jit
def feasibility_constraint(allocations, g):
    """
    Resource feasibility constraint.

    Parameters
    ----------
    allocations : array
        Concatenated [c_values, l_values] for all states
    g : array
        Government spending by state

    Returns
    -------
    array
        Feasibility constraint violations
    """
    S = len(g)
    c = allocations[:S]
    l = allocations[S:]

    n = 1 - l  # Labor
    y = n  # Output

    # Resource constraint: c + g <= y
    return c + g - y


def solve_simple_ramsey_log(log_params: LogUtilityParams, g, initial_guess=None):
    """
    Solve simplified Ramsey problem with log utility.

    Parameters
    ----------
    log_params : LogUtilityParams
        Log utility parameters
    g : array
        Government spending by state
    initial_guess : array, optional
        Initial guess for allocations

    Returns
    -------
    dict
        Solution with optimal allocations and tax rates
    """
    S = len(g)

    if initial_guess is None:
        # Simple initial guess
        c_guess = 0.5 * jnp.ones(S)
        l_guess = 0.5 * jnp.ones(S)
        initial_guess = jnp.concatenate([c_guess, l_guess])

    # Define objectives for log utility
    @jit
    def log_ramsey_objective(allocations):
        c = allocations[:S]
        l = allocations[S:]
        utilities = vmap(lambda ci, li: log_utility(ci, li, log_params))(c, l)
        return -jnp.mean(utilities)

    @jit
    def log_budget_constraint(allocations):
        c = allocations[:S]
        l = allocations[S:]
        n = 1 - l
        Uc_vals = vmap(lambda ci, li: log_utility_c(ci, li, log_params))(c, l)
        Ul_vals = vmap(lambda ci, li: log_utility_l(ci, li, log_params))(c, l)
        τ = 1 - Ul_vals / Uc_vals
        return τ * n - g

    @jit
    def log_feasibility_constraint(allocations):
        c = allocations[:S]
        l = allocations[S:]
        n = 1 - l
        return c + g - n

    @jit
    def penalized_objective_log(allocations, penalty=1000.0):
        obj = log_ramsey_objective(allocations)
        budget_viol = log_budget_constraint(allocations)
        feasibility_viol = log_feasibility_constraint(allocations)

        penalty_term = (penalty * jnp.sum(jnp.maximum(0, -budget_viol)**2) +
                        penalty * jnp.sum(jnp.maximum(0, feasibility_viol)**2))

        return obj + penalty_term

    # Gradient descent
    learning_rate = 0.01
    num_iterations = 1000
    allocations = initial_guess
    grad_fn = jit(grad(penalized_objective_log))

    for i in range(num_iterations):
        grads = grad_fn(allocations)
        allocations = allocations - learning_rate * grads
        allocations = jnp.clip(allocations, 0.01, 0.99)

        if i % 200 == 0:
            obj_val = penalized_objective_log(allocations)
            print(f"Log utility iteration {i}: Objective = {obj_val:.6f}")

    # Extract results
    c_opt = allocations[:S]
    l_opt = allocations[S:]
    Uc_vals = vmap(lambda ci, li: log_utility_c(ci, li, log_params))(c_opt, l_opt)
    Ul_vals = vmap(lambda ci, li: log_utility_l(ci, li, log_params))(c_opt, l_opt)
    τ_opt = 1 - Ul_vals / Uc_vals

    return {
        'c': c_opt,
        'l': l_opt,
        'n': 1 - l_opt,
        'τ': τ_opt,
        'objective': log_ramsey_objective(allocations),
        'budget_constraint': log_budget_constraint(allocations),
        'feasibility_constraint': log_feasibility_constraint(allocations)
    }


def solve_simple_ramsey(crra_params: CRRAUtilityParams, g, initial_guess=None):
    """
    Solve simplified Ramsey problem using constrained optimization.

    Parameters
    ----------
    crra_params : CRRAUtilityParams
        Utility parameters
    g : array
        Government spending by state
    initial_guess : array, optional
        Initial guess for allocations

    Returns
    -------
    dict
        Solution with optimal allocations and tax rates
    """
    S = len(g)

    if initial_guess is None:
        # Simple initial guess
        c_guess = 0.5 * jnp.ones(S)
        l_guess = 0.5 * jnp.ones(S)
        initial_guess = jnp.concatenate([c_guess, l_guess])

    # For simplicity, use a penalty method approach
    @jit
    def penalized_objective(allocations, penalty=1000.0):
        obj = ramsey_objective(allocations, crra_params, g)

        # Add penalties for constraint violations
        budget_viol = budget_constraint(allocations, crra_params, g)
        feasibility_viol = feasibility_constraint(allocations, g)

        penalty_term = (penalty * jnp.sum(jnp.maximum(0, -budget_viol)**2) +     # Budget shortfall penalty
                        penalty * jnp.sum(jnp.maximum(0, feasibility_viol)**2))  # Feasibility penalty

        return obj + penalty_term

    # Simple gradient descent (demonstrative); a fully compiled variant using
    # lax.fori_loop is sketched after this function
    learning_rate = 0.01
    num_iterations = 1000

    allocations = initial_guess

    grad_fn = jit(grad(penalized_objective))

    for i in range(num_iterations):
        grads = grad_fn(allocations)
        allocations = allocations - learning_rate * grads

        # Clip to reasonable bounds
        allocations = jnp.clip(allocations, 0.01, 0.99)

        if i % 200 == 0:
            obj_val = penalized_objective(allocations)
            print(f"Iteration {i}: Objective = {obj_val:.6f}")

    # Extract results
    c_opt = allocations[:S]
    l_opt = allocations[S:]
    τ_opt = compute_tax_rates(c_opt, l_opt, crra_params)

    return {
        'c': c_opt,
        'l': l_opt,
        'n': 1 - l_opt,
        'τ': τ_opt,
        'objective': ramsey_objective(allocations, crra_params, g),
        'budget_constraint': budget_constraint(allocations, crra_params, g),
        'feasibility_constraint': feasibility_constraint(allocations, g)
    }
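

# Illustrative sketch (assumption, not used by solve_simple_ramsey above): once
# the per-iteration printing is dropped, the whole gradient-descent loop can be
# compiled with jax.lax.fori_loop instead of a Python for-loop, e.g.
#     allocations = run_gradient_descent(grad_fn, initial_guess)
from jax import lax

def run_gradient_descent(grad_fn, x0, learning_rate=0.01, num_iterations=1000):
    """Run clipped gradient descent inside a single compiled loop."""
    def body(_, x):
        # Same update as the Python loop above: gradient step, then clip
        return jnp.clip(x - learning_rate * grad_fn(x), 0.01, 0.99)
    return lax.fori_loop(0, num_iterations, body, x0)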


def create_amss_simple_example():
    """Create a simple AMSS example."""

    # Parameters
    β = 0.9
    σ = 2.0
    γ = 2.0

    # Two-state Markov chain
    Π = jnp.array([[0.8, 0.2],
                   [0.3, 0.7]])

    # Government spending in each state
    g = jnp.array([0.1, 0.2])  # Low and high spending

    # Create utility parameters
    crra_params = CRRAUtilityParams(β=β, σ=σ, γ=γ)

    return crra_params, Π, g
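

# Illustrative sketch (assumption): Π is returned above but not used by the
# static solver; it could drive a simulated path of government spending via
# jax.random, as below. `simulate_spending_path` is a hypothetical helper,
# not part of the original module.
from jax import random

def simulate_spending_path(key, Π, g, T=50, s0=0):
    """Simulate T periods of the spending Markov chain and map states to g."""
    states = []
    s = s0
    for key_t in random.split(key, T):
        # Draw the next state from row s of Π
        s = int(random.categorical(key_t, jnp.log(Π[s])))
        states.append(s)
    return g[jnp.array(states)]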


# Example usage
if __name__ == "__main__":
    # Create and solve simple AMSS model
    crra_params, Π, g = create_amss_simple_example()

    print("Solving simplified AMSS model...")
    solution = solve_simple_ramsey(crra_params, g)

    print("\nOptimal allocations:")
    print(f"Consumption: {solution['c']}")
    print(f"Leisure: {solution['l']}")
    print(f"Labor: {solution['n']}")
    print(f"Tax rates: {solution['τ']}")
    print(f"\nObjective value: {solution['objective']:.6f}")