diff --git a/docs/parameters.md b/docs/parameters.md index 56ef12414c..507d991d11 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -138,8 +138,8 @@ Task specific parameters for different tasks (text generation/image generation/e | `--max_prompt_len` | `integer` | Sets NPU specific property for maximum number of tokens in the prompt. | | `--kv_cache_precision` | `string` | Reduced kv cache precision to `u8` lowers the cache size consumption. Accepted values: `u8` or empty (default). | | `--model_distribution_policy` | `string` | TENSOR_PARALLEL distributes tensor to multiple sockets/devices and processes it in parallel. PIPELINE_PARALLEL distributes different tensors to process by each device. Accepted values: `TENSOR_PARALLEL`, `PIPELINE_PARALLEL` or empty (default). | -| `--reasoning_parser` | `string` | Type of parser to use for reasoning content extraction from model output. Currently supported: [qwen3, gptoss] | -| `--tool_parser` | `string` | Type of parser to use for tool calls extraction from model output. Currently supported: [llama3, phi4, hermes3, mistral, qwen3coder, gptoss, devstral, lfm2] | +| `--reasoning_parser` | `string` | Type of parser to use for reasoning content extraction from model output. Currently supported: [qwen3, gptoss, gemma4] | +| `--tool_parser` | `string` | Type of parser to use for tool calls extraction from model output. Currently supported: [llama3, phi4, hermes3, mistral, qwen3coder, gptoss, devstral, lfm2, gemma4] | | `--enable_tool_guided_generation` | `bool` | Enables enforcing tool schema during generation. Requires setting response parser. Default: false. 
| ### Image generation diff --git a/prepare_llm_models.sh b/prepare_llm_models.sh index 6bc2a861ed..7c51e50eb1 100755 --- a/prepare_llm_models.sh +++ b/prepare_llm_models.sh @@ -39,6 +39,7 @@ MISTRAL_MODEL="mistralai/Mistral-7B-Instruct-v0.3" GPT_OSS_MODEL="openai/gpt-oss-20b" DEVSTRAL_MODEL="unsloth/Devstral-Small-2507" LFM2_MODEL="LiquidAI/LFM2-2.6B" +GEMMA4_MODEL="OpenVINO/gemma-4-E4B-it-int4-ov" if [ "$(python3 -c 'import sys; print(sys.version_info[1])')" -le "8" ]; then echo "Prepare models with python > 3.8."; exit 1 ; fi @@ -228,4 +229,13 @@ fi if [ ! -f "$1/$LFM2_MODEL/$TOKENIZER_FILE" ]; then echo "[ERROR] Models file $1/$LFM2_MODEL/$TOKENIZER_FILE does not exist." exit 1 -fi \ No newline at end of file +fi +if [ -f "$1/$GEMMA4_MODEL/$TOKENIZER_FILE" ]; then + echo "Models file $1/$GEMMA4_MODEL/$TOKENIZER_FILE exists. Skipping downloading models." +else + hf download "$GEMMA4_MODEL" --local-dir $1/$GEMMA4_MODEL --include *tokenizer* +fi +if [ ! -f "$1/$GEMMA4_MODEL/$TOKENIZER_FILE" ]; then + echo "[ERROR] Models file $1/$GEMMA4_MODEL/$TOKENIZER_FILE does not exist." 
+ exit 1 +fi diff --git a/spelling-whitelist.txt b/spelling-whitelist.txt index bd12dae11c..85763a1742 100644 --- a/spelling-whitelist.txt +++ b/spelling-whitelist.txt @@ -29,3 +29,4 @@ demos/vlm_npu/README.md:157: mane ==> main, many, maine demos/vlm_npu/README.md:218: mane ==> main, many, maine demos/integration_with_OpenWebUI/README.md:423: Buildin ==> Building, Build in src/test/llm/output_parsers/lfm2_output_parser_test.cpp +src/test/llm/output_parsers/gemma4_output_parser_test.cpp diff --git a/src/llm/BUILD b/src/llm/BUILD index 6be387c954..d3d8fbab95 100644 --- a/src/llm/BUILD +++ b/src/llm/BUILD @@ -197,6 +197,38 @@ ovms_cc_library( ], visibility = ["//visibility:public"], ) +ovms_cc_library( + name = "io_processing_gemma4_tool_parser", + hdrs = ["io_processing/gemma4/tool_parser.hpp", "io_processing/gemma4/reasoning_parser.hpp"], + srcs = ["io_processing/gemma4/tool_parser.cpp", "io_processing/gemma4/reasoning_parser.cpp"], + deps = [ + "@com_github_tencent_rapidjson//:rapidjson", + "//src/port:rapidjson_document", + "//src:libovmslogging", + "//src:libovmsstring_utils", + ":io_processing_utils", + ":io_processing_base_output_parser", + ":io_processing_qwen3_reasoning_parser", + "//third_party:genai", + ], + visibility = ["//visibility:public"], +) + +ovms_cc_library( + name = "io_processing_qwen3_reasoning_parser", + hdrs = ["io_processing/qwen3/reasoning_parser.hpp"], + srcs = ["io_processing/qwen3/reasoning_parser.cpp"], + deps = [ + "@com_github_tencent_rapidjson//:rapidjson", + "//src/port:rapidjson_document", + "//src:libovmslogging", + "//src:libovmsstring_utils", + ":io_processing_utils", + ":io_processing_base_output_parser", + "//third_party:genai", + ], + visibility = ["//visibility:public"], +) ovms_cc_library( # TODO split further so we don't have to recompile everything when changing one parser ... 
name = "output_parsers", @@ -206,7 +238,6 @@ ovms_cc_library( # TODO split further so we don't have to recompile everything w "io_processing/phi4/tool_parser.hpp", "io_processing/devstral/tool_parser.hpp", "io_processing/mistral/tool_parser.hpp", - "io_processing/qwen3/reasoning_parser.hpp", "io_processing/gptoss/reasoning_parser.hpp", "io_processing/gptoss/tool_parser.hpp", "io_processing/gptoss/harmony.hpp", @@ -218,7 +249,6 @@ ovms_cc_library( # TODO split further so we don't have to recompile everything w "io_processing/phi4/tool_parser.cpp", "io_processing/devstral/tool_parser.cpp", "io_processing/mistral/tool_parser.cpp", - "io_processing/qwen3/reasoning_parser.cpp", "io_processing/gptoss/reasoning_parser.cpp", "io_processing/gptoss/tool_parser.cpp", "io_processing/gptoss/harmony.cpp", @@ -234,6 +264,8 @@ ovms_cc_library( # TODO split further so we don't have to recompile everything w ":io_processing_base_output_parser", ":io_processing_qwen3coder_tool_parser", ":io_processing_lfm2_tool_parser", + ":io_processing_gemma4_tool_parser", + ":io_processing_qwen3_reasoning_parser", ":io_processing_utils", ":apis_tool_schema_wrapper", ], diff --git a/src/llm/io_processing/gemma4/gemma4_reasoning_parser.cpp b/src/llm/io_processing/gemma4/gemma4_reasoning_parser.cpp new file mode 100644 index 0000000000..0c933d8657 --- /dev/null +++ b/src/llm/io_processing/gemma4/gemma4_reasoning_parser.cpp @@ -0,0 +1,67 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include +#include + +#include "src/port/rapidjson_document.hpp" + +#include "../../../logging.hpp" +#include "gemma4_reasoning_parser.hpp" +#include "../utils.hpp" + +namespace ovms { +void Gemma4ReasoningParser::parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) { + std::string startReasoningTag = getParsingStartTags()[0]; + std::string endReasoningTag = getParsingEndTag(); + size_t startPos = parsedOutput.content.find(startReasoningTag); + size_t endPos = parsedOutput.content.find(endReasoningTag); + + if (startPos != std::string::npos && endPos != std::string::npos && startPos < endPos) { + size_t reasoningStart = startPos + startReasoningTag.length(); + std::string reasoningText = parsedOutput.content.substr(reasoningStart, endPos - reasoningStart); + parsedOutput.reasoning = reasoningText; + // Remove reasoning from content + parsedOutput.content.erase(startPos, endPos - startPos + endReasoningTag.length()); + } +} + +std::optional Gemma4ReasoningParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) { + if (chunk.empty()) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Received empty chunk for Gemma4ReasoningParser"); + return std::nullopt; + } + + if (chunk.find(getParsingStartTags()[0]) != std::string::npos || chunk.find(getParsingEndTag()) != std::string::npos) { + return std::nullopt; + } else { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + writer.StartObject(); + 
writer.String("delta"); + writer.StartObject(); + writer.String("reasoning_content"); + writer.String(chunk.c_str()); + writer.EndObject(); + writer.EndObject(); + rapidjson::Document doc; + doc.Parse(buffer.GetString()); + return doc; + } + return std::nullopt; +} +} // namespace ovms diff --git a/src/llm/io_processing/gemma4/gemma4_reasoning_parser.hpp b/src/llm/io_processing/gemma4/gemma4_reasoning_parser.hpp new file mode 100644 index 0000000000..f4a6f48a41 --- /dev/null +++ b/src/llm/io_processing/gemma4/gemma4_reasoning_parser.hpp @@ -0,0 +1,48 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#pragma once +#include +#include + +#include "../base_output_parser.hpp" + +namespace ovms { +class Gemma4ReasoningParser : public BaseOutputParser { +protected: + // Tags used to identify the reasoning segment in the content + std::string parsingStartTag = "<|channel>thought\n"; + std::string parsingEndTag = ""; + +public: + Gemma4ReasoningParser() = delete; + explicit Gemma4ReasoningParser(ov::genai::Tokenizer& tokenizer) : + BaseOutputParser(tokenizer) {} + + void parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) override; + std::optional parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) override; + const std::vector& getParsingStartTags() const override { + static const std::vector parsingStartTags{this->parsingStartTag}; + return parsingStartTags; + } + const std::vector& getSpecialParsingStartTags() const override { + static const std::vector specialParsingStartTags{}; + return specialParsingStartTags; + } + const std::string& getParsingEndTag() const override { + return parsingEndTag; + } +}; +} // namespace ovms diff --git a/src/llm/io_processing/gemma4/reasoning_parser.cpp b/src/llm/io_processing/gemma4/reasoning_parser.cpp new file mode 100644 index 0000000000..3e1c4cb4f2 --- /dev/null +++ b/src/llm/io_processing/gemma4/reasoning_parser.cpp @@ -0,0 +1,73 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include
+#include
+#include
+
+#include "src/port/rapidjson_document.hpp"
+
+#include "../../../logging.hpp"
+#include "reasoning_parser.hpp"
+#include "../utils.hpp"
+
+namespace ovms {
+void Gemma4ReasoningParser::parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) {
+    auto startPos = std::string::npos;
+    auto endPos = std::string::npos;
+    // Find the reasoning block by token id; the end marker is searched from the start marker so an earlier stray end token cannot hide the block.
+    auto startIt = std::find(generatedTokens.begin(), generatedTokens.end(), reasoningTokenId);
+    auto endIt = std::find(startIt, generatedTokens.end(), reasoningEndTokenId);
+    // Record token indices only when both markers exist in order; otherwise parsedOutput is left untouched.
+    if (startIt != generatedTokens.end() && endIt != generatedTokens.end() && startIt < endIt) {
+        startPos = std::distance(generatedTokens.begin(), startIt);
+        endPos = std::distance(generatedTokens.begin(), endIt);
+    }
+    // Tokens inside the block are decoded into reasoning, tokens after the end marker into content.
+    if (startPos != std::string::npos && endPos != std::string::npos && startPos < endPos) {
+        size_t reasoningStart = (startPos + 3 < endPos) ? startPos + 3 : endPos;  // skip "<|channel>thought\n" (assumed to be 3 tokens); clamped so a shorter block cannot invert the range
+        std::string reasoningText = tokenizer.decode(std::vector(generatedTokens.begin() + reasoningStart, generatedTokens.begin() + endPos), ov::genai::skip_special_tokens(true));
+        parsedOutput.reasoning = reasoningText;
+        // Remove reasoning from content
+        std::string contentWithoutReasoning = tokenizer.decode(std::vector(generatedTokens.begin() + endPos + 1, generatedTokens.end()), ov::genai::skip_special_tokens(true));  // content MUST never appear before reasoning
+        parsedOutput.content = contentWithoutReasoning;
+    }
+}
+std::optional 
Gemma4ReasoningParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) { + if (chunk.empty()) { + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Received empty chunk for Gemma4ReasoningParser"); + return std::nullopt; + } + + if (chunk.find(getParsingStartTags()[0]) != std::string::npos || chunk.find(getParsingEndTag()) != std::string::npos) { + return std::nullopt; + } else { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + writer.StartObject(); + writer.String("delta"); + writer.StartObject(); + writer.String("reasoning_content"); + writer.String(chunk.c_str()); + writer.EndObject(); + writer.EndObject(); + rapidjson::Document doc; + doc.Parse(buffer.GetString()); + return doc; + } + return std::nullopt; +} +} // namespace ovms diff --git a/src/llm/io_processing/gemma4/reasoning_parser.hpp b/src/llm/io_processing/gemma4/reasoning_parser.hpp new file mode 100644 index 0000000000..f4e10bda96 --- /dev/null +++ b/src/llm/io_processing/gemma4/reasoning_parser.hpp @@ -0,0 +1,56 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#pragma once + +#include +#include +#include + +#include "../qwen3/reasoning_parser.hpp" + +namespace ovms { +class Gemma4ReasoningParser : public Qwen3ReasoningParser { +protected: + const int64_t reasoningTokenId = 100; // <|channel> + const int64_t reasoningEndTokenId = 101; // + + const std::string parsingStartTag = "<|channel>thought\n"; + const std::string parsingEndTag = ""; + +public: + Gemma4ReasoningParser() = delete; + explicit Gemma4ReasoningParser(ov::genai::Tokenizer& tokenizer) : + Qwen3ReasoningParser(tokenizer) {} + void parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) override; + std::optional parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) override; + + bool requiresStreamingWithSpecialTokens() const override { + return true; + } + + const std::vector& getParsingStartTags() const override { + static const std::vector parsingStartTags{this->parsingStartTag}; + return parsingStartTags; + } + const std::vector& getSpecialParsingStartTags() const override { + static const std::vector specialParsingStartTags{}; + return specialParsingStartTags; + } + const std::string& getParsingEndTag() const override { + return parsingEndTag; + } +}; +} // namespace ovms diff --git a/src/llm/io_processing/gemma4/tool_parser.cpp b/src/llm/io_processing/gemma4/tool_parser.cpp new file mode 100644 index 0000000000..33332cee90 --- /dev/null +++ b/src/llm/io_processing/gemma4/tool_parser.cpp @@ -0,0 +1,521 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include "tool_parser.hpp"
+#include "../utils.hpp"
+#include "../../../logging.hpp"
+#include "../../../stringutils.hpp"
+#include "rapidjson/error/en.h"
+#include
+#include
+#include
+
+namespace ovms {
+
+const std::string Gemma4ToolParser::TOOL_CALL_START_TAG = "<|tool_call>";
+const std::string Gemma4ToolParser::TOOL_CALL_END_TAG = "";
+const std::string Gemma4ToolParser::TOOL_CALL_NAME_PREFIX = "call:";
+
+const std::string Gemma4ToolParser::TOOL_ARGS_START_INDICATOR = "{";
+const std::string Gemma4ToolParser::TOOL_ARGS_END_INDICATOR = "}";
+const std::string Gemma4ToolParser::TOOL_ARGS_STRING_INDICATOR = "<|\"|>";
+const std::string Gemma4ToolParser::TOOL_ARGS_SEPARATOR_STR = ",";
+
+const std::string Gemma4ToolParser::TURN_END_TAG = "";
+
+const int64_t Gemma4ToolParser::botTokenId = 48;  // <|tool_call>
+const int64_t Gemma4ToolParser::eotTokenId = 49;  //
+
+const int64_t Gemma4ToolParser::reasoningTokenId = 100;     // <|channel>
+const int64_t Gemma4ToolParser::reasoningEndTokenId = 101;  //
+// Rebuilds a "[...]" argument as a JSON array of strings: each span enclosed in TOOL_ARGS_STRING_INDICATOR becomes one double-quoted element; text outside such spans is not copied.
+std::string Gemma4ToolParser::parseArrayParameter(const std::string& argumentStr) {
+    size_t pos = 1;
+    std::string parsedArguments = "[";
+
+    while (pos != std::string::npos) {
+        size_t stringStartPos = argumentStr.find(TOOL_ARGS_STRING_INDICATOR, pos);
+        if (stringStartPos == std::string::npos) {
+            break;
+        }
+        stringStartPos += TOOL_ARGS_STRING_INDICATOR.size();
+        size_t stringEndPos = argumentStr.find(TOOL_ARGS_STRING_INDICATOR, stringStartPos);
+        if (stringEndPos == std::string::npos) {
+            break;
+        }
+        // Escape embedded double quotes so the element stays a valid JSON string.
+        std::string originalStr = argumentStr.substr(stringStartPos, stringEndPos - stringStartPos);
+        size_t quotePos = 0;
+        while ((quotePos = originalStr.find('\"', quotePos)) != std::string::npos) {
+            originalStr.insert(quotePos, "\\");
+            quotePos += 2;
+        }
+        parsedArguments += (parsedArguments.size() > 1 ? ",\"" : "\"") + originalStr + "\"";  // comma goes before every element except the first
+
+        pos = stringEndPos + TOOL_ARGS_STRING_INDICATOR.size() + 1;
+    }
+    // Append the closing bracket instead of overwriting the last character: with no parsed element that character is the opening '['.
+    parsedArguments += ']';
+
+    return parsedArguments;
+}
+// Rewrites a "{key:value,...}" argument as JSON object text: keys are double-quoted, indicator-delimited values are double-quoted, other values are copied up to the next ','. Returns the input unchanged when no pair is found.
+std::string Gemma4ToolParser::parseObjectParameter(const std::string& argumentStr) {
+    size_t pos = 1;
+    std::vector> keyValuePairs;
+
+    while (pos != std::string::npos) {
+        std::string key, value;
+        bool isStringValue = false;
+        size_t keyEndPos = argumentStr.find(':', pos);
+        if (keyEndPos == std::string::npos) {
+            break;
+        }
+        key = argumentStr.substr(pos, keyEndPos - pos);
+        size_t valueStartPos = keyEndPos + 1;
+        size_t valueEndPos = std::string::npos;
+        if (argumentStr.substr(valueStartPos, TOOL_ARGS_STRING_INDICATOR.size()) == TOOL_ARGS_STRING_INDICATOR) {
+            valueStartPos = valueStartPos + TOOL_ARGS_STRING_INDICATOR.size();
+            valueEndPos = argumentStr.find(TOOL_ARGS_STRING_INDICATOR, valueStartPos);
+            isStringValue = true;
+        } else {
+            valueEndPos = argumentStr.find(',', valueStartPos);
+        }
+
+        if (valueEndPos == std::string::npos) {
+            valueEndPos = argumentStr.size() - 1;
+        }
+        value = argumentStr.substr(valueStartPos, valueEndPos - valueStartPos);
+        if (isStringValue) {
+            value = "\"" + value + "\"";
+        }
+        keyValuePairs.emplace_back(key, value);
+        if (valueEndPos == argumentStr.size() - 1) {
+            break;
+        } else if (isStringValue) {
+            pos = valueEndPos + TOOL_ARGS_STRING_INDICATOR.size() + 1;
+        } else {
+            pos = valueEndPos + 1;
+        }
+    }
+
+    if (keyValuePairs.empty()) {
+        return argumentStr;
+    }
+
+    std::string parsedObject = "{";
+    for (const auto& [key, value] : keyValuePairs) {
+        parsedObject += "\"" + key + "\":" + value + ",";
+    }
+    parsedObject.back() = '}';
+    return parsedObject;
+}
+
+std::string Gemma4ToolParser::normalizeArgStr(const std::string& arg) {
+    if (arg.empty()) {
+        return arg;
+    }
+    // Turns one raw argument value into JSON text: literals are lower-cased, objects/arrays/indicator-quoted strings are rewritten, anything unparsable is emitted as a JSON string.
+    std::string normalized = arg;
+    trim(normalized);
+    std::string lower = normalized;
+    std::transform(lower.begin(), lower.end(), lower.begin(), [](unsigned char c) { return static_cast<char>(::tolower(c)); });  // tolower is undefined for negative char values
+
+    if (lower == "true" || lower == "false" || lower == "null") {
+        return lower;
+    }
+    // A whitespace-only argument is empty after trim; front()/back() on an empty string is undefined behaviour.
+    const char first = normalized.empty() ? '\0' : normalized.front();
+    const char last = normalized.empty() ? '\0' : normalized.back();
+    if (first == '{' && last == '}') {
+        normalized = parseObjectParameter(normalized);
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument contains is an object, changed it to correct JSON format. Modified string: {}", normalized);
+    }
+
+    if (first == '[' && last == ']' && normalized.find(TOOL_ARGS_STRING_INDICATOR) != std::string::npos) {
+        normalized = parseArrayParameter(normalized);
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument is an array, normalized quotes for JSON parsing. Modified string: {}", normalized);
+    }
+    // The second substr is only evaluated when the first matched, so normalized is at least one indicator long there.
+    if (normalized.substr(0, TOOL_ARGS_STRING_INDICATOR.size()) == TOOL_ARGS_STRING_INDICATOR &&
+        normalized.substr(normalized.size() - TOOL_ARGS_STRING_INDICATOR.size(), TOOL_ARGS_STRING_INDICATOR.size()) == TOOL_ARGS_STRING_INDICATOR) {
+        normalized = "\"" + normalized.substr(TOOL_ARGS_STRING_INDICATOR.size(), normalized.size() - 2 * TOOL_ARGS_STRING_INDICATOR.size()) + "\"";
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument is enclosed in string indicators, removed them for JSON parsing. Modified string: {}", normalized);
+    }
+
+    rapidjson::Document tempDoc;
+    rapidjson::Value finalValue;
+    tempDoc.Parse(normalized.c_str());
+    if (tempDoc.HasParseError()) {
+        auto errorCode = tempDoc.GetParseError();
+        auto errorMessage = rapidjson::GetParseError_En(errorCode);
+        size_t errorOffset = tempDoc.GetErrorOffset();
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Failed to parse argument string as JSON. Argument string: {}, Error: {} Offset: {}", normalized, errorMessage, errorOffset);
+        // Fallback: emit the text as a JSON string, dropping one pair of surrounding quotes; empty text parses as an error too, hence the guard.
+        if (!normalized.empty() && normalized.front() == '\"' && normalized.back() == '\"') {
+            normalized = normalized.substr(1, normalized.size() - 2);
+        }
+        finalValue.SetString(normalized.c_str(), static_cast(normalized.size()), tempDoc.GetAllocator());
+    } else {
+        finalValue.CopyFrom(tempDoc, tempDoc.GetAllocator());
+    }
+
+    {
+        rapidjson::StringBuffer buffer;
+        rapidjson::Writer writer(buffer);
+        finalValue.Accept(writer);
+        normalized = buffer.GetString();
+    }
+
+    return normalized;
+}
+// Normalizes one raw argument value, parses the result as JSON and forwards it to writeArgumentOfAnyType.
+void Gemma4ToolParser::writeArgumentToWriter(const std::string& arg, rapidjson::Writer& writer) {
+    std::string normalized = normalizeArgStr(arg);
+
+    rapidjson::Document doc;
+    doc.Parse(normalized.c_str());
+
+    rapidjson::Value& argumentDoc = doc;
+    writeArgumentOfAnyType(argumentDoc, writer);
+}
+// Splits "name:value" at the first ':'; without a ':' the whole text becomes the name and the value is empty.
+std::pair Gemma4ToolParser::parseSingleArgument(const std::string& argumentStr) {
+    std::pair argument;
+
+    size_t colonPos = argumentStr.find(':');
+    if (colonPos != std::string::npos) {
+        argument.first = argumentStr.substr(0, colonPos);
+        std::string value = argumentStr.substr(colonPos + 1);
+        argument.second = value;
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed argument - name: {}, value: {}", argument.first, argument.second);
+    } else {
+        argument.first = argumentStr;
+        argument.second = "";
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Argument string: {} does not contain ':', setting name as entire string and value as empty", argumentStr);
+    }
+    return argument;
+}
+// Splits the argument list on TOOL_ARGS_SEPARATOR_STR (located with findInStringRespectingSpecialChars) and parses each piece with parseSingleArgument.
+std::vector> Gemma4ToolParser::parseArguments(const std::string& argumentsStr) {
+    std::vector args;
+    std::vector> parsedArgs;
+
+    size_t argPos = 0;
+    while (argPos < argumentsStr.length()) {
+        size_t commaPos = findInStringRespectingSpecialChars(argumentsStr, TOOL_ARGS_SEPARATOR_STR, argPos);
+        if (commaPos == std::string::npos) {
+            auto remainingStr = argumentsStr.substr(argPos);
+            args.push_back(remainingStr);
+            
SPDLOG_LOGGER_TRACE(llm_calculator_logger, "No more commas found, adding remaining argument string: {}", remainingStr); + break; + } + std::string argStr = argumentsStr.substr(argPos, commaPos - argPos); + args.push_back(argStr); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed argument string: {}", argStr); + argPos = commaPos + TOOL_ARGS_SEPARATOR_STR.length(); + } + + for (const std::string& arg : args) { + parsedArgs.push_back(parseSingleArgument(arg)); + } + return parsedArgs; +} + +bool Gemma4ToolParser::parseInContentState() { + size_t toolCallStartTagPos = this->streamingContent.find(TOOL_CALL_START_TAG, this->streamingPosition); + if (toolCallStartTagPos != std::string::npos) { + if (toolCallStartTagPos > this->streamingPosition) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Content found before tool call start tag at position: {}", toolCallStartTagPos); + return true; + } + this->streamingPosition = toolCallStartTagPos + TOOL_CALL_START_TAG.length(); + this->currentState = State::ToolCallStarted; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected start of tool call at position: {}", toolCallStartTagPos); + return false; + } + + return true; +} + +bool Gemma4ToolParser::parseInToolCallState() { + size_t argsPos = this->streamingContent.find(TOOL_ARGS_START_INDICATOR, this->streamingPosition); + if (argsPos == std::string::npos) { + return false; + } + + size_t toolNameStart = this->streamingContent.find(TOOL_CALL_NAME_PREFIX, this->streamingPosition); + if (toolNameStart != std::string::npos && toolNameStart < argsPos) { + toolNameStart += TOOL_CALL_NAME_PREFIX.length(); + } else { + toolNameStart = this->streamingPosition; + } + + std::string toolName = this->streamingContent.substr(toolNameStart, argsPos - toolNameStart); + trim(toolName); + this->toolCall = ToolCall{generateRandomId(), toolName, ""}; + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool name: {}", toolName); + this->streamingPosition = argsPos + 
TOOL_ARGS_START_INDICATOR.length();
+    this->currentState = State::ToolCallParameters;
+    this->toolCallIndex++;
+    return true;
+}
+// Waits for TOOL_ARGS_END_INDICATOR; once found, serializes the arguments seen since streamingPosition into toolCall.arguments and moves to State::ToolCallEnded.
+bool Gemma4ToolParser::parseToolCallParametersState() {
+    if (this->streamingContent.back() == TOOL_ARGS_END_INDICATOR.back()) {
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Tool arguments end indicator found at the end of streaming content, attempting to parse arguments: {}", this->streamingContent.substr(this->streamingPosition));
+    }
+    size_t pos = findInStringRespectingSpecialChars(this->streamingContent, TOOL_ARGS_END_INDICATOR, this->streamingPosition);
+    if (pos == std::string::npos) {
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Tool arguments end indicator not found in streaming content starting from position: {}", this->streamingPosition);
+        return false;
+    }
+    std::string argumentsStr = this->streamingContent.substr(this->streamingPosition, pos - this->streamingPosition);
+    SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed arguments string: {}", argumentsStr);
+    std::vector> arguments = parseArguments(argumentsStr);
+
+    // Serialize the parsed name/value pairs as a single JSON object.
+    rapidjson::StringBuffer sb;
+    rapidjson::Writer argsWriter(sb);
+    argsWriter.StartObject();
+
+    for (const std::pair& argument : arguments) {
+        argsWriter.Key(argument.first.c_str());
+        writeArgumentToWriter(argument.second, argsWriter);
+    }
+
+    argsWriter.EndObject();
+    this->toolCall.arguments = sb.GetString();
+    this->currentState = State::ToolCallEnded;
+    this->streamingPosition = pos + TOOL_ARGS_END_INDICATOR.length();
+
+    return true;
+}
+// Picks the state that follows a finished argument list: ToolCallStarted when another call precedes the end tag, AfterToolCall otherwise. Always returns true.
+bool Gemma4ToolParser::parseInToolCallEndedState() {
+    size_t nextToolCallPos = this->streamingContent.find(TOOL_CALL_NAME_PREFIX, this->streamingPosition);
+    size_t toolCallEndTagPos = this->streamingContent.find(TOOL_CALL_END_TAG, this->streamingPosition);
+    SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Current state: ToolCallEnded. Streaming content from current position: {}", this->streamingContent.substr(this->streamingPosition));
+    if (nextToolCallPos != std::string::npos && nextToolCallPos < toolCallEndTagPos) {
+        this->streamingPosition = nextToolCallPos;
+        this->currentState = State::ToolCallStarted;
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected next tool call at position: {}", nextToolCallPos);
+    } else if (toolCallEndTagPos != std::string::npos) {
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected end of tool call at position: {}", toolCallEndTagPos);
+        this->streamingPosition = toolCallEndTagPos + TOOL_CALL_END_TAG.length();
+        this->currentState = State::AfterToolCall;
+    } else {
+        // End tag not found: keep streamingPosition where it is; npos + length would wrap around and rewind the stream.
+        this->currentState = State::AfterToolCall;
+        SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Detected end of tool call at position: {}, returning to content state", toolCallEndTagPos);
+    }
+    return true;
+}
+// Dispatches to the handler of the current state; State::AfterToolCall has no handler and yields false.
+bool Gemma4ToolParser::parseNewContent() {
+    switch (this->currentState) {
+    case State::Content: {
+        return parseInContentState();
+    }
+    case State::ToolCallStarted: {
+        return parseInToolCallState();
+    }
+    case State::ToolCallParameters: {
+        return parseToolCallParametersState();
+    }
+    case State::ToolCallEnded: {
+        return parseInToolCallEndedState();
+    }
+    case State::AfterToolCall:
+        break;
+    }
+    return false;
+}
+// Builds the document {"delta":{"content":<content>}}.
+rapidjson::Document Gemma4ToolParser::wrapDeltaContent(const std::string& content) {
+    rapidjson::Document doc(rapidjson::kObjectType);
+    rapidjson::Value deltaObj(rapidjson::kObjectType);
+    deltaObj.AddMember("content", rapidjson::Value(content.c_str(), doc.GetAllocator()), doc.GetAllocator());
+    doc.AddMember("delta", deltaObj, doc.GetAllocator());
+    return doc;
+}
+// Builds {"arguments":<argsStr>} and hands it to BaseOutputParser::wrapDelta for the given tool call index.
+rapidjson::Document Gemma4ToolParser::wrapDeltaArgs(const std::string& argsStr, int toolCallIndex) {
+    rapidjson::Document doc(rapidjson::kObjectType);
+    doc.AddMember("arguments", rapidjson::Value(argsStr.c_str(), doc.GetAllocator()), 
doc.GetAllocator()); + + return BaseOutputParser::wrapDelta(doc, toolCallIndex); +} + +std::optional Gemma4ToolParser::parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) { + if (chunk.empty()) { + return std::nullopt; + } + + this->streamingContent += chunk; + + if (parseNewContent()) { + if (this->currentState == State::ToolCallParameters) { + return BaseOutputParser::wrapFirstDelta(this->toolCall.name, toolCallIndex); + } + if (this->currentState == State::ToolCallEnded) { + return wrapDeltaArgs(this->toolCall.arguments, toolCallIndex); + } + if (this->currentState == State::Content) { + size_t contentEnd = this->streamingContent.find(TOOL_CALL_START_TAG, this->streamingPosition); + std::string content; + if (contentEnd != std::string::npos) { + content = this->streamingContent.substr(this->streamingPosition, contentEnd - this->streamingPosition); + } else { + content = this->streamingContent.substr(this->streamingPosition); + } + this->streamingPosition += content.size(); + if (!content.empty()) { + return wrapDeltaContent(content); + } + } + if (this->currentState == State::AfterToolCall) { + this->currentState = State::Content; + } + } + + if (finishReason != ov::genai::GenerationFinishReason::NONE) { + if ((this->currentState == State::ToolCallParameters || this->currentState == State::ToolCallEnded) && !this->toolCall.arguments.empty()) { + return wrapDeltaArgs(this->toolCall.arguments, toolCallIndex); + } + + if (this->currentState == State::Content && this->streamingPosition < this->streamingContent.size()) { + auto content = this->streamingContent.substr(this->streamingPosition); + this->streamingPosition += content.size(); + + return wrapDeltaContent(content); + } + } + + return std::nullopt; +} + +bool Gemma4ToolParser::parseSingleToolCall(const std::string& toolStr, ToolCall& toolCall) { + size_t argsPos = toolStr.find(TOOL_ARGS_START_INDICATOR); + if (argsPos != std::string::npos) { + std::string toolNameWithPrefix = 
toolStr.substr(0, argsPos); + if (toolNameWithPrefix.find(TOOL_CALL_NAME_PREFIX) != 0) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Tool name does not start with expected prefix '{}'. Tool string: {}", TOOL_CALL_NAME_PREFIX, toolStr); + return false; + } + std::string toolName = toolNameWithPrefix.substr(TOOL_CALL_NAME_PREFIX.length()); + trim(toolName); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool name: {}", toolName); + + int argsStrLen = toolStr.length() - argsPos - TOOL_ARGS_START_INDICATOR.length() - TOOL_ARGS_END_INDICATOR.length(); + std::string argsStr = toolStr.substr(argsPos + TOOL_ARGS_START_INDICATOR.length(), argsStrLen); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed args string: {}", argsStr); + std::vector> arguments = parseArguments(argsStr); + + toolCall.name = toolName; + rapidjson::Document argsDoc(rapidjson::kObjectType); + rapidjson::StringBuffer sb; + rapidjson::Writer argsWriter(sb); + argsWriter.StartObject(); + for (const std::pair& argument : arguments) { + argsWriter.Key(argument.first.c_str()); + writeArgumentToWriter(argument.second, argsWriter); + } + argsWriter.EndObject(); + toolCall.arguments = sb.GetString(); + toolCall.id = generateRandomId(); + return true; + } + return false; +} + +void Gemma4ToolParser::parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) { + std::vector tools; + std::vector> toolCallPositions; + size_t pos = 0; + + while (pos != std::string::npos) { + size_t start = std::string::npos; + size_t end = std::string::npos; + + auto it = std::find(generatedTokens.begin() + pos, generatedTokens.end(), botTokenId); + if (it != generatedTokens.end()) { + start = std::distance(generatedTokens.begin(), it); + } else { + break; + } + auto itArgs = std::find(generatedTokens.begin() + start, generatedTokens.end(), eotTokenId); + if (itArgs != generatedTokens.end()) { + end = std::distance(generatedTokens.begin(), itArgs); + } else { + break; + } + + std::string toolCallStr = 
tokenizer.decode(std::vector(generatedTokens.begin() + start + 1, generatedTokens.begin() + end + 1), ov::AnyMap{ov::genai::skip_special_tokens(false)}); + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool list string: {}", toolCallStr); + + while (!toolCallStr.empty()) { + size_t nextToolPos = toolCallStr.find(TOOL_CALL_NAME_PREFIX, TOOL_CALL_NAME_PREFIX.length()); + size_t toolEndPos; + if (nextToolPos == std::string::npos) { + toolEndPos = toolCallStr.rfind(TOOL_ARGS_END_INDICATOR); + } else { + toolEndPos = nextToolPos - 1; + } + std::string singleTool; + if (toolEndPos != std::string::npos) { + singleTool = toolCallStr.substr(0, toolEndPos + TOOL_ARGS_END_INDICATOR.length()); + if (toolEndPos + TOOL_ARGS_END_INDICATOR.length() < toolCallStr.length()) { + toolCallStr = toolCallStr.substr(toolEndPos + TOOL_ARGS_END_INDICATOR.length()); + } else { + toolCallStr.clear(); + } + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed single tool string {}", singleTool); + } else { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "No more tool strings found in the decoded string: {}", toolCallStr); + break; + } + + if (!singleTool.empty()) { + tools.push_back(singleTool); + } + } + + pos = end; + toolCallPositions.emplace_back(start, end); + } + + for (const std::string& tool : tools) { + ToolCall toolCall; + auto wasToolCallParsed = parseSingleToolCall(tool, toolCall); + if (wasToolCallParsed) { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Parsed tool call - name: {}, args: {}", toolCall.name, toolCall.arguments); + parsedOutput.toolCalls.push_back(toolCall); + } else { + SPDLOG_LOGGER_TRACE(llm_calculator_logger, "Failed to parse tool call from string: {}", tool); + } + } + std::vector contentWithoutToolCalls = generatedTokens; + for (auto it = toolCallPositions.rbegin(); it != toolCallPositions.rend(); ++it) { + contentWithoutToolCalls.erase(contentWithoutToolCalls.begin() + it->first, contentWithoutToolCalls.begin() + it->second + 1); + } + + auto reasoningEnd = 
std::find(contentWithoutToolCalls.begin(), contentWithoutToolCalls.end(), reasoningEndTokenId); + if (reasoningEnd != contentWithoutToolCalls.end()) { + contentWithoutToolCalls.erase(contentWithoutToolCalls.begin(), reasoningEnd + 1); + } + parsedOutput.content = tokenizer.decode(contentWithoutToolCalls, ov::AnyMap{ov::genai::skip_special_tokens(true)}); +} +} // namespace ovms diff --git a/src/llm/io_processing/gemma4/tool_parser.hpp b/src/llm/io_processing/gemma4/tool_parser.hpp new file mode 100644 index 0000000000..c33d32b6c9 --- /dev/null +++ b/src/llm/io_processing/gemma4/tool_parser.hpp @@ -0,0 +1,104 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//*****************************************************************************
+#pragma once
+#include
+#include
+#include
+#include "src/llm/io_processing/base_output_parser.hpp"
+
+namespace ovms {
+// Tool-call parser for Gemma4 model output, built on BaseOutputParser.
+//
+// Two entry points are overridden:
+//  - parse():      one-shot extraction from the complete list of generated tokens,
+//  - parseChunk(): incremental extraction from streamed text, driven by the
+//                  State machine declared below.
+class Gemma4ToolParser : public BaseOutputParser {
+protected:
+    // Textual markers of the tool-call syntax. Only declared here; the values
+    // are defined out of line.
+    static const std::string TOOL_CALL_START_TAG;
+    static const std::string TOOL_CALL_END_TAG;
+    static const std::string TOOL_CALL_NAME_PREFIX;
+
+    static const std::string TOOL_ARGS_START_INDICATOR;
+    static const std::string TOOL_ARGS_END_INDICATOR;
+    static const std::string TOOL_ARGS_STRING_INDICATOR;
+    static const std::string TOOL_ARGS_SEPARATOR_STR;
+    static const std::string TURN_END_TAG;
+
+    // Token ids used on the unary path: parse() treats botTokenId / eotTokenId
+    // as the opening / closing token of a tool-call span and drops everything up
+    // to and including reasoningEndTokenId from the content.
+    static const int64_t botTokenId;
+    static const int64_t eotTokenId;
+    static const int64_t reasoningTokenId;
+    static const int64_t reasoningEndTokenId;
+
+    // States of the streaming parser; each comment lists the outgoing transition.
+    enum class State {
+        Content,             // Content -> ToolCallStarted (on TOOL_CALL_START_TAG)
+        ToolCallStarted,     // ToolCallStarted -> ToolCallParameters (on TOOL_ARGS_START_INDICATOR, emits name)
+        ToolCallParameters,  // ToolCallParameters -> ToolCallEnded (on TOOL_ARGS_END_INDICATOR, emits args)
+        ToolCallEnded,       // ToolCallEnded -> ToolCallStarted (on TOOL_CALL_NAME_PREFIX) | AfterToolCall (on end tag)
+        AfterToolCall        // AfterToolCall -> Content
+    };
+
+public:
+    // A tokenizer is mandatory, hence no default construction.
+    Gemma4ToolParser() = delete;
+    explicit Gemma4ToolParser(ov::genai::Tokenizer& tokenizer) :
+        BaseOutputParser(tokenizer) {}
+
+    // Unary path: fills parsedOutput.toolCalls and parsedOutput.content.
+    void parse(ParsedOutput& parsedOutput, const std::vector& generatedTokens) override;
+    // Streaming path: returns the next delta, or std::nullopt when nothing can be emitted yet.
+    std::optional parseChunk(const std::string& chunk, ov::genai::GenerationFinishReason finishReason) override;
+    // Tags that start tool-call parsing: just TOOL_CALL_START_TAG.
+    const std::vector& getParsingStartTags() const override {
+        static const std::vector parsingStartTags = {TOOL_CALL_START_TAG};
+        return parsingStartTags;
+    }
+
+    // Special tags to erase: just TURN_END_TAG.
+    const std::vector& getSpecialTagsToErase() const override {
+        static const std::vector tagsToErase = {TURN_END_TAG};
+        return tagsToErase;
+    }
+
+    // This parser defines no special parsing start tags; the list is empty.
+    const std::vector& getSpecialParsingStartTags() const override {
+        static const std::vector beginningOnlyTags = {};
+        return beginningOnlyTags;
+    }
+
+    // Tag that ends tool-call parsing.
+    const std::string& getParsingEndTag() const override {
+        return TOOL_CALL_END_TAG;
+    }
+
+    // NOTE(review): presumably true because the tags matched while streaming are
+    // special tokens that must not be stripped from the stream - confirm.
+    bool requiresStreamingWithSpecialTokens() const override {
+        return true;
+    }
+
+    // Static argument-string helpers, declared public; defined out of line.
+    static std::string normalizeArgStr(const std::string& arg);
+    static std::string parseArrayParameter(const std::string& argumentStr);
+    static std::string parseObjectParameter(const std::string& argumentStr);
+
+private:
+    // Writes one argument value through `writer`.
+    void writeArgumentToWriter(const std::string& arg, rapidjson::Writer& writer);
+
+    // Argument parsing: a single argument, and a whole arguments string.
+    std::pair parseSingleArgument(const std::string& argumentStr);
+    std::vector> parseArguments(const std::string& argumentsStr);
+
+    // Parses one tool-call string into `toolCall`; returns false when it is malformed.
+    bool parseSingleToolCall(const std::string& toolStr, ToolCall& toolCall);
+    // Streaming state-machine step called by parseChunk(), with per-state handlers.
+    bool parseNewContent();
+    bool parseInContentState();
+    bool parseInToolCallState();
+    bool parseToolCallParametersState();
+    bool parseInToolCallEndedState();
+
+    // Builders for the JSON deltas returned by parseChunk().
+    rapidjson::Document wrapDeltaContent(const std::string& content);
+    rapidjson::Document wrapDeltaArgs(const std::string& argsStr, int toolCallIndex);
+
+    // All streamed text received so far; parseChunk() appends every chunk.
+    std::string streamingContent;
+    // Offset into streamingContent of the first character not yet processed.
+    size_t streamingPosition{0};
+    State currentState{State::Content};
+    // Tool call currently being assembled; parseChunk() emits its name and arguments.
+    ToolCall toolCall;
+    // Index passed to the delta builders; starts at -1
+    // (presumably advanced when a tool call begins - confirm in parseNewContent()).
+    int toolCallIndex{-1};
+};
+}  // namespace ovms
diff --git a/src/llm/io_processing/output_parser.cpp b/src/llm/io_processing/output_parser.cpp
index c6d9fe8a67..d0a99002cb 100644
--- a/src/llm/io_processing/output_parser.cpp
+++ b/src/llm/io_processing/output_parser.cpp
@@ -28,8 +28,10 @@
 #include "qwen3/reasoning_parser.hpp"
 #include "qwen3coder/qwen3coder_tool_parser.hpp"
 #include "devstral/tool_parser.hpp"
