Merged
162 changes: 82 additions & 80 deletions src/agentlab/experiments/study.py
@@ -1,12 +1,14 @@
from concurrent.futures import ProcessPoolExecutor
import gzip
import logging
import os
import pickle
import random
import uuid
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Manager, Pool, Queue
from pathlib import Path

import bgym
@@ -19,8 +21,6 @@
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
from multiprocessing import Pool, Manager, Queue
import random

logger = logging.getLogger(__name__)

@@ -238,7 +238,7 @@ def __post_init__(self):

def make_exp_args_list(self):
"""Generate the exp_args_list from the agent_args and the benchmark."""
self.exp_args_list = _agents_on_benchmark(
self.exp_args_list = self.agents_on_benchmark(
self.agent_args,
self.benchmark,
logging_level=self.logging_level,
@@ -424,6 +424,84 @@ def load(dir: Path) -> "Study":
def load_most_recent(root_dir: Path = None, contains=None) -> "Study":
return Study.load(get_most_recent_study(root_dir, contains=contains))

def agents_on_benchmark(
self,
agents: list[AgentArgs] | AgentArgs,
benchmark: bgym.Benchmark,
demo_mode=False,
logging_level: int = logging.INFO,
logging_level_stdout: int = logging.INFO,
ignore_dependencies=False,
):
"""Run one or multiple agents on a benchmark.

Args:
agents: list[AgentArgs] | AgentArgs
The agent configuration(s) to run.
benchmark: bgym.Benchmark
The benchmark to run the agents on.
demo_mode: bool
If True, the experiments will be run in demo mode.
logging_level: int
The logging level for individual jobs.
logging_level_stdout: int
The logging level for the stdout.
ignore_dependencies: bool
If True, the dependencies will be ignored and all experiments can be run in parallel.

Returns:
list[ExpArgs]: The list of experiments to run.

Raises:
ValueError: If multiple agents are run on a benchmark that requires manual reset.
"""

if not isinstance(agents, (list, tuple)):
agents = [agents]

if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
if len(agents) > 1:
raise ValueError(
f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
)

for agent in agents:
agent.set_benchmark(
benchmark, demo_mode
            )  # the agent can adapt (lightly) to the benchmark

env_args_list = benchmark.env_args_list
if demo_mode:
set_demo_mode(env_args_list)

exp_args_list = []

for agent in agents:
for env_args in env_args_list:
exp_args = ExpArgs(
agent_args=agent,
env_args=env_args,
logging_level=logging_level,
logging_level_stdout=logging_level_stdout,
)
exp_args_list.append(exp_args)

for i, exp_args in enumerate(exp_args_list):
exp_args.order = i

        # not required with ray, but keeping around in case we need it for visualwebarena on joblib
# _flag_sequential_exp(exp_args_list, benchmark)

if not ignore_dependencies:
# populate the depends_on field based on the task dependencies in the benchmark
exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
else:
logger.warning(
f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
)

return exp_args_list
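
# Hypothetical usage sketch (an editor's illustration, not part of this diff):
# mirroring what make_exp_args_list does with the new method. `study` stands
# in for a fully constructed Study; agent_args, benchmark, and logging_level
# are the Study fields used above.
def _example_agents_on_benchmark(study: "Study") -> list[ExpArgs]:
    exp_args_list = study.agents_on_benchmark(
        study.agent_args,
        study.benchmark,
        logging_level=study.logging_level,
    )
    # each ExpArgs now carries a sequential `order`; unless
    # ignore_dependencies=True, depends_on is populated from the
    # benchmark's task dependency graph
    return exp_args_list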

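# Conceptual sketch of the dependency step, assuming (beyond what this diff
# shows) that dependency_graph_over_tasks() maps each task name to its
# prerequisite task names, and that ExpArgs exposes exp_id, depends_on, and
# env_args.task_name. Not the actual add_dependencies implementation.
def _sketch_add_dependencies(
    exp_args_list: list[ExpArgs], task_deps: dict[str, list[str]]
) -> list[ExpArgs]:
    # index experiments by the task they run
    by_task: dict[str, list[ExpArgs]] = {}
    for exp_args in exp_args_list:
        by_task.setdefault(exp_args.env_args.task_name, []).append(exp_args)
    # point each experiment at the experiments covering its prerequisite tasks
    for exp_args in exp_args_list:
        prereqs = task_deps.get(exp_args.env_args.task_name, [])
        exp_args.depends_on = tuple(
            dep.exp_id for task in prereqs for dep in by_task.get(task, [])
        )
    return exp_args_list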

def _make_study_name(agent_names, benchmark_names, suffix=None):
"""Make a study name from the agent and benchmark names."""
@@ -634,82 +712,6 @@ def set_demo_mode(env_args_list: list[EnvArgs]):
env_args.slow_mo = 1000


def _agents_on_benchmark(
agents: list[AgentArgs] | AgentArgs,
benchmark: bgym.Benchmark,
demo_mode=False,
logging_level: int = logging.INFO,
logging_level_stdout: int = logging.INFO,
ignore_dependencies=False,
):
"""Run one or multiple agents on a benchmark.

Args:
agents: list[AgentArgs] | AgentArgs
The agent configuration(s) to run.
benchmark: bgym.Benchmark
The benchmark to run the agents on.
demo_mode: bool
If True, the experiments will be run in demo mode.
logging_level: int
The logging level for individual jobs.
logging_level_stdout: int
The logging level for the stdout.
ignore_dependencies: bool
If True, the dependencies will be ignored and all experiments can be run in parallel.

Returns:
list[ExpArgs]: The list of experiments to run.

Raises:
ValueError: If multiple agents are run on a benchmark that requires manual reset.
"""

if not isinstance(agents, (list, tuple)):
agents = [agents]

if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
if len(agents) > 1:
raise ValueError(
f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
)

for agent in agents:
        agent.set_benchmark(benchmark, demo_mode)  # the agent can adapt (lightly) to the benchmark

env_args_list = benchmark.env_args_list
if demo_mode:
set_demo_mode(env_args_list)

exp_args_list = []

for agent in agents:
for env_args in env_args_list:
exp_args = ExpArgs(
agent_args=agent,
env_args=env_args,
logging_level=logging_level,
logging_level_stdout=logging_level_stdout,
)
exp_args_list.append(exp_args)

for i, exp_args in enumerate(exp_args_list):
exp_args.order = i

    # not required with ray, but keeping around in case we need it for visualwebarena on joblib
# _flag_sequential_exp(exp_args_list, benchmark)

if not ignore_dependencies:
# populate the depends_on field based on the task dependencies in the benchmark
exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
else:
logger.warning(
f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
)

return exp_args_list


# def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark):
# if benchmark.name.startswith("visualwebarena"):
# sequential_subset = benchmark.subset_from_glob("requires_reset", "True")