Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions scripts/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,13 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer
f" [green]✓[/green] Added {len(model_version.variables)} variables"
)

# Add parameters (only user-facing ones: those with labels)
# Deduplicate by name - keep first occurrence
# Add parameters - deduplicate by name (keep first occurrence)
# In lite mode, exclude US state parameters (gov.states.*)
seen_names = set()
parameters_to_add = []
skipped_state_params = 0
for p in model_version.parameters:
if p.label is None or p.name in seen_names:
if p.name in seen_names:
continue
# In lite mode, skip state-level parameters for faster seeding
if lite and p.name.startswith("gov.states."):
Expand All @@ -222,8 +221,10 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer
parameters_to_add.append(p)
seen_names.add(p.name)

filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters"
filter_msg += f" (from {len(model_version.parameters)} total, deduplicated by name)"
filter_msg = f" Filtered to {len(parameters_to_add)} parameters"
filter_msg += (
f" (from {len(model_version.parameters)} total, deduplicated by name)"
)
if lite and skipped_state_params > 0:
filter_msg += f", skipped {skipped_state_params} state params (lite mode)"
console.print(filter_msg)
Expand Down Expand Up @@ -626,7 +627,9 @@ def main():

with logfire.span("database_seeding"):
mode_str = " (lite mode)" if args.lite else ""
console.print(f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n")
console.print(
f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n"
)

with next(get_quiet_session()) as session:
# Seed UK model
Expand Down
45 changes: 24 additions & 21 deletions src/policyengine_api/agent_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]:

prop = schema_to_json_schema(spec, param_schema)
prop["description"] = (
param.get("description", "")
+ f" (in: {param_in})"
param.get("description", "") + f" (in: {param_in})"
)
properties[param_name] = prop

Expand Down Expand Up @@ -268,16 +267,18 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]:
if required:
input_schema["required"] = list(set(required))

tools.append({
"name": tool_name,
"description": full_desc[:1024], # Claude has limits
"input_schema": input_schema,
"_meta": {
"path": path,
"method": method,
"parameters": operation.get("parameters", []),
},
})
tools.append(
{
"name": tool_name,
"description": full_desc[:1024], # Claude has limits
"input_schema": input_schema,
"_meta": {
"path": path,
"method": method,
"parameters": operation.get("parameters", []),
},
}
)

return tools

Expand Down Expand Up @@ -347,7 +348,9 @@ def execute_api_tool(
url, params=query_params, json=body_data, headers=headers, timeout=60
)
elif method == "delete":
resp = requests.delete(url, params=query_params, headers=headers, timeout=60)
resp = requests.delete(
url, params=query_params, headers=headers, timeout=60
)
else:
return f"Unsupported method: {method}"

Expand Down Expand Up @@ -415,9 +418,7 @@ def log(msg: str) -> None:
tool_lookup = {t["name"]: t for t in tools}

# Strip _meta from tools before sending to Claude (it doesn't need it)
claude_tools = [
{k: v for k, v in t.items() if k != "_meta"} for t in tools
]
claude_tools = [{k: v for k, v in t.items() if k != "_meta"} for t in tools]
# Add the sleep tool
claude_tools.append(SLEEP_TOOL)

Expand Down Expand Up @@ -477,11 +478,13 @@ def log(msg: str) -> None:

log(f"[TOOL_RESULT] {result[:300]}")

tool_results.append({
"type": "tool_result",
"tool_use_id": block.id,
"content": result,
})
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": result,
}
)

messages.append({"role": "assistant", "content": assistant_content})

Expand Down
22 changes: 18 additions & 4 deletions src/policyengine_api/api/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_traceparent() -> str | None:
TraceContextTextMapPropagator().inject(carrier)
return carrier.get("traceparent")


router = APIRouter(prefix="/agent", tags=["agent"])


