devstral tool parser for tool calling #3851
base: main
New file (`@@ -0,0 +1,57 @@`) — the Devstral generation config builder implementation:

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>
#include <string>
#include <utility>
#include <openvino/genai/generation_config.hpp>

#include "generation_config_builder.hpp"

namespace ovms {

void DevstralGenerationConfigBuilder::parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) {
    // Call the base class method to fill in common configuration
    BaseGenerationConfigBuilder::parseConfigFromRequest(request);

    // For now the only model-specific part is related to tools, so if there are no tools
    // provided in the request we can exit early.
    if (request.toolNameSchemaMap.empty()) {
        return;
    }

    if (enableToolGuidedGeneration || request.toolChoice == "required") {
        // Set the tool guided generation config specific to the Devstral model.
        auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
        triggeredTags->triggers.push_back("[TOOL_CALLS]");

        for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) {
            const auto& toolSchema = toolSchemaWrapper.stringRepr;
            ov::genai::StructuredOutputConfig::Tag tagItem;
            tagItem.begin = "[TOOL_CALLS]" + toolName + "[ARGS]";
            tagItem.end = "";
            tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema);
            triggeredTags->tags.push_back(tagItem);
        }
        if (request.toolChoice == "required") {
            triggeredTags->at_least_one = true;
        }
        ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags;
        setStructuralTagsConfig(structuralTag);
    }
}

}  // namespace ovms
```
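For context, the structural-tags configuration above constrains decoding so that once the model emits the `[TOOL_CALLS]` trigger, the continuation must match `[TOOL_CALLS]<tool_name>[ARGS]<JSON matching that tool's schema>`. A minimal standalone sketch of the object the builder produces for a single tool (the tool name, schema, and helper function name here are hypothetical, not taken from the PR):

```cpp
#include <memory>
#include <string>
#include <openvino/genai/generation_config.hpp>

// Sketch only: mirrors what DevstralGenerationConfigBuilder derives from
// request.toolNameSchemaMap for one hypothetical tool named "get_weather".
ov::genai::StructuredOutputConfig::StructuralTag makeDevstralToolTag() {
    auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
    // Guided generation kicks in only after the model produces this trigger.
    triggeredTags->triggers.push_back("[TOOL_CALLS]");

    ov::genai::StructuredOutputConfig::Tag tag;
    tag.begin = "[TOOL_CALLS]get_weather[ARGS]";
    tag.end = "";
    // Arguments are constrained to the tool's JSON schema (hypothetical schema below).
    tag.content = ov::genai::StructuredOutputConfig::JSONSchema(
        R"({"type":"object","properties":{"city":{"type":"string"}},"required":["city"]})");
    triggeredTags->tags.push_back(tag);

    // With tool_choice == "required", the builder additionally forces at least one tool call.
    triggeredTags->at_least_one = true;

    // A valid constrained completion then looks like:
    //   [TOOL_CALLS]get_weather[ARGS]{"city": "Gdansk"}
    return triggeredTags;
}
```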
New file (`@@ -0,0 +1,33 @@`) — the corresponding header:

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "../base_generation_config_builder.hpp"

namespace ovms {

/*
 * DevstralGenerationConfigBuilder extends BaseGenerationConfigBuilder to provide configuration
 * specific to the Devstral model. It overrides parseConfigFromRequest to set the tool guided
 * generation config.
 */
class DevstralGenerationConfigBuilder : public BaseGenerationConfigBuilder {
public:
    DevstralGenerationConfigBuilder() = delete;
    explicit DevstralGenerationConfigBuilder(const ov::genai::GenerationConfig& baseConfig, bool enableToolGuidedGeneration, DecodingMethod decodingMethod) :
        BaseGenerationConfigBuilder(baseConfig, enableToolGuidedGeneration, decodingMethod) {}

    void parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) override;
};
}  // namespace ovms
```
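A minimal usage sketch, assuming the builder is driven the same way as the other model-specific builders. The request, base config, and decoding method all come from the serving pipeline; the include path, the function name `configureForDevstral`, and the `ovms::` qualification of `DecodingMethod` and `OpenAIChatCompletionsRequest` are assumptions, and how the finished config is read back out of the builder is not part of this diff:

```cpp
#include "devstral_generation_config_builder.hpp"  // illustrative include; actual path not shown in the diff

// Sketch: apply the Devstral builder to an already-parsed chat completions request.
void configureForDevstral(const ov::genai::GenerationConfig& baseConfig,
                          const ovms::OpenAIChatCompletionsRequest& request,
                          ovms::DecodingMethod decodingMethod) {
    ovms::DevstralGenerationConfigBuilder builder(baseConfig, /*enableToolGuidedGeneration=*/true, decodingMethod);
    // Fills common fields via the base class, then attaches the [TOOL_CALLS]/[ARGS]
    // structural tags when the request carries tools (see the .cpp above).
    builder.parseConfigFromRequest(request);
}
```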
New file (`@@ -0,0 +1,225 @@`) — the Devstral tool output parser:

```cpp
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <openvino/genai/tokenizer.hpp>
#include <string>
#include <vector>
#include <regex>

#include "src/port/rapidjson_document.hpp"

#include "../../../logging.hpp"
#include "tool_parser.hpp"
#include "../utils.hpp"
#include "src/stringutils.hpp"

namespace ovms {

void DevstralToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) {
    std::vector<std::string> tools;
    // expected format: [TOOL_CALLS]tool_name[ARGS]{"arg1": "value1", ...}
    if (parsedOutput.content.empty() || generatedTokens.empty()) {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No content to parse for tool calls");
        return;
    }
    size_t firstToolTokenIndex;
    auto it = std::find(generatedTokens.begin(), generatedTokens.end(), this->botTokenId);
    if (it != generatedTokens.end()) {
        firstToolTokenIndex = std::distance(generatedTokens.begin(), it);
    } else {
        return;
    }

    size_t firstArgsTokenIndex;
    auto itArgs = std::find(generatedTokens.begin() + firstToolTokenIndex, generatedTokens.end(), this->argsTokenId);
    if (itArgs != generatedTokens.end()) {
        firstArgsTokenIndex = std::distance(generatedTokens.begin(), itArgs);
    } else {
        return;
    }
    if (firstToolTokenIndex > firstArgsTokenIndex) {
        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "First tool token index is greater than first args token index.");
        return;
    }
    std::vector<int64_t> toolNameTokens(generatedTokens.begin() + (firstToolTokenIndex + 1), generatedTokens.begin() + (firstArgsTokenIndex));
    std::vector<int64_t> argumentsTokens(generatedTokens.begin() + (firstArgsTokenIndex + 1), generatedTokens.end());

    ToolCall toolCall;
    std::string toolName = tokenizer.decode(toolNameTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    std::string arguments = tokenizer.decode(argumentsTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    toolCall.name = toolName;
    toolCall.arguments = arguments;
    toolCall.id = generateRandomId();  // Generate a random ID for the tool call
    parsedOutput.toolCalls.push_back(toolCall);

    // Keep only the content preceding the tool call: decode the tokens from begin() up to firstToolTokenIndex.
    std::vector<int64_t> contentTokens;
    if (firstToolTokenIndex > 0) {
        contentTokens = std::vector<int64_t>(generatedTokens.begin(), generatedTokens.begin() + firstToolTokenIndex);
        parsedOutput.content = tokenizer.decode(contentTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    } else {
        // No content before the tool call; decoding the empty token vector yields an empty string.
        parsedOutput.content = tokenizer.decode(contentTokens, ov::AnyMap{ov::genai::skip_special_tokens(true)});
    }
    return;
}
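
// Illustration (not from the PR; the text, tool name, and arguments are hypothetical): for the generated output
//   I'll check that.[TOOL_CALLS]get_weather[ARGS]{"city": "Gdansk"}
// parse() locates the [TOOL_CALLS] and [ARGS] special tokens and produces:
//   parsedOutput.content        == "I'll check that."
//   toolCalls[0].name           == "get_weather"
//   toolCalls[0].arguments      == "{\"city\": \"Gdansk\"}"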

std::optional<rapidjson::Document> DevstralToolParser::sendFullDelta(ToolCall& toolCall) {
    rapidjson::Document argsDelta;
    argsDelta.Parse(toolCall.arguments.c_str());
    rapidjson::Document argumentsWrapper;
    argumentsWrapper.SetObject();
    rapidjson::Document::AllocatorType& allocator = argumentsWrapper.GetAllocator();
    // now we need to add the string toolCall.arguments to argumentsWrapper under the "arguments" key
    rapidjson::Value toolCallsString(rapidjson::kStringType);
    toolCallsString.SetString(toolCall.arguments.c_str(), allocator);
    argumentsWrapper.AddMember("arguments", toolCallsString, allocator);
    auto currentDelta = wrapDelta(argumentsWrapper, this->toolCallIndex);
    return currentDelta;
}

rapidjson::Document DevstralToolParser::wrapCombinedDelta(ToolCall& toolCall) {
    rapidjson::Document wrappedDelta;
    wrappedDelta.SetObject();
    rapidjson::Value toolCalls(rapidjson::kArrayType);
    rapidjson::Value toolCallObj(rapidjson::kObjectType);
    rapidjson::Value idValue(generateRandomId().c_str(), wrappedDelta.GetAllocator());
    rapidjson::Value toolCallsString(rapidjson::kStringType);

    toolCallObj.AddMember("id", idValue, wrappedDelta.GetAllocator());
    toolCallObj.AddMember("type", "function", wrappedDelta.GetAllocator());
    toolCallObj.AddMember("index", toolCallIndex, wrappedDelta.GetAllocator());
    rapidjson::Value functionObj(rapidjson::kObjectType);
    rapidjson::Value nameValue(toolCall.name.c_str(), wrappedDelta.GetAllocator());
    functionObj.AddMember("name", nameValue, wrappedDelta.GetAllocator());
    // add the string toolCall.arguments to the function object under the "arguments" key
    toolCallsString.SetString(toolCall.arguments.c_str(), wrappedDelta.GetAllocator());
    functionObj.AddMember("arguments", toolCallsString, wrappedDelta.GetAllocator());
    toolCallObj.AddMember("function", functionObj, wrappedDelta.GetAllocator());
    toolCalls.PushBack(toolCallObj, wrappedDelta.GetAllocator());
    rapidjson::Value deltaWrapper(rapidjson::kObjectType);
    deltaWrapper.AddMember("tool_calls", toolCalls, wrappedDelta.GetAllocator());
    wrappedDelta.AddMember("delta", deltaWrapper, wrappedDelta.GetAllocator());
    return wrappedDelta;
}
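
// For reference, wrapCombinedDelta() above yields a document of the shape (tool name and
// arguments are hypothetical; the id is random, the index is the current toolCallIndex):
//   {"delta":{"tool_calls":[{"id":"<random>","type":"function","index":<toolCallIndex>,
//                            "function":{"name":"get_weather","arguments":"{\"city\": \"Gdansk\"}"}}]}}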

rapidjson::Document DevstralToolParser::parseContentChunk() {
    rapidjson::StringBuffer buffer;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
    writer.StartObject();
    writer.String("delta");
    writer.StartObject();
    writer.String("content");
    writer.String(streamContent.c_str());
    writer.EndObject();
    writer.EndObject();
    rapidjson::Document doc;
    doc.Parse(buffer.GetString());
    streamContent.clear();
    return doc;
}
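
// parseContentChunk() above produces a plain content delta, e.g. {"delta":{"content":"<accumulated text>"}},
// and clears streamContent afterwards.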

std::optional<rapidjson::Document> DevstralToolParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) {
    /*
    Devstral tool call format: [TOOL_CALLS]tool_name[ARGS]arguments[</s>]
    It does not support parallel tool calls, so tool calls are always in sequence.

    We have three processing states:
        AWAITING_START_TAG,
        AWAITING_ARGS_TAG,
        PROCESSING_ARGS

    We store the history of chunks in the streamContent string. After a state change is detected,
    we clear streamContent so that it only keeps the unprocessed part.
    */
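
    // Example walk-through (chunk boundaries are illustrative and the tool name/arguments are
    // hypothetical; real chunks follow whatever the detokenizer emits):
    //   chunk: "Sure, "                      -> AWAITING_START_TAG, emits {"delta":{"content":"Sure, "}}
    //   chunk: "[TOOL_CALLS]"                -> switches to AWAITING_ARGS_TAG, emits nothing
    //   chunk: "get_weather[ARGS]"           -> switches to PROCESSING_ARGS, emits the first delta with the tool name
    //   chunk: "{\"city\": \"Gdansk\"}</s>"  -> emits the accumulated arguments via sendFullDelta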

    this->streamContent += chunk;
    SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Chunk content: '{}', StreamContent: '{}', State: {}", chunk, this->streamContent, std::to_string(this->internalState));
    if (this->internalState == AWAITING_START_TAG) {
        // if the chunk ends with </s> we need to remove it and return the parsed content immediately
        if (chunk.size() >= this->streamingEndTag.size() &&
            chunk.substr(chunk.size() - this->streamingEndTag.size()) == this->streamingEndTag) {
            // remove </s> from streamContent
            this->streamContent = this->streamContent.substr(0, this->streamContent.size() - this->streamingEndTag.size());
            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Found end tag in chunk while awaiting start tag. Returning content chunk.");
            return parseContentChunk();
        }
        size_t pos = chunk.find(this->streamingParsingToolCallsStartTag);
        if (pos != std::string::npos) {
            this->internalState = AWAITING_ARGS_TAG;
            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Found [TOOL_CALLS] tag in chunk. Current state: {}", std::to_string(this->internalState));
            this->toolCallIndex++;
            if (pos == 0) {
                this->streamContent.clear();
                return std::nullopt;
            } else {
                // Drop the [TOOL_CALLS] tag itself and keep only what follows it.
                this->streamContent = this->streamContent.substr(pos + this->streamingParsingToolCallsStartTag.length());
                return parseContentChunk();
            }
        } else {
            return parseContentChunk();
        }
    }
    if (this->internalState == AWAITING_ARGS_TAG) {
        // check if the [ARGS] tag is present in the accumulated content and update the state accordingly
        size_t pos = this->streamContent.find(this->streamingParsingArgsStartTag);
        if (pos != std::string::npos) {
            this->internalState = PROCESSING_ARGS;
            this->toolName = this->streamContent.substr(0, pos);
            this->streamContent = this->streamContent.substr(pos + this->streamingParsingArgsStartTag.length());
            // if the remaining streamContent already ends with </s>, return the full tool call delta at once
            if (this->streamContent.size() >= this->streamingEndTag.size() &&
                this->streamContent.substr(this->streamContent.size() - this->streamingEndTag.size()) == this->streamingEndTag) {
                // remove </s> from streamContent
                ToolCall toolCall;
                toolCall.name = this->toolName;
                this->streamContent = this->streamContent.substr(0, this->streamContent.size() - this->streamingEndTag.size());
                if (!this->streamContent.empty()) {
                    toolCall.arguments = this->streamContent;
                } else {
                    toolCall.arguments = "{}";
                }
                this->streamContent = "";
                return wrapCombinedDelta(toolCall);
            } else {
                return wrapFirstDelta(this->toolName, this->toolCallIndex);
            }
        } else {
            return std::nullopt;
        }
    }
    if (this->internalState == PROCESSING_ARGS) {
        size_t endPos = this->streamContent.find(this->streamingEndTag);
        std::string arguments;
        if (endPos != std::string::npos) {
            arguments = this->streamContent.substr(0, endPos);
        } else {
            arguments = this->streamContent;
        }
        if (!arguments.empty()) {
            ToolCall toolCall;
            toolCall.arguments = arguments;
            toolCall.name = this->toolName;
            this->streamContent = "";
            return sendFullDelta(toolCall);
```
Inline review discussion on the `return sendFullDelta(toolCall);` line above:

Collaborator: Shouldn't we stream partial function argument chunks? If I understand correctly, you send the full delta at the end of generation.

Collaborator: We already accepted such an approach for Qwen3 Coder, so I suppose we can have it in other parsers as well, unless there are specific requirements for "real" streaming.
```cpp
        } else {
            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "No valid arguments found in streamContent.");
            return std::nullopt;
        }
    }
    return std::nullopt;
}
}  // namespace ovms
```
A further inline review note: technically we check streamContent, but that is only the case if [ARGS] is added in the chunk; otherwise it would be a different state.
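To close, a minimal sketch of how a streaming handler might drive the parser. The parser's construction and the delta serialization are outside this diff, so the function name, the include paths, and the callback here are illustrative assumptions:

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

#include <openvino/genai/generation_handle.hpp>  // for GenerationFinishReason (assumed location)

#include "src/port/rapidjson_document.hpp"
#include "tool_parser.hpp"  // declares DevstralToolParser (not shown in this diff)

// Sketch only: feed already-decoded text chunks into parseChunk and forward each
// returned delta (e.g. serialized as a "chat.completion.chunk" SSE event).
void streamWithDevstralParsing(ovms::DevstralToolParser& parser,
                               const std::vector<std::string>& decodedChunks,
                               const std::function<void(const rapidjson::Document&)>& sendDelta) {
    for (const auto& chunk : decodedChunks) {
        auto delta = parser.parseChunk(chunk, ov::genai::GenerationFinishReason::NONE);
        if (delta.has_value()) {
            sendDelta(*delta);
        }
    }
}
```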