From d36eb0bbd6225b38eb91d1d764ad16d1c7dffaba Mon Sep 17 00:00:00 2001
From: Leo Boisvert
Date: Fri, 29 Nov 2024 22:01:12 +0000
Subject: [PATCH 1/4] add fix for self-hosted HF models

---
 src/agentlab/llm/huggingface_utils.py | 2 ++
 src/agentlab/llm/llm_utils.py         | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 470324bd..f67a3aae 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -7,6 +7,7 @@
 
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
+from agentlab.llm.llm_utils import Discussion
 
 
 class HFBaseChatModel(AbstractChatModel):
@@ -59,6 +60,7 @@ def __call__(
         if self.tokenizer:
             # messages_formated = _convert_messages_to_dict(messages) ## ?
             try:
+                messages.merge()
                 prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index 43d9dd5b..e3300f96 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -386,6 +386,8 @@ def merge(self):
             else:
                 new_content.append(elem)
         self["content"] = new_content
+        if len(self["content"]) == 1:
+            self["content"] = self["content"][0]["text"]
 
 
 class SystemMessage(BaseMessage):

From c16a69d1b2bb44f1e94716505e303306c2338396 Mon Sep 17 00:00:00 2001
From: Thibault LSDC <78021491+ThibaultLSDC@users.noreply.github.com>
Date: Fri, 29 Nov 2024 17:02:30 -0500
Subject: [PATCH 2/4] Update src/agentlab/llm/huggingface_utils.py

---
 src/agentlab/llm/huggingface_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index f67a3aae..06b7adbe 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -7,7 +7,6 @@
 
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
-from agentlab.llm.llm_utils import Discussion
 
 
 class HFBaseChatModel(AbstractChatModel):

From 5922437b7f7c11964db1b4f5fa5f46134f73dded Mon Sep 17 00:00:00 2001
From: Léo Boisvert
Date: Fri, 29 Nov 2024 17:05:32 -0500
Subject: [PATCH 3/4] Update huggingface_utils.py

---
 src/agentlab/llm/huggingface_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 06b7adbe..153978b1 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer, GPT2TokenizerFast
 
 from agentlab.llm.base_api import AbstractChatModel
+from agentlab.llm.llm_utils import Discussion
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
 
 
 class HFBaseChatModel(AbstractChatModel):
@@ -59,7 +60,8 @@ def __call__(
         if self.tokenizer:
             # messages_formated = _convert_messages_to_dict(messages) ## ?
             try:
-                messages.merge()
+                if isinstance(messages, Discussion):
+                    messages.merge()
                 prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):

From 3d19905f335dd7b58f37d39ad1eb32157b73f339 Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Fri, 29 Nov 2024 17:18:10 -0500
Subject: [PATCH 4/4] updating test

---
 tests/llm/test_llm_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py
index 5597760a..d7c42853 100644
--- a/tests/llm/test_llm_utils.py
+++ b/tests/llm/test_llm_utils.py
@@ -251,8 +251,7 @@ def test_message_merge_only_text():
     ]
     message = llm_utils.BaseMessage(role="system", content=content)
     message.merge()
-    assert len(message["content"]) == 1
-    assert message["content"][0]["text"] == "Hello, world!\nThis is a test."
+    assert message["content"] == "Hello, world!\nThis is a test."
 
 
 def test_message_merge_text_image():
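For context, here is a minimal sketch of the behaviour this series relies on. It is a stand-in, not the real `BaseMessage`/`Discussion` classes from `src/agentlab/llm/llm_utils.py`: only part of `merge()` is visible in the hunks above, so the chunk-joining step is inferred from the expected output asserted in the updated test (`"Hello, world!\nThis is a test."`), while the final collapse-to-string step is the one added in PATCH 1/4. The class name `SketchMessage` is made up for illustration.

```python
# Minimal sketch of the merge() behaviour exercised by the patch series.
# Assumptions: "content" is a list of {"type": "text", "text": ...} chunks and
# consecutive text chunks are joined with "\n" (inferred from the test in
# PATCH 4/4). This is NOT the real agentlab BaseMessage implementation.
class SketchMessage(dict):
    def merge(self):
        new_content = []
        for elem in self["content"]:
            if (
                elem.get("type") == "text"
                and new_content
                and new_content[-1].get("type") == "text"
            ):
                # Fold consecutive text chunks into the previous one.
                new_content[-1]["text"] += "\n" + elem["text"]
            else:
                new_content.append(elem)
        self["content"] = new_content
        # Step added in PATCH 1/4: a single remaining text chunk becomes a
        # plain string, the shape that tokenizer.apply_chat_template typically
        # expects from self-hosted HF chat templates.
        if len(self["content"]) == 1:
            self["content"] = self["content"][0]["text"]


if __name__ == "__main__":
    msg = SketchMessage(
        role="system",
        content=[
            {"type": "text", "text": "Hello, world!"},
            {"type": "text", "text": "This is a test."},
        ],
    )
    msg.merge()
    # Mirrors the assertion in the updated test_message_merge_only_text().
    assert msg["content"] == "Hello, world!\nThis is a test."
```

The `isinstance(messages, Discussion)` guard added in PATCH 3/4 restricts this merge to `Discussion` inputs, so other message formats passed to `HFBaseChatModel.__call__` reach `apply_chat_template` unchanged.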