|
37 | 37 |
|
38 | 38 | from ._logger import logger |
39 | 39 | from ._utils import suppress_stdout_stderr, Singleton |
| 40 | +from .llama_chat_template import chat_formatter_to_chat_completion_handler, ChatFormatterResponse, LlamaChatCompletionHandlerRegistry, ChatFormatter, LlamaChatCompletionHandler, register_chat_completion_handler, Jinja2ChatFormatter
40 | 41 |
|
41 | 42 | ### Common Chat Templates and Special Tokens ### |
42 | 43 |
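
The added import pulls the chat-format machinery back in from the relocated `.llama_chat_template` module, which is where the classes and helpers removed further down in this diff now live. As a minimal sketch of the intended usage after the move (assuming the package is importable as `llama_cpp` and the relocated module keeps the same public names), downstream code can import them from the new module directly:

    # All of these names appear in the import added above; the top-level
    # package name `llama_cpp` is an assumption about how it is installed.
    from llama_cpp.llama_chat_template import (
        ChatFormatter,
        ChatFormatterResponse,
        Jinja2ChatFormatter,
        LlamaChatCompletionHandler,
        LlamaChatCompletionHandlerRegistry,
        chat_formatter_to_chat_completion_handler,
        register_chat_completion_handler,
    )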
|
|
59 | 60 | ### Chat Completion Handler ### |
60 | 61 |
|
61 | 62 |
|
62 | | -class LlamaChatCompletionHandler(Protocol): |
63 | | - """Base Protocol for a llama chat completion handler. |
64 | | -
|
65 | | - Very generic protocol that can be used to implement any chat format. |
66 | | - The only hard requirement is that it must return a ChatCompletion when |
67 | | - stream=False and an iterator of ChatCompletionChunks when stream=True.""" |
68 | | - |
69 | | - def __call__( |
70 | | - self, |
71 | | - *, |
72 | | - # llama.cpp instance |
73 | | - llama: llama.Llama, |
74 | | - # openai api parameters |
75 | | - messages: List[llama_types.ChatCompletionRequestMessage], |
76 | | - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, |
77 | | - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, |
78 | | - tools: Optional[List[llama_types.ChatCompletionTool]] = None, |
79 | | - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, |
80 | | - temperature: float = 0.2, |
81 | | - top_p: float = 0.95, |
82 | | - top_k: int = 40, |
83 | | - stream: bool = False, |
84 | | - stop: Optional[Union[str, List[str]]] = [], |
85 | | - seed: Optional[int] = None, |
86 | | - response_format: Optional[ |
87 | | - llama_types.ChatCompletionRequestResponseFormat |
88 | | - ] = None, |
89 | | - max_tokens: Optional[int] = None, |
90 | | - presence_penalty: float = 0.0, |
91 | | - frequency_penalty: float = 0.0, |
92 | | - repeat_penalty: float = 1.1, |
93 | | - model: Optional[str] = None, |
94 | | - logit_bias: Optional[Dict[str, float]] = None, |
95 | | - # llama.cpp parameters |
96 | | - min_p: float = 0.05, |
97 | | - typical_p: float = 1.0, |
98 | | - tfs_z: float = 1.0, |
99 | | - mirostat_mode: int = 0, |
100 | | - mirostat_tau: float = 5.0, |
101 | | - mirostat_eta: float = 0.1, |
102 | | - logits_processor: Optional[llama.LogitsProcessorList] = None, |
103 | | - grammar: Optional[llama.LlamaGrammar] = None, |
104 | | - logprobs: Optional[bool] = None, |
105 | | - top_logprobs: Optional[int] = None, |
106 | | - **kwargs, # type: ignore |
107 | | - ) -> Union[ |
108 | | - llama_types.CreateChatCompletionResponse, |
109 | | - Iterator[llama_types.CreateChatCompletionStreamResponse], |
110 | | - ]: ... |
111 | | - |
112 | | - |
113 | | -class LlamaChatCompletionHandlerNotFoundException(Exception): |
114 | | - pass |
115 | | - |
116 | | - |
117 | | -class LlamaChatCompletionHandlerRegistry(Singleton): |
118 | | - _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {} |
119 | | - |
120 | | - def register_chat_completion_handler( |
121 | | - self, |
122 | | - name: str, |
123 | | - chat_handler: LlamaChatCompletionHandler, |
124 | | - overwrite: bool = False, |
125 | | - ): |
126 | | - if not overwrite and name in self._chat_handlers: |
127 | | - raise ValueError( |
128 | | - f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." |
129 | | - ) |
130 | | - self._chat_handlers[name] = chat_handler |
131 | | - |
132 | | - def unregister_chat_handler(self, name: str): |
133 | | - if name in self._chat_handlers: |
134 | | - del self._chat_handlers[name] |
135 | | - else: |
136 | | - raise ValueError(f"No formatter registered under the name '{name}'.") |
137 | | - |
138 | | - def get_chat_completion_handler_by_name( |
139 | | - self, name: str |
140 | | - ) -> LlamaChatCompletionHandler: |
141 | | - try: |
142 | | - chat_handler = self._chat_handlers[name] |
143 | | - return chat_handler |
144 | | - except KeyError: |
145 | | - raise LlamaChatCompletionHandlerNotFoundException( |
146 | | - f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})" |
147 | | - ) |
148 | | - |
149 | | - |
150 | | -def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler: |
151 | | - return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name( |
152 | | - name |
153 | | - ) |
154 | | - |
155 | | - |
156 | | -def register_chat_completion_handler(name: str): |
157 | | - def decorator(f: LlamaChatCompletionHandler): |
158 | | - LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f) |
159 | | - return f |
160 | | - |
161 | | - return decorator |
162 | | - |
163 | | - |
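
The registry and decorator removed here (relocated per the import at the top of this diff) are what let callers plug in a custom chat format under a name. A minimal sketch of the registration pattern, assuming the relocated module keeps this API; the handler and the format name "my-raw-format" are hypothetical, and only the non-streaming case is handled:

    from llama_cpp.llama_chat_template import register_chat_completion_handler

    @register_chat_completion_handler("my-raw-format")  # hypothetical format name
    def my_raw_chat_handler(*, llama, messages, stream=False, **kwargs):
        if stream:
            raise NotImplementedError("streaming is omitted in this sketch")
        # Naive prompt construction: just join the message contents.
        prompt = "\n".join(str(m.get("content") or "") for m in messages)
        completion = llama.create_completion(prompt=prompt, stream=False)
        # The protocol requires a chat-completion shaped response when stream=False.
        return {
            "id": "chat" + completion["id"],
            "object": "chat.completion",
            "created": completion["created"],
            "model": completion["model"],
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": completion["choices"][0]["text"]},
                    "logprobs": None,
                    "finish_reason": completion["choices"][0]["finish_reason"],
                }
            ],
            "usage": completion["usage"],
        }
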
164 | | -### Chat Formatter ### |
165 | | - |
166 | | - |
167 | | -@dataclasses.dataclass |
168 | | -class ChatFormatterResponse: |
169 | | - """Dataclass that stores completion parameters for a given chat format and |
170 | | - create_chat_completion request. |
171 | | -
|
172 | | - prompt contains the formatted prompt generated from the chat format and messages. |
173 | | - stop contains the stop token or list of stop tokens to use for the chat format.""" |
174 | | - |
175 | | - prompt: str |
176 | | - stop: Optional[Union[str, List[str]]] = None |
177 | | - stopping_criteria: Optional[llama.StoppingCriteriaList] = None |
178 | | - added_special: bool = False |
179 | | - |
180 | | - |
181 | | -class ChatFormatter(Protocol): |
182 | | - """Base Protocol for a chat formatter. A chat formatter is a function that |
183 | | - takes a list of messages and returns a chat format response which can be used |
184 | | - to generate a completion. The response can also include a stop token or list |
185 | | - of stop tokens to use for the completion.""" |
186 | | - |
187 | | - def __call__( |
188 | | - self, |
189 | | - *, |
190 | | - messages: List[llama_types.ChatCompletionRequestMessage], |
191 | | - **kwargs: Any, |
192 | | - ) -> ChatFormatterResponse: ... |
193 | | - |
194 | | - |
195 | | -class Jinja2ChatFormatter(ChatFormatter): |
196 | | - def __init__( |
197 | | - self, |
198 | | - template: str, |
199 | | - eos_token: str, |
200 | | - bos_token: str, |
201 | | - add_generation_prompt: bool = True, |
202 | | - stop_token_ids: Optional[List[int]] = None, |
203 | | - ): |
204 | | - """A chat formatter that uses jinja2 templates to format the prompt.""" |
205 | | - self.template = template |
206 | | - self.eos_token = eos_token |
207 | | - self.bos_token = bos_token |
208 | | - self.add_generation_prompt = add_generation_prompt |
209 | | - self.stop_token_ids = ( |
210 | | - set(stop_token_ids) if stop_token_ids is not None else None |
211 | | - ) |
212 | | - |
213 | | - self._environment = ImmutableSandboxedEnvironment( |
214 | | - loader=jinja2.BaseLoader(), |
215 | | - trim_blocks=True, |
216 | | - lstrip_blocks=True, |
217 | | - ).from_string(self.template) |
218 | | - |
219 | | - @staticmethod |
220 | | - def strftime_now(f: str) -> str: |
221 | | - return datetime.now().strftime(f) |
222 | | - |
223 | | - def __call__( |
224 | | - self, |
225 | | - *, |
226 | | - messages: List[llama_types.ChatCompletionRequestMessage], |
227 | | - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, |
228 | | - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, |
229 | | - tools: Optional[List[llama_types.ChatCompletionTool]] = None, |
230 | | - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, |
231 | | - **kwargs: Any, |
232 | | - ) -> ChatFormatterResponse: |
233 | | - def raise_exception(message: str): |
234 | | - raise ValueError(message) |
235 | | - |
236 | | - prompt = self._environment.render( |
237 | | - messages=messages, |
238 | | - eos_token=self.eos_token, |
239 | | - bos_token=self.bos_token, |
240 | | - raise_exception=raise_exception, |
241 | | - add_generation_prompt=self.add_generation_prompt, |
242 | | - functions=functions, |
243 | | - function_call=function_call, |
244 | | - tools=tools, |
245 | | - tool_choice=tool_choice, |
246 | | - strftime_now=self.strftime_now, |
247 | | - ) |
248 | | - |
249 | | - stopping_criteria = None |
250 | | - if self.stop_token_ids is not None: |
251 | | - |
252 | | - def stop_on_last_token( |
253 | | - tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single] |
254 | | - ) -> bool: |
255 | | - return tokens[-1] in self.stop_token_ids |
256 | | - |
257 | | - stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token]) |
258 | | - |
259 | | - return ChatFormatterResponse( |
260 | | - prompt=prompt, |
261 | | - stop=[self.eos_token], |
262 | | - stopping_criteria=stopping_criteria, |
263 | | - added_special=True, |
264 | | - ) |
265 | | - |
266 | | - def to_chat_handler(self) -> LlamaChatCompletionHandler: |
267 | | - return chat_formatter_to_chat_completion_handler(self) |
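
For reference, the Jinja2ChatFormatter removed here renders a model's chat template into a prompt plus stop tokens, and `to_chat_handler()` wraps it into a full completion handler. A short sketch with a made-up ChatML-style template and special tokens (real templates normally come from the model's metadata); passing the handler to `Llama(chat_handler=...)` is an assumption about the caller:

    from llama_cpp.llama_chat_template import Jinja2ChatFormatter

    # Illustrative ChatML-style template (not tied to any particular model).
    template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
    formatter = Jinja2ChatFormatter(
        template=template,
        eos_token="<|im_end|>",
        bos_token="<s>",
    )
    result = formatter(messages=[{"role": "user", "content": "Hi"}])
    print(result.prompt)                   # rendered prompt string
    print(result.stop)                     # ["<|im_end|>"], per __call__ above
    handler = formatter.to_chat_handler()  # usable as Llama(chat_handler=handler)
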
268 | 63 |
|
269 | 64 |
|
270 | 65 | def _convert_text_completion_logprobs_to_chat( |
@@ -356,7 +151,7 @@ def _convert_text_completion_chunks_to_chat( |
356 | 151 | "finish_reason": chunk["choices"][0]["finish_reason"], |
357 | 152 | } |
358 | 153 | ], |
359 | | - "usage": chunk.get("usage") if "usage" in chunk else None, |
| 154 | + **({"usage": chunk["usage"]} if "usage" in chunk and chunk["usage"] is not None else {}), |
360 | 155 | } |
361 | 156 |
|
362 | 157 |
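
The change above stops emitting a "usage" key on every converted stream chunk: previously the key was always present (usually as None), whereas now it only appears when the source chunk actually carries usage data, typically the final chunk when usage reporting is enabled. A small self-contained illustration of the new conditional spread (numbers are made up):

    def tail(chunk):
        # Mirrors the conditional spread introduced above.
        return {**({"usage": chunk["usage"]} if "usage" in chunk and chunk["usage"] is not None else {})}

    assert tail({"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}) != {}
    assert tail({"usage": None}) == {}   # key dropped instead of being forwarded as None
    assert tail({}) == {}                # key dropped when absent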
|
@@ -568,154 +363,6 @@ def _stream_response_to_function_stream( |
568 | 363 | return _stream_response_to_function_stream(chunks) |
569 | 364 |
|
570 | 365 |
|
571 | | -def chat_formatter_to_chat_completion_handler( |
572 | | - chat_formatter: ChatFormatter, |
573 | | -) -> LlamaChatCompletionHandler: |
574 | | - def chat_completion_handler( |
575 | | - *, |
576 | | - llama: llama.Llama, |
577 | | - messages: List[llama_types.ChatCompletionRequestMessage], |
578 | | - functions: Optional[List[llama_types.ChatCompletionFunction]] = None, |
579 | | - function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, |
580 | | - tools: Optional[List[llama_types.ChatCompletionTool]] = None, |
581 | | - tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, |
582 | | - temperature: float = 0.2, |
583 | | - top_p: float = 0.95, |
584 | | - top_k: int = 40, |
585 | | - min_p: float = 0.05, |
586 | | - typical_p: float = 1.0, |
587 | | - stream: bool = False, |
588 | | - stop: Optional[Union[str, List[str]]] = [], |
589 | | - seed: Optional[int] = None, |
590 | | - response_format: Optional[ |
591 | | - llama_types.ChatCompletionRequestResponseFormat |
592 | | - ] = None, |
593 | | - max_tokens: Optional[int] = None, |
594 | | - presence_penalty: float = 0.0, |
595 | | - frequency_penalty: float = 0.0, |
596 | | - repeat_penalty: float = 1.1, |
597 | | - tfs_z: float = 1.0, |
598 | | - mirostat_mode: int = 0, |
599 | | - mirostat_tau: float = 5.0, |
600 | | - mirostat_eta: float = 0.1, |
601 | | - model: Optional[str] = None, |
602 | | - logits_processor: Optional[llama.LogitsProcessorList] = None, |
603 | | - grammar: Optional[llama.LlamaGrammar] = None, |
604 | | - logit_bias: Optional[Dict[str, float]] = None, |
605 | | - logprobs: Optional[bool] = None, |
606 | | - top_logprobs: Optional[int] = None, |
607 | | - **kwargs, # type: ignore |
608 | | - ) -> Union[ |
609 | | - llama_types.CreateChatCompletionResponse, |
610 | | - Iterator[llama_types.CreateChatCompletionStreamResponse], |
611 | | - ]: |
612 | | - result = chat_formatter( |
613 | | - messages=messages, |
614 | | - functions=functions, |
615 | | - function_call=function_call, |
616 | | - tools=tools, |
617 | | - tool_choice=tool_choice, |
618 | | - ) |
619 | | - prompt = llama.tokenize( |
620 | | - result.prompt.encode("utf-8"), |
621 | | - add_bos=not result.added_special, |
622 | | - special=True, |
623 | | - ) |
624 | | - if result.stop is not None: |
625 | | - stop = [] if stop is None else [stop] if isinstance(stop, str) else stop |
626 | | - rstop = result.stop if isinstance(result.stop, list) else [result.stop] |
627 | | - stop = stop + rstop |
628 | | - |
629 | | - stopping_criteria = None |
630 | | - if result.stopping_criteria is not None: |
631 | | - stopping_criteria = result.stopping_criteria |
632 | | - |
633 | | - if response_format is not None and response_format["type"] == "json_object": |
634 | | - grammar = _grammar_for_response_format( |
635 | | - response_format, verbose=llama.verbose |
636 | | - ) |
637 | | - |
638 | | - # Convert legacy functions to tools |
639 | | - if functions is not None: |
640 | | - tools = [ |
641 | | - { |
642 | | - "type": "function", |
643 | | - "function": function, |
644 | | - } |
645 | | - for function in functions |
646 | | - ] |
647 | | - |
648 | | - # Convert legacy function_call to tool_choice |
649 | | - if function_call is not None: |
650 | | - if isinstance(function_call, str) and ( |
651 | | - function_call == "none" or function_call == "auto" |
652 | | - ): |
653 | | - tool_choice = function_call |
654 | | - if isinstance(function_call, dict) and "name" in function_call: |
655 | | - tool_choice = { |
656 | | - "type": "function", |
657 | | - "function": { |
658 | | - "name": function_call["name"], |
659 | | - }, |
660 | | - } |
661 | | - |
662 | | - tool = None |
663 | | - if ( |
664 | | - tool_choice is not None |
665 | | - and isinstance(tool_choice, dict) |
666 | | - and tools is not None |
667 | | - ): |
668 | | - name = tool_choice["function"]["name"] |
669 | | - tool = next((t for t in tools if t["function"]["name"] == name), None) |
670 | | - if tool is None: |
671 | | - raise ValueError(f"Tool choice '{name}' not found in tools.") |
672 | | - schema = tool["function"]["parameters"] |
673 | | - try: |
674 | | - # create grammar from json schema |
675 | | - grammar = llama_grammar.LlamaGrammar.from_json_schema( |
676 | | - json.dumps(schema), verbose=llama.verbose |
677 | | - ) |
678 | | - except Exception as e: |
679 | | - if llama.verbose: |
680 | | - print(str(e), file=sys.stderr) |
681 | | - grammar = llama_grammar.LlamaGrammar.from_string( |
682 | | - llama_grammar.JSON_GBNF, verbose=llama.verbose |
683 | | - ) |
684 | | - |
685 | | - completion_or_chunks = llama.create_completion( |
686 | | - prompt=prompt, |
687 | | - temperature=temperature, |
688 | | - top_p=top_p, |
689 | | - top_k=top_k, |
690 | | - min_p=min_p, |
691 | | - typical_p=typical_p, |
692 | | - logprobs=top_logprobs if logprobs else None, |
693 | | - stream=stream, |
694 | | - stop=stop, |
695 | | - seed=seed, |
696 | | - max_tokens=max_tokens, |
697 | | - presence_penalty=presence_penalty, |
698 | | - frequency_penalty=frequency_penalty, |
699 | | - repeat_penalty=repeat_penalty, |
700 | | - tfs_z=tfs_z, |
701 | | - mirostat_mode=mirostat_mode, |
702 | | - mirostat_tau=mirostat_tau, |
703 | | - mirostat_eta=mirostat_eta, |
704 | | - model=model, |
705 | | - logits_processor=logits_processor, |
706 | | - stopping_criteria=stopping_criteria, |
707 | | - grammar=grammar, |
708 | | - logit_bias=logit_bias, |
709 | | - ) |
710 | | - if tool is not None: |
711 | | - tool_name = tool["function"]["name"] |
712 | | - return _convert_completion_to_chat_function( |
713 | | - tool_name, completion_or_chunks, stream |
714 | | - ) |
715 | | - return _convert_completion_to_chat(completion_or_chunks, stream=stream) |
716 | | - |
717 | | - return chat_completion_handler |
718 | | - |
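
Among other things, the removed handler above folds the legacy OpenAI-style `functions` / `function_call` parameters into the newer `tools` / `tool_choice` form before running the completion. A standalone sketch of just that mapping (pure data transformation, no llama.cpp calls; the "get_weather" function is hypothetical):

    # Legacy-style inputs.
    functions = [{"name": "get_weather", "parameters": {"type": "object", "properties": {}}}]
    function_call = {"name": "get_weather"}

    # Same conversion the removed handler performed.
    tools = [{"type": "function", "function": fn} for fn in functions]
    if isinstance(function_call, str) and function_call in ("none", "auto"):
        tool_choice = function_call
    else:
        tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

    assert tools[0]["function"]["name"] == tool_choice["function"]["name"]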
719 | 366 |
|
720 | 367 | def hf_autotokenizer_to_chat_formatter( |
721 | 368 | pretrained_model_name_or_path: Union[str, os.PathLike[str]] |
|