From 6b3d0fc78e1c6680d9754c182e64323cb0054418 Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Tue, 11 Feb 2025 16:59:46 -0500
Subject: [PATCH 1/5] adding log_prob option for chat models

---
 src/agentlab/llm/base_api.py  |  1 +
 src/agentlab/llm/chat_api.py  | 20 ++++++++++++++++++--
 src/agentlab/llm/llm_utils.py |  3 ++-
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/agentlab/llm/base_api.py b/src/agentlab/llm/base_api.py
index 9c1ebf5f..b6d1a7be 100644
--- a/src/agentlab/llm/base_api.py
+++ b/src/agentlab/llm/base_api.py
@@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
     max_new_tokens: int = None
     temperature: float = 0.1
     vision_support: bool = False
+    log_probs: bool = False
 
     @abstractmethod
     def make_model(self) -> AbstractChatModel:
diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index 7392e666..bf3380b2 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -87,6 +87,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -100,6 +101,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -115,6 +117,7 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             deployment_name=self.deployment_name,
+            log_probs=self.log_probs,
         )
 
 
@@ -225,6 +228,7 @@ def __init__(
         client_class=OpenAI,
         client_args=None,
         pricing_func=None,
+        log_probs=False,
     ):
         assert max_retry > 0, "max_retry should be greater than 0"
 
@@ -233,6 +237,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
+        self.logprobs = log_probs
 
         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -279,6 +284,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
+            logprobs=self.logprobs,
         )
 
         if completion.usage is None:
@@ -308,7 +314,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
         tracking.TRACKER.instance(input_tokens, output_tokens, cost)
 
         if n_samples == 1:
-            return AIMessage(completion.choices[0].message.content)
+            res = AIMessage(completion.choices[0].message.content)
+            if self.logprobs:
+                res["logprobs"] = completion.choices[0].logprobs
+            return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
 
@@ -328,6 +337,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         super().__init__(
             model_name=model_name,
@@ -339,6 +349,7 @@ def __init__(
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -351,6 +362,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         client_args = {
             "base_url": "https://openrouter.ai/api/v1",
@@ -366,6 +378,7 @@ def __init__(
             client_class=OpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openrouter,
+            log_probs=log_probs,
         )
 
 
@@ -379,6 +392,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
@@ -399,6 +413,7 @@ def __init__(
             client_class=AzureOpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -412,6 +427,7 @@ def __init__(
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
     ):
         super().__init__(model_name, base_model_name, n_retry_server)
         if temperature < 1e-3:
@@ -422,4 +438,4 @@ def __init__(
         token = os.environ["TGI_TOKEN"]
         client = InferenceClient(model=model_url, token=token)
 
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index 2bbf219d..d6f9e822 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -382,9 +382,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):
 
 
 class BaseMessage(dict):
-    def __init__(self, role: str, content: Union[str, list[dict]]):
+    def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
         self["role"] = role
         self["content"] = deepcopy(content)
+        self.update(kwargs)
 
     def __str__(self, warn_if_image=False) -> str:
         if isinstance(self["content"], str):

From 9f37a74eb0f19e538896caa95f49dfbef4be265c Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Thu, 13 Feb 2025 10:37:44 -0500
Subject: [PATCH 2/5] vscode not saving my stuff :(

---
 src/agentlab/llm/huggingface_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 32f12082..2e682343 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -100,7 +100,10 @@ def __call__(
         while True:
             try:
                 temperature = temperature if temperature is not None else self.temperature
-                response = AIMessage(self.llm(prompt, temperature=temperature))
+                answer = self.llm(prompt, temperature=temperature)
+                response = AIMessage(answer)
+                if hasattr(answer, "details"):
+                    response["log_prob"] = answer.details.log_prob
                 responses.append(response)
                 break
             except Exception as e:

From edcf995fb4e238d910d0bdc84248b4fa13ecc42a Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Tue, 18 Feb 2025 16:17:48 -0500
Subject: [PATCH 3/5] debugging log_probs for hf models

---
 src/agentlab/llm/chat_api.py          | 11 ++++++-----
 src/agentlab/llm/huggingface_utils.py | 13 +++++++------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index bf3380b2..35c68b26 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -145,6 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
+                log_probs=self.log_probs
             )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
@@ -237,7 +238,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
-        self.logprobs = log_probs
+        self.log_probs = log_probs
 
         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -284,7 +285,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
-            logprobs=self.logprobs,
+            logprobs=self.log_probs,
         )
 
         if completion.usage is None:
@@ -315,8 +316,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
 
         if n_samples == 1:
             res = AIMessage(completion.choices[0].message.content)
-            if self.logprobs:
-                res["logprobs"] = completion.choices[0].logprobs
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].logprobs
             return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
@@ -429,7 +430,7 @@ def __init__(
         n_retry_server: Optional[int] = 4,
         log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 2e682343..719bfa61 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -2,12 +2,11 @@
 import time
 from typing import Any, List, Optional, Union
 
-from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast
-
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
+from pydantic import Field
+from transformers import AutoTokenizer, GPT2TokenizerFast
 
 
 class HFBaseChatModel(AbstractChatModel):
@@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )
 
-    def __init__(self, model_name, base_model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         super().__init__()
         self.n_retry_server = n_retry_server
+        self.log_probs = log_probs
 
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -102,8 +102,9 @@ def __call__(
                 temperature = temperature if temperature is not None else self.temperature
                 answer = self.llm(prompt, temperature=temperature)
                 response = AIMessage(answer)
-                if hasattr(answer, "details"):
-                    response["log_prob"] = answer.details.log_prob
+                if self.log_probs:
+                    response["content"] = answer.generated_text
+                    response["log_probs"] = answer.details
                 responses.append(response)
                 break
             except Exception as e:

From abb44c7dff5f48ab1ede00e85ebda0fd0fd5e05b Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Thu, 20 Feb 2025 14:19:42 -0500
Subject: [PATCH 4/5] korbit 0_o

---
 src/agentlab/llm/llm_utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index d6f9e822..2536200e 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -383,6 +383,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):
 
 class BaseMessage(dict):
     def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
+        allowed_attrs = {"log_probs"}
+        invalid_attrs = set(kwargs.keys()) - allowed_attrs
+        if invalid_attrs:
+            raise ValueError(f"Invalid attributes: {invalid_attrs}")
         self["role"] = role
         self["content"] = deepcopy(content)
         self.update(kwargs)
@@ -465,8 +469,8 @@ def __init__(self, content: Union[str, list[dict]]):
 
 
 class AIMessage(BaseMessage):
-    def __init__(self, content: Union[str, list[dict]]):
-        super().__init__("assistant", content)
+    def __init__(self, content: Union[str, list[dict]], log_probs=None):
+        super().__init__("assistant", content, log_probs=log_probs)
 
 
 class Discussion:

From faac3bfacbd0ab5508a8f163abf6e94bda5040bc Mon Sep 17 00:00:00 2001
From: ThibaultLSDC
Date: Thu, 20 Feb 2025 14:20:26 -0500
Subject: [PATCH 5/5] format

---
 src/agentlab/llm/chat_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index 35c68b26..cfad3c3a 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -145,7 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
-                log_probs=self.log_probs
+                log_probs=self.log_probs,
             )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
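
Usage sketch (a minimal illustration of the behavior this series adds; the message-level
API below is taken directly from PATCH 4, while the OpenAIChatModel class name, its
import path, and the "gpt-4o-mini" model name are assumptions not visible in the hunks):

    from agentlab.llm.llm_utils import AIMessage, BaseMessage

    # PATCH 4: AIMessage now carries an optional log_probs attribute.
    msg = AIMessage("The answer is 42.", log_probs=[-0.12, -0.03])
    assert msg["role"] == "assistant"
    print(msg["log_probs"])  # [-0.12, -0.03]

    # PATCH 4: BaseMessage rejects unknown keyword attributes, so typos fail fast.
    try:
        BaseMessage("assistant", "hi", log_prob=[-0.5])
    except ValueError as err:
        print(err)  # Invalid attributes: {'log_prob'}

    # Hypothetical end-to-end call: with log_probs=True, the returned AIMessage
    # exposes the per-token logprobs reported by the backend under "log_probs".
    from agentlab.llm.chat_api import OpenAIChatModel  # class name assumed

    model = OpenAIChatModel(model_name="gpt-4o-mini", log_probs=True)
    answer = model([dict(role="user", content="Say hi")])
    print(answer["content"], answer["log_probs"])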