From 731786db0de7f7725e2d9efd671170ff69a0e6d0 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 3 Dec 2024 16:17:14 -0500 Subject: [PATCH 1/6] looking good --- src/agentlab/llm/chat_api.py | 8 ++++-- src/agentlab/llm/llm_utils.py | 54 +++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index e22d2405..f8f02766 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -261,7 +261,7 @@ def __init__( **client_args, ) - def __call__(self, messages: list[dict]) -> dict: + def __call__(self, messages: list[dict], n_samples: int = 1) -> dict: # Initialize retry tracking attributes self.retries = 0 self.success = False @@ -275,6 +275,7 @@ def __call__(self, messages: list[dict]) -> dict: completion = self.client.chat.completions.create( model=self.model_name, messages=messages, + n=n_samples, temperature=self.temperature, max_tokens=self.max_tokens, ) @@ -305,7 +306,10 @@ def __call__(self, messages: list[dict]) -> dict: ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) - return AIMessage(completion.choices[0].message.content) + if n_samples == 1: + return AIMessage(completion.choices[0].message.content) + else: + return [AIMessage(c.message.content) for c in completion.choices] def get_stats(self): return { diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index e3300f96..708d68e8 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -90,6 +90,60 @@ def retry( raise ParseError(f"Could not parse a valid value after {n_retry} retries.") +def retry_multiple( + chat: "ChatModel", + messages: "Discussion", + n_retry: int, + parser: callable, + log: bool = True, + num_samples: int = 1, +): + """Retry querying the chat models with the response from the parser until it + returns a valid value. + + If the answer is not valid, it will retry and append to the chat the retry + message. It will stop after `n_retry`. + + Note, each retry has to resend the whole prompt to the API. This can be slow + and expensive. + + Args: + chat (ChatModel): a ChatModel object taking a list of messages and + returning a list of answers, all in OpenAI format. + messages (list): the list of messages so far. This list will be modified with + the new messages and the retry messages. + n_retry (int): the maximum number of sequential retries. + parser (callable): a function taking a message and retruning a parsed value, + or raising a ParseError + log (bool): whether to log the retry messages. + + Returns: + dict: the parsed value, with a string at key "action". + + Raises: + ParseError: if the parser could not parse the response after n_retry retries. + """ + tries = 0 + while tries < n_retry: + answer_list = chat(messages, num_samples=num_samples) + # TODO: could we change this to not use inplace modifications ? + messages.append(answer) + parsed_answers = [] + errors = [] + for answer in answer_list: + try: + parsed_answers.append(parser(answer["content"])) + except ParseError as parsing_error: + errors.append(str(parsing_error)) + tries += 1 + if log: + msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(errors)}" + logging.info(msg) + messages.append(dict(role="user", content=str(errors))) + + raise ParseError(f"Could not parse a valid value after {n_retry} retries.") + + def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"): """Use tiktoken to truncate a text to a maximum number of tokens.""" enc = tiktoken.encoding_for_model(model_name) From a300d7162281caea6367c3eb90b5a6a31704a455 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 3 Dec 2024 18:27:18 -0500 Subject: [PATCH 2/6] adding bypass on benchmark version --- src/agentlab/experiments/reproducibility_util.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/agentlab/experiments/reproducibility_util.py b/src/agentlab/experiments/reproducibility_util.py index c5d971a1..01f3fdc9 100644 --- a/src/agentlab/experiments/reproducibility_util.py +++ b/src/agentlab/experiments/reproducibility_util.py @@ -19,7 +19,9 @@ def _get_repo(module): return Repo(Path(module.__file__).resolve().parent, search_parent_directories=True) -def _get_benchmark_version(benchmark: bgym.Benchmark) -> str: +def _get_benchmark_version( + benchmark: bgym.Benchmark, allow_bypass_benchmark_version: bool = False +) -> str: benchmark_name = benchmark.name if hasattr(benchmark, "get_version"): @@ -42,7 +44,10 @@ def _get_benchmark_version(benchmark: bgym.Benchmark) -> str: elif benchmark_name.startswith("assistantbench"): return metadata.distribution("browsergym.assistantbench").version else: - raise ValueError(f"Unknown benchmark {benchmark_name}") + if allow_bypass_benchmark_version: + return "bypassed" + else: + raise ValueError(f"Unknown benchmark {benchmark_name}") def _get_git_username(repo: Repo) -> str: @@ -183,6 +188,7 @@ def get_reproducibility_info( "*inspect_results.ipynb", ), ignore_changes=False, + allow_bypass_benchmark_version=False, ): """ Retrieve a dict of information that could influence the reproducibility of an experiment. @@ -205,7 +211,7 @@ def get_reproducibility_info( "benchmark": benchmark.name, "study_id": study_id, "comment": comment, - "benchmark_version": _get_benchmark_version(benchmark), + "benchmark_version": _get_benchmark_version(benchmark, allow_bypass_benchmark_version), "date": datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), "os": f"{platform.system()} ({platform.version()})", "python_version": platform.python_version(), From 9ab98d80b5c4f403476d033bfb1c2f3629d1b89d Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 3 Dec 2024 18:27:31 -0500 Subject: [PATCH 3/6] darglint --- src/agentlab/llm/llm_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 708d68e8..31e8c137 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -116,6 +116,7 @@ def retry_multiple( parser (callable): a function taking a message and retruning a parsed value, or raising a ParseError log (bool): whether to log the retry messages. + num_samples (int): the number of samples to generate from the model. Returns: dict: the parsed value, with a string at key "action". From da8208fc58848858611c72571165702a8449c714 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 3 Dec 2024 18:30:07 -0500 Subject: [PATCH 4/6] wth vscode --- src/agentlab/experiments/study.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index c091d117..8a65b3a2 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod import gzip import logging import pickle @@ -269,6 +268,7 @@ def set_reproducibility_info(self, strict_reproducibility=False, comment=None): self.uuid, ignore_changes=not strict_reproducibility, comment=comment, + allow_bypass_benchmark_version=not strict_reproducibility, ) if self.reproducibility_info is not None: repro.assert_compatible( From e124fea6139358e2ec61dd83cc91442affe393c4 Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 3 Dec 2024 18:33:05 -0500 Subject: [PATCH 5/6] forgot return in retry multiple --- src/agentlab/llm/llm_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 31e8c137..cec2226f 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -136,6 +136,10 @@ def retry_multiple( parsed_answers.append(parser(answer["content"])) except ParseError as parsing_error: errors.append(str(parsing_error)) + # if we have a valid answer, return it + if parsed_answers: + return parsed_answers, tries + else: tries += 1 if log: msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(errors)}" From bea6cff41d38cb21b45078a015aed1cd7e5c005d Mon Sep 17 00:00:00 2001 From: ThibaultLSDC Date: Tue, 3 Dec 2024 18:35:32 -0500 Subject: [PATCH 6/6] darglint --- src/agentlab/llm/llm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index cec2226f..d7dcc5cf 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -119,7 +119,7 @@ def retry_multiple( num_samples (int): the number of samples to generate from the model. Returns: - dict: the parsed value, with a string at key "action". + list[dict]: the parsed value, with a string at key "action". Raises: ParseError: if the parser could not parse the response after n_retry retries.