-
Notifications
You must be signed in to change notification settings - Fork 108
Aj/tool use agent chat completion support #248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,60 +5,28 @@ | |
|
|
||
| import bgym | ||
| import numpy as np | ||
| from browsergym.core.observation import extract_screenshot | ||
| from PIL import Image, ImageDraw | ||
|
|
||
| from agentlab.agents import agent_utils | ||
| from agentlab.agents.agent_args import AgentArgs | ||
| from agentlab.llm.llm_utils import image_to_png_base64_url | ||
| from agentlab.llm.response_api import ( | ||
| BaseModelArgs, | ||
| ClaudeResponseModelArgs, | ||
| MessageBuilder, | ||
| OpenAIChatModelArgs, | ||
| OpenAIResponseModelArgs, | ||
| OpenRouterModelArgs, | ||
| ResponseLLMOutput, | ||
| VLLMModelArgs, | ||
| ) | ||
| from agentlab.llm.tracking import cost_tracker_decorator | ||
| from browsergym.core.observation import extract_screenshot | ||
|
|
||
| if TYPE_CHECKING: | ||
| from openai.types.responses import Response | ||
|
|
||
|
|
||
| def tag_screenshot_with_action(screenshot: Image, action: str) -> Image: | ||
| """ | ||
| If action is a coordinate action, try to render it on the screenshot. | ||
|
|
||
| e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot | ||
|
|
||
| Args: | ||
| screenshot: The screenshot to tag. | ||
| action: The action to tag the screenshot with. | ||
|
|
||
| Returns: | ||
| The tagged screenshot. | ||
|
|
||
| Raises: | ||
| ValueError: If the action parsing fails. | ||
| """ | ||
| if action.startswith("mouse_click"): | ||
| try: | ||
| coords = action[action.index("(") + 1 : action.index(")")].split(",") | ||
| coords = [c.strip() for c in coords] | ||
| if len(coords) != 2: | ||
| raise ValueError(f"Invalid coordinate format: {coords}") | ||
| if coords[0].startswith("x="): | ||
| coords[0] = coords[0][2:] | ||
| if coords[1].startswith("y="): | ||
| coords[1] = coords[1][2:] | ||
| x, y = float(coords[0].strip()), float(coords[1].strip()) | ||
| draw = ImageDraw.Draw(screenshot) | ||
| radius = 5 | ||
| draw.ellipse( | ||
| (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red" | ||
| ) | ||
| except (ValueError, IndexError) as e: | ||
| logging.warning(f"Failed to parse action '{action}': {e}") | ||
| return screenshot | ||
|
|
||
|
|
||
| @dataclass | ||
| class ToolUseAgentArgs(AgentArgs): | ||
| model_args: OpenAIResponseModelArgs = None | ||
|
|
@@ -97,19 +65,9 @@ def __init__( | |
| self.model_args = model_args | ||
| self.use_first_obs = use_first_obs | ||
| self.tag_screenshot = tag_screenshot | ||
|
|
||
| self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False) | ||
|
|
||
| self.tools = self.action_set.to_tool_description(api=model_args.api) | ||
|
|
||
| # count tools tokens | ||
| from agentlab.llm.llm_utils import count_tokens | ||
|
|
||
| tool_str = json.dumps(self.tools, indent=2) | ||
| print(f"Tool description: {tool_str}") | ||
| tool_tokens = count_tokens(tool_str, model_args.model_name) | ||
| print(f"Tool tokens: {tool_tokens}") | ||
|
|
||
| self.call_ids = [] | ||
|
|
||
| # self.tools.append( | ||
|
|
@@ -131,7 +89,7 @@ def __init__( | |
| # ) | ||
|
|
||
| self.llm = model_args.make_model(extra_kwargs={"tools": self.tools}) | ||
|
|
||
| self.msg_builder = model_args.get_message_builder() | ||
| self.messages: list[MessageBuilder] = [] | ||
|
|
||
| def obs_preprocessor(self, obs): | ||
|
|
@@ -140,7 +98,7 @@ def obs_preprocessor(self, obs): | |
| obs["screenshot"] = extract_screenshot(page) | ||
| if self.tag_screenshot: | ||
| screenshot = Image.fromarray(obs["screenshot"]) | ||
| screenshot = tag_screenshot_with_action(screenshot, obs["last_action"]) | ||
| screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"]) | ||
| obs["screenshot_tag"] = np.array(screenshot) | ||
| else: | ||
| raise ValueError("No page found in the observation.") | ||
|
|
@@ -150,56 +108,31 @@ def obs_preprocessor(self, obs): | |
| @cost_tracker_decorator | ||
| def get_action(self, obs: Any) -> float: | ||
| if len(self.messages) == 0: | ||
| system_message = MessageBuilder.system().add_text( | ||
| "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal." | ||
| ) | ||
| self.messages.append(system_message) | ||
|
|
||
| goal_message = MessageBuilder.user() | ||
| for content in obs["goal_object"]: | ||
| if content["type"] == "text": | ||
| goal_message.add_text(content["text"]) | ||
| elif content["type"] == "image_url": | ||
| goal_message.add_image(content["image_url"]) | ||
| self.messages.append(goal_message) | ||
|
|
||
| extra_info = [] | ||
|
|
||
| extra_info.append( | ||
| """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n""" | ||
| ) | ||
|
|
||
| self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info))) | ||
|
|
||
| if self.use_first_obs: | ||
| msg = "Here is the first observation." | ||
| screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" | ||
| if self.tag_screenshot: | ||
| msg += " A red dot on screenshots indicate the previous click action." | ||
| message = MessageBuilder.user().add_text(msg) | ||
| message.add_image(image_to_png_base64_url(obs[screenshot_key])) | ||
| self.messages.append(message) | ||
| self.initalize_messages(obs) | ||
| else: | ||
| if obs["last_action_error"] == "": | ||
| if obs["last_action_error"] == "": # Check No error in the last action | ||
| screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" | ||
| tool_message = MessageBuilder.tool().add_image( | ||
| tool_message = self.msg_builder.tool().add_image( | ||
| image_to_png_base64_url(obs[screenshot_key]) | ||
| ) | ||
| tool_message.update_last_raw_response(self.last_response) | ||
| tool_message.add_tool_id(self.previous_call_id) | ||
| self.messages.append(tool_message) | ||
| else: | ||
| tool_message = MessageBuilder.tool().add_text( | ||
| tool_message = self.msg_builder.tool().add_text( | ||
| f"Function call failed: {obs['last_action_error']}" | ||
| ) | ||
| tool_message.add_tool_id(self.previous_call_id) | ||
| tool_message.update_last_raw_response(self.last_response) | ||
| self.messages.append(tool_message) | ||
|
|
||
| response: ResponseLLMOutput = self.llm(messages=self.messages) | ||
|
|
||
| action = response.action | ||
| think = response.think | ||
| self.last_response = response | ||
| self.previous_call_id = response.last_computer_call_id | ||
| self.messages.append(response.assistant_message) | ||
| self.messages.append(response.assistant_message) # this is tool call | ||
|
|
||
| return ( | ||
| action, | ||
|
|
@@ -210,6 +143,37 @@ def get_action(self, obs: Any) -> float: | |
| ), | ||
| ) | ||
|
|
||
| def initalize_messages(self, obs: Any) -> None: | ||
| system_message = self.msg_builder.system().add_text( | ||
| "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal." | ||
| ) | ||
| self.messages.append(system_message) | ||
|
|
||
| goal_message = self.msg_builder.user() | ||
| for content in obs["goal_object"]: | ||
| if content["type"] == "text": | ||
| goal_message.add_text(content["text"]) | ||
| elif content["type"] == "image_url": | ||
| goal_message.add_image(content["image_url"]) | ||
| self.messages.append(goal_message) | ||
|
|
||
| extra_info = [] | ||
|
|
||
| extra_info.append( | ||
| """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n""" | ||
| ) | ||
|
|
||
| self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info))) | ||
|
|
||
| if self.use_first_obs: | ||
| msg = "Here is the first observation." | ||
| screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" | ||
| if self.tag_screenshot: | ||
| msg += " A red dot on screenshots indicate the previous click action." | ||
| message = self.msg_builder.user().add_text(msg) | ||
| message.add_image(image_to_png_base64_url(obs[screenshot_key])) | ||
| self.messages.append(message) | ||
|
|
||
|
|
||
| OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs( | ||
| model_name="gpt-4.1", | ||
|
|
@@ -220,6 +184,14 @@ def get_action(self, obs: Any) -> float: | |
| vision_support=True, | ||
| ) | ||
|
|
||
| OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs( | ||
| model_name="gpt-4o-2024-08-06", | ||
| max_total_tokens=200_000, | ||
| max_input_tokens=200_000, | ||
| max_new_tokens=2_000, | ||
| temperature=0.1, | ||
| vision_support=True, | ||
| ) | ||
|
|
||
| CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs( | ||
| model_name="claude-3-7-sonnet-20250219", | ||
|
|
@@ -231,6 +203,103 @@ def get_action(self, obs: Any) -> float: | |
| ) | ||
|
|
||
|
|
||
|
|
||
|
|
||
| # def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs: | ||
| # default_model_args = { | ||
| # "max_total_tokens": 200_000, | ||
| # "max_input_tokens": 180_000, | ||
| # "max_new_tokens": 2_000, | ||
| # "temperature": 0.1, | ||
| # "vision_support": True, | ||
| # } | ||
| # merged_args = {**default_model_args, **open_router_args} | ||
|
|
||
| # return OpenRouterModelArgs(model_name=model_name, **merged_args) | ||
|
|
||
|
|
||
| # def get_openrouter_tool_use_agent( | ||
| # model_name: str, | ||
| # model_args: dict = {}, | ||
| # use_first_obs=True, | ||
| # tag_screenshot=True, | ||
| # use_raw_page_output=True, | ||
| # ) -> ToolUseAgentArgs: | ||
| # # To Do : Check if OpenRouter endpoint specific args are working | ||
| # if not supports_tool_calling(model_name): | ||
| # raise ValueError(f"Model {model_name} does not support tool calling.") | ||
|
|
||
| # model_args = get_openrouter_model(model_name, **model_args) | ||
|
|
||
| # return ToolUseAgentArgs( | ||
| # model_args=model_args, | ||
| # use_first_obs=use_first_obs, | ||
| # tag_screenshot=tag_screenshot, | ||
| # use_raw_page_output=use_raw_page_output, | ||
| # ) | ||
|
|
||
|
|
||
| # OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview") | ||
|
|
||
|
|
||
| AGENT_CONFIG = ToolUseAgentArgs( | ||
| model_args=CLAUDE_MODEL_CONFIG, | ||
| ) | ||
|
|
||
| # MT_TOOL_USE_AGENT = ToolUseAgentArgs( | ||
| # model_args=OPENROUTER_MODEL, | ||
| # ) | ||
| CHATAPI_AGENT_CONFIG = ToolUseAgentArgs( | ||
| model_args=OpenAIChatModelArgs( | ||
| model_name="gpt-4o-2024-11-20", | ||
| max_total_tokens=200_000, | ||
| max_input_tokens=200_000, | ||
| max_new_tokens=2_000, | ||
| temperature=0.7, | ||
| vision_support=True, | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| OAI_CHAT_TOOl_AGENT = ToolUseAgentArgs( | ||
| model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06") | ||
| ) | ||
|
Comment on lines
+264
to
+266
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Inconsistent Constant Naming
Tell me more. What is the issue? Variable name contains a typo ('TOOl' instead of 'TOOL') and is inconsistent with other constant naming patterns. Why this matters: Inconsistent naming makes the code harder to search for and understand, especially when the inconsistency is due to a typo. Suggested change ∙ Feature Preview: OAI_CHAT_TOOL_AGENT = ToolUseAgentArgs(
model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
)
Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
|
|
||
| PROVIDER_FACTORY_MAP = { | ||
| "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs}, | ||
| "openrouter": OpenRouterModelArgs, | ||
| "vllm": VLLMModelArgs, | ||
| "antrophic": ClaudeResponseModelArgs, | ||
| } | ||
|
Comment on lines
+269
to
+274
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Provider Name Typo
Tell me more. What is the issue? The 'antrophic' key in PROVIDER_FACTORY_MAP contains a typo (should be 'anthropic'). Why this matters: This typo will cause runtime errors when trying to use Anthropic models, as the provider key won't match. Suggested change ∙ Feature Preview: Correct the spelling in the dictionary: PROVIDER_FACTORY_MAP = {
"openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
"openrouter": OpenRouterModelArgs,
"vllm": VLLMModelArgs,
"anthropic": ClaudeResponseModelArgs,
}
Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
|
|
||
| def get_tool_use_agent( | ||
| api_provider: str, | ||
| model_args: "BaseModelArgs", | ||
| tool_use_agent_args: dict = None, | ||
| api_provider_spec=None, | ||
| ) -> ToolUseAgentArgs: | ||
|
|
||
| if api_provider == "openai": | ||
| assert ( | ||
| api_provider_spec is not None | ||
| ), "Endpoint specification is required for OpenAI provider. Choose between 'chatcompletion' and 'response'." | ||
|
|
||
| model_args_factory = ( | ||
| PROVIDER_FACTORY_MAP[api_provider] | ||
| if api_provider_spec is None | ||
| else PROVIDER_FACTORY_MAP[api_provider][api_provider_spec] | ||
| ) | ||
|
|
||
| # Create the agent with model arguments from the factory | ||
| agent = ToolUseAgentArgs( | ||
| model_args=model_args_factory(**model_args), **(tool_use_agent_args or {}) | ||
| ) | ||
| return agent | ||
|
|
||
|
|
||
| ## We have three providers that we want to support. | ||
| # Anthropic | ||
| # OpenAI | ||
| # vllm (uses OpenAI API) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe the static method from the message builder could be in the llm object instead. This way we wouldn't need to implement this extra thing.
e.g.
message = self.llm.tool().add_image
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we can merge it like this for now and move it afterwards.