Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 100 additions & 9 deletions src/agentlab/experiments/loop.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import gzip
import importlib.metadata
import json
Expand All @@ -13,12 +14,15 @@
from collections import defaultdict
from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Optional

import gymnasium as gym
import numpy as np
import PIL.Image
from browsergym.core.chat import Chat
from browsergym.core.hint_labeling import HintLabeling, HintLabelingInputs
from browsergym.experiments.agent import Agent
from browsergym.experiments.utils import count_tokens
from dataclasses_json import DataClassJsonMixin
Expand Down Expand Up @@ -49,6 +53,8 @@ class EnvArgs(DataClassJsonMixin):
storage_state: Optional[str | Path | dict] = None
task_kwargs: Optional[dict] = None # use default value from BrowserGym
pre_observation_delay: float = None # seconds, wait for JS events to be fired
use_chat_ui: bool = False
use_hint_labeling_ui: bool = False

def make_env(
self, action_mapping, exp_dir, exp_task_kwargs: dict = {}, use_raw_page_output=True
Expand Down Expand Up @@ -96,6 +102,8 @@ def make_env(
wait_for_user_message=self.wait_for_user_message,
action_mapping=action_mapping, # action mapping is provided by the agent
use_raw_page_output=use_raw_page_output,
use_chat_ui=self.use_chat_ui,
use_hint_labeling_ui=self.use_hint_labeling_ui,
**extra_kwargs,
)

Expand Down Expand Up @@ -443,15 +451,22 @@ def run(self):
# will end the episode after saving the step info.
step_info.truncated = True

step_info.save_step_info(
self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
)
logger.debug("Step info saved.")
# step_info.save_step_info(
# self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
# )
# logger.debug("Step info saved.")

if hasattr(env.unwrapped, "chat") and isinstance(env.unwrapped.chat, Chat):
_send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
logger.debug("Chat info sent.")

if hasattr(env.unwrapped, "hint_labeling") and isinstance(
env.unwrapped.hint_labeling, HintLabeling
):
action = _update_hint_labeling(
env.unwrapped.hint_labeling, action, agent, step_info
)

if action is None:
logger.debug("Agent returned None action. Ending episode.")
break
Expand Down Expand Up @@ -481,10 +496,11 @@ def run(self):

finally:
try:
if step_info is not None:
step_info.save_step_info(
self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
)
pass
# if step_info is not None:
# step_info.save_step_info(
# self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
# )
except Exception as e:
logger.error(f"Error while saving step info in the finally block: {e}")
try:
Expand All @@ -508,7 +524,8 @@ def run(self):
except Exception as e:
logger.exception(f"Error while closing the environment: {e}")
try:
self._unset_logger() # stop writing logs to run logfile
# self._unset_logger() # stop writing logs to run logfile
pass
except Exception as e:
logger.exception(f"Error while unsetting the logger: {e}")

Expand Down Expand Up @@ -943,6 +960,80 @@ def _send_chat_info(chat: Chat, action: str, agent_info: dict):
chat.add_message(role="info", msg=msg)


def _convert_np_array_to_base64(np_array: np.ndarray):
im = PIL.Image.fromarray(np_array)
buffered = BytesIO()
im.save(buffered, format="PNG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_b64


def _update_hint_labeling(
hint_labeling: HintLabeling, action: str, agent: Agent, step_info: StepInfo
):
"""Update the hint labeling with the action and agent info."""
context = HintLabelingInputs(
goal=step_info.obs.get("goal", ""), # TODO: is this goal deprecated?
error_feedback=step_info.obs.get("last_action_error", ""),
screenshot=_convert_np_array_to_base64(step_info.obs["screenshot"]),
axtree=step_info.obs["axtree_txt"],
history=[], # TODO: add history
hint="",
suggestions=[
{
"id": "1",
"action": action,
"think": step_info.agent_info.think,
}
],
)
while True:
# update hint labeling ui context
logger.info("Updating Hint Labeling UI context...")
hint_labeling.update_context(context)

# wait for hint labeling response
logger.info("Waiting for Hint Labeling UI response...")
response = hint_labeling.wait_for_response()

# if payload is for reprompt, we ask for 5 suggestions and we combine everything
if response["type"] == "reprompt":
# reprompt model 5 times
hint = response["payload"]["hint"]
agent.flags.extra_instructions = hint
seen_actions = set()
suggestions = []
for _ in range(5):
# TODO: make this more optimal
action = step_info.from_action(agent)
think = step_info.agent_info.think
if action not in seen_actions:
seen_actions.add(action)
suggestions.append(
{"id": str(len(seen_actions)), "action": action, "think": think}
)

# update context
context = HintLabelingInputs(
goal=context.goal,
error_feedback=context.error_feedback,
screenshot=context.screenshot,
axtree=context.axtree,
history=context.history,
hint=hint,
suggestions=suggestions,
)
continue

# otherwise, if payload is for action, we return the updated action and save the hint
elif response["type"] == "step":
step_info.agent_info.think = response["payload"]["think"]
action = response["payload"]["action"]
return action
else:
raise ValueError(f"Unknown response type: {response['type']}")


def _flatten_dict(d, parent_key="", sep="."):
"""Recursively flatten a nested dictionary."""
items = []
Expand Down
Loading