diff --git a/src/policyengine_api/api/agent.py b/src/policyengine_api/api/agent.py
index c53e17c..03c060a 100644
--- a/src/policyengine_api/api/agent.py
+++ b/src/policyengine_api/api/agent.py
@@ -154,15 +154,13 @@ def _parse_claude_stream_event(line: str) -> dict | None:
 async def _stream_modal_sandbox(question: str, api_base_url: str):
     """Stream output from Claude Code running in Modal Sandbox."""
-    import queue
-    import threading
     from concurrent.futures import ThreadPoolExecutor
 
     with logfire.span(
         "agent_stream", question=question[:100], api_base_url=api_base_url
     ):
         sb = None
-        executor = ThreadPoolExecutor(max_workers=1)
+        executor = ThreadPoolExecutor(max_workers=2)
         try:
             from policyengine_api.agent_sandbox import run_claude_code_in_sandbox
@@ -174,91 +172,47 @@ async def _stream_modal_sandbox(question: str, api_base_url: str):
             )
             logfire.info("sandbox_created")
 
-            line_queue = queue.Queue()
-            lines_received = 0
-
-            def stream_reader():
-                nonlocal lines_received
+            # Read stdout synchronously in executor, yield lines as we get them
+            def read_next_line(stdout_iter):
                 try:
-                    logfire.info("reader_started")
-                    for line in process.stdout:
-                        lines_received += 1
-                        # Log line length and first 500 chars (avoid scrubbing)
-                        line_preview = (
-                            line[:500].replace("session", "sess1on") if line else None
-                        )
-                        # Check if multiple JSON objects concatenated (embedded newlines)
-                        newline_count = line.count("\n") if line else 0
-                        logfire.info(
-                            "raw_line",
-                            line_num=lines_received,
-                            line_len=len(line) if line else 0,
-                            newline_count=newline_count,
-                            line_preview=line_preview,
-                        )
-                        line_queue.put(("line", line))
-                    logfire.info("stdout_exhausted", total_lines=lines_received)
-                    process.wait()
-                    logfire.info("process_exited", returncode=process.returncode)
-                    if process.returncode != 0:
-                        stderr = process.stderr.read()
-                        logfire.error(
-                            "process_failed",
-                            returncode=process.returncode,
-                            stderr=stderr[:500] if stderr else None,
-                        )
-                        line_queue.put(("error", (process.returncode, stderr)))
-                    else:
-                        line_queue.put(("done", process.returncode))
-                except Exception as e:
-                    logfire.exception("reader_error", error=str(e))
-                    line_queue.put(("exception", str(e)))
-
-            reader_thread = threading.Thread(target=stream_reader, daemon=True)
-            reader_thread.start()
+                    return next(stdout_iter)
+                except StopIteration:
+                    return None
 
+            stdout_iter = iter(process.stdout)
+            lines_received = 0
             events_sent = 0
+
             while True:
-                try:
-                    item = await loop.run_in_executor(
-                        executor, lambda: line_queue.get(timeout=0.1)
-                    )
-                    event_type, data = item
-
-                    if event_type == "line":
-                        parsed = _parse_claude_stream_event(data)
-                        if parsed:
-                            events_sent += 1
-                            logfire.info(
-                                "event",
-                                num=events_sent,
-                                type=parsed["type"],
-                                content=parsed["content"][:200]
-                                if parsed["content"]
-                                else None,
-                            )
-                            yield f"data: {json.dumps(parsed)}\n\n"
-                    elif event_type == "error":
-                        returncode, stderr = data
-                        yield f"data: {json.dumps({'type': 'error', 'content': stderr})}\n\n"
-                        yield f"data: {json.dumps({'type': 'done', 'returncode': returncode})}\n\n"
-                        break
-                    elif event_type == "done":
-                        logfire.info(
-                            "complete",
-                            returncode=data,
-                            events_sent=events_sent,
-                            lines_received=lines_received,
-                        )
-                        yield f"data: {json.dumps({'type': 'done', 'returncode': data})}\n\n"
-                        break
-                    elif event_type == "exception":
-                        raise Exception(data)
-                except Exception as e:
-                    if "Empty" in type(e).__name__:
-                        await asyncio.sleep(0)
-                        continue
-                    raise
+                line = await loop.run_in_executor(executor, read_next_line, stdout_iter)
+
+                if line is None:
+                    # stdout exhausted
+                    logfire.info("stdout_exhausted", total_lines=lines_received)
+                    break
+
+                lines_received += 1
+                logfire.info(
+                    "raw_line",
+                    line_num=lines_received,
+                    line_len=len(line),
+                    line_preview=line[:300].replace("session", "sess1on"),
+                )
+
+                parsed = _parse_claude_stream_event(line)
+                if parsed:
+                    events_sent += 1
+                    yield f"data: {json.dumps(parsed)}\n\n"
+
+            # Wait for process to finish
+            returncode = await loop.run_in_executor(executor, process.wait)
+            logfire.info(
+                "complete",
+                returncode=returncode,
+                events_sent=events_sent,
+                lines_received=lines_received,
+            )
+            yield f"data: {json.dumps({'type': 'done', 'returncode': returncode})}\n\n"
 
         except Exception as e:
             logfire.exception("failed", error=str(e))