diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 470324bd..153978b1 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer, GPT2TokenizerFast
 
 from agentlab.llm.base_api import AbstractChatModel
+from agentlab.llm.llm_utils import Discussion
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
 
 
 
@@ -59,6 +60,8 @@ def __call__(
         if self.tokenizer:
             # messages_formated = _convert_messages_to_dict(messages) ## ?
             try:
+                if isinstance(messages, Discussion):
+                    messages.merge()
                 prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index 43d9dd5b..e3300f96 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -386,6 +386,8 @@ def merge(self):
             else:
                 new_content.append(elem)
         self["content"] = new_content
+        if len(self["content"]) == 1:
+            self["content"] = self["content"][0]["text"]
 
 
 class SystemMessage(BaseMessage):
diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py
index 5597760a..d7c42853 100644
--- a/tests/llm/test_llm_utils.py
+++ b/tests/llm/test_llm_utils.py
@@ -251,8 +251,7 @@ def test_message_merge_only_text():
     ]
     message = llm_utils.BaseMessage(role="system", content=content)
     message.merge()
-    assert len(message["content"]) == 1
-    assert message["content"][0]["text"] == "Hello, world!\nThis is a test."
+    assert message["content"] == "Hello, world!\nThis is a test."
 
 
 def test_message_merge_text_image():