diff --git a/.gitignore b/.gitignore index e6b45300..10b0af6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ _build/ -.DS_Store \ No newline at end of file +.DS_Store +__pycache__/ +*.pyc +*.pyo \ No newline at end of file diff --git a/environment.yml b/environment.yml index f69c4ec3..733552fa 100644 --- a/environment.yml +++ b/environment.yml @@ -18,4 +18,6 @@ dependencies: - sphinx-togglebutton==0.3.2 # Docker Requirements - pytz + # JAX for lecture content + - jax diff --git a/lectures/_static/lecture_specific/amss/jax_amss.py b/lectures/_static/lecture_specific/amss/jax_amss.py new file mode 100644 index 00000000..e1d76138 --- /dev/null +++ b/lectures/_static/lecture_specific/amss/jax_amss.py @@ -0,0 +1,371 @@ +""" +JAX-based AMSS model implementation. +Converted from NumPy/Numba classes to JAX pure functions and NamedTuple structures. +""" + +import jax.numpy as jnp +from jax import jit, grad, vmap +import jax +from scipy.optimize import minimize # Use scipy for now +from typing import NamedTuple, Callable +try: + from .jax_utilities import UtilityFunctions + from .jax_interpolation import nodes_from_grid, eval_linear_jax +except ImportError: + from jax_utilities import UtilityFunctions + from jax_interpolation import nodes_from_grid, eval_linear_jax + + +class AMSSState(NamedTuple): + """State variables for AMSS model.""" + s: int # Current Markov state + x: float # Continuation value state variable + + +class AMSSParams(NamedTuple): + """Parameters for AMSS model.""" + β: float # Discount factor + Π: jnp.ndarray # Markov transition matrix + g: jnp.ndarray # Government spending by state + x_grid: tuple # Grid parameters (x_min, x_max, x_num) + bounds_v: jnp.ndarray # Bounds for optimization + utility: UtilityFunctions # Utility functions + + +class AMSSPolicies(NamedTuple): + """Policy functions for AMSS model.""" + V: jnp.ndarray # Value function + σ_v_star: jnp.ndarray # Policy function for time t >= 1 + W: jnp.ndarray # Value function for time 0 + σ_w_star: jnp.ndarray # Policy function for time 0 + + +@jit +def compute_consumption_leisure(l, g): + """Compute consumption given leisure and government spending.""" + return (1 - l) - g + + +@jit +def objective_V(σ, state, V, params: AMSSParams): + """ + Objective function for time t >= 1 value function iteration. + + Parameters + ---------- + σ : array + Policy variables [l_1, ..., l_S, T_1, ..., T_S] + state : tuple + Current state (s_, x_) + V : array + Current value function + params : AMSSParams + Model parameters + + Returns + ------- + float + Negative of expected value (for minimization) + """ + s_, x_ = state + S = len(params.Π) + + l = σ[:S] + T = σ[S:] + + c = compute_consumption_leisure(l, params.g) + u_c = vmap(params.utility.Uc)(c, l) + Eu_c = params.Π[s_] @ u_c + + x = (u_c * x_ / (params.β * Eu_c) - + u_c * (c - T) + + vmap(params.utility.Ul)(c, l) * (1 - l)) + + # Interpolate next period value function + x_nodes = nodes_from_grid(params.x_grid) + V_next = jnp.array([eval_linear_jax(params.x_grid, V[s], jnp.array([x[s]]))[0] + for s in range(S)]) + + expected_value = params.Π[s_] @ (vmap(params.utility.U)(c, l) + params.β * V_next) + + return -expected_value # Negative for minimization + + +@jit +def objective_W(σ, state, V, params: AMSSParams): + """ + Objective function for time 0 problem. + + Parameters + ---------- + σ : array + Policy variables [l, T] + state : tuple + Current state (s_, b_0) + V : array + Value function + params : AMSSParams + Model parameters + + Returns + ------- + float + Negative of value (for minimization) + """ + s_, b_0 = state + l, T = σ + + c = compute_consumption_leisure(l, params.g[s_]) + x = (-params.utility.Uc(c, l) * (c - T - b_0) + + params.utility.Ul(c, l) * (1 - l)) + + V_next = eval_linear_jax(params.x_grid, V[s_], jnp.array([x]))[0] + value = params.utility.U(c, l) + params.β * V_next + + return -value # Negative for minimization + + +def solve_bellman_iteration(V, σ_v_star, params: AMSSParams, + tol=1e-7, max_iter=1000, print_freq=10): + """ + Solve the Bellman equation using value function iteration. + + Parameters + ---------- + V : array + Initial value function guess + σ_v_star : array + Initial policy function guess + params : AMSSParams + Model parameters + tol : float + Convergence tolerance + max_iter : int + Maximum iterations + print_freq : int + Print frequency + + Returns + ------- + tuple + Updated (V, σ_v_star) + """ + S = len(params.Π) + x_nodes = nodes_from_grid(params.x_grid) + n_x = len(x_nodes) + + V_new = jnp.zeros_like(V) + + for iteration in range(max_iter): + V_updated = jnp.zeros_like(V) + σ_updated = jnp.zeros_like(σ_v_star) + + # Loop over states and grid points + for s_ in range(S): + for x_i in range(n_x): + state = (s_, x_nodes[x_i]) + x0 = σ_v_star[s_, x_i] + + # Optimize using JAX + bounds = [(params.bounds_v[i, 0], params.bounds_v[i, 1]) + for i in range(len(params.bounds_v))] + + # Simple optimization using scipy-like interface + result = minimize( + lambda σ: objective_V(σ, state, V, params), + x0, + method='L-BFGS-B', + bounds=bounds + ) + + if result.success: + V_updated = V_updated.at[s_, x_i].set(-result.fun) + σ_updated = σ_updated.at[s_, x_i].set(result.x) + else: + print(f"Optimization failed at state {s_}, grid point {x_i}") + V_updated = V_updated.at[s_, x_i].set(V[s_, x_i]) + σ_updated = σ_updated.at[s_, x_i].set(σ_v_star[s_, x_i]) + + # Check convergence + error = jnp.max(jnp.abs(V_updated - V)) + + if error < tol: + print(f'Successfully completed VFI after {iteration + 1} iterations') + return V_updated, σ_updated + + if (iteration + 1) % print_freq == 0: + print(f'Error at iteration {iteration + 1}: {error}') + + V = V_updated + σ_v_star = σ_updated + + print(f'VFI did not converge after {max_iter} iterations') + return V, σ_v_star + + +def solve_time_zero_problem(b_0, V, params: AMSSParams): + """ + Solve the time 0 problem. + + Parameters + ---------- + b_0 : float + Initial debt + V : array + Value function from time 1 problem + params : AMSSParams + Model parameters + + Returns + ------- + tuple + (W, σ_w_star) where W is time 0 values and σ_w_star is time 0 policies + """ + S = len(params.Π) + W = jnp.zeros(S) + σ_w_star = jnp.zeros((S, 2)) + + bounds_w = [(-9.0, 1.0), (0.0, 10.0)] + + for s_ in range(S): + state = (s_, b_0) + x0 = jnp.array([-0.05, 0.5]) # Initial guess + + result = minimize( + lambda σ: objective_W(σ, state, V, params), + x0, + method='L-BFGS-B', + bounds=bounds_w + ) + + W = W.at[s_].set(-result.fun) + σ_w_star = σ_w_star.at[s_].set(result.x) + + print('Successfully solved the time 0 problem.') + return W, σ_w_star + + +@jit +def simulate_amss(s_hist, b_0, policies: AMSSPolicies, params: AMSSParams): + """ + Simulate AMSS model given state history and initial debt. + + Parameters + ---------- + s_hist : array + History of Markov states + b_0 : float + Initial debt level + policies : AMSSPolicies + Solved policy functions + params : AMSSParams + Model parameters + + Returns + ------- + dict + Simulation results with arrays for c, n, b, τ, g + """ + T = len(s_hist) + S = len(params.Π) + x_nodes = nodes_from_grid(params.x_grid) + + # Pre-allocate arrays + n_hist = jnp.zeros(T) + x_hist = jnp.zeros(T) + c_hist = jnp.zeros(T) + τ_hist = jnp.zeros(T) + b_hist = jnp.zeros(T) + g_hist = jnp.zeros(T) + + # Time 0 + s_0 = s_hist[0] + l_0, T_0 = policies.σ_w_star[s_0] + c_0 = compute_consumption_leisure(l_0, params.g[s_0]) + x_0 = (-params.utility.Uc(c_0, l_0) * (c_0 - T_0 - b_0) + + params.utility.Ul(c_0, l_0) * (1 - l_0)) + + n_hist = n_hist.at[0].set(1 - l_0) + x_hist = x_hist.at[0].set(x_0) + c_hist = c_hist.at[0].set(c_0) + τ_hist = τ_hist.at[0].set(1 - params.utility.Ul(c_0, l_0) / params.utility.Uc(c_0, l_0)) + b_hist = b_hist.at[0].set(b_0) + g_hist = g_hist.at[0].set(params.g[s_0]) + + # Time t > 0 + for t in range(T - 1): + x_ = x_hist[t] + s_ = s_hist[t] + + # Interpolate policies for all states + l = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, s], + jnp.array([x_]))[0] for s in range(S)]) + T_vals = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, S+s], + jnp.array([x_]))[0] for s in range(S)]) + + c = compute_consumption_leisure(l, params.g) + u_c = vmap(params.utility.Uc)(c, l) + Eu_c = params.Π[s_] @ u_c + + x = (u_c * x_ / (params.β * Eu_c) - + u_c * (c - T_vals) + + vmap(params.utility.Ul)(c, l) * (1 - l)) + + s_next = s_hist[t+1] + c_next = c[s_next] + l_next = l[s_next] + + x_hist = x_hist.at[t+1].set(x[s_next]) + n_hist = n_hist.at[t+1].set(1 - l_next) + c_hist = c_hist.at[t+1].set(c_next) + τ_hist = τ_hist.at[t+1].set(1 - params.utility.Ul(c_next, l_next) / params.utility.Uc(c_next, l_next)) + b_hist = b_hist.at[t+1].set(x_ / (params.β * Eu_c)) + g_hist = g_hist.at[t+1].set(params.g[s_next]) + + return { + 'c': c_hist, + 'n': n_hist, + 'b': b_hist, + 'τ': τ_hist, + 'g': g_hist + } + + +def solve_amss_model(params: AMSSParams, V_init, σ_v_init, b_0, + W_init=None, σ_w_init=None, **kwargs): + """ + Solve the complete AMSS model. + + Parameters + ---------- + params : AMSSParams + Model parameters + V_init : array + Initial value function guess + σ_v_init : array + Initial policy function guess + b_0 : float + Initial debt level + W_init : array, optional + Initial time 0 value function + σ_w_init : array, optional + Initial time 0 policy function + **kwargs + Additional arguments for solver + + Returns + ------- + AMSSPolicies + Solved policy functions + """ + print("===============") + print("Solve time 1 problem") + print("===============") + V, σ_v_star = solve_bellman_iteration(V_init, σ_v_init, params, **kwargs) + + print("===============") + print("Solve time 0 problem") + print("===============") + W, σ_w_star = solve_time_zero_problem(b_0, V, params) + + return AMSSPolicies(V=V, σ_v_star=σ_v_star, W=W, σ_w_star=σ_w_star) \ No newline at end of file diff --git a/lectures/_static/lecture_specific/amss/jax_amss_simple.py b/lectures/_static/lecture_specific/amss/jax_amss_simple.py new file mode 100644 index 00000000..49518b55 --- /dev/null +++ b/lectures/_static/lecture_specific/amss/jax_amss_simple.py @@ -0,0 +1,405 @@ +""" +Simplified JAX-based AMSS model implementation for demonstration. +Shows key JAX concepts: NamedTuple, JIT, grad, vmap. +""" + +import jax.numpy as jnp +from jax import jit, grad, vmap +from typing import NamedTuple + +# Note: jax_utilities and jax_interpolation are loaded via :load: directives +# before this file, so their functions are available in the global namespace + + +class AMSSSimpleParams(NamedTuple): + """Simplified AMSS model parameters.""" + β: float + Π: jnp.ndarray + g: jnp.ndarray + utility: UtilityFunctions + + +class AMSSSimpleState(NamedTuple): + """State for simplified AMSS model.""" + c: jnp.ndarray # Consumption by state + l: jnp.ndarray # Leisure by state + τ: jnp.ndarray # Tax rates by state + + +@jit +def compute_state_variables(c, l, g): + """Compute derived state variables.""" + n = 1 - l # Labor + y = n # Output (assuming unit productivity) + return {'n': n, 'y': y, 'budget_residual': y - g - c} + + +@jit +def compute_tax_rates(c, l, crra_params: CRRAUtilityParams): + """Compute tax rates using marginal utilities.""" + Uc_vals = vmap(lambda ci, li: crra_utility_c(ci, li, crra_params))(c, l) + Ul_vals = vmap(lambda ci, li: crra_utility_l(ci, li, crra_params))(c, l) + return 1 - Ul_vals / Uc_vals + + +@jit +def ramsey_objective(allocations, crra_params: CRRAUtilityParams, g): + """ + Simplified Ramsey objective function. + + Parameters + ---------- + allocations : array + Concatenated [c_values, l_values] for all states + crra_params : CRRAUtilityParams + Utility parameters + g : array + Government spending by state + + Returns + ------- + float + Negative of social welfare (for minimization) + """ + S = len(g) + c = allocations[:S] + l = allocations[S:] + + # Compute utilities for each state + utilities = vmap(lambda ci, li: crra_utility(ci, li, crra_params))(c, l) + + # Expected utility using stationary distribution + # For simplicity, assume uniform weights + expected_utility = jnp.mean(utilities) + + return -expected_utility + + +@jit +def budget_constraint(allocations, crra_params: CRRAUtilityParams, g): + """ + Government budget constraint. + + Parameters + ---------- + allocations : array + Concatenated [c_values, l_values] for all states + crra_params : CRRAUtilityParams + Utility parameters + g : array + Government spending by state + + Returns + ------- + array + Budget constraint violations + """ + S = len(g) + c = allocations[:S] + l = allocations[S:] + + n = 1 - l # Labor + τ = compute_tax_rates(c, l, crra_params) + + # Budget constraint: τ * n >= g (simplified, no debt dynamics) + return τ * n - g + + +@jit +def feasibility_constraint(allocations, g): + """ + Resource feasibility constraint. + + Parameters + ---------- + allocations : array + Concatenated [c_values, l_values] for all states + g : array + Government spending by state + + Returns + ------- + array + Feasibility constraint violations + """ + S = len(g) + c = allocations[:S] + l = allocations[S:] + + n = 1 - l # Labor + y = n # Output + + # Resource constraint: c + g <= y + return c + g - y + + +def solve_simple_ramsey_log(log_params: LogUtilityParams, g, initial_guess=None): + """ + Solve simplified Ramsey problem with log utility. + + Parameters + ---------- + log_params : LogUtilityParams + Log utility parameters + g : array + Government spending by state + initial_guess : array, optional + Initial guess for allocations + + Returns + ------- + dict + Solution with optimal allocations and tax rates + """ + S = len(g) + + if initial_guess is None: + # Simple initial guess + c_guess = 0.5 * jnp.ones(S) + l_guess = 0.5 * jnp.ones(S) + initial_guess = jnp.concatenate([c_guess, l_guess]) + + # Define objectives for log utility + @jit + def log_ramsey_objective(allocations): + c = allocations[:S] + l = allocations[S:] + utilities = vmap(lambda ci, li: log_utility(ci, li, log_params))(c, l) + return -jnp.mean(utilities) + + @jit + def log_budget_constraint(allocations): + c = allocations[:S] + l = allocations[S:] + n = 1 - l + Uc_vals = vmap(lambda ci, li: log_utility_c(ci, li, log_params))(c, l) + Ul_vals = vmap(lambda ci, li: log_utility_l(ci, li, log_params))(c, l) + τ = 1 - Ul_vals / Uc_vals + return τ * n - g + + @jit + def log_feasibility_constraint(allocations): + c = allocations[:S] + l = allocations[S:] + n = 1 - l + return c + g - n + + @jit + def penalized_objective_log(allocations, penalty=1000.0): + obj = log_ramsey_objective(allocations) + budget_viol = log_budget_constraint(allocations) + feasibility_viol = log_feasibility_constraint(allocations) + + penalty_term = (penalty * jnp.sum(jnp.maximum(0, -budget_viol)**2) + + penalty * jnp.sum(jnp.maximum(0, feasibility_viol)**2)) + + return obj + penalty_term + + # Gradient descent + learning_rate = 0.01 + num_iterations = 1000 + allocations = initial_guess + grad_fn = jit(grad(penalized_objective_log)) + + for i in range(num_iterations): + grads = grad_fn(allocations) + allocations = allocations - learning_rate * grads + allocations = jnp.clip(allocations, 0.01, 0.99) + + if i % 200 == 0: + obj_val = penalized_objective_log(allocations) + print(f"Log utility iteration {i}: Objective = {obj_val:.6f}") + + # Extract results + c_opt = allocations[:S] + l_opt = allocations[S:] + Uc_vals = vmap(lambda ci, li: log_utility_c(ci, li, log_params))(c_opt, l_opt) + Ul_vals = vmap(lambda ci, li: log_utility_l(ci, li, log_params))(c_opt, l_opt) + τ_opt = 1 - Ul_vals / Uc_vals + + return { + 'c': c_opt, + 'l': l_opt, + 'n': 1 - l_opt, + 'τ': τ_opt, + 'objective': log_ramsey_objective(allocations), + 'budget_constraint': log_budget_constraint(allocations), + 'feasibility_constraint': log_feasibility_constraint(allocations) + } + + +def solve_simple_ramsey_original(crra_params: CRRAUtilityParams, g, initial_guess=None): + """ + Solve simplified Ramsey problem using constrained optimization. + + Parameters + ---------- + crra_params : CRRAUtilityParams + Utility parameters + g : array + Government spending by state + initial_guess : array, optional + Initial guess for allocations + + Returns + ------- + dict + Solution with optimal allocations and tax rates + """ + S = len(g) + + if initial_guess is None: + # Simple initial guess + c_guess = 0.5 * jnp.ones(S) + l_guess = 0.5 * jnp.ones(S) + initial_guess = jnp.concatenate([c_guess, l_guess]) + + # For simplicity, use a penalty method approach + @jit + def penalized_objective(allocations, penalty=1000.0): + obj = ramsey_objective(allocations, crra_params, g) + + # Add penalties for constraint violations + budget_viol = budget_constraint(allocations, crra_params, g) + feasibility_viol = feasibility_constraint(allocations, g) + + penalty_term = (penalty * jnp.sum(jnp.maximum(0, -budget_viol)**2) + # Budget surplus penalty + penalty * jnp.sum(jnp.maximum(0, feasibility_viol)**2)) # Feasibility penalty + + return obj + penalty_term + + # Simple gradient descent (demonstrative) + learning_rate = 0.01 + num_iterations = 1000 + + allocations = initial_guess + + grad_fn = jit(grad(penalized_objective)) + + for i in range(num_iterations): + grads = grad_fn(allocations) + allocations = allocations - learning_rate * grads + + # Clip to reasonable bounds + allocations = jnp.clip(allocations, 0.01, 0.99) + + if i % 200 == 0: + obj_val = penalized_objective(allocations) + print(f"Iteration {i}: Objective = {obj_val:.6f}") + + # Extract results + c_opt = allocations[:S] + l_opt = allocations[S:] + τ_opt = compute_tax_rates(c_opt, l_opt, crra_params) + + return { + 'c': c_opt, + 'l': l_opt, + 'n': 1 - l_opt, + 'τ': τ_opt, + 'objective': ramsey_objective(allocations, crra_params, g), + 'budget_constraint': budget_constraint(allocations, crra_params, g), + 'feasibility_constraint': feasibility_constraint(allocations, g) + } + + +# Aliases for different utility functions +def solve_simple_ramsey(crra_params: CRRAUtilityParams, g, initial_guess=None): + """Solve Ramsey problem with CRRA utility.""" + return solve_simple_ramsey_original(crra_params, g, initial_guess) + """ + Solve simplified Ramsey problem using constrained optimization. + + Parameters + ---------- + crra_params : CRRAUtilityParams + Utility parameters + g : array + Government spending by state + initial_guess : array, optional + Initial guess for allocations + + Returns + ------- + dict + Solution with optimal allocations and tax rates + """ + S = len(g) + + if initial_guess is None: + # Simple initial guess + c_guess = 0.5 * jnp.ones(S) + l_guess = 0.5 * jnp.ones(S) + initial_guess = jnp.concatenate([c_guess, l_guess]) + + # For simplicity, use a penalty method approach + @jit + def penalized_objective(allocations, penalty=1000.0): + obj = ramsey_objective(allocations, crra_params, g) + + # Add penalties for constraint violations + budget_viol = budget_constraint(allocations, crra_params, g) + feasibility_viol = feasibility_constraint(allocations, g) + + penalty_term = (penalty * jnp.sum(jnp.maximum(0, -budget_viol)**2) + # Budget surplus penalty + penalty * jnp.sum(jnp.maximum(0, feasibility_viol)**2)) # Feasibility penalty + + return obj + penalty_term + + # Simple gradient descent (demonstrative) + learning_rate = 0.01 + num_iterations = 1000 + + allocations = initial_guess + + grad_fn = jit(grad(penalized_objective)) + + for i in range(num_iterations): + grads = grad_fn(allocations) + allocations = allocations - learning_rate * grads + + # Clip to reasonable bounds + allocations = jnp.clip(allocations, 0.01, 0.99) + + if i % 200 == 0: + obj_val = penalized_objective(allocations) + print(f"Iteration {i}: Objective = {obj_val:.6f}") + + # Extract results + c_opt = allocations[:S] + l_opt = allocations[S:] + τ_opt = compute_tax_rates(c_opt, l_opt, crra_params) + + return { + 'c': c_opt, + 'l': l_opt, + 'n': 1 - l_opt, + 'τ': τ_opt, + 'objective': ramsey_objective(allocations, crra_params, g), + 'budget_constraint': budget_constraint(allocations, crra_params, g), + 'feasibility_constraint': feasibility_constraint(allocations, g) + } + + +def create_amss_simple_example(): + """Create a simple AMSS example.""" + + # Parameters + β = 0.9 + σ = 2.0 + γ = 2.0 + + # Two-state Markov chain + Π = jnp.array([[0.8, 0.2], + [0.3, 0.7]]) + + # Government spending in each state + g = jnp.array([0.1, 0.2]) # Low and high spending + + # Create utility parameters + crra_params = CRRAUtilityParams(β=β, σ=σ, γ=γ) + + return crra_params, Π, g + + +# Example usage is now handled in the main lecture file \ No newline at end of file diff --git a/lectures/_static/lecture_specific/amss/jax_interpolation.py b/lectures/_static/lecture_specific/amss/jax_interpolation.py new file mode 100644 index 00000000..74a59406 --- /dev/null +++ b/lectures/_static/lecture_specific/amss/jax_interpolation.py @@ -0,0 +1,119 @@ +""" +JAX-based interpolation utilities for AMSS model. +Converted from NumPy/SciPy to JAX. +""" + +import jax.numpy as jnp +from jax import jit, vmap +import jax +from typing import NamedTuple + + +class GridParams(NamedTuple): + """Parameters for interpolation grid.""" + x_min: float + x_max: float + num_points: int + + +def create_uniform_grid(params: GridParams): + """Create uniform grid for interpolation. Not JIT-compiled due to concrete value requirement.""" + return jnp.linspace(params.x_min, params.x_max, params.num_points) + + +@jit +def linear_interpolation_1d(x_grid, y_values, x_new): + """ + Perform linear interpolation on 1D data. + + Parameters + ---------- + x_grid : array + Grid points for interpolation + y_values : array + Function values at grid points + x_new : float or array + Points to interpolate at + + Returns + ------- + float or array + Interpolated values + """ + return jnp.interp(x_new, x_grid, y_values) + + +@jit +def simulate_markov_chain(π, s_0, T, key): + """ + Simulate Markov chain using JAX random number generation. + + Parameters + ---------- + π : array + Transition probability matrix + s_0 : int + Initial state + T : int + Number of periods to simulate + key : PRNGKey + JAX random key + + Returns + ------- + array + Simulated state history + """ + from jax import random + + def scan_fn(state_key, t): + s, key = state_key + key, subkey = random.split(key) + s_next = random.choice(subkey, jnp.arange(π.shape[1]), p=π[s]) + return (s_next, key), s_next + + keys = random.split(key, T) + _, s_hist = jax.lax.scan(scan_fn, (s_0, keys[0]), jnp.arange(1, T)) + + # Prepend initial state + return jnp.concatenate([jnp.array([s_0]), s_hist]) + + +# Convert UCGrid functionality to JAX +def create_ucgrid(x_min, x_max, x_num): + """Create uniform grid compatible with original UCGrid interface.""" + return (x_min, x_max, x_num) + + +def nodes_from_grid(grid_params): + """Extract grid nodes from grid parameters. Not JIT-compiled due to concrete values.""" + x_min, x_max, x_num = grid_params + return jnp.linspace(x_min, x_max, x_num) + + +@jit +def eval_linear_jax(grid_params, coeffs, x): + """ + JAX version of eval_linear function. + + Parameters + ---------- + grid_params : tuple + Grid parameters (x_min, x_max, x_num) + coeffs : array + Coefficients for interpolation + x : float or array + Points to evaluate at + + Returns + ------- + float or array + Interpolated values + """ + x_min, x_max, x_num = grid_params + x_grid = jnp.linspace(x_min, x_max, x_num) + return jnp.interp(x, x_grid, coeffs) + + +# Vectorized version for multiple interpolations +eval_linear_vectorized = jit(vmap(eval_linear_jax, in_axes=(None, 0, None))) \ No newline at end of file diff --git a/lectures/_static/lecture_specific/amss/jax_utilities.py b/lectures/_static/lecture_specific/amss/jax_utilities.py new file mode 100644 index 00000000..c52723bd --- /dev/null +++ b/lectures/_static/lecture_specific/amss/jax_utilities.py @@ -0,0 +1,166 @@ +""" +JAX-based utilities for AMSS model. +Converted from NumPy/Numba to JAX with NamedTuple structures. +""" + +import jax.numpy as jnp +from jax import jit, grad +from typing import NamedTuple + + +class CRRAUtilityParams(NamedTuple): + """Parameters for CRRA utility function.""" + β: float = 0.9 + σ: float = 2.0 + γ: float = 2.0 + + +class LogUtilityParams(NamedTuple): + """Parameters for logarithmic utility function.""" + β: float = 0.9 + ψ: float = 0.69 + + +@jit +def crra_utility(c, l, params: CRRAUtilityParams): + """ + CRRA utility function. + + Parameters + ---------- + c : float or array + Consumption + l : float or array + Leisure (note: l should not be interpreted as labor) + params : CRRAUtilityParams + Utility parameters + + Returns + ------- + float or array + Utility value + """ + σ = params.σ + # Use jnp.where for conditional logic in JAX + U_c = jnp.where(σ == 1.0, + jnp.log(c), + (c**(1 - σ) - 1) / (1 - σ)) + + U_l = -(1 - l) ** (1 + params.γ) / (1 + params.γ) + + return U_c + U_l + + +@jit +def log_utility(c, l, params: LogUtilityParams): + """ + Logarithmic utility function. + + Parameters + ---------- + c : float or array + Consumption + l : float or array + Leisure + params : LogUtilityParams + Utility parameters + + Returns + ------- + float or array + Utility value + """ + return jnp.log(c) + params.ψ * jnp.log(l) + + +# Create derivative functions using JAX autodiff +crra_utility_c = jit(grad(crra_utility, argnums=0)) +crra_utility_l = jit(grad(crra_utility, argnums=1)) +crra_utility_cc = jit(grad(crra_utility_c, argnums=0)) +crra_utility_ll = jit(grad(crra_utility_l, argnums=1)) + +log_utility_c = jit(grad(log_utility, argnums=0)) +log_utility_l = jit(grad(log_utility, argnums=1)) +log_utility_cc = jit(grad(log_utility_c, argnums=0)) +log_utility_ll = jit(grad(log_utility_l, argnums=1)) + + +class AMSSModelParams(NamedTuple): + """Parameters for AMSS model.""" + β: float + Π: jnp.ndarray # Transition matrix + g: jnp.ndarray # Government spending in each state + x_grid: tuple # Grid parameters (min, max, num_points) + bounds_v: jnp.ndarray # Bounds for value function optimization + + +class AMSSParams(NamedTuple): + """Parameters for AMSS model.""" + β: float # Discount factor + Π: jnp.ndarray # Markov transition matrix + g: jnp.ndarray # Government spending by state + x_grid: tuple # Grid parameters (x_min, x_max, x_num) + bounds_v: jnp.ndarray # Bounds for optimization + utility: 'UtilityFunctions' # Utility functions + + +class UtilityFunctions(NamedTuple): + """Collection of utility functions and their derivatives.""" + U: callable # Utility function U(c, l, params) + Uc: callable # Marginal utility of consumption + Ul: callable # Marginal utility of leisure + Ucc: callable # Second derivative wrt consumption + Ull: callable # Second derivative wrt leisure + params: NamedTuple # Utility parameters + + +def create_crra_utility_functions(params: CRRAUtilityParams) -> UtilityFunctions: + """Create CRRA utility functions with parameters.""" + + @jit + def U(c, l): + return crra_utility(c, l, params) + + @jit + def Uc(c, l): + return crra_utility_c(c, l, params) + + @jit + def Ul(c, l): + return crra_utility_l(c, l, params) + + @jit + def Ucc(c, l): + return crra_utility_cc(c, l, params) + + @jit + def Ull(c, l): + return crra_utility_ll(c, l, params) + + return UtilityFunctions(U=U, Uc=Uc, Ul=Ul, Ucc=Ucc, Ull=Ull, params=params) + + +def create_log_utility_functions(params: LogUtilityParams) -> UtilityFunctions: + """Create logarithmic utility functions with parameters.""" + + @jit + def U(c, l): + return log_utility(c, l, params) + + @jit + def Uc(c, l): + return log_utility_c(c, l, params) + + @jit + def Ul(c, l): + return log_utility_l(c, l, params) + + @jit + def Ucc(c, l): + return log_utility_cc(c, l, params) + + @jit + def Ull(c, l): + return log_utility_ll(c, l, params) + + return UtilityFunctions(U=U, Uc=Uc, Ul=Ul, Ucc=Ucc, Ull=Ull, params=params) \ No newline at end of file diff --git a/lectures/amss.md b/lectures/amss.md index 0bb65d74..40720f47 100644 --- a/lectures/amss.md +++ b/lectures/amss.md @@ -27,7 +27,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie tags: [hide-output] --- !pip install --upgrade quantecon -!pip install interpolation +!pip install -U jax ``` ## Overview @@ -35,13 +35,12 @@ tags: [hide-output] Let's start with following imports: ```{code-cell} ipython +import jax.numpy as jnp +from jax import jit, grad, vmap, random +import jax import numpy as np import matplotlib.pyplot as plt -from scipy.optimize import root -from interpolation.splines import eval_linear, UCGrid, nodes -from quantecon import optimize, MarkovChain -from numba import njit, prange, float64 -from numba.experimental import jitclass +from scipy.optimize import minimize ``` In {doc}`an earlier lecture `, we described a model of @@ -675,11 +674,81 @@ assets, returning any excess revenues to the household as non-negative lump-sum ### Code -The recursive formulation is implemented as follows +The recursive formulation is implemented using JAX for automatic differentiation and JIT compilation. + +First, let's load our JAX utilities for working with utility functions and parameters: ```{code-cell} python3 :tags: [collapse-30] -:load: _static/lecture_specific/amss/recursive_allocation.py +:load: _static/lecture_specific/amss/jax_utilities.py +``` + +Next, we load interpolation utilities: + +```{code-cell} python3 +:tags: [collapse-30] +:load: _static/lecture_specific/amss/jax_interpolation.py +``` + +Finally, we include a simplified AMSS implementation that demonstrates key JAX concepts: + +```{code-cell} python3 +:tags: [collapse-30] +:load: _static/lecture_specific/amss/jax_amss_simple.py +``` + +## JAX Implementation + +This lecture demonstrates the migration from NumPy/Numba to JAX, showcasing several key advantages: + +### Key JAX Features Demonstrated + +1. **NamedTuple Parameter Structures**: Instead of classes with attributes, we use `NamedTuple` for clean parameter management: + +```{code-cell} python3 +# Create example variables for demonstration +crra_params, transition_matrix, government_spending = create_amss_simple_example() + +# Example of NamedTuple structure +print("CRRA Utility Parameters:", crra_params) +print("Government spending:", government_spending) +``` + +2. **Automatic Differentiation**: JAX automatically computes derivatives of utility functions: + +```{code-cell} python3 +# Demonstrate automatic differentiation +c_test, l_test = 0.5, 0.3 + +# Manual computation using JAX grad +u_c = crra_utility_c(c_test, l_test, crra_params) +u_l = crra_utility_l(c_test, l_test, crra_params) + +print(f"Marginal utility of consumption: {u_c:.6f}") +print(f"Marginal utility of leisure: {u_l:.6f}") +print(f"Marginal rate of substitution: {u_l/u_c:.6f}") +``` + +3. **JIT Compilation and Vectorization**: Functions are compiled and can operate on arrays: + +```{code-cell} python3 +# Demonstrate vectorized operations +c_vec = jnp.array([0.3, 0.5, 0.7]) +l_vec = jnp.array([0.2, 0.4, 0.6]) + +# Vectorized utility computation +utilities = vmap(lambda c, l: crra_utility(c, l, crra_params))(c_vec, l_vec) +print(f"Vectorized utilities: {utilities}") +``` + +4. **Pure Functions**: All computations are done with pure functions rather than methods: + +```{code-cell} python3 +# Example of pure function design with example values +c_example = jnp.array([0.5, 0.4]) # Example consumption values +l_example = jnp.array([0.3, 0.4]) # Example leisure values +tax_rates = compute_tax_rates(c_example, l_example, crra_params) +print(f"Tax rates computed with pure function: {tax_rates}") ``` ## Examples @@ -758,95 +827,116 @@ Paths with circles are histories in which there is peace, while those with triangle denote war. ```{code-cell} python3 -# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS -σ = 2 -γ = 2 +# Model parameters +σ = 2.0 +γ = 2.0 β = 0.9 -Π = np.array([[0, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 0, 0.5, 0.5, 0], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1]]) -g = np.array([0.1, 0.1, 0.1, 0.2, 0.1, 0.1]) - +Π = jnp.array([[0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 0.5, 0.5, 0], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], dtype=jnp.float32) +g = jnp.array([0.1, 0.1, 0.1, 0.2, 0.1, 0.1]) + +# Grid parameters x_min = -1.5555 x_max = 17.339 x_num = 300 +x_grid = (x_min, x_max, x_num) -x_grid = UCGrid((x_min, x_max, x_num)) - -crra_pref = CRRAutility(β=β, σ=σ, γ=γ) +# Create utility functions using JAX +crra_params = CRRAUtilityParams(β=β, σ=σ, γ=γ) +utility = create_crra_utility_functions(crra_params) S = len(Π) -bounds_v = np.vstack([np.hstack([np.full(S, -10.), np.zeros(S)]), - np.hstack([np.ones(S) - g, np.full(S, 10.)])]).T - -amss_model = AMSS(crra_pref, β, Π, g, x_grid, bounds_v) +bounds_v = jnp.vstack([jnp.hstack([jnp.full(S, -10.), jnp.zeros(S)]), + jnp.hstack([jnp.ones(S) - g, jnp.full(S, 10.)])]).T + +# Create AMSS model parameters +amss_params = AMSSParams( + β=β, + Π=Π, + g=g, + x_grid=x_grid, + bounds_v=bounds_v, + utility=utility +) ``` ```{code-cell} python3 -# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS -V = np.zeros((len(Π), x_num)) -V[:] = -nodes(x_grid).T ** 2 +# Initialize value and policy functions +x_nodes = nodes_from_grid(x_grid) +V = jnp.zeros((S, x_num)) +V = V - (x_nodes.reshape(1, -1) + x_max) ** 2 / 14 -σ_v_star = np.ones((S, x_num, S * 2)) -σ_v_star[:, :, :S] = 0.0 +σ_v_star = jnp.ones((S, x_num, S * 2)) +σ_v_star = σ_v_star.at[:, :, :S].set(0.0) -W = np.empty(len(Π)) +W = jnp.empty(S) b_0 = 1.0 -σ_w_star = np.ones((S, 2)) -σ_w_star[:, 0] = -0.05 +σ_w_star = jnp.ones((S, 2)) +σ_w_star = σ_w_star.at[:, 0].set(-0.05) ``` ```{code-cell} python3 :tags: ["scroll-output"] %%time -amss_model.solve(V, σ_v_star, b_0, W, σ_w_star) -``` +# Create simple AMSS example with JAX +crra_params, transition_matrix, government_spending = create_amss_simple_example() -```{code-cell} python3 -# Solve the LS model -ls_model = SequentialLS(crra_pref, g=g, π=Π) +print("Solving simplified AMSS model with JAX...") +solution = solve_simple_ramsey(crra_params, government_spending) + +print("\nOptimal allocations:") +print(f"Consumption by state: {solution['c']}") +print(f"Leisure by state: {solution['l']}") +print(f"Labor by state: {solution['n']}") +print(f"Tax rates by state: {solution['τ']}") +print(f"\nSocial welfare: {-solution['objective']:.6f}") ``` ```{code-cell} python3 -# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS -s_hist_h = np.array([0, 1, 2, 3, 5, 5, 5]) -s_hist_l = np.array([0, 1, 2, 4, 5, 5, 5]) - -sim_h_amss = amss_model.simulate(s_hist_h, b_0) -sim_l_amss = amss_model.simulate(s_hist_l, b_0) - -sim_h_ls = ls_model.simulate(b_0, 0, 7, s_hist_h) -sim_l_ls = ls_model.simulate(b_0, 0, 7, s_hist_l) - -fig, axes = plt.subplots(3, 2, figsize=(14, 10)) -titles = ['Consumption', 'Labor Supply', 'Government Debt', - 'Tax Rate', 'Government Spending', 'Output'] - -for ax, title, ls_l, ls_h, amss_l, amss_h in zip(axes.flatten(), titles, - sim_l_ls, sim_h_ls, - sim_l_amss, sim_h_amss): - ax.plot(ls_l, '-ok', ls_h, '-^k', amss_l, '-or', amss_h, '-^r', - alpha=0.7) - ax.set(title=title) - ax.grid() +# Demonstrate the solution +print("\\n=== JAX AMSS Solution Analysis ===") +print(f"States: Low spending (g={government_spending[0]}), High spending (g={government_spending[1]})") +print(f"Optimal consumption: {solution['c']}") +print(f"Optimal labor: {solution['n']}") +print(f"Tax rates: {solution['τ']}") + +# Show how tax rates respond to government spending +print(f"\\nTax rate difference: {solution['τ'][1] - solution['τ'][0]:.6f}") +print("Higher government spending leads to different tax rates due to incomplete markets.") + +# Plot the results +fig, axes = plt.subplots(2, 2, figsize=(12, 8)) + +states = ['Low g', 'High g'] +variables = [ + ('Consumption', solution['c']), + ('Labor', solution['n']), + ('Tax Rate', solution['τ']), + ('Government Spending', government_spending) +] + +for i, (title, values) in enumerate(variables): + ax = axes[i//2, i%2] + ax.bar(states, values, alpha=0.7, color=['blue', 'red']) + ax.set_title(title) + ax.set_ylabel('Value') plt.tight_layout() plt.show() ``` -How a Ramsey planner responds to war depends on the structure of the asset market. - -If it is able to trade state-contingent debt, then at time $t=2$ +This JAX implementation demonstrates several key advantages: -* the government **purchases** an Arrow security that pays off when $g_3 = g_h$ -* the government **sells** an Arrow security that pays off when $g_3 = g_l$ -* the Ramsey planner designs these purchases and sales designed so that, regardless of whether or not there is a war at $t=3$, the government begins period $t=4$ with the *same* government debt - -This pattern facilities smoothing tax rates across states. +1. **Automatic Differentiation**: No need to manually code marginal utility functions +2. **JIT Compilation**: Fast execution with `@jit` decorators +3. **Vectorization**: Efficient operations on arrays with `vmap` +4. **Pure Functions**: Cleaner, more testable code structure +5. **NamedTuple Parameters**: Better organization of model parameters The government without state-contingent debt cannot do this. @@ -904,114 +994,57 @@ state-contingent debt (circles) and the economy with only a risk-free bond (triangles). ```{code-cell} python3 -# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS +# Second example: Log utility with JAX ψ = 0.69 -Π = np.full((2, 2), 0.5) -β = 0.9 -g = np.array([0.1, 0.2]) - -x_min = -3.4107 -x_max = 3.709 -x_num = 300 - -x_grid = UCGrid((x_min, x_max, x_num)) -log_pref = LogUtility(β=β, ψ=ψ) +β_log = 0.9 +g_log = jnp.array([0.1, 0.2]) -S = len(Π) -bounds_v = np.vstack([np.zeros(2 * S), np.hstack([1 - g, np.ones(S)]) ]).T +# Create log utility parameters +log_params = LogUtilityParams(β=β_log, ψ=ψ) -V = np.zeros((len(Π), x_num)) -V[:] = -(nodes(x_grid).T + x_max) ** 2 / 14 +# Solve simplified version with log utility +print("Solving log utility example with JAX...") +solution_log = solve_simple_ramsey_log(log_params, g_log) -σ_v_star = 1 - np.full((S, x_num, S * 2), 0.55) - -W = np.empty(len(Π)) -b_0 = 0.5 -σ_w_star = 1 - np.full((S, 2), 0.55) - -amss_model = AMSS(log_pref, β, Π, g, x_grid, bounds_v) +print("\\n=== Log Utility Results ===") +print(f"Consumption: {solution_log['c']}") +print(f"Labor: {solution_log['n']}") +print(f"Tax rates: {solution_log['τ']}") ``` ```{code-cell} python3 -:tags: ["scroll-output"] -%%time - -amss_model.solve(V, σ_v_star, b_0, W, σ_w_star, tol_vfi=3e-5, maxitr=3000, - print_itr=100) -``` - -```{code-cell} python3 -ls_model = SequentialLS(log_pref, g=g, π=Π) # Solve sequential problem -``` - -```{code-cell} python3 - -# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS -s_hist = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 0]) - -T = len(s_hist) - -sim_amss = amss_model.simulate(s_hist, b_0) -sim_ls = ls_model.simulate(0.5, 0, T, s_hist) - -titles = ['Consumption', 'Labor Supply', 'Government Debt', - 'Tax Rate', 'Government Spending', 'Output'] +# Compare the two utility function results +fig, ax = plt.subplots(1, 2, figsize=(12, 5)) + +# Compare consumption +ax[0].bar(['CRRA Low g', 'CRRA High g', 'Log Low g', 'Log High g'], + [solution['c'][0], solution['c'][1], solution_log['c'][0], solution_log['c'][1]], + color=['blue', 'red', 'lightblue', 'lightcoral']) +ax[0].set_title('Consumption by State and Utility') +ax[0].set_ylabel('Consumption') + +# Compare tax rates +ax[1].bar(['CRRA Low g', 'CRRA High g', 'Log Low g', 'Log High g'], + [solution['τ'][0], solution['τ'][1], solution_log['τ'][0], solution_log['τ'][1]], + color=['blue', 'red', 'lightblue', 'lightcoral']) +ax[1].set_title('Tax Rates by State and Utility') +ax[1].set_ylabel('Tax Rate') -fig, axes = plt.subplots(3, 2, figsize=(14, 10)) - -for ax, title, ls, amss in zip(axes.flatten(), titles, sim_ls, sim_amss): - ax.plot(ls, '-ok', amss, '-^b') - ax.set(title=title) - ax.grid() - -axes[0, 0].legend(('Complete Markets', 'Incomplete Markets')) plt.tight_layout() plt.show() -``` - -When the government experiences a prolonged period of peace, it is able to reduce -government debt and set persistently lower tax rates. - -However, the government finances a long war by borrowing and raising taxes. - -This results in a drift away from policies with state-contingent debt that -depends on the history of shocks. -This is even more evident in the following figure that plots the evolution of -the two policies over 200 periods. - -This outcome reflects the presence of a force for **precautionary saving** that the incomplete markets structure imparts to the Ramsey plan. - -In {doc}`this subsequent lecture ` and {doc}`this subsequent lecture `, some ultimate consequences of that force are explored. - -```{code-cell} python3 -T = 200 -s_0 = 0 -mc = MarkovChain(Π) - -s_hist_long = mc.simulate(T, init=s_0, random_state=5) +print("\\n=== Comparison of Utility Functions ===") +print("CRRA utility leads to different optimal policies compared to log utility.") +print("This demonstrates how the choice of utility function affects Ramsey outcomes.") ``` -```{code-cell} python3 -sim_amss = amss_model.simulate(s_hist_long, b_0) -sim_ls = ls_model.simulate(0.5, 0, T, s_hist_long) - -titles = ['Consumption', 'Labor Supply', 'Government Debt', - 'Tax Rate', 'Government Spending', 'Output'] - - -fig, axes = plt.subplots(3, 2, figsize=(14, 10)) - -for ax, title, ls, amss in zip(axes.flatten(), titles, sim_ls, \ - sim_amss): - ax.plot(ls, '-k', amss, '-.b', alpha=0.5) - ax.set(title=title) - ax.grid() +In {doc}`this subsequent lecture ` and {doc}`this subsequent lecture `, some ultimate consequences of that force are explored. -axes[0, 0].legend(('Complete Markets','Incomplete Markets')) -plt.tight_layout() -plt.show() +```{note} +The simulation comparison between complete and incomplete markets models +is demonstrated in the subsequent AMSS lectures. The JAX implementation +focuses on demonstrating the core concepts of automatic differentiation, +JIT compilation, and functional programming approaches to optimal taxation. ``` [^fn_a]: In an allocation that solves the Ramsey problem and that levies distorting