diff --git a/src/BUILD b/src/BUILD index f4778b258b..24f405597f 100644 --- a/src/BUILD +++ b/src/BUILD @@ -2587,6 +2587,7 @@ cc_test( "@mediapipe//mediapipe/calculators/ovms:ovms_calculator", "@mediapipe//mediapipe/framework:calculator_runner", ":text2image_test", + ":lora_adapter_test", ], "//:disable_mediapipe" : [ @@ -2598,6 +2599,22 @@ cc_test( linkopts = LINKOPTS_ADJUSTED, ) +cc_library( + name = "lora_adapter_test", + linkstatic = 1, + alwayslink = True, + srcs = [ + "test/llm/lora_adapter_test.cpp", + ], + deps = [ + ":test_test_with_temp_dir", + "//src/llm:genai_servables", + "@com_google_googletest//:gtest", + ], + copts = COPTS_TESTS, + local_defines = COMMON_LOCAL_DEFINES, +) + cc_library( name = "test_constructor_enabled_model_manager", hdrs = ["test/constructor_enabled_model_manager.hpp",], diff --git a/src/llm/language_model/continuous_batching/servable_initializer.cpp b/src/llm/language_model/continuous_batching/servable_initializer.cpp index 27f4f51aee..70b5229cf0 100644 --- a/src/llm/language_model/continuous_batching/servable_initializer.cpp +++ b/src/llm/language_model/continuous_batching/servable_initializer.cpp @@ -198,6 +198,11 @@ Status ContinuousBatchingServableInitializer::initialize(std::shared_ptrpluginConfig); if (!status.ok()) { SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config()); diff --git a/src/llm/language_model/legacy/servable_initializer.cpp b/src/llm/language_model/legacy/servable_initializer.cpp index 4ee7d4820a..65a53edaba 100644 --- a/src/llm/language_model/legacy/servable_initializer.cpp +++ b/src/llm/language_model/legacy/servable_initializer.cpp @@ -76,6 +76,11 @@ Status LegacyServableInitializer::initialize(std::shared_ptr& ser return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED; } + status = initializeLoraAdapters(nodeOptions, graphPath, properties); + if (!status.ok()) { + return status; + } + status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), properties->pluginConfig); if (!status.ok()) { SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config()); diff --git a/src/llm/llm_calculator.proto b/src/llm/llm_calculator.proto index c8edacf88e..d347a35264 100644 --- a/src/llm/llm_calculator.proto +++ b/src/llm/llm_calculator.proto @@ -26,6 +26,11 @@ message LLMCalculatorOptions { optional LLMCalculatorOptions ext = 113473750; } + message LoraAdapter { + required string model_path = 1; + optional float alpha = 2 [default = 1]; + } + message KVCrushConfig { enum AnchorPointMode { RANDOM = 0; @@ -135,4 +140,6 @@ message LLMCalculatorOptions { optional bool enable_tool_guided_generation = 23 [default = false]; optional SparseAttentionConfig sparse_attention_config = 24; + + repeated LoraAdapter lora_adapter = 25; } diff --git a/src/llm/servable_initializer.cpp b/src/llm/servable_initializer.cpp index 68913ee66e..06022e7a86 100644 --- a/src/llm/servable_initializer.cpp +++ b/src/llm/servable_initializer.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -316,6 +317,44 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr properties) { + if (nodeOptions.lora_adapter_size() <= 0) { + return StatusCode::OK; + } + SPDLOG_INFO("LoRA adapters will be applied to the model. Number of adapters: {}", nodeOptions.lora_adapter_size()); + ov::genai::AdapterConfig adapterConfig; + for (int i = 0; i < nodeOptions.lora_adapter_size(); ++i) { + const auto& loraAdapterOption = nodeOptions.lora_adapter(i); + SPDLOG_INFO("Processing LoRA adapter number {} with model path: {} alpha: {}", i, loraAdapterOption.model_path(), loraAdapterOption.alpha()); + if (loraAdapterOption.alpha() <= 0.0f || loraAdapterOption.alpha() > 1.0f) { + SPDLOG_ERROR("LoRA adapter alpha value {} is out of valid range (0.0, 1.0]", loraAdapterOption.alpha()); + return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED; + } + auto fsLoraPath = std::filesystem::path(loraAdapterOption.model_path()); + std::string loraPath; + if (fsLoraPath.is_relative()) { + loraPath = (std::filesystem::path(graphPath) / fsLoraPath).string(); + } else { + loraPath = fsLoraPath.string(); + } + try { + ov::genai::Adapter adapter(loraPath); + adapterConfig.add(adapter, loraAdapterOption.alpha()); + SPDLOG_INFO("Registered LoRA adapter from path: {} with alpha: {}", loraPath, loraAdapterOption.alpha()); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error during LoRA adapter initialization for model_path: {} exception: {}", loraPath, e.what()); + return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED; + } catch (...) { + SPDLOG_ERROR("Error during LoRA adapter initialization for model_path: {}", loraPath); + return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED; + } + } + // since it is only applied once at initialization, static mode is sufficient and more efficient. + adapterConfig.set_mode(ov::genai::AdapterConfig::MODE_STATIC); + properties->pluginConfig.insert(ov::genai::adapters(adapterConfig)); + return StatusCode::OK; +} + Status parseModelsPath(std::string& outPath, std::string modelsPath, std::string graphPath) { auto fsModelsPath = std::filesystem::path(modelsPath); if (fsModelsPath.is_relative()) { diff --git a/src/llm/servable_initializer.hpp b/src/llm/servable_initializer.hpp index d742db9c3e..f7fd948d87 100644 --- a/src/llm/servable_initializer.hpp +++ b/src/llm/servable_initializer.hpp @@ -61,6 +61,7 @@ class GenAiServableInitializer { virtual Status initialize(std::shared_ptr& servable, const mediapipe::LLMCalculatorOptions& nodeOptions, std::string graphPath) = 0; }; Status parseModelsPath(std::string& outPath, std::string modelsPath, std::string graphPath); +Status initializeLoraAdapters(const mediapipe::LLMCalculatorOptions& nodeOptions, const std::string& graphPath, std::shared_ptr properties); std::optional parseMaxModelLength(std::string& modelsPath); Status determinePipelineType(PipelineType& pipelineType, const mediapipe::LLMCalculatorOptions& nodeOptions, const std::string& graphPath); Status initializeGenAiServable(std::shared_ptr& servable, const ::mediapipe::CalculatorGraphConfig::Node& graphNodeConfig, std::string graphPath); diff --git a/src/llm/visual_language_model/legacy/servable_initializer.cpp b/src/llm/visual_language_model/legacy/servable_initializer.cpp index b8576c5851..44cfc41a3e 100644 --- a/src/llm/visual_language_model/legacy/servable_initializer.cpp +++ b/src/llm/visual_language_model/legacy/servable_initializer.cpp @@ -75,6 +75,11 @@ Status VisualLanguageModelLegacyServableInitializer::initialize(std::shared_ptr< return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED; } + status = initializeLoraAdapters(nodeOptions, graphPath, properties); + if (!status.ok()) { + return status; + } + status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), properties->pluginConfig); if (!status.ok()) { SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config()); diff --git a/src/test/llm/lora_adapter_test.cpp b/src/test/llm/lora_adapter_test.cpp new file mode 100644 index 0000000000..0f719eb0f8 --- /dev/null +++ b/src/test/llm/lora_adapter_test.cpp @@ -0,0 +1,217 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#include +#include +#include +#include +#include +#include + +#include +#include + +#pragma warning(push) +#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 4005 4456 6246) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#include "mediapipe/framework/calculator_graph.h" +#pragma GCC diagnostic pop +#pragma warning(pop) + +#include "src/llm/servable.hpp" +#include "src/llm/servable_initializer.hpp" +#include "src/status.hpp" +#include "src/test/test_with_temp_dir.hpp" + +using namespace ovms; + +class LoraAdapterInitTest : public TestWithTempDir { +protected: + std::shared_ptr properties; + mediapipe::LLMCalculatorOptions nodeOptions; + std::string loraDir; + std::string loraFilePath; + + // Creates a minimal valid safetensors file that ov::genai::Adapter can load + static void createMinimalSafetensorsFile(const std::string& dir) { + std::filesystem::create_directories(dir); + std::string path = dir + "/adapter_model.safetensors"; + + std::string header = + R"({"lora_A.weight":{"dtype":"F32","shape":[1,2],"data_offsets":[0,8]},)" + R"("lora_B.weight":{"dtype":"F32","shape":[2,1],"data_offsets":[8,16]}})"; + // Pad header to 8-byte alignment + while (header.size() % 8 != 0) + header += ' '; + + uint64_t headerLen = header.size(); + std::ofstream f(path, std::ios::binary); + f.write(reinterpret_cast(&headerLen), sizeof(headerLen)); + f.write(header.data(), headerLen); + // 16 bytes of zero tensor data (4 floats of zeros) + const std::vector zeros(16, 0); + f.write(zeros.data(), zeros.size()); + } + + void SetUp() override { + TestWithTempDir::SetUp(); + properties = std::make_shared(); + loraDir = directoryPath + "/lora_adapter"; + createMinimalSafetensorsFile(loraDir); + loraFilePath = loraDir + "/adapter_model.safetensors"; + } +}; + +// --- Protobuf parsing tests --- + +TEST_F(LoraAdapterInitTest, ProtobufLoraAdapterFieldsParsedCorrectly) { + std::string pbtxt = R"( + models_path: "/some/model" + lora_adapter { model_path: "/path/to/lora1" alpha: 0.5 } + lora_adapter { model_path: "/path/to/lora2" } + )"; + mediapipe::LLMCalculatorOptions opts; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(pbtxt, &opts)); + ASSERT_EQ(opts.lora_adapter_size(), 2); + EXPECT_EQ(opts.lora_adapter(0).model_path(), "/path/to/lora1"); + EXPECT_FLOAT_EQ(opts.lora_adapter(0).alpha(), 0.5f); + EXPECT_EQ(opts.lora_adapter(1).model_path(), "/path/to/lora2"); + EXPECT_FLOAT_EQ(opts.lora_adapter(1).alpha(), 1.0f); // default +} + +// --- No adapters --- + +TEST_F(LoraAdapterInitTest, NoAdaptersReturnsOk) { + ASSERT_EQ(initializeLoraAdapters(nodeOptions, "/some/path", properties), StatusCode::OK); + EXPECT_TRUE(properties->pluginConfig.empty()); +} + +// --- Invalid path --- + +TEST_F(LoraAdapterInitTest, NonExistentPathFails) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(directoryPath + "/nonexistent_lora"); + adapter->set_alpha(0.5f); + EXPECT_EQ(initializeLoraAdapters(nodeOptions, "", properties), + StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED); +} + +// --- Alpha validation --- + +TEST_F(LoraAdapterInitTest, AlphaZeroFails) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); + adapter->set_alpha(0.0f); + EXPECT_EQ(initializeLoraAdapters(nodeOptions, "", properties), + StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED); +} + +TEST_F(LoraAdapterInitTest, AlphaNegativeFails) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); + adapter->set_alpha(-0.5f); + EXPECT_EQ(initializeLoraAdapters(nodeOptions, "", properties), + StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED); +} + +TEST_F(LoraAdapterInitTest, AlphaAboveOneFails) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); + adapter->set_alpha(1.5f); + EXPECT_EQ(initializeLoraAdapters(nodeOptions, "", properties), + StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED); +} + +// --- Happy paths --- + +TEST_F(LoraAdapterInitTest, ValidAdapterWithAlpha) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); + adapter->set_alpha(0.5f); + ASSERT_EQ(initializeLoraAdapters(nodeOptions, "", properties), StatusCode::OK); + EXPECT_FALSE(properties->pluginConfig.empty()); + EXPECT_EQ(properties->pluginConfig.count("adapters"), 1); +} + +TEST_F(LoraAdapterInitTest, DefaultAlphaSucceeds) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); + // alpha defaults to 1.0 in proto + ASSERT_EQ(initializeLoraAdapters(nodeOptions, "", properties), StatusCode::OK); + EXPECT_EQ(properties->pluginConfig.count("adapters"), 1); +} + +TEST_F(LoraAdapterInitTest, AlphaExactlyOneSucceeds) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); + adapter->set_alpha(1.0f); + ASSERT_EQ(initializeLoraAdapters(nodeOptions, "", properties), StatusCode::OK); + EXPECT_EQ(properties->pluginConfig.count("adapters"), 1); +} + +TEST_F(LoraAdapterInitTest, MultipleAdaptersRegistered) { + auto* a1 = nodeOptions.add_lora_adapter(); + a1->set_model_path(loraFilePath); + a1->set_alpha(0.3f); + auto* a2 = nodeOptions.add_lora_adapter(); + a2->set_model_path(loraFilePath); + a2->set_alpha(0.7f); + ASSERT_EQ(initializeLoraAdapters(nodeOptions, "", properties), StatusCode::OK); + EXPECT_EQ(properties->pluginConfig.count("adapters"), 1); +} + +// --- Path resolution --- + +TEST_F(LoraAdapterInitTest, RelativePathResolvedAgainstGraphPath) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path("lora_adapter/adapter_model.safetensors"); + adapter->set_alpha(0.5f); + // graphPath = directoryPath, so relative "lora_adapter" resolves to directoryPath/lora_adapter + ASSERT_EQ(initializeLoraAdapters(nodeOptions, directoryPath, properties), StatusCode::OK); + EXPECT_EQ(properties->pluginConfig.count("adapters"), 1); +} + +TEST_F(LoraAdapterInitTest, AbsolutePathIgnoresGraphPath) { + auto* adapter = nodeOptions.add_lora_adapter(); + adapter->set_model_path(loraFilePath); // absolute path + adapter->set_alpha(0.5f); + ASSERT_EQ(initializeLoraAdapters(nodeOptions, "/wrong/graph/path", properties), StatusCode::OK); + EXPECT_EQ(properties->pluginConfig.count("adapters"), 1); +} + +// --- Mixed valid/invalid --- + +TEST_F(LoraAdapterInitTest, SecondAdapterInvalidAlphaFailsAll) { + auto* a1 = nodeOptions.add_lora_adapter(); + a1->set_model_path(loraFilePath); + a1->set_alpha(0.5f); + auto* a2 = nodeOptions.add_lora_adapter(); + a2->set_model_path(loraFilePath); + a2->set_alpha(0.0f); // invalid + EXPECT_EQ(initializeLoraAdapters(nodeOptions, "", properties), + StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED); +} + +TEST_F(LoraAdapterInitTest, SecondAdapterInvalidPathFailsAll) { + auto* a1 = nodeOptions.add_lora_adapter(); + a1->set_model_path(loraFilePath); + a1->set_alpha(0.5f); + auto* a2 = nodeOptions.add_lora_adapter(); + a2->set_model_path(directoryPath + "/no_such_adapter.safetensors"); + a2->set_alpha(0.5f); + EXPECT_EQ(initializeLoraAdapters(nodeOptions, "", properties), + StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED); +}