From 845515830ccfd27e568f1fa09c6296d6dbb56c7b Mon Sep 17 00:00:00 2001 From: Roberto Montalti <37136851+rhighs@users.noreply.github.com> Date: Mon, 16 Dec 2024 12:13:42 +0100 Subject: [PATCH 1/3] Add variable name extraction methods to prompt clients (#1013) --- langfuse/model.py | 147 +++++++++++++++++++++---------- tests/test_prompt.py | 103 ++++++++++++++++++++++ tests/test_prompt_compilation.py | 79 +++++------------ 3 files changed, 224 insertions(+), 105 deletions(-) diff --git a/langfuse/model.py b/langfuse/model.py index 6d2b25b1e..6ffaa3ee2 100644 --- a/langfuse/model.py +++ b/langfuse/model.py @@ -1,7 +1,7 @@ """@private""" from abc import ABC, abstractmethod -from typing import Optional, TypedDict, Any, Dict, Union, List +from typing import Optional, TypedDict, Any, Dict, Union, List, Tuple import re from langfuse.api.resources.commons.types.dataset import ( @@ -54,6 +54,74 @@ class ChatMessageDict(TypedDict): content: str +class ChatMessageVariables(TypedDict): + role: str + variables: List[str] + + +class TemplateParser: + OPENING = "{{" + CLOSING = "}}" + + @staticmethod + def _parse_next_variable( + content: str, start_idx: int + ) -> Optional[Tuple[str, int, int]]: + """Returns (variable_name, start_pos, end_pos) or None if no variable found""" + var_start = content.find(TemplateParser.OPENING, start_idx) + if var_start == -1: + return None + + var_end = content.find(TemplateParser.CLOSING, var_start) + if var_end == -1: + return None + + variable_name = content[ + var_start + len(TemplateParser.OPENING) : var_end + ].strip() + return (variable_name, var_start, var_end + len(TemplateParser.CLOSING)) + + @staticmethod + def find_variable_names(content: str) -> List[str]: + names = [] + curr_idx = 0 + + while curr_idx < len(content): + result = TemplateParser._parse_next_variable(content, curr_idx) + if not result: + break + names.append(result[0]) + curr_idx = result[2] + + return names + + @staticmethod + def compile_template(content: str, data: Dict[str, Any] = {}) -> str: + result_list = [] + curr_idx = 0 + + while curr_idx < len(content): + result = TemplateParser._parse_next_variable(content, curr_idx) + + if not result: + result_list.append(content[curr_idx:]) + break + + variable_name, var_start, var_end = result + result_list.append(content[curr_idx:var_start]) + + if variable_name in data: + result_list.append( + str(data[variable_name]) if data[variable_name] is not None else "" + ) + else: + result_list.append(content[var_start:var_end]) + + curr_idx = var_end + + return "".join(result_list) + + class BasePromptClient(ABC): name: str version: int @@ -73,6 +141,10 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False): def compile(self, **kwargs) -> Union[str, List[ChatMessage]]: pass + @abstractmethod + def variable_names(self, **kwargs) -> Union[List[str], List[ChatMessageVariables]]: + pass + @abstractmethod def __eq__(self, other): pass @@ -85,47 +157,6 @@ def get_langchain_prompt(self): def _get_langchain_prompt_string(content: str): return re.sub(r"{{\s*(\w+)\s*}}", r"{\g<1>}", content) - @staticmethod - def _compile_template_string(content: str, data: Dict[str, Any] = {}) -> str: - opening = "{{" - closing = "}}" - - result_list = [] - curr_idx = 0 - - while curr_idx < len(content): - # Find the next opening tag - var_start = content.find(opening, curr_idx) - - if var_start == -1: - result_list.append(content[curr_idx:]) - break - - # Find the next closing tag - var_end = content.find(closing, var_start) - - if var_end == -1: - result_list.append(content[curr_idx:]) - break - - # Append the content before the variable - result_list.append(content[curr_idx:var_start]) - - # Extract the variable name - variable_name = content[var_start + len(opening) : var_end].strip() - - # Append the variable value - if variable_name in data: - result_list.append( - str(data[variable_name]) if data[variable_name] is not None else "" - ) - else: - result_list.append(content[var_start : var_end + len(closing)]) - - curr_idx = var_end + len(closing) - - return "".join(result_list) - class TextPromptClient(BasePromptClient): def __init__(self, prompt: Prompt_Text, is_fallback: bool = False): @@ -133,7 +164,15 @@ def __init__(self, prompt: Prompt_Text, is_fallback: bool = False): self.prompt = prompt.prompt def compile(self, **kwargs) -> str: - return self._compile_template_string(self.prompt, kwargs) + return TemplateParser.compile_template(self.prompt, kwargs) + + def variable_names(self) -> List[str]: + """Find all the variable names in the prompt template + + Returns: + List[str]: The list of variable names found in the prompt template + """ + return TemplateParser.find_variable_names(self.prompt) def __eq__(self, other): if isinstance(self, other.__class__): @@ -160,7 +199,7 @@ def get_langchain_prompt(self, **kwargs) -> str: str: The string that can be plugged into Langchain's PromptTemplate. """ prompt = ( - self._compile_template_string(self.prompt, kwargs) + TemplateParser.compile_template(self.prompt, kwargs) if kwargs else self.prompt ) @@ -178,7 +217,23 @@ def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False): def compile(self, **kwargs) -> List[ChatMessageDict]: return [ ChatMessageDict( - content=self._compile_template_string(chat_message["content"], kwargs), + content=TemplateParser.compile_template( + chat_message["content"], kwargs + ), + role=chat_message["role"], + ) + for chat_message in self.prompt + ] + + def variable_names(self) -> List[ChatMessageVariables]: + """Find all the variable names in the chat prompt template per each chat message item + Returns: + List[ChatMessageVariables]: The list of variable names found in the prompt + template coupled with the message role + """ + return [ + ChatMessageVariables( + variables=TemplateParser.find_variable_names(chat_message["content"]), role=chat_message["role"], ) for chat_message in self.prompt @@ -215,7 +270,7 @@ def get_langchain_prompt(self, **kwargs): ( msg["role"], self._get_langchain_prompt_string( - self._compile_template_string(msg["content"], kwargs) + TemplateParser.compile_template(msg["content"], kwargs) if kwargs else msg["content"] ), diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 959dbe8b4..8f1e76029 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -994,3 +994,106 @@ def test_do_not_link_observation_if_fallback(): assert len(trace.observations) == 1 assert trace.observations[0].prompt_id is None + + +def test_variable_names_on_content_with_variable_names(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_1", + prompt="test prompt with var names {{ var1 }} {{ var2 }}", + is_active=True, + type="text", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_1") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variable_names() + + assert var_names == ["var1", "var2"] + + +def test_variable_names_on_content_with_no_variable_names(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_2", + prompt="test prompt with no var names", + is_active=True, + type="text", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_2") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variable_names() + + assert var_names == [] + + +def test_variable_names_on_content_with_variable_names_chat_messages(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_3", + prompt=[ + { + "role": "system", + "content": "test prompt with template vars {{ var1 }} {{ var2 }}", + }, + {"role": "user", "content": "test prompt 2 with template vars {{ var3 }}"}, + ], + is_active=True, + type="chat", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_3") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variable_names() + + assert var_names == [ + {"role": "system", "variables": ["var1", "var2"]}, + {"role": "user", "variables": ["var3"]}, + ] + + +def test_variable_names_on_content_with_no_variable_names_chat_messages(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_4", + prompt=[ + {"role": "system", "content": "test prompt with no template vars"}, + {"role": "user", "content": "test prompt 2 with no template vars"}, + ], + is_active=True, + type="chat", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_4") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variable_names() + + assert var_names == [ + {"role": "system", "variables": []}, + {"role": "user", "variables": []}, + ] diff --git a/tests/test_prompt_compilation.py b/tests/test_prompt_compilation.py index 1c838a50d..856025717 100644 --- a/tests/test_prompt_compilation.py +++ b/tests/test_prompt_compilation.py @@ -1,16 +1,13 @@ import pytest -from langfuse.model import BasePromptClient +from langfuse.model import TemplateParser def test_basic_replacement(): template = "Hello, {{ name }}!" expected = "Hello, John!" - assert ( - BasePromptClient._compile_template_string(template, {"name": "John"}) - == expected - ) + assert TemplateParser.compile_template(template, {"name": "John"}) == expected def test_multiple_replacements(): @@ -18,7 +15,7 @@ def test_multiple_replacements(): expected = "Hello, John! Your balance is $100." assert ( - BasePromptClient._compile_template_string( + TemplateParser.compile_template( template, {"greeting": "Hello", "name": "John", "balance": "$100"} ) == expected @@ -29,103 +26,77 @@ def test_no_replacements(): template = "This is a test." expected = "This is a test." - assert BasePromptClient._compile_template_string(template) == expected + assert TemplateParser.compile_template(template) == expected def test_content_as_variable_name(): template = "This is a {{content}}." expected = "This is a dog." - assert ( - BasePromptClient._compile_template_string(template, {"content": "dog"}) - == expected - ) + assert TemplateParser.compile_template(template, {"content": "dog"}) == expected def test_unmatched_opening_tag(): template = "Hello, {{name! Your balance is $100." expected = "Hello, {{name! Your balance is $100." - assert ( - BasePromptClient._compile_template_string(template, {"name": "John"}) - == expected - ) + assert TemplateParser.compile_template(template, {"name": "John"}) == expected def test_unmatched_closing_tag(): template = "Hello, {{name}}! Your balance is $100}}" expected = "Hello, John! Your balance is $100}}" - assert ( - BasePromptClient._compile_template_string(template, {"name": "John"}) - == expected - ) + assert TemplateParser.compile_template(template, {"name": "John"}) == expected def test_missing_variable(): template = "Hello, {{name}}!" expected = "Hello, {{name}}!" - assert BasePromptClient._compile_template_string(template) == expected + assert TemplateParser.compile_template(template) == expected def test_none_variable(): template = "Hello, {{name}}!" expected = "Hello, !" - assert ( - BasePromptClient._compile_template_string(template, {"name": None}) == expected - ) + assert TemplateParser.compile_template(template, {"name": None}) == expected def test_strip_whitespace(): template = "Hello, {{ name }}!" expected = "Hello, John!" - assert ( - BasePromptClient._compile_template_string(template, {"name": "John"}) - == expected - ) + assert TemplateParser.compile_template(template, {"name": "John"}) == expected def test_special_characters(): template = "Symbols: {{symbol}}." expected = "Symbols: @$%^&*." - assert ( - BasePromptClient._compile_template_string(template, {"symbol": "@$%^&*"}) - == expected - ) + assert TemplateParser.compile_template(template, {"symbol": "@$%^&*"}) == expected def test_multiple_templates_one_var(): template = "{{a}} + {{a}} = {{b}}" expected = "1 + 1 = 2" - assert ( - BasePromptClient._compile_template_string(template, {"a": 1, "b": 2}) - == expected - ) + assert TemplateParser.compile_template(template, {"a": 1, "b": 2}) == expected def test_unused_variable(): template = "{{a}} + {{a}}" expected = "1 + 1" - assert ( - BasePromptClient._compile_template_string(template, {"a": 1, "b": 2}) - == expected - ) + assert TemplateParser.compile_template(template, {"a": 1, "b": 2}) == expected def test_single_curly_braces(): template = "{{a}} + {a} = {{b}" expected = "1 + {a} = {{b}" - assert ( - BasePromptClient._compile_template_string(template, {"a": 1, "b": 2}) - == expected - ) + assert TemplateParser.compile_template(template, {"a": 1, "b": 2}) == expected def test_complex_json(): @@ -138,17 +109,14 @@ def test_complex_json(): "key2": "val2", }}""" - assert ( - BasePromptClient._compile_template_string(template, {"a": 1, "b": 2}) - == expected - ) + assert TemplateParser.compile_template(template, {"a": 1, "b": 2}) == expected def test_replacement_with_empty_string(): template = "Hello, {{name}}!" expected = "Hello, !" - assert BasePromptClient._compile_template_string(template, {"name": ""}) == expected + assert TemplateParser.compile_template(template, {"name": ""}) == expected def test_variable_case_sensitivity(): @@ -156,9 +124,7 @@ def test_variable_case_sensitivity(): expected = "John != john" assert ( - BasePromptClient._compile_template_string( - template, {"Name": "John", "name": "john"} - ) + TemplateParser.compile_template(template, {"Name": "John", "name": "john"}) == expected ) @@ -167,10 +133,7 @@ def test_start_with_closing_braces(): template = "}}" expected = "}}" - assert ( - BasePromptClient._compile_template_string(template, {"name": "john"}) - == expected - ) + assert TemplateParser.compile_template(template, {"name": "john"}) == expected def test_unescaped_JSON_variable_value(): @@ -204,9 +167,7 @@ def test_unescaped_JSON_variable_value(): } }""" - compiled = BasePromptClient._compile_template_string( - template, {"some_json": some_json} - ) + compiled = TemplateParser.compile_template(template, {"some_json": some_json}) assert compiled == some_json @@ -219,4 +180,4 @@ def test_unescaped_JSON_variable_value(): ], ) def test_various_templates(template, data, expected): - assert BasePromptClient._compile_template_string(template, data) == expected + assert TemplateParser.compile_template(template, data) == expected From c8f2658870b29a6e39259c446c0cbc3863349c09 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 16 Dec 2024 14:06:52 +0100 Subject: [PATCH 2/3] return variables as list of strings --- langfuse/model.py | 37 ++++++++++++++----------------------- tests/test_prompt.py | 8 ++++---- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/langfuse/model.py b/langfuse/model.py index 6ffaa3ee2..8a009006f 100644 --- a/langfuse/model.py +++ b/langfuse/model.py @@ -54,11 +54,6 @@ class ChatMessageDict(TypedDict): content: str -class ChatMessageVariables(TypedDict): - role: str - variables: List[str] - - class TemplateParser: OPENING = "{{" CLOSING = "}}" @@ -96,7 +91,10 @@ def find_variable_names(content: str) -> List[str]: return names @staticmethod - def compile_template(content: str, data: Dict[str, Any] = {}) -> str: + def compile_template(content: str, data: Optional[Dict[str, Any]]) -> str: + if data is None: + return content + result_list = [] curr_idx = 0 @@ -141,8 +139,9 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False): def compile(self, **kwargs) -> Union[str, List[ChatMessage]]: pass + @property @abstractmethod - def variable_names(self, **kwargs) -> Union[List[str], List[ChatMessageVariables]]: + def variables(self) -> List[str]: pass @abstractmethod @@ -166,12 +165,9 @@ def __init__(self, prompt: Prompt_Text, is_fallback: bool = False): def compile(self, **kwargs) -> str: return TemplateParser.compile_template(self.prompt, kwargs) - def variable_names(self) -> List[str]: - """Find all the variable names in the prompt template - - Returns: - List[str]: The list of variable names found in the prompt template - """ + @property + def variables(self) -> List[str]: + """Return all the variable names in the prompt template.""" return TemplateParser.find_variable_names(self.prompt) def __eq__(self, other): @@ -225,18 +221,13 @@ def compile(self, **kwargs) -> List[ChatMessageDict]: for chat_message in self.prompt ] - def variable_names(self) -> List[ChatMessageVariables]: - """Find all the variable names in the chat prompt template per each chat message item - Returns: - List[ChatMessageVariables]: The list of variable names found in the prompt - template coupled with the message role - """ + @property + def variables(self) -> List[str]: + """Return all the variable names in the chat prompt template.""" return [ - ChatMessageVariables( - variables=TemplateParser.find_variable_names(chat_message["content"]), - role=chat_message["role"], - ) + variable for chat_message in self.prompt + for variable in TemplateParser.find_variable_names(chat_message["content"]) ] def __eq__(self, other): diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 8f1e76029..da8ac9fbd 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -1013,7 +1013,7 @@ def test_variable_names_on_content_with_variable_names(): assert prompt_client.prompt == second_prompt_client.prompt assert prompt_client.labels == ["production", "latest"] - var_names = second_prompt_client.variable_names() + var_names = second_prompt_client.variables assert var_names == ["var1", "var2"] @@ -1035,7 +1035,7 @@ def test_variable_names_on_content_with_no_variable_names(): assert prompt_client.prompt == second_prompt_client.prompt assert prompt_client.labels == ["production", "latest"] - var_names = second_prompt_client.variable_names() + var_names = second_prompt_client.variables assert var_names == [] @@ -1063,7 +1063,7 @@ def test_variable_names_on_content_with_variable_names_chat_messages(): assert prompt_client.prompt == second_prompt_client.prompt assert prompt_client.labels == ["production", "latest"] - var_names = second_prompt_client.variable_names() + var_names = second_prompt_client.variables assert var_names == [ {"role": "system", "variables": ["var1", "var2"]}, @@ -1091,7 +1091,7 @@ def test_variable_names_on_content_with_no_variable_names_chat_messages(): assert prompt_client.prompt == second_prompt_client.prompt assert prompt_client.labels == ["production", "latest"] - var_names = second_prompt_client.variable_names() + var_names = second_prompt_client.variables assert var_names == [ {"role": "system", "variables": []}, From 81e9006be4e4fa0153d9314502c477aa85d88483 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Mon, 16 Dec 2024 14:46:51 +0100 Subject: [PATCH 3/3] fix tests --- langfuse/model.py | 2 +- tests/test_prompt.py | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/langfuse/model.py b/langfuse/model.py index 8a009006f..712a4ea2a 100644 --- a/langfuse/model.py +++ b/langfuse/model.py @@ -91,7 +91,7 @@ def find_variable_names(content: str) -> List[str]: return names @staticmethod - def compile_template(content: str, data: Optional[Dict[str, Any]]) -> str: + def compile_template(content: str, data: Optional[Dict[str, Any]] = None) -> str: if data is None: return content diff --git a/tests/test_prompt.py b/tests/test_prompt.py index da8ac9fbd..ad7cf9a92 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -1065,10 +1065,7 @@ def test_variable_names_on_content_with_variable_names_chat_messages(): var_names = second_prompt_client.variables - assert var_names == [ - {"role": "system", "variables": ["var1", "var2"]}, - {"role": "user", "variables": ["var3"]}, - ] + assert var_names == ["var1", "var2", "var3"] def test_variable_names_on_content_with_no_variable_names_chat_messages(): @@ -1093,7 +1090,4 @@ def test_variable_names_on_content_with_no_variable_names_chat_messages(): var_names = second_prompt_client.variables - assert var_names == [ - {"role": "system", "variables": []}, - {"role": "user", "variables": []}, - ] + assert var_names == []