diff --git a/src/agentlab/experiments/loop.py b/src/agentlab/experiments/loop.py index de4b976a..299c3f83 100644 --- a/src/agentlab/experiments/loop.py +++ b/src/agentlab/experiments/loop.py @@ -1,3 +1,4 @@ +import base64 import gzip import importlib.metadata import json @@ -13,12 +14,15 @@ from collections import defaultdict from dataclasses import asdict, dataclass, field, is_dataclass from datetime import datetime +from io import BytesIO from pathlib import Path from typing import Optional import gymnasium as gym import numpy as np +import PIL.Image from browsergym.core.chat import Chat +from browsergym.core.hint_labeling import HintLabeling, HintLabelingInputs from browsergym.experiments.agent import Agent from browsergym.experiments.utils import count_tokens from dataclasses_json import DataClassJsonMixin @@ -49,6 +53,8 @@ class EnvArgs(DataClassJsonMixin): storage_state: Optional[str | Path | dict] = None task_kwargs: Optional[dict] = None # use default value from BrowserGym pre_observation_delay: float = None # seconds, wait for JS events to be fired + use_chat_ui: bool = False + use_hint_labeling_ui: bool = False def make_env( self, action_mapping, exp_dir, exp_task_kwargs: dict = {}, use_raw_page_output=True @@ -96,6 +102,8 @@ def make_env( wait_for_user_message=self.wait_for_user_message, action_mapping=action_mapping, # action mapping is provided by the agent use_raw_page_output=use_raw_page_output, + use_chat_ui=self.use_chat_ui, + use_hint_labeling_ui=self.use_hint_labeling_ui, **extra_kwargs, ) @@ -443,15 +451,22 @@ def run(self): # will end the episode after saving the step info. 
def _convert_np_array_to_base64(np_array: np.ndarray, image_format: str = "PNG") -> str:
    """Serialize an image held in a numpy array to a base64-encoded string.

    Args:
        np_array: Image data in any layout accepted by ``PIL.Image.fromarray``
            (typically an ``HxWx3`` uint8 RGB array — TODO confirm with callers).
        image_format: Container format passed to ``PIL.Image.save``. Defaults to
            ``"PNG"`` (lossless), preserving the original behavior.

    Returns:
        The serialized image bytes encoded as a base64 ASCII string.
    """
    image = PIL.Image.fromarray(np_array)
    buffered = BytesIO()
    # Encode in-memory; no temp file is touched.
    image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def _update_hint_labeling(
    hint_labeling: HintLabeling,
    action: str,
    agent: Agent,
    step_info: StepInfo,
    n_reprompt_samples: int = 5,
):
    """Drive one round-trip of the hint-labeling UI and return the chosen action.

    Pushes the current step (goal, screenshot, axtree, the agent's proposed
    action and chain-of-thought) to the hint-labeling UI, then blocks until the
    labeler responds. On a "reprompt" response, the labeler's hint is injected
    into the agent's extra instructions and the agent is re-run up to
    ``n_reprompt_samples`` times to collect distinct action suggestions; the UI
    is updated and we wait again. On a "step" response, the labeler's chosen
    action (and possibly edited think text) is written back to ``step_info``
    and returned.

    Args:
        hint_labeling: UI bridge providing ``update_context`` / ``wait_for_response``.
        action: The action originally proposed by the agent for this step.
        agent: The agent, re-prompted when the labeler supplies a hint.
        step_info: Current step record; ``obs`` and ``agent_info`` are read, and
            ``agent_info.think`` is overwritten on a "step" response.
        n_reprompt_samples: How many times to re-run the agent per reprompt
            (was a hard-coded 5; duplicates are collapsed).

    Returns:
        The action string selected by the labeler.

    Raises:
        ValueError: If the UI returns an unknown response type.
    """
    context = HintLabelingInputs(
        goal=step_info.obs.get("goal", ""),  # TODO: is this goal deprecated?
        error_feedback=step_info.obs.get("last_action_error", ""),
        screenshot=_convert_np_array_to_base64(step_info.obs["screenshot"]),
        axtree=step_info.obs["axtree_txt"],
        history=[],  # TODO: add history
        hint="",
        suggestions=[
            {
                "id": "1",
                "action": action,
                "think": step_info.agent_info.think,
            }
        ],
    )
    while True:
        # Push the (possibly refreshed) context to the UI, then block on the labeler.
        logger.info("Updating Hint Labeling UI context...")
        hint_labeling.update_context(context)
        logger.info("Waiting for Hint Labeling UI response...")
        response = hint_labeling.wait_for_response()

        if response["type"] == "reprompt":
            # The labeler supplied a hint: re-run the agent several times and
            # surface the distinct actions it proposes as suggestions.
            hint = response["payload"]["hint"]
            # NOTE(review): this permanently overwrites the agent's extra
            # instructions — confirm the hint is meant to persist beyond this
            # labeling session (it is never restored).
            agent.flags.extra_instructions = hint
            seen_actions = set()
            suggestions = []
            for _ in range(n_reprompt_samples):
                # TODO: make this more optimal — each call re-runs the agent.
                candidate = step_info.from_action(agent)
                think = step_info.agent_info.think
                if candidate not in seen_actions:
                    seen_actions.add(candidate)
                    suggestions.append(
                        {"id": str(len(seen_actions)), "action": candidate, "think": think}
                    )
            # Rebuild the context with the hint and the fresh suggestions;
            # everything else carries over unchanged.
            context = HintLabelingInputs(
                goal=context.goal,
                error_feedback=context.error_feedback,
                screenshot=context.screenshot,
                axtree=context.axtree,
                history=context.history,
                hint=hint,
                suggestions=suggestions,
            )
        elif response["type"] == "step":
            # The labeler committed an action: record their (possibly edited)
            # think text and hand the chosen action back to the run loop.
            step_info.agent_info.think = response["payload"]["think"]
            return response["payload"]["action"]
        else:
            raise ValueError(f"Unknown response type: {response['type']}")
flatten a nested dictionary.""" items = []