Expand Down Expand Up @@ -93,7 +94,9 @@ def _run_local_agent(
from policyengine_api.agent_sandbox import _run_agent_impl

try:
history_dicts = [{"role": m.role, "content": m.content} for m in (history or [])]
history_dicts = [
{"role": m.role, "content": m.content} for m in (history or [])
]
result = _run_agent_impl(question, api_base_url, call_id, history_dicts)
_calls[call_id]["status"] = result.get("status", "completed")
_calls[call_id]["result"] = result
Expand Down Expand Up @@ -136,9 +139,15 @@ async def run_agent(request: RunRequest) -> RunResponse:

traceparent = get_traceparent()
run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent")
history_dicts = [{"role": m.role, "content": m.content} for m in request.history]
history_dicts = [
{"role": m.role, "content": m.content} for m in request.history
]
call = run_fn.spawn(
request.question, api_base_url, call_id, history_dicts, traceparent=traceparent
request.question,
api_base_url,
call_id,
history_dicts,
traceparent=traceparent,
)

_calls[call_id] = {
Expand Down Expand Up @@ -166,7 +175,12 @@ async def run_agent(request: RunRequest) -> RunResponse:
# Run in background using asyncio
loop = asyncio.get_event_loop()
loop.run_in_executor(
None, _run_local_agent, call_id, request.question, api_base_url, request.history
None,
_run_local_agent,
call_id,
request.question,
api_base_url,
request.history,
)

return RunResponse(call_id=call_id, status="running")
Expand Down
1 change: 1 addition & 0 deletions src/policyengine_api/api/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def get_traceparent() -> str | None:
TraceContextTextMapPropagator().inject(carrier)
return carrier.get("traceparent")


router = APIRouter(prefix="/household", tags=["household"])


Expand Down
4 changes: 2 additions & 2 deletions src/policyengine_api/api/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def list_variables(
# Case-insensitive search using ILIKE
# Note: Variables don't have a label field, only name and description
search_pattern = f"%{search}%"
search_filter = Variable.name.ilike(search_pattern) | Variable.description.ilike(
search_filter = Variable.name.ilike(
search_pattern
)
) | Variable.description.ilike(search_pattern)
query = query.where(search_filter)

variables = session.exec(
Expand Down
81 changes: 66 additions & 15 deletions src/policyengine_api/modal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ def download_dataset(


@app.function(
image=uk_image, secrets=[db_secrets, logfire_secrets], memory=4096, cpu=4, timeout=600
image=uk_image,
secrets=[db_secrets, logfire_secrets],
memory=4096,
cpu=4,
timeout=600,
)
def simulate_household_uk(
job_id: str,
Expand Down Expand Up @@ -277,7 +281,11 @@ def simulate_household_uk(


@app.function(
image=us_image, secrets=[db_secrets, logfire_secrets], memory=4096, cpu=4, timeout=600
image=us_image,
secrets=[db_secrets, logfire_secrets],
memory=4096,
cpu=4,
timeout=600,
)
def simulate_household_us(
job_id: str,
Expand Down Expand Up @@ -416,7 +424,11 @@ def simulate_household_us(


@app.function(
image=uk_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800
image=uk_image,
secrets=[db_secrets, logfire_secrets],
memory=8192,
cpu=8,
timeout=1800,
)
def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> None:
"""Run a single UK economy simulation and write results to database."""
Expand Down Expand Up @@ -445,14 +457,17 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N
Simulation,
SimulationStatus,
)

with Session(engine) as session:
simulation = session.get(Simulation, UUID(simulation_id))
if not simulation:
raise ValueError(f"Simulation {simulation_id} not found")

# Skip if already completed
if simulation.status == SimulationStatus.COMPLETED:
logfire.info("Simulation already completed", simulation_id=simulation_id)
logfire.info(
"Simulation already completed", simulation_id=simulation_id
)
return

# Update status to running
Expand All @@ -475,7 +490,9 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N
pe_model_version = uk_latest

# Get policy and dynamic
policy = _get_pe_policy_uk(simulation.policy_id, pe_model_version, session)
policy = _get_pe_policy_uk(
simulation.policy_id, pe_model_version, session
)
dynamic = _get_pe_dynamic_uk(
simulation.dynamic_id, pe_model_version, session
)
Expand Down Expand Up @@ -512,7 +529,11 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N
logfire.info("UK economy simulation completed", simulation_id=simulation_id)

except Exception as e:
logfire.error("UK economy simulation failed", simulation_id=simulation_id, error=str(e))
logfire.error(
"UK economy simulation failed",
simulation_id=simulation_id,
error=str(e),
)
# Use raw SQL to mark as failed - models may not be available
try:
from sqlmodel import text
Expand All @@ -534,7 +555,11 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N


@app.function(
image=us_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800
image=us_image,
secrets=[db_secrets, logfire_secrets],
memory=8192,
cpu=8,
timeout=1800,
)
def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> None:
"""Run a single US economy simulation and write results to database."""
Expand Down Expand Up @@ -563,14 +588,17 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N
Simulation,
SimulationStatus,
)

with Session(engine) as session:
simulation = session.get(Simulation, UUID(simulation_id))
if not simulation:
raise ValueError(f"Simulation {simulation_id} not found")

# Skip if already completed
if simulation.status == SimulationStatus.COMPLETED:
logfire.info("Simulation already completed", simulation_id=simulation_id)
logfire.info(
"Simulation already completed", simulation_id=simulation_id
)
return

# Update status to running
Expand All @@ -593,7 +621,9 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N
pe_model_version = us_latest

# Get policy and dynamic
policy = _get_pe_policy_us(simulation.policy_id, pe_model_version, session)
policy = _get_pe_policy_us(
simulation.policy_id, pe_model_version, session
)
dynamic = _get_pe_dynamic_us(
simulation.dynamic_id, pe_model_version, session
)
Expand Down Expand Up @@ -630,7 +660,11 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N
logfire.info("US economy simulation completed", simulation_id=simulation_id)

except Exception as e:
logfire.error("US economy simulation failed", simulation_id=simulation_id, error=str(e))
logfire.error(
"US economy simulation failed",
simulation_id=simulation_id,
error=str(e),
)
# Use raw SQL to mark as failed - models may not be available
try:
from sqlmodel import text
Expand All @@ -652,7 +686,11 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N


@app.function(
image=uk_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800
image=uk_image,
secrets=[db_secrets, logfire_secrets],
memory=8192,
cpu=8,
timeout=1800,
)
def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
"""Run UK economy comparison analysis (decile impacts, budget impact, etc)."""
Expand Down Expand Up @@ -690,6 +728,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
SimulationStatus,
TaxBenefitModelVersion,
)

with Session(engine) as session:
# Load report and related data
report = session.get(Report, UUID(job_id))
Expand Down Expand Up @@ -851,7 +890,10 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
)
session.add(program_stat)
except KeyError as e:
logfire.warn(f"Skipping {prog_name}: variable not found", error=str(e))
logfire.warn(
f"Skipping {prog_name}: variable not found",
error=str(e),
)

# Mark simulations and report as completed
baseline_sim.status = SimulationStatus.COMPLETED
Expand Down Expand Up @@ -890,7 +932,11 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:


@app.function(
image=us_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800
image=us_image,
secrets=[db_secrets, logfire_secrets],
memory=8192,
cpu=8,
timeout=1800,
)
def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
"""Run US economy comparison analysis (decile impacts, budget impact, etc)."""
Expand Down Expand Up @@ -1036,7 +1082,9 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:

# Calculate program statistics
with logfire.span("calculate_program_statistics"):
PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation})
PEProgramStats.model_rebuild(
_types_namespace={"Simulation": PESimulation}
)

programs = {
"income_tax": {"entity": "tax_unit", "is_tax": True},
Expand Down Expand Up @@ -1074,7 +1122,10 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
)
session.add(program_stat)
except KeyError as e:
logfire.warn(f"Skipping {prog_name}: variable not found", error=str(e))
logfire.warn(
f"Skipping {prog_name}: variable not found",
error=str(e),
)

# Mark simulations and report as completed
baseline_sim.status = SimulationStatus.COMPLETED
Expand Down
Loading