diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py
index 22a49a78d..03ea70040 100644
--- a/src/codegen/extensions/langchain/graph.py
+++ b/src/codegen/extensions/langchain/graph.py
@@ -6,17 +6,24 @@
 import anthropic
 import openai
 from langchain.tools import BaseTool
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.stores import InMemoryBaseStore
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START
 from langgraph.graph.state import CompiledGraph, StateGraph
-from langgraph.prebuilt import ToolNode
 from langgraph.pregel import RetryPolicy
 
 from codegen.agents.utils import AgentConfig
 from codegen.extensions.langchain.llm import LLM
 from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT
+from codegen.extensions.langchain.utils.custom_tool_node import CustomToolNode
 from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens
 
 
@@ -87,6 +94,7 @@ def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMe
         self.config = config
         self.max_messages = config.get("max_messages", 100) if config else 100
         self.keep_first_messages = config.get("keep_first_messages", 1) if config else 1
+        self.store = InMemoryBaseStore()
 
     # =================================== NODES ====================================
 
@@ -459,7 +467,7 @@ def get_field_descriptions(tool_obj):
 
         # Add nodes
         builder.add_node("reasoner", self.reasoner, retry=retry_policy)
-        builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
+        builder.add_node("tools", CustomToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
         builder.add_node("summarize_conversation", self.summarize_conversation, retry=retry_policy)
 
         # Add edges
@@ -471,7 +479,7 @@ def get_field_descriptions(tool_obj):
         )
         builder.add_conditional_edges("summarize_conversation", self.should_continue)
 
-        return builder.compile(checkpointer=checkpointer, debug=debug)
+        return builder.compile(checkpointer=checkpointer, store=self.store, debug=debug)
 
 
 def create_react_agent(
diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py
index 4c457e46d..dadcf6314 100644
--- a/src/codegen/extensions/langchain/llm.py
+++ b/src/codegen/extensions/langchain/llm.py
@@ -89,7 +89,7 @@ def _get_model(self) -> BaseChatModel:
             if not os.getenv("ANTHROPIC_API_KEY"):
                 msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
                 raise ValueError(msg)
-            max_tokens = 16384 if "claude-3-7" in self.model_name else 8192
+            max_tokens = 8192
             return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000)
 
         elif self.model_provider == "openai":
diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py
index 9f8041dc9..3a1193be1 100644
--- a/src/codegen/extensions/langchain/tools.py
+++ b/src/codegen/extensions/langchain/tools.py
@@ -4,8 +4,10 @@
 from typing import Annotated, ClassVar, Literal, Optional
 
 from langchain_core.messages import ToolMessage
+from langchain_core.stores import InMemoryBaseStore
 from langchain_core.tools import InjectedToolCallId
 from langchain_core.tools.base import BaseTool
+from langgraph.prebuilt import InjectedStore
 from pydantic import BaseModel, Field
 
 from codegen.extensions.linear.linear_client import LinearClient
@@ -196,11 +198,13 @@ def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
 class CreateFileInput(BaseModel):
     """Input for creating a file."""
 
+    model_config = {"arbitrary_types_allowed": True}
     filepath: str = Field(..., description="Path where to create the file")
