diff --git a/scripts/seed.py b/scripts/seed.py index f3fbfa8..24d4843 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -206,14 +206,13 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer f" [green]✓[/green] Added {len(model_version.variables)} variables" ) - # Add parameters (only user-facing ones: those with labels) - # Deduplicate by name - keep first occurrence + # Add parameters - deduplicate by name (keep first occurrence) # In lite mode, exclude US state parameters (gov.states.*) seen_names = set() parameters_to_add = [] skipped_state_params = 0 for p in model_version.parameters: - if p.label is None or p.name in seen_names: + if p.name in seen_names: continue # In lite mode, skip state-level parameters for faster seeding if lite and p.name.startswith("gov.states."): @@ -222,8 +221,10 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer parameters_to_add.append(p) seen_names.add(p.name) - filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters" - filter_msg += f" (from {len(model_version.parameters)} total, deduplicated by name)" + filter_msg = f" Filtered to {len(parameters_to_add)} parameters" + filter_msg += ( + f" (from {len(model_version.parameters)} total, deduplicated by name)" + ) if lite and skipped_state_params > 0: filter_msg += f", skipped {skipped_state_params} state params (lite mode)" console.print(filter_msg) @@ -626,7 +627,9 @@ def main(): with logfire.span("database_seeding"): mode_str = " (lite mode)" if args.lite else "" - console.print(f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n") + console.print( + f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n" + ) with next(get_quiet_session()) as session: # Seed UK model diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index 57f7faf..604125a 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -235,8 +235,7 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]: prop = schema_to_json_schema(spec, param_schema) prop["description"] = ( - param.get("description", "") - + f" (in: {param_in})" + param.get("description", "") + f" (in: {param_in})" ) properties[param_name] = prop @@ -268,16 +267,18 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]: if required: input_schema["required"] = list(set(required)) - tools.append({ - "name": tool_name, - "description": full_desc[:1024], # Claude has limits - "input_schema": input_schema, - "_meta": { - "path": path, - "method": method, - "parameters": operation.get("parameters", []), - }, - }) + tools.append( + { + "name": tool_name, + "description": full_desc[:1024], # Claude has limits + "input_schema": input_schema, + "_meta": { + "path": path, + "method": method, + "parameters": operation.get("parameters", []), + }, + } + ) return tools @@ -347,7 +348,9 @@ def execute_api_tool( url, params=query_params, json=body_data, headers=headers, timeout=60 ) elif method == "delete": - resp = requests.delete(url, params=query_params, headers=headers, timeout=60) + resp = requests.delete( + url, params=query_params, headers=headers, timeout=60 + ) else: return f"Unsupported method: {method}" @@ -415,9 +418,7 @@ def log(msg: str) -> None: tool_lookup = {t["name"]: t for t in tools} # Strip _meta from tools before sending to Claude (it doesn't need it) - claude_tools = [ - {k: v for k, v in t.items() if k != "_meta"} for t in tools - ] + claude_tools = [{k: v for k, v in t.items() if k != "_meta"} for t in tools] # Add the sleep tool claude_tools.append(SLEEP_TOOL) @@ -477,11 +478,13 @@ def log(msg: str) -> None: log(f"[TOOL_RESULT] {result[:300]}") - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) messages.append({"role": "assistant", "content": assistant_content}) diff --git a/src/policyengine_api/api/agent.py b/src/policyengine_api/api/agent.py index 7b7d108..6c26e80 100644 --- a/src/policyengine_api/api/agent.py +++ b/src/policyengine_api/api/agent.py @@ -24,6 +24,7 @@ def get_traceparent() -> str | None: TraceContextTextMapPropagator().inject(carrier) return carrier.get("traceparent") + router = APIRouter(prefix="/agent", tags=["agent"]) @@ -93,7 +94,9 @@ def _run_local_agent( from policyengine_api.agent_sandbox import _run_agent_impl try: - history_dicts = [{"role": m.role, "content": m.content} for m in (history or [])] + history_dicts = [ + {"role": m.role, "content": m.content} for m in (history or []) + ] result = _run_agent_impl(question, api_base_url, call_id, history_dicts) _calls[call_id]["status"] = result.get("status", "completed") _calls[call_id]["result"] = result @@ -136,9 +139,15 @@ async def run_agent(request: RunRequest) -> RunResponse: traceparent = get_traceparent() run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent") - history_dicts = [{"role": m.role, "content": m.content} for m in request.history] + history_dicts = [ + {"role": m.role, "content": m.content} for m in request.history + ] call = run_fn.spawn( - request.question, api_base_url, call_id, history_dicts, traceparent=traceparent + request.question, + api_base_url, + call_id, + history_dicts, + traceparent=traceparent, ) _calls[call_id] = { @@ -166,7 +175,12 @@ async def run_agent(request: RunRequest) -> RunResponse: # Run in background using asyncio loop = asyncio.get_event_loop() loop.run_in_executor( - None, _run_local_agent, call_id, request.question, api_base_url, request.history + None, + _run_local_agent, + call_id, + request.question, + api_base_url, + request.history, ) return RunResponse(call_id=call_id, status="running") diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index b0e99c9..09ea701 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -28,6 +28,7 @@ def get_traceparent() -> str | None: TraceContextTextMapPropagator().inject(carrier) return carrier.get("traceparent") + router = APIRouter(prefix="/household", tags=["household"]) diff --git a/src/policyengine_api/api/variables.py b/src/policyengine_api/api/variables.py index d660b1b..3c24f3d 100644 --- a/src/policyengine_api/api/variables.py +++ b/src/policyengine_api/api/variables.py @@ -56,9 +56,9 @@ def list_variables( # Case-insensitive search using ILIKE # Note: Variables don't have a label field, only name and description search_pattern = f"%{search}%" - search_filter = Variable.name.ilike(search_pattern) | Variable.description.ilike( + search_filter = Variable.name.ilike( search_pattern - ) + ) | Variable.description.ilike(search_pattern) query = query.where(search_filter) variables = session.exec( diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index cdc3d8d..3a7b9a7 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -147,7 +147,11 @@ def download_dataset( @app.function( - image=uk_image, secrets=[db_secrets, logfire_secrets], memory=4096, cpu=4, timeout=600 + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, ) def simulate_household_uk( job_id: str, @@ -277,7 +281,11 @@ def simulate_household_uk( @app.function( - image=us_image, secrets=[db_secrets, logfire_secrets], memory=4096, cpu=4, timeout=600 + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=4096, + cpu=4, + timeout=600, ) def simulate_household_us( job_id: str, @@ -416,7 +424,11 @@ def simulate_household_us( @app.function( - image=uk_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=8192, + cpu=8, + timeout=1800, ) def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> None: """Run a single UK economy simulation and write results to database.""" @@ -445,6 +457,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N Simulation, SimulationStatus, ) + with Session(engine) as session: simulation = session.get(Simulation, UUID(simulation_id)) if not simulation: @@ -452,7 +465,9 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N # Skip if already completed if simulation.status == SimulationStatus.COMPLETED: - logfire.info("Simulation already completed", simulation_id=simulation_id) + logfire.info( + "Simulation already completed", simulation_id=simulation_id + ) return # Update status to running @@ -475,7 +490,9 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N pe_model_version = uk_latest # Get policy and dynamic - policy = _get_pe_policy_uk(simulation.policy_id, pe_model_version, session) + policy = _get_pe_policy_uk( + simulation.policy_id, pe_model_version, session + ) dynamic = _get_pe_dynamic_uk( simulation.dynamic_id, pe_model_version, session ) @@ -512,7 +529,11 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N logfire.info("UK economy simulation completed", simulation_id=simulation_id) except Exception as e: - logfire.error("UK economy simulation failed", simulation_id=simulation_id, error=str(e)) + logfire.error( + "UK economy simulation failed", + simulation_id=simulation_id, + error=str(e), + ) # Use raw SQL to mark as failed - models may not be available try: from sqlmodel import text @@ -534,7 +555,11 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N @app.function( - image=us_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=8192, + cpu=8, + timeout=1800, ) def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> None: """Run a single US economy simulation and write results to database.""" @@ -563,6 +588,7 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N Simulation, SimulationStatus, ) + with Session(engine) as session: simulation = session.get(Simulation, UUID(simulation_id)) if not simulation: @@ -570,7 +596,9 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N # Skip if already completed if simulation.status == SimulationStatus.COMPLETED: - logfire.info("Simulation already completed", simulation_id=simulation_id) + logfire.info( + "Simulation already completed", simulation_id=simulation_id + ) return # Update status to running @@ -593,7 +621,9 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N pe_model_version = us_latest # Get policy and dynamic - policy = _get_pe_policy_us(simulation.policy_id, pe_model_version, session) + policy = _get_pe_policy_us( + simulation.policy_id, pe_model_version, session + ) dynamic = _get_pe_dynamic_us( simulation.dynamic_id, pe_model_version, session ) @@ -630,7 +660,11 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N logfire.info("US economy simulation completed", simulation_id=simulation_id) except Exception as e: - logfire.error("US economy simulation failed", simulation_id=simulation_id, error=str(e)) + logfire.error( + "US economy simulation failed", + simulation_id=simulation_id, + error=str(e), + ) # Use raw SQL to mark as failed - models may not be available try: from sqlmodel import text @@ -652,7 +686,11 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N @app.function( - image=uk_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=8192, + cpu=8, + timeout=1800, ) def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: """Run UK economy comparison analysis (decile impacts, budget impact, etc).""" @@ -690,6 +728,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: SimulationStatus, TaxBenefitModelVersion, ) + with Session(engine) as session: # Load report and related data report = session.get(Report, UUID(job_id)) @@ -851,7 +890,10 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: ) session.add(program_stat) except KeyError as e: - logfire.warn(f"Skipping {prog_name}: variable not found", error=str(e)) + logfire.warn( + f"Skipping {prog_name}: variable not found", + error=str(e), + ) # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED @@ -890,7 +932,11 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: @app.function( - image=us_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=8192, + cpu=8, + timeout=1800, ) def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: """Run US economy comparison analysis (decile impacts, budget impact, etc).""" @@ -1036,7 +1082,9 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: # Calculate program statistics with logfire.span("calculate_program_statistics"): - PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + PEProgramStats.model_rebuild( + _types_namespace={"Simulation": PESimulation} + ) programs = { "income_tax": {"entity": "tax_unit", "is_tax": True}, @@ -1074,7 +1122,10 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: ) session.add(program_stat) except KeyError as e: - logfire.warn(f"Skipping {prog_name}: variable not found", error=str(e)) + logfire.warn( + f"Skipping {prog_name}: variable not found", + error=str(e), + ) # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED diff --git a/test_fixtures/fixtures_parameters.py b/test_fixtures/fixtures_parameters.py index ff69b0e..12995b3 100644 --- a/test_fixtures/fixtures_parameters.py +++ b/test_fixtures/fixtures_parameters.py @@ -41,8 +41,17 @@ def model_version(session): # ----------------------------------------------------------------------------- -def create_parameter(session, model_version, name: str, label: str) -> Parameter: - """Create and persist a Parameter.""" +def create_parameter( + session, model_version, name: str, label: str | None = None +) -> Parameter: + """Create and persist a Parameter. + + Args: + session: The database session. + model_version: The TaxBenefitModelVersion to associate with. + name: The parameter name (e.g., "gov.irs.income.bracket.rates.1"). + label: Optional human-readable label. If None, parameter has no label. + """ param = Parameter( name=name, label=label, diff --git a/tests/test_agent_policy_questions.py b/tests/test_agent_policy_questions.py index 1550f89..68d80d0 100644 --- a/tests/test_agent_policy_questions.py +++ b/tests/test_agent_policy_questions.py @@ -218,4 +218,6 @@ def test_turn_efficiency(self, question, max_expected_turns): print(f"Result: {result['result'][:300]}") if result["turns"] > max_expected_turns: - print(f"WARNING: Took {result['turns']} turns, expected <= {max_expected_turns}") + print( + f"WARNING: Took {result['turns']} turns, expected <= {max_expected_turns}" + ) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index f95016b..72cdd8f 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -42,6 +42,60 @@ def test__given_nonexistent_parameter_id__then_returns_404(client): assert response.status_code == 404 +def test__given_parameter_without_label__then_can_create_and_retrieve( + client, + session, + model_version, # noqa: F811 +): + """Parameters without labels can be created and retrieved via API. + + This tests that the API supports parameters with label=None, which is + important for bracket parameters (e.g., [0].rate) and breakdown parameters + (e.g., .SINGLE, .CA) that don't have explicit labels in their YAML files. + """ + # Given - create a parameter without a label (simulating bracket/breakdown params) + param = create_parameter( + session, model_version, "gov.test.rates[0].rate", label=None + ) + + # When + response = client.get(f"/parameters/{param.id}") + + # Then + assert response.status_code == 200 + data = response.json() + assert data["name"] == "gov.test.rates[0].rate" + assert data["label"] is None + + +def test__given_mixed_label_parameters__then_returns_all( + client, + session, + model_version, # noqa: F811 +): + """Parameters with and without labels are both returned in list queries. + + This ensures the API doesn't filter out parameters based on label presence. + """ + # Given - create parameters with and without labels + param_with_label = create_parameter( + session, model_version, "gov.test.labeled", label="Test labeled param" + ) + param_without_label = create_parameter( + session, model_version, "gov.test.unlabeled", label=None + ) + + # When + response = client.get("/parameters") + + # Then + assert response.status_code == 200 + data = response.json() + param_names = [p["name"] for p in data] + assert param_with_label.name in param_names + assert param_without_label.name in param_names + + # ----------------------------------------------------------------------------- # Parameter Value Endpoint Tests # -----------------------------------------------------------------------------