From c9b435376944c67c2983a7b919de2ca4c4db2963 Mon Sep 17 00:00:00 2001 From: jemeza-codegen Date: Tue, 18 Mar 2025 18:36:06 -0700 Subject: [PATCH 1/2] fix!: catches circumstance where the LLM response is truncated due to it reaching its max output token limit --- src/codegen/extensions/langchain/graph.py | 16 +++++-- src/codegen/extensions/langchain/llm.py | 2 +- src/codegen/extensions/langchain/tools.py | 33 ++++++++------ .../langchain/utils/custom_tool_node.py | 43 +++++++++++++++++++ src/codegen/extensions/tools/create_file.py | 12 +++++- 5 files changed, 85 insertions(+), 21 deletions(-) create mode 100644 src/codegen/extensions/langchain/utils/custom_tool_node.py diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index 2987f6863..697c18e25 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -6,17 +6,24 @@ import anthropic import openai from langchain.tools import BaseTool -from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.messages import ( + AIMessage, + AnyMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_core.prompts import ChatPromptTemplate +from langchain_core.stores import InMemoryBaseStore from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START from langgraph.graph.state import CompiledGraph, StateGraph -from langgraph.prebuilt import ToolNode from langgraph.pregel import RetryPolicy from codegen.agents.utils import AgentConfig from codegen.extensions.langchain.llm import LLM from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT +from codegen.extensions.langchain.utils.custom_tool_node import CustomToolNode from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens @@ -87,6 +94,7 @@ def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMe self.config = 
config self.max_messages = config.get("max_messages", 100) if config else 100 self.keep_first_messages = config.get("keep_first_messages", 1) if config else 1 + self.store = InMemoryBaseStore() # =================================== NODES ==================================== @@ -459,7 +467,7 @@ def get_field_descriptions(tool_obj): # Add nodes builder.add_node("reasoner", self.reasoner, retry=retry_policy) - builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy) + builder.add_node("tools", CustomToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy) builder.add_node("summarize_conversation", self.summarize_conversation, retry=retry_policy) # Add edges @@ -471,7 +479,7 @@ def get_field_descriptions(tool_obj): ) builder.add_conditional_edges("summarize_conversation", self.should_continue) - return builder.compile(checkpointer=checkpointer, debug=debug) + return builder.compile(checkpointer=checkpointer, store=self.store, debug=debug) def create_react_agent( diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index 4c457e46d..dadcf6314 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -89,7 +89,7 @@ def _get_model(self) -> BaseChatModel: if not os.getenv("ANTHROPIC_API_KEY"): msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables." 
raise ValueError(msg) - max_tokens = 16384 if "claude-3-7" in self.model_name else 8192 + max_tokens = 8192 return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000) elif self.model_provider == "openai": diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index ee17aa691..d7dbc2f06 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -1,9 +1,11 @@ """Langchain tools for workspace operations.""" from collections.abc import Callable -from typing import ClassVar, Literal +from typing import Annotated, ClassVar, Literal +from langchain_core.stores import InMemoryBaseStore from langchain_core.tools.base import BaseTool +from langgraph.prebuilt import InjectedStore from pydantic import BaseModel, Field from codegen.extensions.linear.linear_client import LinearClient @@ -189,11 +191,13 @@ def _run(self, filepath: str, content: str) -> str: class CreateFileInput(BaseModel): """Input for creating a file.""" + model_config = {"arbitrary_types_allowed": True} filepath: str = Field(..., description="Path where to create the file") + store: Annotated[InMemoryBaseStore, InjectedStore()] content: str = Field( - ..., + default="", description=""" -Content for the new file (REQUIRED). +Content for the new file. ⚠️ IMPORTANT: This parameter MUST be a STRING, not a dictionary, JSON object, or any other data type. Example: content="print('Hello world')" @@ -207,19 +211,14 @@ class CreateFileTool(BaseTool): name: ClassVar[str] = "create_file" description: ClassVar[str] = """ -Create a new file in the codebase. Always provide content for the new file, even if minimal. - -⚠️ CRITICAL WARNING ⚠️ -Both parameters MUST be provided as STRINGS: -The content for the new file always needs to be provided. +Create a new file in the codebase. 1. filepath: The path where to create the file (as a string) 2. 
content: The content for the new file (as a STRING, NOT as a dictionary or JSON object) ✅ CORRECT usage: create_file(filepath="path/to/file.py", content="print('Hello world')") - -The content parameter is REQUIRED and MUST be a STRING. If you receive a validation error about +If you receive a validation error about missing content, you are likely trying to pass a dictionary instead of a string. """ args_schema: ClassVar[type[BaseModel]] = CreateFileInput @@ -228,8 +227,15 @@ class CreateFileTool(BaseTool): def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, filepath: str, content: str) -> str: - result = create_file(self.codebase, filepath, content) + def _run(self, filepath: str, store: InMemoryBaseStore, content: str = "") -> str: + create_file_tool_status = store.mget([self.name])[0] + if create_file_tool_status and create_file_tool_status.get("max_tokens_reached", False): + max_tokens = create_file_tool_status.get("max_tokens", None) + store.mset([(self.name, {"max_tokens": max_tokens, "max_tokens_reached": False})]) + result = create_file(self.codebase, filepath, content, max_tokens=max_tokens) + else: + result = create_file(self.codebase, filepath, content) + return result.render() @@ -1094,6 +1100,7 @@ class SearchFilesByNameInput(BaseModel): page: int = Field(default=1, description="Page number to return (1-based)") files_per_page: int | float = Field(default=10, description="Number of files per page to return, use math.inf to return all files") + class SearchFilesByNameTool(BaseTool): """Tool for searching files by filename across a codebase.""" @@ -1107,8 +1114,6 @@ class SearchFilesByNameTool(BaseTool): args_schema: ClassVar[type[BaseModel]] = SearchFilesByNameInput codebase: Codebase = Field(exclude=True) - - def __init__(self, codebase: Codebase): super().__init__(codebase=codebase) diff --git a/src/codegen/extensions/langchain/utils/custom_tool_node.py 
b/src/codegen/extensions/langchain/utils/custom_tool_node.py new file mode 100644 index 000000000..bdbe4ab0e --- /dev/null +++ b/src/codegen/extensions/langchain/utils/custom_tool_node.py @@ -0,0 +1,43 @@ +from typing import Any, Literal, Optional, Union + +from langchain_core.messages import ( + AIMessage, + AnyMessage, + ToolCall, +) +from langchain_core.stores import InMemoryBaseStore +from langgraph.prebuilt import ToolNode +from pydantic import BaseModel + + +class CustomToolNode(ToolNode): + """Extended ToolNode that detects truncated tool calls.""" + + def _parse_input( + self, + input: Union[ + list[AnyMessage], + dict[str, Any], + BaseModel, + ], + store: Optional[InMemoryBaseStore], + ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]: + """Parse the input and check for truncated tool calls.""" + messages = input.get("messages", []) + if isinstance(messages, list): + if isinstance(messages[-1], AIMessage): + response_metadata = messages[-1].response_metadata + # Check if the stop reason is due to max tokens + if response_metadata.get("stop_reason") == "max_tokens": + # Check if the response metadata contains usage information + if "usage" not in response_metadata or "output_tokens" not in response_metadata["usage"]: + msg = "Response metadata is missing usage information." 
+ raise ValueError(msg) + + output_tokens = response_metadata["usage"]["output_tokens"] + for tool_call in messages[-1].tool_calls: + if tool_call.get("name") == "create_file": + # Set the max tokens and max tokens reached flag in the store + store.mset([(tool_call["name"], {"max_tokens": output_tokens, "max_tokens_reached": True})]) + + return super()._parse_input(input, store) diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py index 3a54303ff..887a27a02 100644 --- a/src/codegen/extensions/tools/create_file.py +++ b/src/codegen/extensions/tools/create_file.py @@ -1,6 +1,6 @@ """Tool for creating new files.""" -from typing import ClassVar +from typing import ClassVar, Optional from pydantic import Field @@ -23,7 +23,7 @@ class CreateFileObservation(Observation): str_template: ClassVar[str] = "Created file {filepath}" -def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileObservation: +def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: Optional[int] = None) -> CreateFileObservation: """Create a new file. Args: @@ -34,6 +34,14 @@ def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileOb Returns: CreateFileObservation containing new file state, or error if file exists """ + if max_tokens: + error = f"Your response reached the max output tokens limit of {max_tokens} tokens (~{max_tokens * 0.75 / 20} lines). Consider breaking up the content into smaller files."
+ return CreateFileObservation( + status="error", + error=error, + filepath=filepath, + file_info=ViewFileObservation(status="error", error=error, filepath=filepath, content="", line_count=0), + ) if codebase.has_file(filepath): return CreateFileObservation( status="error", From 8deef96d25f6479857569c8d558fe90b63354387 Mon Sep 17 00:00:00 2001 From: jemeza-codegen Date: Wed, 19 Mar 2025 15:16:29 -0700 Subject: [PATCH 2/2] fix: missing argument to ViewFileObservation --- src/codegen/extensions/tools/create_file.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py index b3ed3e3f1..cc22d3ede 100644 --- a/src/codegen/extensions/tools/create_file.py +++ b/src/codegen/extensions/tools/create_file.py @@ -35,12 +35,14 @@ def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: Opt CreateFileObservation containing new file state, or error if file exists """ if max_tokens: - error = f"Your response reached the max output tokens limit of {max_tokens} tokens (~{max_tokens * 0.75 / 20} lines). Consider breaking up the content into smaller files." + error = f"""Your response reached the max output tokens limit of {max_tokens} tokens (~ {max_tokens / 10} lines). +Create the file in chunks or break up the content into smaller files. + """ return CreateFileObservation( status="error", error=error, filepath=filepath, - file_info=ViewFileObservation(status="error", error=error, filepath=filepath, content="", line_count=0), + file_info=ViewFileObservation(status="error", error=error, filepath=filepath, content="", raw_content="", line_count=0), ) if codebase.has_file(filepath): return CreateFileObservation(