+    store: Annotated[InMemoryBaseStore, InjectedStore()]
     content: str = Field(
-        ...,
+        default="",
         description="""
-Content for the new file (REQUIRED).
+Content for the new file.
 
 ⚠️ IMPORTANT: This parameter MUST be a STRING, not a dictionary, JSON object, or any other data type.
 Example: content="print('Hello world')"
@@ -214,19 +218,14 @@ class CreateFileTool(BaseTool):
     name: ClassVar[str] = "create_file"
     description: ClassVar[str] = """
-Create a new file in the codebase. Always provide content for the new file, even if minimal.
-
-⚠️ CRITICAL WARNING ⚠️
-Both parameters MUST be provided as STRINGS:
-The content for the new file always needs to be provided.
+Create a new file in the codebase.
 
 1. filepath: The path where to create the file (as a string)
 2. content: The content for the new file (as a STRING, NOT as a dictionary or JSON object)
 
 ✅ CORRECT usage:
 create_file(filepath="path/to/file.py", content="print('Hello world')")
-
-The content parameter is REQUIRED and MUST be a STRING. If you receive a validation error about
+If you receive a validation error about
 missing content, you are likely trying to pass a dictionary instead of a string.
 
 """
     args_schema: ClassVar[type[BaseModel]] = CreateFileInput
@@ -235,8 +234,15 @@ class CreateFileTool(BaseTool):
 
     def __init__(self, codebase: Codebase) -> None:
         super().__init__(codebase=codebase)
 
-    def _run(self, filepath: str, content: str) -> str:
-        result = create_file(self.codebase, filepath, content)
+    def _run(self, filepath: str, store: InMemoryBaseStore, content: str = "") -> str:
+        create_file_tool_status = store.mget([self.name])[0]
+        if create_file_tool_status and create_file_tool_status.get("max_tokens_reached", False):
+            max_tokens = create_file_tool_status.get("max_tokens", None)
+            store.mset([(self.name, {"max_tokens": max_tokens, "max_tokens_reached": False})])
+            result = create_file(self.codebase, filepath, content, max_tokens=max_tokens)
+        else:
+            result = create_file(self.codebase, filepath, content)
+
         return result.render()
 
diff --git a/src/codegen/extensions/langchain/utils/custom_tool_node.py b/src/codegen/extensions/langchain/utils/custom_tool_node.py
new file mode 100644
index 000000000..bdbe4ab0e
--- /dev/null
+++ b/src/codegen/extensions/langchain/utils/custom_tool_node.py
@@ -0,0 +1,43 @@
+from typing import Any, Literal, Optional, Union
+
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    ToolCall,
+)
+from langchain_core.stores import InMemoryBaseStore
+from langgraph.prebuilt import ToolNode
+from pydantic import BaseModel
+
+
+class CustomToolNode(ToolNode):
+    """Extended ToolNode that detects truncated tool calls."""
+
+    def _parse_input(
+        self,
+        input: Union[
+            list[AnyMessage],
+            dict[str, Any],
+            BaseModel,
+        ],
+        store: Optional[InMemoryBaseStore],
+    ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
+        """Parse the input and check for truncated tool calls."""
+        messages = input.get("messages", [])
+        if isinstance(messages, list) and messages:
+            if isinstance(messages[-1], AIMessage):
+                response_metadata = messages[-1].response_metadata
+                # Check if the stop reason is due to max tokens
+                if response_metadata.get("stop_reason") == "max_tokens":
+                    # Check if the response metadata contains usage information
+                    if "usage" not in response_metadata or "output_tokens" not in response_metadata["usage"]:
+                        msg = "Response metadata is missing usage information."
+                        raise ValueError(msg)
+
+                    output_tokens = response_metadata["usage"]["output_tokens"]
+                    for tool_call in messages[-1].tool_calls:
+                        if tool_call.get("name") == "create_file":
+                            # Set the max tokens and max tokens reached flag in the store
+                            store.mset([(tool_call["name"], {"max_tokens": output_tokens, "max_tokens_reached": True})])
+
+        return super()._parse_input(input, store)
diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py
index 77d8dd20d..cc22d3ede 100644
--- a/src/codegen/extensions/tools/create_file.py
+++ b/src/codegen/extensions/tools/create_file.py
@@ -1,6 +1,6 @@
 """Tool for creating new files."""
 
-from typing import ClassVar
+from typing import ClassVar, Optional
 
 from pydantic import Field
 
@@ -23,7 +23,7 @@ class CreateFileObservation(Observation):
     str_template: ClassVar[str] = "Created file {filepath}"
 
 
-def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileObservation:
+def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: Optional[int] = None) -> CreateFileObservation:
     """Create a new file.
 
     Args:
@@ -34,6 +34,16 @@ def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileOb
     Returns:
         CreateFileObservation containing new file state, or error if file exists
     """
+    if max_tokens:
+        error = f"""Your response reached the max output tokens limit of {max_tokens} tokens (~ {max_tokens / 10} lines).
+Create the file in chunks or break up the content into smaller files.
+        """
+        return CreateFileObservation(
+            status="error",
+            error=error,
+            filepath=filepath,
+            file_info=ViewFileObservation(status="error", error=error, filepath=filepath, content="", raw_content="", line_count=0),
+        )
     if codebase.has_file(filepath):
         return CreateFileObservation(
             status="error",