+#include "gemma4/reasoning_parser.hpp"
 #include "gptoss/reasoning_parser.hpp"
 #include "lfm2/lfm2_tool_parser.hpp"
+#include "gemma4/tool_parser.hpp"
 
 namespace ovms {
 OutputParser::TagLookupStatus OutputParser::StreamOutputCache::lookupTag(const std::string& tag) const {
@@ -50,8 +52,8 @@ OutputParser::TagLookupStatus OutputParser::StreamOutputCache::lookupTag(const s
     } else if 
(tag.size() < buffer.size()) { /* If the tag is shorter than the buffer, we check: - a) if the tag is a substring of the buffer (tag is fully matched) - b) if the buffer and tag overlap (part of the tag is matched) + a) if the tag is a substring of the buffer (tag is fully matched) + b) if the buffer and tag overlap (part of the tag is matched) in the first case we return FOUND_COMPLETE, in the second FOUND_INCOMPLETE otherwise we return NOT_FOUND */ @@ -65,8 +67,8 @@ OutputParser::TagLookupStatus OutputParser::StreamOutputCache::lookupTag(const s } else { /* If the tag and buffer are of the same length, we check: - a) if they are equal (tag is fully matched) - b) if they overlap (part of the tag is matched) + a) if they are equal (tag is fully matched) + b) if they overlap (part of the tag is matched) in the first case we return FOUND_COMPLETE, in the second FOUND_INCOMPLETE otherwise we return NOT_FOUND */ @@ -184,12 +186,16 @@ OutputParser::OutputParser(ov::genai::Tokenizer& tokenizer, const std::string to toolParser = std::make_unique(tokenizer, toolNameSchemaMap); } else if (toolParserName == "lfm2") { toolParser = std::make_unique(tokenizer); + } else if (toolParserName == "gemma4") { + toolParser = std::make_unique(tokenizer); } else if (!toolParserName.empty()) { throw std::runtime_error("Unsupported tool parser: " + toolParserName); } if (reasoningParserName == "qwen3") { reasoningParser = std::make_unique(tokenizer); + } else if (reasoningParserName == "gemma4") { + reasoningParser = std::make_unique(tokenizer); } else if (reasoningParserName == "gptoss") { reasoningParser = std::make_unique(tokenizer); } else if (!reasoningParserName.empty()) { diff --git a/src/llm/io_processing/utils.cpp b/src/llm/io_processing/utils.cpp index e26ca376b4..c58ed1ccf0 100644 --- a/src/llm/io_processing/utils.cpp +++ b/src/llm/io_processing/utils.cpp @@ -70,6 +70,11 @@ size_t findInStringRespectingSpecialChars(const std::string& str, const std::str int singleQuoteDepth = 
0; for (size_t i = startPos; i < str.length(); ++i) { + if (bracketDepth == 0 && braceDepth == 0 && quoteDepth == 0 && singleQuoteDepth == 0 && + str.compare(i, target.length(), target) == 0) { + return i; + } + if (str[i] == '{') { braceDepth++; } else if (str[i] == '}') { @@ -80,14 +85,10 @@ size_t findInStringRespectingSpecialChars(const std::string& str, const std::str bracketDepth--; } else if (str[i] == '"' && (i == 0 || str[i - 1] != '\\')) { quoteDepth = 1 - quoteDepth; - } else if (str[i] == '\'' && (i == 0 || str[i - 1] != '\\')) { + } else if (quoteDepth == 0 && str[i] == '\'' && (i == 0 || str[i - 1] != '\\')) { singleQuoteDepth = 1 - singleQuoteDepth; - } else if (bracketDepth == 0 && braceDepth == 0 && quoteDepth == 0 && singleQuoteDepth == 0 && - str.compare(i, target.length(), target) == 0) { - return i; } } return std::string::npos; } - } // namespace ovms diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index 6297745360..6ca45b0b03 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -194,8 +194,10 @@ absl::Status VisualLanguageModelLegacyServable::prepareCompleteResponse(std::sha completeText = 
std::move(executionContext->lastStreamerCallbackOutput); executionContext->lastStreamerCallbackOutput.clear(); } - executionContext->response = executionContext->apiHandler->serializeUnaryResponse(legacyExecutionContext->results, completeText); + executionContext->apiHandler->setPromptTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_input_tokens()); + executionContext->apiHandler->setCompletionTokensUsage(legacyExecutionContext->results.perf_metrics.get_num_generated_tokens()); + executionContext->response = executionContext->apiHandler->serializeUnaryResponse(legacyExecutionContext->results, completeText); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Complete unary response: {}", executionContext->response); return absl::OkStatus(); } diff --git a/src/test/llm/output_parsers/gemma4_output_parser_test.cpp b/src/test/llm/output_parsers/gemma4_output_parser_test.cpp new file mode 100644 index 0000000000..2cd39e3cc2 --- /dev/null +++ b/src/test/llm/output_parsers/gemma4_output_parser_test.cpp @@ -0,0 +1,902 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../../llm/io_processing/base_output_parser.hpp" +#include "../../../llm/io_processing/output_parser.hpp" +#include "../../platform_utils.hpp" + +using namespace ovms; + +#ifdef _WIN32 +const std::string tokenizerPath = getWindowsRepoRootPath() + "\\src\\test\\llm_testing\\OpenVINO\\gemma-4-E4B-it-int4-ov"; +#else +// Hardcoded for usage in docker container +const std::string tokenizerPath = "/ovms/src/test/llm_testing/OpenVINO/gemma-4-E4B-it-int4-ov"; +#endif + +static std::unique_ptr gemma4Tokenizer; +static const ToolsSchemas_t& EMPTY_TOOLS_SCHEMA = {}; // not used in gemma4 + +class Gemma4OutputParserTest : public ::testing::Test { +protected: + std::unique_ptr outputParserWithRegularToolParsing; + + static void SetUpTestSuite() { + try { + gemma4Tokenizer = std::make_unique(tokenizerPath); + } catch (const std::exception& e) { + FAIL() << "Failed to initialize gemma4 tokenizer: " << e.what(); + } catch (...) 
{ + FAIL() << "Failed to initialize gemma4 tokenizer due to unknown error."; + } + } + + static void TearDownTestSuite() { + gemma4Tokenizer.reset(); + } + + void SetUp() override { + // Gemma4 provides both a tool parser and a reasoning parser; enable both + outputParserWithRegularToolParsing = std::make_unique(*gemma4Tokenizer, "gemma4", "gemma4", EMPTY_TOOLS_SCHEMA); + } + + void assertChunkEqual(const std::optional& doc, const std::optional& expectedDelta, const std::string& chunk) { + if (!expectedDelta.has_value() && !doc.has_value()) { + return; + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } else { + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk; + } + } +}; + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithSingleToolCall) { + std::string inputWithProperClosure = "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithSingleToolCallAndReasoning) { + std::string inputWithProperClosure = "<|channel>thought\nSome reasoning 
content<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, "Some reasoning content"); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseReasoningWithoutToolCall) { + std::string inputWithProperClosure = "<|channel>thought\nSome reasoning contentSOME CONTENT WITHOUT TOOL CALL"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "SOME CONTENT WITHOUT TOOL CALL"); + EXPECT_EQ(parsedOutput.reasoning, "Some reasoning content"); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithNoToolsInTheRequest) { + std::string inputWithProperClosure = "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + std::string inputWithoutSpecialTokens = "call:example_tool{arg1:value1,arg2:42}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), 
generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, false); + EXPECT_EQ(parsedOutput.content, inputWithoutSpecialTokens); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithObjectArguments) { + std::string inputWithProperClosure = "<|tool_call>call:dummy{config:{name:<|\"|>astro_config<|\"|>,value:99}}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "dummy"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"config\":{\"name\":\"astro_config\",\"value\":99}}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArguments) { + std::string inputWithProperClosure = "<|tool_call>call:test1{arg1:<|\"|>data1,data2<|\"|>}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "test1"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"data1,data2\"}"); + 
EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithListOfStringsAsArgument) { + std::string inputWithProperClosure = "<|tool_call>call:generate_DNA_sequence{length:100,preferences:[<|\"|>G<|\"|>,<|\"|>C<|\"|>]}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "generate_DNA_sequence"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"length\":100,\"preferences\":[\"G\",\"C\"]}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParserToolCallWithBooleanArgument) { + std::string inputWithProperClosure = "<|tool_call>call:check_status{flag:true}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "check_status"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"flag\":true}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseTwoToolCallsAtOnce) { + std::string inputWithProperClosure = 
"<|tool_call>call:dummy1{config:{name:<|\"|>astro_config<|\"|>,value:99}}call:dummy2{config:{value:199,name:<|\"|>second_config<|\"|>}}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 2); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "dummy1"); + EXPECT_EQ(parsedOutput.toolCalls[1].name, "dummy2"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"config\":{\"name\":\"astro_config\",\"value\":99}}"); + EXPECT_EQ(parsedOutput.toolCalls[1].arguments, "{\"config\":{\"value\":199,\"name\":\"second_config\"}}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + EXPECT_EQ(parsedOutput.toolCalls[1].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithArrayArguments) { + std::string inputWithProperClosure = "<|tool_call>call:sort{array:[42,17,89,5,33],order:<|\"|>descending<|\"|>}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "sort"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + } +} + +TEST_F(Gemma4OutputParserTest, 
ParseToolCallOutputWithThreeToolCalls) { + std::string inputWithProperClosure = "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}call:another_tool{param1:<|\"|>data<|\"|>,param2:true}call:third_tool{key:<|\"|>value<|\"|>}"; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 3); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + auto firstToolCallId = parsedOutput.toolCalls[0].id; + + EXPECT_EQ(parsedOutput.toolCalls[1].name, "another_tool"); + EXPECT_EQ(parsedOutput.toolCalls[1].arguments, "{\"param1\":\"data\",\"param2\":true}"); + EXPECT_EQ(parsedOutput.toolCalls[1].id.empty(), false); + auto secondToolCallId = parsedOutput.toolCalls[1].id; + EXPECT_NE(firstToolCallId, secondToolCallId); + + EXPECT_EQ(parsedOutput.toolCalls[2].name, "third_tool"); + EXPECT_EQ(parsedOutput.toolCalls[2].arguments, "{\"key\":\"value\"}"); + EXPECT_EQ(parsedOutput.toolCalls[2].id.empty(), false); + auto thirdToolCallId = parsedOutput.toolCalls[2].id; + EXPECT_NE(firstToolCallId, thirdToolCallId); + EXPECT_NE(secondToolCallId, thirdToolCallId); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithThreeToolCallsWithContentInBetween) { + std::string inputWithProperClosure = "Before tool calls content. " + "<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}" + "This is some content between tool calls." 
+ "<|tool_call>call:another_tool{param1:<|\"|>data<|\"|>,param2:true}" + " This is some content between second and third tool call. " + "<|tool_call>call:third_tool{key:<|\"|>value<|\"|>}" + "After tool calls content."; + + std::vector inputs = {inputWithProperClosure}; + for (auto& input : inputs) { + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "Before tool calls content. This is some content between tool calls. This is some content between second and third tool call. After tool calls content."); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 3); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); + auto firstToolCallId = parsedOutput.toolCalls[0].id; + + EXPECT_EQ(parsedOutput.toolCalls[1].name, "another_tool"); + EXPECT_EQ(parsedOutput.toolCalls[1].arguments, "{\"param1\":\"data\",\"param2\":true}"); + EXPECT_EQ(parsedOutput.toolCalls[1].id.empty(), false); + auto secondToolCallId = parsedOutput.toolCalls[1].id; + EXPECT_NE(firstToolCallId, secondToolCallId); + + EXPECT_EQ(parsedOutput.toolCalls[2].name, "third_tool"); + EXPECT_EQ(parsedOutput.toolCalls[2].arguments, "{\"key\":\"value\"}"); + EXPECT_EQ(parsedOutput.toolCalls[2].id.empty(), false); + auto thirdToolCallId = parsedOutput.toolCalls[2].id; + EXPECT_NE(firstToolCallId, thirdToolCallId); + EXPECT_NE(secondToolCallId, thirdToolCallId); + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithEmptyArguments) { + // Tool call with empty braces (no arguments) + std::string input = "<|tool_call>call:no_args_tool{}"; + auto generatedTensor = 
gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "no_args_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{}"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithMultipleUtfChars) { + // Tool call whose string argument contains a multi-byte UTF-8 character (emoji), followed by array-valued arguments + std::string input = R"(<|tool_call>call:post_tweet{content:<|"|>Check out the sorted report! 🚀 We've made improvements to the content. Tagging @currenttech and mentioning Julia for our insightful team. #currenttech #trend<|"|>,mentions:[<|"|>@currenttech<|"|>,<|"|>Julia<|"|>],tags:[<|"|>#currenttrend<|"|>]})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "post_tweet"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"content":"Check out the sorted report! 🚀 We've made improvements to the content. Tagging @currenttech and mentioning Julia for our insightful team. 
#currenttech #trend","mentions":["@currenttech","Julia"],"tags":["#currenttrend"]})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithMultipleUtfCharsStreaming) { + std::vector>> chunkToDeltaVec{ + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"post", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"_tweet", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{content", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"post_tweet"}}]}})"}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"Check out the sorted report!", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 🚀", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" We've made improvements", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" to the content.", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" Tagging @currenttech", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" and mentioning Julia", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" for our insightful team.", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" #currenttech #trend", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"mentions", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":[", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"@currenttech", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + 
{"Julia", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"],", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"tags", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":[", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"#currenttrend", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"]}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"content\":\"Check out the sorted report! 🚀 We've made improvements to the content. Tagging @currenttech and mentioning Julia for our insightful team. #currenttech #trend\",\"mentions\":[\"@currenttech\",\"Julia\"],\"tags\":[\"#currenttrend\"]}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + }; + + for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + std::string idKey = "\"id\":\""; + auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = 
docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? [&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithContentAndNoToolCalls) { + std::string input = "This is a regular model response without tool calls."; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "This is a regular model response without tool calls."); + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); + EXPECT_EQ(parsedOutput.reasoning, ""); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallOutputWithContentAndSingleToolCall) { + std::string input = "This is a content part and next will 
be a tool call.\n\n<|tool_call>call:example_tool{arg1:<|\"|>value1<|\"|>,arg2:42}"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, "This is a content part and next will be a tool call.\n\n"); + EXPECT_EQ(parsedOutput.reasoning, ""); + + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "example_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, "{\"arg1\":\"value1\",\"arg2\":42}"); + EXPECT_EQ(parsedOutput.toolCalls[0].id.empty(), false); +} + +TEST_F(Gemma4OutputParserTest, HolisticStreaming) { + std::vector>> chunkToDeltaVec{ + {"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG"}})"}, + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"sort", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{array", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"sort"}}]}})"}, + {":[", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"42", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 17", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 89", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 5", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 33", 
ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"],", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"order", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"desc", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ending", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"}}]}})"}, + {"call:d", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ummy", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{config", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":1,"function":{"name":"dummy"}}]}})"}, + {":{", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"name", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"astro_config", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"value", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"99", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\"config\":{\"name\":\"astro_config\",\"value\":99}}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ANOTHER_CONTENT_AFTER_TOOL_CALL", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"ANOTHER_CONTENT_AFTER_TOOL_CALL"}})"}, + }; + + 
for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; // Both are nullopt, OK + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + // If both strings contain "id":"...", compare id values by length and alphanumeric, else compare whole strings + std::string expected = expectedDelta.value(); + std::string idKey = "\"id\":\""; + auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + // Compare everything except the id value + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = 
expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? [&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, StreamingWithBiggerChunks) { + std::vector>> chunkToDeltaVec{ + {"SOME_CONTENT", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"SOME_CONTENT"}})"}, + {"MORE_CONTENT<|tool_call>", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"MORE_CONTENT"}})"}, + {"call:sort", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{array:", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"sort"}}]}})"}, + {"[42, 17, 89, 5, 33],order:<|\"|>descending<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ANOTHER_CONTENT_AFTER_TOOL_CALL", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"ANOTHER_CONTENT_AFTER_TOOL_CALL"}})"}, + }; + + for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; // Both are nullopt, OK + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + std::string 
idKey = "\"id\":\""; + auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? 
[&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, StreamingWithWhitespacesBetweenToolCalls) { + std::vector>> chunkToDeltaVec{ + {"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"JUST_SOME_STRING_BEFORE_SPECIAL_STARTING_TAG"}})"}, + {"<|tool_call>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"\n", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"call:sort", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{array", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":0,"function":{"name":"sort"}}]}})"}, + {":[", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"42", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 17", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 89", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 5", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 33", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"],", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"order", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"desc", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ending", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>}", 
ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"array\":[42,17,89,5,33],\"order\":\"descending\"}"}}]}})"}, + {" call:d", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"ummy", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{config", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":1,"function":{"name":"dummy"}}]}})"}, + {":{", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"name", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"astro_config", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {",", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"value", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"99", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"}}", ov ::genai ::GenerationFinishReason ::NONE, R"({"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\"config\":{\"name\":\"astro_config\",\"value\":99}}"}}]}})"}, + {"call: solve", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"{e", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"id":"XXXXXXXXX","type":"function","index":2,"function":{"name":"solve"}}]}})"}, + {"quation", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {":<|\"|>", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"2", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"*", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"(", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"x", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"+", ov::genai::GenerationFinishReason::NONE, 
std::nullopt}, + {"5)", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" =", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {" 13", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"<|\"|>}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":2,"function":{"arguments":"{\"equation\":\"2*(x+5) = 13\"}"}}]}})"}, + {"", ov::genai::GenerationFinishReason::NONE, std::nullopt}, + {"And some content after second tool call", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"content":"And some content after second tool call"}})"}, + }; + + for (const auto& [chunk, finishReason, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, true, finishReason); + if (!expectedDelta.has_value() && !doc.has_value()) { + continue; // Both are nullopt, OK + } + if (expectedDelta.has_value() && doc.has_value()) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + std::string docStr = buffer.GetString(); + std::string expected = expectedDelta.value(); + std::string idKey = "\"id\":\""; + auto docIdPos = docStr.find(idKey); + auto expectedIdPos = expected.find(idKey); + if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) { + auto docIdStart = docIdPos + idKey.size(); + auto docIdEnd = docStr.find("\"", docIdStart); + auto expectedIdStart = expectedIdPos + idKey.size(); + auto expectedIdEnd = expected.find("\"", expectedIdStart); + ASSERT_NE(docIdEnd, std::string::npos); + ASSERT_NE(expectedIdEnd, std::string::npos); + std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart); + std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart); + EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk; + EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk; + // Compare everything 
except the id value + std::string docStrNoId = docStr; + std::string expectedNoId = expected; + docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*')); + expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*')); + EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk; + } else { + EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk; + } + } else { + std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt"; + std::string docStr = doc.has_value() ? [&]() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc->Accept(writer); + return std::string(buffer.GetString()); + }() + : "std::nullopt"; + FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk + << "\nexpectedDelta: " << expectedStr + << "\ndoc: " << docStr; + } + } +} + +TEST_F(Gemma4OutputParserTest, ToolCallsWithoutToolsInTheRequestStreaming) { + std::vector>> chunkToDeltaVec{ + {"<|tool_call>", "{\"delta\":{\"content\":\"<|tool_call>\"}}"}, + {"call:super", "{\"delta\":{\"content\":\"call:super\"}}"}, + {"_tool_number_two", "{\"delta\":{\"content\":\"_tool_number_two\"}}"}, + {"{arg1", "{\"delta\":{\"content\":\"{arg1\"}}"}, + {":<|\"|>", "{\"delta\":{\"content\":\":<|\\\"|>\"}}"}, + {"val{{{ue1", "{\"delta\":{\"content\":\"val{{{ue1\"}}"}, + {"<|\"|>}", "{\"delta\":{\"content\":\"<|\\\"|>}\"}}"}, + {"", "{\"delta\":{\"content\":\"\"}}"}, + }; + + for (const auto& [chunk, expectedDelta] : chunkToDeltaVec) { + std::optional doc = outputParserWithRegularToolParsing->parseChunk(chunk, false, ov::genai::GenerationFinishReason::NONE); + assertChunkEqual(doc, expectedDelta, chunk); + } +} + +// Malformed tool calls + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithMissingParentheses) { + std::string input = "<|tool_call>call:broken_tool"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector 
generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithMissingClosingParenthesis) { + std::string input = "<|tool_call>call:broken_tool{arg1:<|\"|>value1<|\"|>"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 0); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithArgumentMissingEquals) { + std::string input = "<|tool_call>call:broken{malformed_arg}"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "broken"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingComparison) { + std::string input = R"x(<|tool_call>call:search{query:<|"|>price >= 100, (sale)<|"|>,limit:5})x"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "search"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"x({"query":"price >= 100, (sale)","limit":5})x"); +} + +TEST_F(Gemma4OutputParserTest, 
ParseToolCallWithStringArgumentsContainingBracesAndBrackets) { + std::string input = R"(<|tool_call>call:format{template:<|"|>Hello {name}, items: [a, b, c]<|"|>,count:3})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "format"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"template":"Hello {name}, items: [a, b, c]","count":3})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingSpecialCharacters) { + std::string impl = "import package\nimport package2\n\ndef func(a, b):\n\td={\"python\": \"dict\"}\n\tl = [\"list \\\"with escaped text\\\"\", 123, []]\n\treturn f\"formatted {a} and {b}\""; + std::string input = R"(<|tool_call>call:execute{code:<|"|>)" + impl + R"(<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "execute"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"code":"import package\nimport package2\n\ndef func(a, b):\n\td={\"python\": \"dict\"}\n\tl = [\"list \\\"with escaped text\\\"\", 123, []]\n\treturn f\"formatted {a} and {b}\""})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingEscapedQuotes) { + std::string input = R"x(<|tool_call>call:execute{code:<|"|>print(\"hello world\")<|"|>,verbose:true})x"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + 
std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "execute"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"x({"code":"print(\"hello world\")","verbose":true})x"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingApostrophes) { + std::string input = R"(<|tool_call>call:log{message:<|"|>it's a test, isn't it?<|"|>,level:<|"|>warn<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "log"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"message":"it's a test, isn't it?","level":"warn"})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingBackslashes) { + std::string input = R"(<|tool_call>call:read_file{path:<|"|>C:\Users\test\file.txt<|"|>,encoding:<|"|>utf-8<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "read_file"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"path":"C:\\Users\\test\\file.txt","encoding":"utf-8"})"); +} + +TEST_F(Gemma4OutputParserTest, 
ParseToolCallWithStringArgumentsArrayWithStringsContainingQuotes) { + std::string input = R"(<|tool_call>call:save{lines:[<|"|>it's the wonderful day<|"|>,<|"|>He said: "My name's John"<|"|>,<|"|>That's Johns' car.<|"|>]})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "save"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"lines":["it's the wonderful day","He said: \"My name's John\"","That's Johns' car."]})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsObjectWithStringsContainingQuotes) { + std::string input = R"(<|tool_call>call:save{obj:{name:<|"|>it's the wonderful day<|"|>,greeting:<|"|>Hello, my name's Jan<|"|>,note:<|"|>That's Johns' car.<|"|>}})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "save"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"obj":{"name":"it's the wonderful day","greeting":"Hello, my name's Jan","note":"That's Johns' car."}})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithStringArgumentsContainingNestedJSON) { + std::string input = R"(<|tool_call>call:send{payload:<|"|>{'key': 'value', 'count': 42}<|"|>,endpoint:<|"|>api<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + 
generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "send"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"payload":"{'key': 'value', 'count': 42}","endpoint":"api"})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithEmptyStringArgument) { + std::string input = R"(<|tool_call>call:create{name:<|"|><|"|>,value:0})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "create"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"name":"","value":0})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithUnicodeCharactersInArguments) { + std::string input = R"(<|tool_call>call:translate{text:<|"|>zażółć gęślą jaźń<|"|>,lang:<|"|>pl<|"|>})"; + auto generatedTensor = gemma4Tokenizer->encode(input).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "translate"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"({"text":"zażółć gęślą jaźń","lang":"pl"})"); +} + +TEST_F(Gemma4OutputParserTest, ParseToolCallWithPythonCodeAsArgument) { + std::string input = R"x(<|tool_call>call:string_tool{param:<|"|> + if __name__ == "__main__": + addresses = {} + addresses["Hodor"] = """The door""" + addresses["Arya"] = "Winterfell" + for 
name, address in addresses.items(): + print(f'\n\t{name} lives at {address}\n\r')<|"|>})x"; + auto generatedTensor = gemma4Tokenizer->encode(input, ov::genai::add_special_tokens(false)).input_ids; + std::vector generatedTokens(generatedTensor.data(), generatedTensor.data() + generatedTensor.get_size()); + ParsedOutput parsedOutput = outputParserWithRegularToolParsing->parse(generatedTokens, true); + EXPECT_EQ(parsedOutput.content, ""); + ASSERT_EQ(parsedOutput.toolCalls.size(), 1); + EXPECT_EQ(parsedOutput.toolCalls[0].name, "string_tool"); + EXPECT_EQ(parsedOutput.toolCalls[0].arguments, R"x({"param":"\n if __name__ == \"__main__\":\n addresses = {}\n addresses[\"Hodor\"] = \"\"\"The door\"\"\"\n addresses[\"Arya\"] = \"Winterfell\"\n for name, address in addresses.items():\n print(f'\\n\\t{name} lives at {address}\\n\\r')"})x"); +} diff --git a/windows_prepare_llm_models.bat b/windows_prepare_llm_models.bat index 86e7594fef..c49fdc1f0f 100644 --- a/windows_prepare_llm_models.bat +++ b/windows_prepare_llm_models.bat @@ -45,6 +45,7 @@ set "MISTRAL_MODEL=mistralai/Mistral-7B-Instruct-v0.3" set "GPTOSS_MODEL=openai/gpt-oss-20b" set "DEVSTRAL_MODEL=unsloth/Devstral-Small-2507" set "LFM2_MODEL=LiquidAI/LFM2-2.6B" +set "GEMMA4_MODEL=OpenVINO/gemma-4-E4B-it-int4-ov" echo Downloading LLM testing models to directory %~1 set "PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly" @@ -85,6 +86,7 @@ call :download_tokenizer "%MISTRAL_MODEL%" "%~1\%MISTRAL_MODEL%" call :download_tokenizer "%GPTOSS_MODEL%" "%~1\%GPTOSS_MODEL%" call :download_tokenizer "%DEVSTRAL_MODEL%" "%~1\%DEVSTRAL_MODEL%" call :download_tokenizer "%LFM2_MODEL%" "%~1\%LFM2_MODEL%" +call :download_openvino_tokenizer "%GEMMA4_MODEL%" "%~1" exit /b 0 @@ -131,6 +133,22 @@ if not exist "%repository%\%model%\openvino_tokenizer.bin" ( ) exit /b 0 +:: Helper subroutine to fetch the tokenizer files of a model repository with hf download, skipped when openvino_tokenizer.bin is already present +:: Arguments: first - model id, second - root directory; files are placed in the model id subdirectory of the root directory +:download_openvino_tokenizer +set "model=%~1" +set "repository=%~2" + +if not exist
"%repository%\%model%\openvino_tokenizer.bin" ( + echo Downloading tokenizer and detokenizer for %model% model to %repository%\%model% directory. + mkdir "%repository%\%model%" + hf download "%model%" --local-dir "%repository%\%model%" --include *tokenizer* +) else ( + echo Models file %repository%\%model%\openvino_tokenizer.bin exists. Skipping downloading models. +) +:: Return to the caller; a label does not end a subroutine, so without this exit execution would continue into the download_tokenizer helper that follows +exit /b 0 + :: Helper subroutine to download tokenizers :download_tokenizer set "model=%~1"