Commit 8cc571f ("checkpoint")
1 parent 066638c
File tree: 3 files changed (+756, -361 lines)


llama_cpp/llama_chat_format.py

Lines changed: 2 additions & 355 deletions
@@ -37,6 +37,7 @@

 from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton
+from .llama_chat_template import chat_formatter_to_chat_completion_handler, ChatFormatterResponse, LlamaChatCompletionHandlerRegistry, ChatFormatter, LlamaChatCompletionHandler, register_chat_completion_handler, Jinja2ChatFormatter

 ### Common Chat Templates and Special Tokens ###

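This first hunk is the heart of the refactor: the chat-handler and chat-formatter machinery removed below now lives in a new llama_chat_template module, and re-importing the names into llama_chat_format keeps the old import path working. A minimal sketch of the two equivalent import paths, assuming the module layout this commit introduces:

    # Old path still resolves, because llama_chat_format re-exports the names:
    from llama_cpp.llama_chat_format import Jinja2ChatFormatter

    # New home of the implementation after this commit:
    from llama_cpp.llama_chat_template import Jinja2ChatFormatter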
@@ -59,212 +60,6 @@
 ### Chat Completion Handler ###


-class LlamaChatCompletionHandler(Protocol):
-    """Base Protocol for a llama chat completion handler.
-
-    Very generic protocol that can be used to implement any chat format.
-    The only hard requirement is that it must return a ChatCompletion when
-    stream=False and an iterator of ChatCompletionChunks when stream=True."""
-
-    def __call__(
-        self,
-        *,
-        # llama.cpp instance
-        llama: llama.Llama,
-        # openai api parameters
-        messages: List[llama_types.ChatCompletionRequestMessage],
-        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
-        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
-        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
-        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        seed: Optional[int] = None,
-        response_format: Optional[
-            llama_types.ChatCompletionRequestResponseFormat
-        ] = None,
-        max_tokens: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        model: Optional[str] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        # llama.cpp parameters
-        min_p: float = 0.05,
-        typical_p: float = 1.0,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        logits_processor: Optional[llama.LogitsProcessorList] = None,
-        grammar: Optional[llama.LlamaGrammar] = None,
-        logprobs: Optional[bool] = None,
-        top_logprobs: Optional[int] = None,
-        **kwargs,  # type: ignore
-    ) -> Union[
-        llama_types.CreateChatCompletionResponse,
-        Iterator[llama_types.CreateChatCompletionStreamResponse],
-    ]: ...
-
-
-class LlamaChatCompletionHandlerNotFoundException(Exception):
-    pass
-
-
-class LlamaChatCompletionHandlerRegistry(Singleton):
-    _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
-
-    def register_chat_completion_handler(
-        self,
-        name: str,
-        chat_handler: LlamaChatCompletionHandler,
-        overwrite: bool = False,
-    ):
-        if not overwrite and name in self._chat_handlers:
-            raise ValueError(
-                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
-            )
-        self._chat_handlers[name] = chat_handler
-
-    def unregister_chat_handler(self, name: str):
-        if name in self._chat_handlers:
-            del self._chat_handlers[name]
-        else:
-            raise ValueError(f"No formatter registered under the name '{name}'.")
-
-    def get_chat_completion_handler_by_name(
-        self, name: str
-    ) -> LlamaChatCompletionHandler:
-        try:
-            chat_handler = self._chat_handlers[name]
-            return chat_handler
-        except KeyError:
-            raise LlamaChatCompletionHandlerNotFoundException(
-                f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
-            )
-
-
-def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
-    return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
-        name
-    )
-
-
-def register_chat_completion_handler(name: str):
-    def decorator(f: LlamaChatCompletionHandler):
-        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
-        return f
-
-    return decorator
-
-
-### Chat Formatter ###
-
-
-@dataclasses.dataclass
-class ChatFormatterResponse:
-    """Dataclass that stores completion parameters for a given chat format and
-    create_chat_completion request.
-
-    prompt contains the formatted prompt generated from the chat format and messages.
-    stop contains the stop token or list of stop tokens to use for the chat format."""
-
-    prompt: str
-    stop: Optional[Union[str, List[str]]] = None
-    stopping_criteria: Optional[llama.StoppingCriteriaList] = None
-    added_special: bool = False
-
-
-class ChatFormatter(Protocol):
-    """Base Protocol for a chat formatter. A chat formatter is a function that
-    takes a list of messages and returns a chat format response which can be used
-    to generate a completion. The response can also include a stop token or list
-    of stop tokens to use for the completion."""
-
-    def __call__(
-        self,
-        *,
-        messages: List[llama_types.ChatCompletionRequestMessage],
-        **kwargs: Any,
-    ) -> ChatFormatterResponse: ...
-
-
-class Jinja2ChatFormatter(ChatFormatter):
-    def __init__(
-        self,
-        template: str,
-        eos_token: str,
-        bos_token: str,
-        add_generation_prompt: bool = True,
-        stop_token_ids: Optional[List[int]] = None,
-    ):
-        """A chat formatter that uses jinja2 templates to format the prompt."""
-        self.template = template
-        self.eos_token = eos_token
-        self.bos_token = bos_token
-        self.add_generation_prompt = add_generation_prompt
-        self.stop_token_ids = (
-            set(stop_token_ids) if stop_token_ids is not None else None
-        )
-
-        self._environment = ImmutableSandboxedEnvironment(
-            loader=jinja2.BaseLoader(),
-            trim_blocks=True,
-            lstrip_blocks=True,
-        ).from_string(self.template)
-
-    @staticmethod
-    def strftime_now(f: str) -> str:
-        return datetime.now().strftime(f)
-
-    def __call__(
-        self,
-        *,
-        messages: List[llama_types.ChatCompletionRequestMessage],
-        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
-        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
-        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
-        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
-        **kwargs: Any,
-    ) -> ChatFormatterResponse:
-        def raise_exception(message: str):
-            raise ValueError(message)
-
-        prompt = self._environment.render(
-            messages=messages,
-            eos_token=self.eos_token,
-            bos_token=self.bos_token,
-            raise_exception=raise_exception,
-            add_generation_prompt=self.add_generation_prompt,
-            functions=functions,
-            function_call=function_call,
-            tools=tools,
-            tool_choice=tool_choice,
-            strftime_now=self.strftime_now,
-        )
-
-        stopping_criteria = None
-        if self.stop_token_ids is not None:
-
-            def stop_on_last_token(
-                tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
-            ) -> bool:
-                return tokens[-1] in self.stop_token_ids
-
-            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
-
-        return ChatFormatterResponse(
-            prompt=prompt,
-            stop=[self.eos_token],
-            stopping_criteria=stopping_criteria,
-            added_special=True,
-        )
-
-    def to_chat_handler(self) -> LlamaChatCompletionHandler:
-        return chat_formatter_to_chat_completion_handler(self)


 def _convert_text_completion_logprobs_to_chat(
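Everything deleted above moved wholesale into llama_chat_template: the handler protocol, the Singleton-backed registry, the ChatFormatterResponse dataclass, and the sandboxed-Jinja2 formatter. The registration flow itself is unchanged by this commit; a minimal sketch of wiring a custom format by name, using a made-up template and the name "my-format" purely for illustration:

    from llama_cpp.llama_chat_format import (
        Jinja2ChatFormatter,
        LlamaChatCompletionHandlerRegistry,
    )

    # Hypothetical template: plain role/content rendering, no special tokens.
    formatter = Jinja2ChatFormatter(
        template=(
            "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
            "{% if add_generation_prompt %}assistant: {% endif %}"
        ),
        eos_token="</s>",
        bos_token="<s>",
    )

    # to_chat_handler() wraps the formatter in the generic completion logic;
    # the registry (a Singleton) then makes it addressable by name.
    LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
        "my-format", formatter.to_chat_handler()
    )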
@@ -356,7 +151,7 @@ def _convert_text_completion_chunks_to_chat
                 "finish_reason": chunk["choices"][0]["finish_reason"],
             }
         ],
-        "usage": chunk.get("usage") if "usage" in chunk else None,
+        **({"usage": chunk["usage"]} if "usage" in chunk and chunk["usage"] is not None else {}),
     }


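This second hunk is a behavioral fix, not a cleanup: the old line always emitted a "usage" key in the converted chunk, with None as a placeholder, while the conditional dict unpacking omits the key entirely unless the chunk actually carries a non-None value. A standalone sketch of the difference:

    chunk = {"id": "x"}  # a chunk that reports no usage

    old = {"usage": chunk.get("usage") if "usage" in chunk else None}
    # old == {"usage": None}  (key always present)

    new = {**({"usage": chunk["usage"]} if "usage" in chunk and chunk["usage"] is not None else {})}
    # new == {}  (key omitted when there is nothing to report)

Omitting the key matters for OpenAI-compatible clients that distinguish an absent "usage" field from an explicit null.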
@@ -568,154 +363,6 @@ def _stream_response_to_function_stream
     return _stream_response_to_function_stream(chunks)


-def chat_formatter_to_chat_completion_handler(
-    chat_formatter: ChatFormatter,
-) -> LlamaChatCompletionHandler:
-    def chat_completion_handler(
-        *,
-        llama: llama.Llama,
-        messages: List[llama_types.ChatCompletionRequestMessage],
-        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
-        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
-        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
-        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        min_p: float = 0.05,
-        typical_p: float = 1.0,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        seed: Optional[int] = None,
-        response_format: Optional[
-            llama_types.ChatCompletionRequestResponseFormat
-        ] = None,
-        max_tokens: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[llama.LogitsProcessorList] = None,
-        grammar: Optional[llama.LlamaGrammar] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        top_logprobs: Optional[int] = None,
-        **kwargs,  # type: ignore
-    ) -> Union[
-        llama_types.CreateChatCompletionResponse,
-        Iterator[llama_types.CreateChatCompletionStreamResponse],
-    ]:
-        result = chat_formatter(
-            messages=messages,
-            functions=functions,
-            function_call=function_call,
-            tools=tools,
-            tool_choice=tool_choice,
-        )
-        prompt = llama.tokenize(
-            result.prompt.encode("utf-8"),
-            add_bos=not result.added_special,
-            special=True,
-        )
-        if result.stop is not None:
-            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
-            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
-            stop = stop + rstop
-
-        stopping_criteria = None
-        if result.stopping_criteria is not None:
-            stopping_criteria = result.stopping_criteria
-
-        if response_format is not None and response_format["type"] == "json_object":
-            grammar = _grammar_for_response_format(
-                response_format, verbose=llama.verbose
-            )
-
-        # Convert legacy functions to tools
-        if functions is not None:
-            tools = [
-                {
-                    "type": "function",
-                    "function": function,
-                }
-                for function in functions
-            ]
-
-        # Convert legacy function_call to tool_choice
-        if function_call is not None:
-            if isinstance(function_call, str) and (
-                function_call == "none" or function_call == "auto"
-            ):
-                tool_choice = function_call
-            if isinstance(function_call, dict) and "name" in function_call:
-                tool_choice = {
-                    "type": "function",
-                    "function": {
-                        "name": function_call["name"],
-                    },
-                }
-
-        tool = None
-        if (
-            tool_choice is not None
-            and isinstance(tool_choice, dict)
-            and tools is not None
-        ):
-            name = tool_choice["function"]["name"]
-            tool = next((t for t in tools if t["function"]["name"] == name), None)
-            if tool is None:
-                raise ValueError(f"Tool choice '{name}' not found in tools.")
-            schema = tool["function"]["parameters"]
-            try:
-                # create grammar from json schema
-                grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                    json.dumps(schema), verbose=llama.verbose
-                )
-            except Exception as e:
-                if llama.verbose:
-                    print(str(e), file=sys.stderr)
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF, verbose=llama.verbose
-                )
-
-        completion_or_chunks = llama.create_completion(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=min_p,
-            typical_p=typical_p,
-            logprobs=top_logprobs if logprobs else None,
-            stream=stream,
-            stop=stop,
-            seed=seed,
-            max_tokens=max_tokens,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            repeat_penalty=repeat_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            stopping_criteria=stopping_criteria,
-            grammar=grammar,
-            logit_bias=logit_bias,
-        )
-        if tool is not None:
-            tool_name = tool["function"]["name"]
-            return _convert_completion_to_chat_function(
-                tool_name, completion_or_chunks, stream
-            )
-        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
-
-    return chat_completion_handler


 def hf_autotokenizer_to_chat_formatter(
     pretrained_model_name_or_path: Union[str, os.PathLike[str]]

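The deleted chat_formatter_to_chat_completion_handler (also relocated to llama_chat_template) is the glue between the two protocols: it renders the formatter's prompt, tokenizes it, merges the formatter's stop tokens with the caller's, optionally builds a grammar for JSON response formats or tool calls, and finally delegates to llama.create_completion. A usage sketch of the end-to-end flow; the model path and template are placeholders, not part of this commit:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import (
        Jinja2ChatFormatter,
        chat_formatter_to_chat_completion_handler,
    )

    llm = Llama(model_path="./model.gguf")  # placeholder path

    formatter = Jinja2ChatFormatter(
        template=(
            "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
            "{% if add_generation_prompt %}assistant: {% endif %}"
        ),
        eos_token="</s>",
        bos_token="<s>",
    )

    handler = chat_formatter_to_chat_completion_handler(formatter)
    response = handler(
        llama=llm,
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=32,
    )
    print(response["choices"][0]["message"]["content"])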