diff --git a/src/agentlab/experiments/reproducibility_util.py b/src/agentlab/experiments/reproducibility_util.py index c5d971a1..01f3fdc9 100644 --- a/src/agentlab/experiments/reproducibility_util.py +++ b/src/agentlab/experiments/reproducibility_util.py @@ -19,7 +19,9 @@ def _get_repo(module): return Repo(Path(module.__file__).resolve().parent, search_parent_directories=True) -def _get_benchmark_version(benchmark: bgym.Benchmark) -> str: +def _get_benchmark_version( + benchmark: bgym.Benchmark, allow_bypass_benchmark_version: bool = False +) -> str: benchmark_name = benchmark.name if hasattr(benchmark, "get_version"): @@ -42,7 +44,10 @@ def _get_benchmark_version(benchmark: bgym.Benchmark) -> str: elif benchmark_name.startswith("assistantbench"): return metadata.distribution("browsergym.assistantbench").version else: - raise ValueError(f"Unknown benchmark {benchmark_name}") + if allow_bypass_benchmark_version: + return "bypassed" + else: + raise ValueError(f"Unknown benchmark {benchmark_name}") def _get_git_username(repo: Repo) -> str: @@ -183,6 +188,7 @@ def get_reproducibility_info( "*inspect_results.ipynb", ), ignore_changes=False, + allow_bypass_benchmark_version=False, ): """ Retrieve a dict of information that could influence the reproducibility of an experiment. 
@@ -205,7 +211,7 @@ def get_reproducibility_info( "benchmark": benchmark.name, "study_id": study_id, "comment": comment, - "benchmark_version": _get_benchmark_version(benchmark), + "benchmark_version": _get_benchmark_version(benchmark, allow_bypass_benchmark_version), "date": datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), "os": f"{platform.system()} ({platform.version()})", "python_version": platform.python_version(), diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index c091d117..8a65b3a2 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod import gzip import logging import pickle @@ -269,6 +268,7 @@ def set_reproducibility_info(self, strict_reproducibility=False, comment=None): self.uuid, ignore_changes=not strict_reproducibility, comment=comment, + allow_bypass_benchmark_version=not strict_reproducibility, ) if self.reproducibility_info is not None: repro.assert_compatible( diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index e22d2405..f8f02766 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -261,7 +261,7 @@ def __init__( **client_args, ) - def __call__(self, messages: list[dict]) -> dict: + def __call__(self, messages: list[dict], n_samples: int = 1) -> dict: # Initialize retry tracking attributes self.retries = 0 self.success = False @@ -275,6 +275,7 @@ def __call__(self, messages: list[dict]) -> dict: completion = self.client.chat.completions.create( model=self.model_name, messages=messages, + n=n_samples, temperature=self.temperature, max_tokens=self.max_tokens, ) @@ -305,7 +306,10 @@ def __call__(self, messages: list[dict]) -> dict: ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) - return AIMessage(completion.choices[0].message.content) + if n_samples == 1: + return AIMessage(completion.choices[0].message.content) + else: + return [AIMessage(c.message.content) for c 
in completion.choices] def get_stats(self): return { diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index e3300f96..d7dcc5cf 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -90,6 +90,65 @@ def retry( raise ParseError(f"Could not parse a valid value after {n_retry} retries.") +def retry_multiple( + chat: "ChatModel", + messages: "Discussion", + n_retry: int, + parser: callable, + log: bool = True, + num_samples: int = 1, +): + """Retry querying the chat models with the response from the parser until it + returns a valid value. + + If the answer is not valid, it will retry and append to the chat the retry + message. It will stop after `n_retry`. + + Note, each retry has to resend the whole prompt to the API. This can be slow + and expensive. + + Args: + chat (ChatModel): a ChatModel object taking a list of messages and + returning a list of answers, all in OpenAI format. + messages (list): the list of messages so far. This list will be modified with + the new messages and the retry messages. + n_retry (int): the maximum number of sequential retries. + parser (callable): a function taking a message and returning a parsed value, + or raising a ParseError + log (bool): whether to log the retry messages. + num_samples (int): the number of samples to generate from the model. + + Returns: + tuple[list, int]: the list of successfully parsed values and the number of retries used. + + Raises: + ParseError: if the parser could not parse the response after n_retry retries. + """ + tries = 0 + while tries < n_retry: + answer_list = chat(messages, n_samples=num_samples) + # TODO: could we change this to not use inplace modifications ? 
+ answer_list = answer_list if isinstance(answer_list, list) else [answer_list] + parsed_answers = [] + errors = [] + for answer in answer_list: + try: + parsed_answers.append(parser(answer["content"])) + except ParseError as parsing_error: + errors.append(str(parsing_error)) + # if we have a valid answer, return it + if parsed_answers: + return parsed_answers, tries + else: + tries += 1 + if log: + msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(errors)}" + logging.info(msg) + messages.append(dict(role="user", content=str(errors))) + + raise ParseError(f"Could not parse a valid value after {n_retry} retries.") + + def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"): """Use tiktoken to truncate a text to a maximum number of tokens.""" enc = tiktoken.encoding_for_model(model_name)