Skip to content

Commit 9f0ac03

Browse files
Copilot and HumphreyYang committed
Fix function naming and test complete JAX AMSS workflow
Co-authored-by: HumphreyYang <39026988+HumphreyYang@users.noreply.github.com>
1 parent 58fb11c commit 9f0ac03

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

lectures/_static/lecture_specific/amss/jax_amss_simple.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,87 @@ def penalized_objective_log(allocations, penalty=1000.0):
229229
'budget_constraint': log_budget_constraint(allocations),
230230
'feasibility_constraint': log_feasibility_constraint(allocations)
231231
}
232+
233+
234+
def solve_simple_ramsey_original(crra_params: CRRAUtilityParams, g, initial_guess=None):
    """
    Solve the simplified Ramsey problem via a penalty-method gradient descent.

    Constraint violations are folded into the objective as quadratic
    penalties, and the penalized objective is minimized with plain
    (demonstrative) gradient descent under JAX autodiff.

    Parameters
    ----------
    crra_params : CRRAUtilityParams
        CRRA utility parameters.
    g : array
        Government spending in each state.
    initial_guess : array, optional
        Starting point for the stacked allocation vector [c, l].

    Returns
    -------
    dict
        Optimal consumption, leisure, labor, tax rates, the raw objective,
        and the residuals of both constraints at the solution.
    """
    n_states = len(g)

    # Default start: consumption and leisure at one half in every state.
    if initial_guess is None:
        initial_guess = jnp.concatenate(
            [0.5 * jnp.ones(n_states), 0.5 * jnp.ones(n_states)]
        )

    @jit
    def penalized_objective(allocations, penalty=1000.0):
        # Base Ramsey objective plus quadratic penalties for violations.
        base = ramsey_objective(allocations, crra_params, g)
        budget_viol = budget_constraint(allocations, crra_params, g)
        feasibility_viol = feasibility_constraint(allocations, g)
        # Penalize budget surplus (negative residual) and infeasibility.
        penalty_term = penalty * (
            jnp.sum(jnp.maximum(0, -budget_viol) ** 2)
            + jnp.sum(jnp.maximum(0, feasibility_viol) ** 2)
        )
        return base + penalty_term

    # Fixed-step gradient descent — demonstrative, not production-grade.
    step_size = 0.01
    max_iters = 1000
    grad_fn = jit(grad(penalized_objective))

    allocations = initial_guess
    for i in range(max_iters):
        allocations = allocations - step_size * grad_fn(allocations)
        # Keep every entry inside a sensible open interval.
        allocations = jnp.clip(allocations, 0.01, 0.99)
        if i % 200 == 0:
            obj_val = penalized_objective(allocations)
            print(f"Iteration {i}: Objective = {obj_val:.6f}")

    # Split the stacked vector back into consumption and leisure.
    c_opt = allocations[:n_states]
    l_opt = allocations[n_states:]
    τ_opt = compute_tax_rates(c_opt, l_opt, crra_params)

    return {
        'c': c_opt,
        'l': l_opt,
        'n': 1 - l_opt,
        'τ': τ_opt,
        'objective': ramsey_objective(allocations, crra_params, g),
        'budget_constraint': budget_constraint(allocations, crra_params, g),
        'feasibility_constraint': feasibility_constraint(allocations, g),
    }
307+
308+
309+
# Aliases for different utility functions
def solve_simple_ramsey(crra_params: CRRAUtilityParams, g, initial_guess=None):
    """Alias: solve the Ramsey problem under CRRA utility."""
    return solve_simple_ramsey_original(crra_params, g, initial_guess)
232313
"""
233314
Solve simplified Ramsey problem using constrained optimization.
234315

0 commit comments

Comments
 (0)