@@ -99,12 +99,7 @@ def __init__(self, tokenizer: TokenizerLike):
9999 self .bot_token = "[TOOL_CALLS]"
100100 self .bot_token_id = self .vocab .get (self .bot_token )
101101 self .tool_call_regex = re .compile (r"\[{.*}\]" , re .DOTALL )
102- if not _is_pre_v11_tokeniser (self .model_tokenizer ):
103- self .fn_name_regex = re .compile (
104- r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)" , re .DOTALL
105- )
106- else :
107- self .fn_name_regex = None
102+ self ._is_pre_v11 = _is_pre_v11_tokeniser (self .model_tokenizer )
108103
109104 if self .bot_token_id is None :
110105 raise RuntimeError (
@@ -148,23 +143,24 @@ def extract_tool_calls(
148143 tool_content = model_output .replace (self .bot_token , "" ).strip ()
149144
150145 try :
151- # we first try to directly load the json as parsing very nested
152- # jsons is difficult
153146 try :
154- if self .fn_name_regex :
147+ if not self ._is_pre_v11 :
155148 function_call_arr = []
156149 for single_tool_content in model_output .split (self .bot_token ):
157- matches = self .fn_name_regex .findall (single_tool_content )
158-
159- for match in matches :
160- fn_name = match [0 ]
161- args = match [1 ]
162-
163- # fn_name is encoded outside serialized json dump
164- # only arguments are serialized
165- function_call_arr .append (
166- {"name" : fn_name , "arguments" : json .loads (args )}
167- )
150+ if "{" not in single_tool_content :
151+ continue
152+
153+ end_name = single_tool_content .find ("{" )
154+ fn_name , args = (
155+ single_tool_content [:end_name ],
156+ single_tool_content [end_name :],
157+ )
158+
159+ # fn_name is encoded outside serialized json dump
160+ # only arguments are serialized
161+ function_call_arr .append (
162+ {"name" : fn_name , "arguments" : json .loads (args )}
163+ )
168164 else :
169165 function_call_arr = json .loads (tool_content )
170166 except json .JSONDecodeError :
0 commit comments