16 changes: 12 additions & 4 deletions src/codegen/extensions/langchain/graph.py
@@ -6,17 +6,24 @@
import anthropic
import openai
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.stores import InMemoryBaseStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START
from langgraph.graph.state import CompiledGraph, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.pregel import RetryPolicy

from codegen.agents.utils import AgentConfig
from codegen.extensions.langchain.llm import LLM
from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT
from codegen.extensions.langchain.utils.custom_tool_node import CustomToolNode
from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens


@@ -87,6 +94,7 @@ def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMe
self.config = config
self.max_messages = config.get("max_messages", 100) if config else 100
self.keep_first_messages = config.get("keep_first_messages", 1) if config else 1
self.store = InMemoryBaseStore()

# =================================== NODES ====================================

@@ -459,7 +467,7 @@ def get_field_descriptions(tool_obj):

# Add nodes
builder.add_node("reasoner", self.reasoner, retry=retry_policy)
builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
builder.add_node("tools", CustomToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
builder.add_node("summarize_conversation", self.summarize_conversation, retry=retry_policy)

# Add edges
@@ -471,7 +479,7 @@
)
builder.add_conditional_edges("summarize_conversation", self.should_continue)

return builder.compile(checkpointer=checkpointer, debug=debug)
return builder.compile(checkpointer=checkpointer, store=self.store, debug=debug)


def create_react_agent(
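Note: taken together, the graph.py changes create one InMemoryBaseStore per agent and hand it to builder.compile(store=...), so the same store instance the custom tool node inspects is intended to be available to tools that declare an injected-store argument. A minimal, self-contained sketch of that compile-time wiring follows; the State schema and reasoner node are placeholders for illustration, not code from this PR.

from typing import Annotated

from langchain_core.messages import AnyMessage
from langchain_core.stores import InMemoryBaseStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict


class State(TypedDict):
    # Placeholder state schema; the real graph defines its own state type.
    messages: Annotated[list[AnyMessage], add_messages]


def reasoner(state: State) -> dict:
    # Placeholder node standing in for the LLM-backed reasoner.
    return {"messages": []}


store = InMemoryBaseStore()
builder = StateGraph(State)
builder.add_node("reasoner", reasoner)
builder.add_edge(START, "reasoner")
builder.add_edge("reasoner", END)

# Same pattern as the diff above: the store handed to compile() is the one the
# tool node receives and that InjectedStore()-annotated tool arguments read.
graph = builder.compile(checkpointer=MemorySaver(), store=store)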
2 changes: 1 addition & 1 deletion src/codegen/extensions/langchain/llm.py
@@ -89,7 +89,7 @@
if not os.getenv("ANTHROPIC_API_KEY"):
msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
raise ValueError(msg)
max_tokens = 16384 if "claude-3-7" in self.model_name else 8192
max_tokens = 8192
return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000)

elif self.model_provider == "openai":
@@ -129,7 +129,7 @@

def bind_tools(
self,
tools: Sequence[BaseTool],

Check failure (GitHub Actions / mypy) on line 132 in src/codegen/extensions/langchain/llm.py: error: Argument 1 of "bind_tools" is incompatible with supertype "BaseChatModel"; supertype defines the argument type as "Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool]" [override]
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the underlying model.
28 changes: 17 additions & 11 deletions src/codegen/extensions/langchain/tools.py
@@ -4,8 +4,10 @@
from typing import Annotated, ClassVar, Literal, Optional

from langchain_core.messages import ToolMessage
from langchain_core.stores import InMemoryBaseStore
from langchain_core.tools import InjectedToolCallId
from langchain_core.tools.base import BaseTool
from langgraph.prebuilt import InjectedStore
from pydantic import BaseModel, Field

from codegen.extensions.linear.linear_client import LinearClient
@@ -64,11 +66,11 @@
class ViewFileTool(BaseTool):
"""Tool for viewing file contents and metadata."""

name: ClassVar[str] = "view_file"

Check failure (GitHub Actions / mypy) on line 69 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = """View the contents and metadata of a file in the codebase.

Check failure (GitHub Actions / mypy) on line 70 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
For large files (>500 lines), content will be paginated. Use start_line and end_line to navigate through the file.
The response will indicate if there are more lines available to view."""
args_schema: ClassVar[type[BaseModel]] = ViewFileInput

Check failure (GitHub Actions / mypy) on line 73 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
Expand Down Expand Up @@ -106,9 +108,9 @@
class ListDirectoryTool(BaseTool):
"""Tool for listing directory contents."""

name: ClassVar[str] = "list_directory"

Check failure (GitHub Actions / mypy) on line 111 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "List contents of a directory in the codebase"

Check failure (GitHub Actions / mypy) on line 112 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
args_schema: ClassVar[type[BaseModel]] = ListDirectoryInput

Check failure (GitHub Actions / mypy) on line 113 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
@@ -137,9 +139,9 @@
class SearchTool(BaseTool):
"""Tool for searching the codebase."""

name: ClassVar[str] = "search"

Check failure (GitHub Actions / mypy) on line 142 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "Search the codebase using text search or regex pattern matching"

Check failure (GitHub Actions / mypy) on line 143 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
args_schema: ClassVar[type[BaseModel]] = SearchInput

Check failure (GitHub Actions / mypy) on line 144 in src/codegen/extensions/langchain/tools.py: error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
@@ -196,11 +198,13 @@
class CreateFileInput(BaseModel):
"""Input for creating a file."""

model_config = {"arbitrary_types_allowed": True}
filepath: str = Field(..., description="Path where to create the file")
store: Annotated[InMemoryBaseStore, InjectedStore()]
content: str = Field(
...,
default="",
description="""
Content for the new file (REQUIRED).
Content for the new file.

⚠️ IMPORTANT: This parameter MUST be a STRING, not a dictionary, JSON object, or any other data type.
Example: content="print('Hello world')"
@@ -214,19 +218,14 @@

name: ClassVar[str] = "create_file"
description: ClassVar[str] = """
Create a new file in the codebase. Always provide content for the new file, even if minimal.

⚠️ CRITICAL WARNING ⚠️
Both parameters MUST be provided as STRINGS:
The content for the new file always needs to be provided.
Create a new file in the codebase.

1. filepath: The path where to create the file (as a string)
2. content: The content for the new file (as a STRING, NOT as a dictionary or JSON object)

✅ CORRECT usage:
create_file(filepath="path/to/file.py", content="print('Hello world')")

The content parameter is REQUIRED and MUST be a STRING. If you receive a validation error about
If you receive a validation error about
missing content, you are likely trying to pass a dictionary instead of a string.
"""
args_schema: ClassVar[type[BaseModel]] = CreateFileInput
Expand All @@ -235,8 +234,15 @@
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, filepath: str, content: str) -> str:
result = create_file(self.codebase, filepath, content)
def _run(self, filepath: str, store: InMemoryBaseStore, content: str = "") -> str:
create_file_tool_status = store.mget([self.name])[0]
if create_file_tool_status and create_file_tool_status.get("max_tokens_reached", False):
max_tokens = create_file_tool_status.get("max_tokens", None)
store.mset([(self.name, {"max_tokens": max_tokens, "max_tokens_reached": False})])
result = create_file(self.codebase, filepath, content, max_tokens=max_tokens)
else:
result = create_file(self.codebase, filepath, content)

return result.render()


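For reference, a small self-contained sketch of the store handshake the new _run relies on. The "create_file" key and the dict shape mirror the code above; the 8192 figure is only an illustrative value, not something taken from this PR.

from langchain_core.stores import InMemoryBaseStore

store = InMemoryBaseStore()

# Before any truncation has been recorded, mget returns None for the key,
# so _run falls through to the plain create_file(...) call.
assert store.mget(["create_file"])[0] is None

# What the custom tool node (see custom_tool_node.py below) writes after an
# LLM response stops on max_tokens:
store.mset([("create_file", {"max_tokens": 8192, "max_tokens_reached": True})])

# What _run reads back on the next create_file invocation: it resets the flag
# and passes max_tokens through so create_file() can return a helpful error.
status = store.mget(["create_file"])[0]
if status and status.get("max_tokens_reached", False):
    store.mset([("create_file", {"max_tokens": status["max_tokens"], "max_tokens_reached": False})])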
43 changes: 43 additions & 0 deletions src/codegen/extensions/langchain/utils/custom_tool_node.py
@@ -0,0 +1,43 @@
from typing import Any, Literal, Optional, Union

from langchain_core.messages import (
AIMessage,
AnyMessage,
ToolCall,
)
from langchain_core.stores import InMemoryBaseStore
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel


class CustomToolNode(ToolNode):
"""Extended ToolNode that detects truncated tool calls."""

def _parse_input(
self,
input: Union[
list[AnyMessage],
dict[str, Any],
BaseModel,
],
store: Optional[InMemoryBaseStore],
) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
"""Parse the input and check for truncated tool calls."""
messages = input.get("messages", [])
if isinstance(messages, list):
if isinstance(messages[-1], AIMessage):
response_metadata = messages[-1].response_metadata
# Check if the stop reason is due to max tokens
if response_metadata.get("stop_reason") == "max_tokens":
# Check if the response metadata contains usage information
if "usage" not in response_metadata or "output_tokens" not in response_metadata["usage"]:
msg = "Response metadata is missing usage information."
raise ValueError(msg)

output_tokens = response_metadata["usage"]["output_tokens"]
for tool_call in messages[-1].tool_calls:
if tool_call.get("name") == "create_file":
# Set the max tokens and max tokens reached flag in the store
store.mset([(tool_call["name"], {"max_tokens": output_tokens, "max_tokens_reached": True})])

return super()._parse_input(input, store)
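For context, a hedged sketch of the kind of message _parse_input guards against: an AIMessage whose Anthropic-style response metadata reports a max_tokens stop together with usage information. The metadata keys mirror the checks above; the concrete values and tool arguments are made up for illustration.

from langchain_core.messages import AIMessage

truncated = AIMessage(
    content="",
    tool_calls=[
        {"name": "create_file", "args": {"filepath": "big.py", "content": "..."}, "id": "call_1", "type": "tool_call"},
    ],
    response_metadata={"stop_reason": "max_tokens", "usage": {"output_tokens": 8192}},
)

# The same conditions CustomToolNode._parse_input checks before flagging the store:
meta = truncated.response_metadata
assert meta.get("stop_reason") == "max_tokens"
assert "usage" in meta and "output_tokens" in meta["usage"]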
14 changes: 12 additions & 2 deletions src/codegen/extensions/tools/create_file.py
@@ -1,6 +1,6 @@
"""Tool for creating new files."""

from typing import ClassVar
from typing import ClassVar, Optional

from pydantic import Field

@@ -23,7 +23,7 @@ class CreateFileObservation(Observation):
str_template: ClassVar[str] = "Created file {filepath}"


def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileObservation:
def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: Optional[int] = None) -> CreateFileObservation:
"""Create a new file.

Args:
@@ -34,6 +34,16 @@ def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileOb
Returns:
CreateFileObservation containing new file state, or error if file exists
"""
if max_tokens:
error = f"""Your response reached the max output tokens limit of {max_tokens} tokens (~ {max_tokens / 10} lines).
Create the file in chunks or break up the content into smaller files.
"""
return CreateFileObservation(
status="error",
error=error,
filepath=filepath,
file_info=ViewFileObservation(status="error", error=error, filepath=filepath, content="", raw_content="", line_count=0),
)
if codebase.has_file(filepath):
return CreateFileObservation(
status="error",