10 changes: 5 additions & 5 deletions src/agentlab/llm/chat_api.py
@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )

-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )

@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature

         if token is None:
             token = os.environ["TGI_TOKEN"]

         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
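
Taken together, the chat_api.py changes turn the constructor's temperature into a per-call default: __call__ now accepts an optional temperature and falls back to self.temperature only when the caller passes none. Below is a minimal usage sketch of that contract; the class name and constructor signature are assumptions inferred from the diff, not the repo's exact API.

# Usage sketch (hypothetical) for the per-call temperature override.
# `OpenAIChatModel` and its constructor arguments are assumed for illustration;
# only __call__(messages, n_samples, temperature) is taken from the diff.
from agentlab.llm.chat_api import OpenAIChatModel  # hypothetical class name

model = OpenAIChatModel(model_name="gpt-4o-mini", temperature=0.7)

messages = [{"role": "user", "content": "Summarize the page in one line."}]

default_t = model(messages)                # no override: uses self.temperature (0.7)
greedy = model(messages, temperature=0.0)  # 0.0 is not None, so it wins for this call

Note that the fallback `temperature = temperature if temperature is not None else self.temperature` sits inside the retry loop; after the first iteration the local name is already bound to a concrete value, so all retries of one call sample at the same temperature.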
66 changes: 43 additions & 23 deletions src/agentlab/llm/huggingface_utils.py
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 from pydantic import Field
 from transformers import AutoTokenizer, GPT2TokenizerFast
@@ -12,7 +12,7 @@

 class HFBaseChatModel(AbstractChatModel):
     """
-    Custom LLM Chatbot that can interface with HuggingFace models.
+    Custom LLM Chatbot that can interface with HuggingFace models, with support for multiple samples.

     This class allows for the creation of a custom chatbot using models hosted
     on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
     Attributes:
         llm (Any): The HuggingFaceHub model instance.
         prompt_template (Any): Template for the prompt to be used for the model's input sequence.
+        tokenizer (Any): The tokenizer to use for the model.
+        n_retry_server (int): Number of times to retry on server failure.
     """

     llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,44 +55,62 @@ def __init__(self, model_name, n_retry_server):
     def __call__(
         self,
         messages: list[dict],
-    ) -> dict:
-
-        # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation.
-
+        n_samples: int = 1,
+        temperature: float = None,
+    ) -> Union[AIMessage, List[AIMessage]]:
+        """
+        Generate one or more responses for the given messages.
+
+        Args:
+            messages: List of message dictionaries containing the conversation history.
+            n_samples: Number of independent responses to generate. Defaults to 1.
+            temperature: The temperature for response sampling. Defaults to None.
+
+        Returns:
+            If n_samples=1, returns a single AIMessage.
+            If n_samples>1, returns a list of AIMessages.
+
+        Raises:
+            Exception: If the server fails to respond after n_retry_server attempts or if the chat template fails.
+        """
         if self.tokenizer:
             # messages_formated = _convert_messages_to_dict(messages) ## ?
             try:
                 if isinstance(messages, Discussion):
                     messages.merge()
                 prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
                     logging.warning(
-                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role"
+                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                         "Retrying with the 'system' role appended to the 'user' role."
                     )
                     messages = _prepend_system_to_first_user(messages)
                     prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                 else:
                     raise e

         elif self.prompt_template:
             prompt = self.prompt_template.construct_prompt(messages)

-        itr = 0
-        while True:
-            try:
-                response = AIMessage(self.llm(prompt))
-                return response
-            except Exception as e:
-                if itr == self.n_retry_server - 1:
-                    raise e
-                logging.warning(
-                    f"Failed to get a response from the server: \n{e}\n"
-                    f"Retrying... ({itr+1}/{self.n_retry_server})"
-                )
-                time.sleep(5)
-                itr += 1
+        responses = []
+        for _ in range(n_samples):
+            itr = 0
+            while True:
+                try:
+                    temperature = temperature if temperature is not None else self.temperature
+                    response = AIMessage(self.llm(prompt, temperature=temperature))
+                    responses.append(response)
+                    break
+                except Exception as e:
+                    if itr == self.n_retry_server - 1:
+                        raise e
+                    logging.warning(
+                        f"Failed to get a response from the server: \n{e}\n"
+                        f"Retrying... ({itr+1}/{self.n_retry_server})"
+                    )
+                    time.sleep(5)
+                    itr += 1
+
+        return responses[0] if n_samples == 1 else responses

     def _llm_type(self):
         return "huggingface"