Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/agentlab/experiments/reproducibility_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def _get_repo(module):
return Repo(Path(module.__file__).resolve().parent, search_parent_directories=True)


def _get_benchmark_version(benchmark: bgym.Benchmark) -> str:
def _get_benchmark_version(
benchmark: bgym.Benchmark, allow_bypass_benchmark_version: bool = False
) -> str:
benchmark_name = benchmark.name

if hasattr(benchmark, "get_version"):
Expand All @@ -42,7 +44,10 @@ def _get_benchmark_version(benchmark: bgym.Benchmark) -> str:
elif benchmark_name.startswith("assistantbench"):
return metadata.distribution("browsergym.assistantbench").version
else:
raise ValueError(f"Unknown benchmark {benchmark_name}")
if allow_bypass_benchmark_version:
return "bypassed"
else:
raise ValueError(f"Unknown benchmark {benchmark_name}")


def _get_git_username(repo: Repo) -> str:
Expand Down Expand Up @@ -183,6 +188,7 @@ def get_reproducibility_info(
"*inspect_results.ipynb",
),
ignore_changes=False,
allow_bypass_benchmark_version=False,
):
"""
Retrieve a dict of information that could influence the reproducibility of an experiment.
Expand All @@ -205,7 +211,7 @@ def get_reproducibility_info(
"benchmark": benchmark.name,
"study_id": study_id,
"comment": comment,
"benchmark_version": _get_benchmark_version(benchmark),
"benchmark_version": _get_benchmark_version(benchmark, allow_bypass_benchmark_version),
"date": datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
"os": f"{platform.system()} ({platform.version()})",
"python_version": platform.python_version(),
Expand Down
2 changes: 1 addition & 1 deletion src/agentlab/experiments/study.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
import gzip
import logging
import pickle
Expand Down Expand Up @@ -269,6 +268,7 @@ def set_reproducibility_info(self, strict_reproducibility=False, comment=None):
self.uuid,
ignore_changes=not strict_reproducibility,
comment=comment,
allow_bypass_benchmark_version=not strict_reproducibility,
)
if self.reproducibility_info is not None:
repro.assert_compatible(
Expand Down
8 changes: 6 additions & 2 deletions src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
**client_args,
)

def __call__(self, messages: list[dict]) -> dict:
def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
# Initialize retry tracking attributes
self.retries = 0
self.success = False
Expand All @@ -275,6 +275,7 @@ def __call__(self, messages: list[dict]) -> dict:
completion = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
n=n_samples,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
Expand Down Expand Up @@ -305,7 +306,10 @@ def __call__(self, messages: list[dict]) -> dict:
):
tracking.TRACKER.instance(input_tokens, output_tokens, cost)

return AIMessage(completion.choices[0].message.content)
if n_samples == 1:
return AIMessage(completion.choices[0].message.content)
else:
return [AIMessage(c.message.content) for c in completion.choices]

def get_stats(self):
return {
Expand Down
59 changes: 59 additions & 0 deletions src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,65 @@ def retry(
raise ParseError(f"Could not parse a valid value after {n_retry} retries.")


def retry_multiple(
    chat: "ChatModel",
    messages: "Discussion",
    n_retry: int,
    parser: callable,
    log: bool = True,
    num_samples: int = 1,
):
    """Sample multiple answers and retry until at least one parses.

    Queries the chat model for `num_samples` answers and runs `parser` on each
    of them. If none of the samples parses, the parser errors are appended to
    the chat as a retry message and the model is queried again, up to
    `n_retry` times.

    Note, each retry has to resend the whole prompt to the API. This can be slow
    and expensive.

    Args:
        chat (ChatModel): a ChatModel object taking a list of messages and an
            `n_samples` keyword, returning either a single answer (when
            `n_samples == 1`) or a list of answers, all in OpenAI format.
        messages (list): the list of messages so far. This list will be modified
            in place with the new messages and the retry messages.
        n_retry (int): the maximum number of sequential retries.
        parser (callable): a function taking a message content string and
            returning a parsed value, or raising a ParseError.
        log (bool): whether to log the retry messages.
        num_samples (int): the number of samples to generate from the model.

    Returns:
        tuple[list, int]: the list of successfully parsed values (one per
            sample that parsed) and the number of retries that were needed.

    Raises:
        ParseError: if no sample could be parsed after n_retry retries.
    """
    tries = 0
    while tries < n_retry:
        answer_list = chat(messages, n_samples=num_samples)
        # The chat model returns a single message (not a list) when
        # n_samples == 1 — normalize so the loop below always sees a list.
        if not isinstance(answer_list, list):
            answer_list = [answer_list]
        # Record the first sample in the conversation history so retries see
        # an assistant turn.
        # TODO: could we change this to not use inplace modifications ?
        messages.append(answer_list[0])
        parsed_answers = []
        errors = []
        for answer in answer_list:
            try:
                parsed_answers.append(parser(answer["content"]))
            except ParseError as parsing_error:
                errors.append(str(parsing_error))
        # If at least one sample parsed, return all valid ones.
        if parsed_answers:
            return parsed_answers, tries
        tries += 1
        if log:
            msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer_list[0]['content']}\n[User]:\n{str(errors)}"
            logging.info(msg)
        messages.append(dict(role="user", content=str(errors)))

    raise ParseError(f"Could not parse a valid value after {n_retry} retries.")


def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"):
"""Use tiktoken to truncate a text to a maximum number of tokens."""
enc = tiktoken.encoding_for_model(model_name)
Expand Down
Loading