From 0d7df435b2b431c957624da9b0299569fa79e47e Mon Sep 17 00:00:00 2001
From: Leo Boisvert
Date: Wed, 4 Dec 2024 22:22:19 +0000
Subject: [PATCH 1/3] support multiple samples for HF models

---
 src/agentlab/llm/huggingface_utils.py | 60 +++++++++++++++++----------
 1 file changed, 37 insertions(+), 23 deletions(-)

diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 9bb2d7ab..5a3ce2c7 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 from pydantic import Field
 from transformers import AutoTokenizer, GPT2TokenizerFast
@@ -12,7 +12,7 @@
 
 class HFBaseChatModel(AbstractChatModel):
     """
-    Custom LLM Chatbot that can interface with HuggingFace models.
+    Custom LLM Chatbot that can interface with HuggingFace models, with support for multiple samples.
 
     This class allows for the creation of a custom chatbot using models hosted
     on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
     Attributes:
         llm (Any): The HuggingFaceHub model instance.
        prompt_template (Any): Template for the prompt to be used for the model's input sequence.
+        tokenizer (Any): The tokenizer to use for the model.
+        n_retry_server (int): Number of times to retry on server failure.
     """
 
     llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,12 +55,20 @@ def __init__(self, model_name, n_retry_server):
     def __call__(
         self,
         messages: list[dict],
-    ) -> dict:
-
-        # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation.
-
+        n_samples: int = 1,
+    ) -> Union[AIMessage, List[AIMessage]]:
+        """
+        Generate one or more responses for the given messages.
+
+        Args:
+            messages: List of message dictionaries containing the conversation history.
+            n_samples: Number of independent responses to generate. Defaults to 1.
+
+        Returns:
+            If n_samples=1, returns a single AIMessage.
+            If n_samples>1, returns a list of AIMessages.
+        """
         if self.tokenizer:
-            # messages_formated = _convert_messages_to_dict(messages)  ## ?
             try:
                 if isinstance(messages, Discussion):
                     messages.merge()
@@ -66,31 +76,35 @@ def __call__(
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
                     logging.warning(
-                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role"
+                        "Failed to apply the chat template, possibly because it doesn't support the 'system' role. "
                         "Retrying with the 'system' role appended to the 'user' role."
                     )
                     messages = _prepend_system_to_first_user(messages)
                     prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                 else:
                     raise e
-
         elif self.prompt_template:
             prompt = self.prompt_template.construct_prompt(messages)
 
-        itr = 0
-        while True:
-            try:
-                response = AIMessage(self.llm(prompt))
-                return response
-            except Exception as e:
-                if itr == self.n_retry_server - 1:
-                    raise e
-                logging.warning(
-                    f"Failed to get a response from the server: \n{e}\n"
-                    f"Retrying... ({itr+1}/{self.n_retry_server})"
-                )
-                time.sleep(5)
-                itr += 1
+        responses = []
+        for _ in range(n_samples):
+            itr = 0
+            while True:
+                try:
+                    response = AIMessage(self.llm(prompt))
+                    responses.append(response)
+                    break
+                except Exception as e:
+                    if itr == self.n_retry_server - 1:
+                        raise e
+                    logging.warning(
+                        f"Failed to get a response from the server: \n{e}\n"
+                        f"Retrying... ({itr+1}/{self.n_retry_server})"
+                    )
+                    time.sleep(5)
+                    itr += 1
+
+        return responses[0] if n_samples == 1 else responses
 
     def _llm_type(self):
         return "huggingface"
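Taken together, the hunks above change `__call__` from always returning a single response to optionally returning several. A minimal usage sketch follows; it assumes `chat_model` is an already-configured `HFBaseChatModel` subclass and the example messages are illustrative — neither is part of this diff.

```python
# Sketch of the new n_samples behavior (assumes `chat_model` is a configured
# HFBaseChatModel subclass; the messages below are illustrative only).
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Name three HTML tags."},
]

single = chat_model(messages)                # n_samples defaults to 1 -> one AIMessage
samples = chat_model(messages, n_samples=3)  # n_samples > 1 -> list of AIMessages

assert not isinstance(single, list)
assert isinstance(samples, list) and len(samples) == 3
```

Note that all samples reuse the same rendered prompt and are produced by independent server calls, so any diversity between them comes entirely from server-side sampling — which the next patch makes controllable per call.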
From e5d5171dde08ce6e17be3ea0d1f1503a61bf00ad Mon Sep 17 00:00:00 2001
From: Leo Boisvert
Date: Wed, 4 Dec 2024 22:57:42 +0000
Subject: [PATCH 2/3] tweaks for per-call temperature setting

---
 src/agentlab/llm/chat_api.py          | 10 +++++-----
 src/agentlab/llm/huggingface_utils.py |  5 ++++-
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index f8f02766..afc6d158 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )
 
@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
 
         if token is None:
             token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 5a3ce2c7..835304ff 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -56,6 +56,7 @@ def __call__(
         self,
         messages: list[dict],
         n_samples: int = 1,
+        temperature: float = None,
     ) -> Union[AIMessage, List[AIMessage]]:
         """
         Generate one or more responses for the given messages.
@@ -63,6 +64,7 @@ def __call__(
         Args:
             messages: List of message dictionaries containing the conversation history.
             n_samples: Number of independent responses to generate. Defaults to 1.
+            temperature: Sampling temperature for this call; falls back to the model's default temperature when None.
 
         Returns:
             If n_samples=1, returns a single AIMessage.
@@ -91,7 +93,8 @@ def __call__(
             itr = 0
             while True:
                 try:
-                    response = AIMessage(self.llm(prompt))
+                    temperature = temperature if temperature is not None else self.temperature
+                    response = AIMessage(self.llm(prompt, temperature=temperature))
                     responses.append(response)
                     break
                 except Exception as e:
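Continuing the sketch from patch 1 (same hypothetical `chat_model` and `messages`), the per-call override composes with `n_samples`:

```python
# Per-call temperature override (values are illustrative).
focused = chat_model(messages, temperature=0.1)  # near-greedy, this call only
default = chat_model(messages)                   # temperature=None falls back to self.temperature

# Higher temperature with several samples yields more diverse candidates
# from the same prompt, e.g. for best-of-n reranking.
candidates = chat_model(messages, n_samples=5, temperature=0.9)
```

One design note: resolving the temperature at call time (falling back to `self.temperature` when the argument is None) turns the constructor's temperature into a default rather than a fixed setting, which is why the `partial` in `chat_api.py` no longer bakes it in.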
From 6a2c783e704c0bec607131b55df52c4c82db9f8f Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Thu, 5 Dec 2024 09:28:39 -0500
Subject: [PATCH 3/3] darglint

---
 src/agentlab/llm/huggingface_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 835304ff..364221b5 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -69,6 +69,9 @@ def __call__(
         Returns:
             If n_samples=1, returns a single AIMessage.
             If n_samples>1, returns a list of AIMessages.
+
+        Raises:
+            Exception: If the server fails to respond after n_retry_server attempts, or if applying the chat template fails.
         """
         if self.tokenizer:
             try:
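Since the docstring now documents that the last server error is re-raised once `n_retry_server` attempts are exhausted, callers may want a guard like the following sketch (reusing the hypothetical `chat_model` and `messages` from the earlier examples):

```python
import logging

# Guarding against the documented failure mode: after n_retry_server failed
# attempts, __call__ re-raises the last server exception.
try:
    reply = chat_model(messages)
except Exception as err:
    logging.error("LLM call failed after retries: %s", err)
    reply = None  # caller-defined fallback
```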