diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py index 719bfa61..8c3e862d 100644 --- a/src/agentlab/llm/huggingface_utils.py +++ b/src/agentlab/llm/huggingface_utils.py @@ -2,11 +2,12 @@ import time from typing import Any, List, Optional, Union +from pydantic import Field +from transformers import AutoTokenizer, GPT2TokenizerFast + from agentlab.llm.base_api import AbstractChatModel from agentlab.llm.llm_utils import AIMessage, Discussion from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template -from pydantic import Field -from transformers import AutoTokenizer, GPT2TokenizerFast class HFBaseChatModel(AbstractChatModel): @@ -104,7 +105,7 @@ def __call__( response = AIMessage(answer) if self.log_probs: response["content"] = answer.generated_text - response["log_prob"] = answer.details + response["log_probs"] = answer.details responses.append(response) break except Exception as e: