diff --git a/src/agentlab/llm/base_api.py b/src/agentlab/llm/base_api.py
index 9c1ebf5f..b6d1a7be 100644
--- a/src/agentlab/llm/base_api.py
+++ b/src/agentlab/llm/base_api.py
@@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
     max_new_tokens: int = None
     temperature: float = 0.1
     vision_support: bool = False
+    log_probs: bool = False
 
     @abstractmethod
     def make_model(self) -> AbstractChatModel:
diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index 096aae00..5ecc05ab 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -87,6 +87,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -100,6 +101,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -115,6 +117,7 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             deployment_name=self.deployment_name,
+            log_probs=self.log_probs,
         )
 
 
@@ -142,6 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
+                log_probs=self.log_probs,
             )
         elif self.backend == "vllm":
             return VLLMChatModel(
@@ -232,6 +236,7 @@ def __init__(
         client_class=OpenAI,
         client_args=None,
         pricing_func=None,
+        log_probs=False,
     ):
         assert max_retry > 0, "max_retry should be greater than 0"
 
@@ -240,6 +245,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
+        self.log_probs = log_probs
 
         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -286,6 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
                     n=n_samples,
                     temperature=temperature,
                     max_tokens=self.max_tokens,
+                    logprobs=self.log_probs,
                 )
 
                 if completion.usage is None:
@@ -315,7 +322,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             tracking.TRACKER.instance(input_tokens, output_tokens, cost)
 
         if n_samples == 1:
-            return AIMessage(completion.choices[0].message.content)
+            res = AIMessage(completion.choices[0].message.content)
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].logprobs
+            return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
 
@@ -335,6 +345,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         super().__init__(
             model_name=model_name,
@@ -346,6 +357,7 @@ def __init__(
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -358,6 +370,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         client_args = {
             "base_url": "https://openrouter.ai/api/v1",
@@ -373,6 +386,7 @@ def __init__(
             client_class=OpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openrouter,
+            log_probs=log_probs,
         )
 
 
@@ -386,6 +400,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
@@ -406,6 +421,7 @@ def __init__(
             client_class=AzureOpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -419,8 +435,9 @@ def __init__(
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
@@ -429,7 +446,7 @@ def __init__(
         token = os.environ["TGI_TOKEN"]
         client = InferenceClient(model=model_url, token=token)
 
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
 
 
 class VLLMChatModel(ChatModel):
diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 32f12082..719bfa61 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -2,12 +2,11 @@
 import time
 from typing import Any, List, Optional, Union
 
-from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast
-
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
+from pydantic import Field
+from transformers import AutoTokenizer, GPT2TokenizerFast
 
 
 class HFBaseChatModel(AbstractChatModel):
@@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )
 
-    def __init__(self, model_name, base_model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         super().__init__()
         self.n_retry_server = n_retry_server
+        self.log_probs = log_probs
 
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -100,7 +100,11 @@ def __call__(
         while True:
             try:
                 temperature = temperature if temperature is not None else self.temperature
-                response = AIMessage(self.llm(prompt, temperature=temperature))
+                answer = self.llm(prompt, temperature=temperature)
+                response = AIMessage(answer)
+                if self.log_probs:
+                    response["content"] = answer.generated_text
+                    response["log_probs"] = answer.details
                 responses.append(response)
                 break
             except Exception as e:
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index 2bbf219d..2536200e 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -382,9 +382,14 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):
 
 
 class BaseMessage(dict):
-    def __init__(self, role: str, content: Union[str, list[dict]]):
+    def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
+        allowed_attrs = {"log_probs"}
+        invalid_attrs = set(kwargs.keys()) - allowed_attrs
+        if invalid_attrs:
+            raise ValueError(f"Invalid attributes: {invalid_attrs}")
         self["role"] = role
         self["content"] = deepcopy(content)
+        self.update(kwargs)
 
     def __str__(self, warn_if_image=False) -> str:
         if isinstance(self["content"], str):
@@ -464,8 +469,8 @@ def __init__(self, content: Union[str, list[dict]]):
 
 
 class AIMessage(BaseMessage):
-    def __init__(self, content: Union[str, list[dict]]):
-        super().__init__("assistant", content)
+    def __init__(self, content: Union[str, list[dict]], log_probs=None):
+        super().__init__("assistant", content, log_probs=log_probs)
 
 
 class Discussion:
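Usage sketch (not part of the patch): with `log_probs` threaded through as above, enabling the flag on the OpenAI-backed model and reading the extra key off the returned `AIMessage` would look roughly like this. `OpenAIChatModel` and the `log_probs` kwarg come from the diff; the model name and prompt are placeholders, and the shape of the stored object follows the OpenAI SDK's `ChoiceLogprobs`.

```python
# Sketch under the assumptions above; not part of the patch itself.
from agentlab.llm.chat_api import OpenAIChatModel

model = OpenAIChatModel(
    model_name="gpt-4o-mini",  # placeholder model name
    temperature=0.1,
    max_tokens=100,
    log_probs=True,  # new flag introduced by this patch
)

messages = [{"role": "user", "content": "Say hello."}]
response = model(messages)  # returns an AIMessage (a dict subclass)

print(response["content"])
# With log_probs=True the patch stores completion.choices[0].logprobs on the
# message; in the OpenAI SDK that object exposes a .content list of per-token
# entries, each with .token and .logprob.
if response.get("log_probs") is not None:
    for entry in response["log_probs"].content:
        print(entry.token, entry.logprob)
```

Note that `AIMessage` forwards `log_probs=log_probs` unconditionally, so every message dict carries a `log_probs` key (defaulting to `None`); the `response.get(...)` check above accommodates that.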
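On the TGI path, `details=log_probs` changes the return type of `InferenceClient.text_generation` from a plain string to a details object, which is why `__call__` re-points `response["content"]` at `answer.generated_text`. A minimal sketch of what the client returns when `details=True`, assuming a hypothetical local TGI endpoint:

```python
# Sketch only; the endpoint URL and prompt are placeholders.
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8080")  # hypothetical TGI server
answer = client.text_generation("Say hello.", max_new_tokens=16, details=True)

print(answer.generated_text)
for tok in answer.details.tokens:  # per-token log probabilities
    print(tok.text, tok.logprob)
```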