diff --git a/CMakeLists.txt b/CMakeLists.txt index ce0def6000b..ec552518b9b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,6 +103,7 @@ include(${PROJECT_SOURCE_DIR}/tools/cmake/Codegen.cmake) include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake) include(CMakeDependentOption) include(ExternalProject) +include(FetchContent) include(GNUInstallDirs) if(NOT CMAKE_CXX_STANDARD) @@ -406,6 +407,14 @@ set(_common_include_directories $ ) +if(TARGET jinja2cpp) + install( + TARGETS jinja2cpp + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} + ) +endif() + # # The `__srcs` lists are defined by executorch_load_build_variables. # @@ -803,7 +812,7 @@ endif() if(EXECUTORCH_BUILD_EXTENSION_LLM) if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) - set(SUPPORT_REGEX_LOOKAHEAD ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/chat_template) # llama/runner/CMakeLists.txt builds a shared library libllama_runner.so # that transitively depends on tokenizers. Need to build tokenizers with # -fPIC. diff --git a/extension/llm/chat_template/BUCK b/extension/llm/chat_template/BUCK new file mode 100644 index 00000000000..bd8ea6e199e --- /dev/null +++ b/extension/llm/chat_template/BUCK @@ -0,0 +1,18 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +oncall("executorch") + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +non_fbcode_target(_kind = define_common_targets,) + +# !!!! fbcode/executorch/extension/llm/chat_template/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!! + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +fbcode_target(_kind = define_common_targets,) diff --git a/extension/llm/chat_template/CMakeLists.txt b/extension/llm/chat_template/CMakeLists.txt new file mode 100644 index 00000000000..88ce80b157f --- /dev/null +++ b/extension/llm/chat_template/CMakeLists.txt @@ -0,0 +1,116 @@ +if(NOT EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) + return() +endif() + +include(FetchContent) +cmake_policy(SET CMP0077 NEW) + +FetchContent_Declare( + jinja2cpp + GIT_REPOSITORY https://github.com/jinja2cpp/Jinja2Cpp.git + GIT_TAG 1.3.2 + GIT_SUBMODULES_RECURSE TRUE +) + +set(JINJA2CPP_BUILD_TESTS + OFF + CACHE BOOL "" + FORCE +) +set(JINJA2CPP_BUILD_SHARED + OFF + CACHE BOOL "" + FORCE +) +set(JINJA2CPP_INSTALL + OFF + CACHE BOOL "" + FORCE +) +# Enable PCRE2-based regex lookahead support in Jinja2Cpp. This must be set +# BEFORE FetchContent_MakeAvailable(jinja2cpp) so it propagates to the +# Jinja2Cpp configure step. 
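+# cmake_policy(SET CMP0077 NEW) above makes the option() calls inside the
+# fetched Jinja2Cpp project honor variables that this file has already set,
+# which is why each Jinja2Cpp option (JINJA2CPP_BUILD_TESTS,
+# JINJA2CPP_BUILD_SHARED, JINJA2CPP_INSTALL, and SUPPORT_REGEX_LOOKAHEAD below)
+# is pre-seeded with the `CACHE BOOL "" FORCE` form before
+# FetchContent_MakeAvailable() configures the sub-project.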
+set(SUPPORT_REGEX_LOOKAHEAD + ON + CACHE BOOL "" + FORCE +) + +FetchContent_MakeAvailable(jinja2cpp) +if(NOT TARGET jinja2cpp) + message(FATAL_ERROR "Jinja2Cpp target not found after FetchContent.") +endif() + +if(DEFINED jinja2cpp_SOURCE_DIR) + function(executorch_copy_nonstd_header dep_name target header_name dest_root) + set(_copied FALSE) + if(TARGET ${target}) + get_target_property(_aliased ${target} ALIASED_TARGET) + if(_aliased) + set(_resolved_target ${_aliased}) + else() + set(_resolved_target ${target}) + endif() + get_target_property( + _include_dirs ${_resolved_target} INTERFACE_INCLUDE_DIRECTORIES + ) + foreach(_dir IN LISTS _include_dirs) + if(EXISTS "${_dir}/nonstd/${header_name}") + file(MAKE_DIRECTORY "${dest_root}/nonstd") + file( + COPY "${_dir}/nonstd/${header_name}" + DESTINATION "${dest_root}/nonstd" + ) + set(_copied TRUE) + break() + endif() + endforeach() + endif() + if(NOT _copied) + set( + _fallback_path + "${CMAKE_BINARY_DIR}/_deps/${dep_name}-src/include/nonstd/${header_name}" + ) + if(EXISTS "${_fallback_path}") + file(MAKE_DIRECTORY "${dest_root}/nonstd") + file(COPY "${_fallback_path}" DESTINATION "${dest_root}/nonstd") + endif() + endif() + endfunction() + + set(_jinja2cpp_nonstd_root + "${jinja2cpp_SOURCE_DIR}/thirdparty/nonstd" + ) + executorch_copy_nonstd_header( + expected-lite + nonstd::expected-lite + expected.hpp + "${_jinja2cpp_nonstd_root}/expected-lite/include" + ) + executorch_copy_nonstd_header( + variant-lite + nonstd::variant-lite + variant.hpp + "${_jinja2cpp_nonstd_root}/variant-lite/include" + ) + executorch_copy_nonstd_header( + optional-lite + nonstd::optional-lite + optional.hpp + "${_jinja2cpp_nonstd_root}/optional-lite/include" + ) + executorch_copy_nonstd_header( + string-view-lite + nonstd::string-view-lite + string_view.hpp + "${_jinja2cpp_nonstd_root}/string-view-lite/include" + ) +endif() + +# Install the chat_templates.h header so that downstream consumers of the +# installed ExecuTorch SDK can include +# . 
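+# i.e. #include <executorch/extension/llm/chat_template/chat_templates.h>.
+# A downstream CMake consumer would then pick the header up along these lines
+# (illustrative sketch only; the exact package and target names depend on how
+# the installed SDK is imported):
+#   find_package(executorch CONFIG REQUIRED)
+#   target_link_libraries(my_app PRIVATE executorch::extension_llm_runner)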
+install( + FILES chat_templates.h + DESTINATION include/executorch/extension/llm/chat_template +) diff --git a/extension/llm/chat_template/chat_templates.h b/extension/llm/chat_template/chat_templates.h new file mode 100644 index 00000000000..b9e059b80e6 --- /dev/null +++ b/extension/llm/chat_template/chat_templates.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +enum class ChatTemplateType { + None, + Llama3, + Llama32, + Gemma3, + Custom, +}; + +constexpr std::string_view kLlama3Template = R"({{ bos_token }}{%- for message in messages -%}<|start_header_id|>{{ message.role }}<|end_header_id|> + +{{ message.content }}<|eot_id|>{%- endfor -%}{%- if add_generation_prompt -%}<|start_header_id|>assistant<|end_header_id|> + +{%- endif -%})"; + +constexpr std::string_view kGemma3Template = R"({{ bos_token }}{%- for message in messages -%}{%- if message.role == 'assistant' -%}model +{%- else -%}{{ message.role }} +{%- endif -%}{{ message.content }}{%- endfor -%}{%- if add_generation_prompt -%}model +{%- endif -%})"; + +inline const std::unordered_map + kEmbeddedTemplates = { + {ChatTemplateType::Llama3, kLlama3Template}, + {ChatTemplateType::Llama32, kLlama3Template}, + {ChatTemplateType::Gemma3, kGemma3Template}, + }; + +struct ModelTokens { + std::string bos_token; + std::string eos_token; + std::vector stop_tokens; +}; + +inline const std::unordered_map kModelTokens = { + {ChatTemplateType::Llama3, + {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Llama32, + {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Gemma3, + {"", "", {"", ""}}}, +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/chat_template/targets.bzl b/extension/llm/chat_template/targets.bzl new file mode 100644 index 00000000000..197e5da9fb3 --- /dev/null +++ b/extension/llm/chat_template/targets.bzl @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
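+
+# Buck consumers depend on this header-only library directly, e.g.
+#   deps = ["//executorch/extension/llm/chat_template:chat_templates"],
+# which is how runner_lib in extension/llm/runner/targets.bzl picks up
+# chat_templates.h in this change.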
+ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "chat_templates", + exported_headers = [ + "chat_templates.h", + ], + visibility = ["PUBLIC"], + ) diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index 43b89f0a908..da5b42c3f4a 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -54,6 +54,14 @@ endif() list(APPEND runner_deps kernels_util_all_deps) target_link_libraries(extension_llm_runner PUBLIC ${runner_deps}) +target_link_libraries(extension_llm_runner PRIVATE $) +target_include_directories( + extension_llm_runner PRIVATE + $ +) +target_compile_definitions( + extension_llm_runner PUBLIC EXECUTORCH_USE_JINJA2CPP +) set_target_properties( extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON ) @@ -116,23 +124,19 @@ if(EXECUTORCH_BUILD_PYBIND) portable_lib ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} ) + # Set properties for the Python extension set_target_properties( _llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_VISIBILITY_PRESET "hidden" INTERPROCEDURAL_OPTIMIZATION TRUE - CXX_STANDARD 20 ) if(APPLE) - set(RPATH - "@loader_path/../../pybindings;@loader_path/../../../../torch/lib" - ) + set(RPATH "@loader_path/../../pybindings") else() - set(RPATH "$ORIGIN/../../pybindings:$ORIGIN/../../../../torch/lib") + set(RPATH "$ORIGIN/../../pybindings") endif() - set_target_properties( - _llm_runner PROPERTIES BUILD_RPATH "${RPATH}" INSTALL_RPATH "${RPATH}" - ) + set_target_properties(_llm_runner PROPERTIES INSTALL_RPATH ${RPATH}) # Add include directories target_include_directories( _llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS} diff --git a/extension/llm/runner/chat_types.h b/extension/llm/runner/chat_types.h new file mode 100644 index 00000000000..6a7cd2b625f --- /dev/null +++ b/extension/llm/runner/chat_types.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace executorch::extension::llm { + +struct ChatMessage { + std::string role; + std::string content; +}; + +struct ChatConversation { + std::vector messages; + std::string bos_token; + std::string eos_token; + bool add_generation_prompt = true; +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/jinja_chat_formatter.cpp b/extension/llm/runner/jinja_chat_formatter.cpp new file mode 100644 index 00000000000..00054554886 --- /dev/null +++ b/extension/llm/runner/jinja_chat_formatter.cpp @@ -0,0 +1,236 @@ +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { +namespace { + +std::string readFileToString(const std::filesystem::path& path) { + std::ifstream file(path); + if (!file) { + throw std::runtime_error("Failed to open template file: " + path.string()); + } + std::ostringstream buffer; + buffer << file.rdbuf(); + return buffer.str(); +} + +bool templateIncludesBos( + const std::string& template_str, + const ModelTokens& model_tokens) { + if (!model_tokens.bos_token.empty() && + template_str.find(model_tokens.bos_token) != std::string::npos) { + return true; + } + return template_str.find("bos_token") != std::string::npos; +} + +std::string normalizeTemplate(std::string input) { + // These replacements normalize vLLM/HuggingFace Jinja templates so they + // compile/render correctly with Jinja2Cpp, which has stricter parser + // semantics than Python Jinja2. 
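+  // For example, a HuggingFace-style fragment such as
+  //   {%- if tools is not none -%} ... {%- endif -%}
+  // is rewritten below to
+  //   {%- if tools -%} ... {%- endif -%}
+  // before the template is handed to Jinja2Cpp.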
+ // + // IMPORTANT: "not tools is none" in Python Jinja means "tools is not none" + // (truthy when tools is defined and non-null), so we map it to a simple + // truthy check on `tools`. Mapping to "not tools" was a bug that would + // skip tool blocks for non-empty tools lists. + constexpr std::array, 10> + replacements = {{ + {"tools = none", "tools = []"}, + {"tools = None", "tools = []"}, + {"tools is not none", "tools"}, + {"tools is not None", "tools"}, + {"not tools is none", "tools"}, + {"not tools is None", "tools"}, + {"tools is none", "not tools"}, + {"tools is None", "not tools"}, + {"messages[1:]", "messages_tail"}, + {"{ \"output\": message.content } | tojson", "message.content | tojson"}, + }}; + // Handle special case that can't be constexpr due to escape sequence + const std::pair gemmaReplacement = { + "{{'model\\n'}}", "{{ 'model\\n' }}"}; + for (const auto& replacement : replacements) { + size_t pos = 0; + while ((pos = input.find(replacement.first, pos)) != std::string::npos) { + input.replace(pos, replacement.first.size(), replacement.second); + pos += replacement.second.size(); + } + } + // Apply the gemma replacement separately + size_t pos = 0; + while ((pos = input.find(gemmaReplacement.first, pos)) != std::string::npos) { + input.replace(pos, gemmaReplacement.first.size(), gemmaReplacement.second); + pos += gemmaReplacement.second.size(); + } + return input; +} + +ChatTemplateType detectTemplateType(const std::string& template_str) { + if (template_str.find("") != std::string::npos) { + return ChatTemplateType::Gemma3; + } + if (template_str.find("<|start_header_id|>") != std::string::npos) { + return ChatTemplateType::Llama3; + } + return ChatTemplateType::Custom; +} + +std::string toLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; +} + +} // namespace + +} // namespace executorch::extension::llm + +namespace jinja2 { + +// NOLINTBEGIN(facebook-hte-MisplacedTemplateSpecialization,facebook-hte-ShadowingClass) +// This template specialization must be in the jinja2 namespace for the library +// to find it via ADL during template instantiation. 
+template <> +struct TypeReflection + : TypeReflected { + static auto& GetAccessors() { + static std::unordered_map accessors = { + {"role", + [](const executorch::extension::llm::ChatMessage& msg) { + return jinja2::Reflect(msg.role); + }}, + {"content", + [](const executorch::extension::llm::ChatMessage& msg) { + return jinja2::Reflect(msg.content); + }}, + }; + return accessors; + } +}; +// NOLINTEND(facebook-hte-MisplacedTemplateSpecialization,facebook-hte-ShadowingClass) + +} // namespace jinja2 + +namespace executorch::extension::llm { + +JinjaChatFormatter::JinjaChatFormatter( + const std::string& template_str, + ChatTemplateType type) + : template_str_(template_str), type_(type) { + auto tokens_it = kModelTokens.find(type_); + if (tokens_it != kModelTokens.end()) { + model_tokens_ = tokens_it->second; + } + includes_bos_ = templateIncludesBos(template_str_, model_tokens_); + const std::string normalized_template = normalizeTemplate(template_str_); + compiled_template_ = std::make_unique(); + auto load_result = compiled_template_->Load(normalized_template); + if (!load_result) { + throw std::runtime_error( + "Failed to parse chat template: " + + load_result.error().ToString()); + } +} + +JinjaChatFormatter::~JinjaChatFormatter() = default; + +std::unique_ptr JinjaChatFormatter::fromTemplate( + ChatTemplateType type) { + auto it = kEmbeddedTemplates.find(type); + if (it == kEmbeddedTemplates.end()) { + throw std::runtime_error("Unsupported embedded chat template type."); + } + return std::unique_ptr( + new JinjaChatFormatter(std::string(it->second), type)); +} + +std::unique_ptr JinjaChatFormatter::fromString( + const std::string& template_str) { + const ChatTemplateType inferred_type = detectTemplateType(template_str); + return std::unique_ptr( + new JinjaChatFormatter(template_str, inferred_type)); +} + +std::unique_ptr JinjaChatFormatter::fromFile( + const std::filesystem::path& path) { + return fromString(readFileToString(path)); +} + +std::string JinjaChatFormatter::format( + const std::string& prompt, + const std::string& system_prompt) const { + ChatConversation conversation; + if (!system_prompt.empty()) { + conversation.messages.push_back({"system", system_prompt}); + } + conversation.messages.push_back({"user", prompt}); + conversation.bos_token = model_tokens_.bos_token; + conversation.eos_token = model_tokens_.eos_token; + conversation.add_generation_prompt = true; + return formatConversation(conversation); +} + +std::string JinjaChatFormatter::formatConversation( + const ChatConversation& conversation) const { + jinja2::ValuesMap params; + params["messages"] = jinja2::ValuesList(); + params["messages_tail"] = jinja2::ValuesList(); + bool is_first = true; + for (const auto& msg : conversation.messages) { + params["messages"].asList().push_back(jinja2::Reflect(msg)); + if (!is_first) { + params["messages_tail"].asList().push_back(jinja2::Reflect(msg)); + } + is_first = false; + } + params["bos_token"] = conversation.bos_token; + params["eos_token"] = conversation.eos_token; + params["add_generation_prompt"] = conversation.add_generation_prompt; + // Provide vLLM/HuggingFace-style defaults that templates often reference. + // Templates that don't use these will simply ignore them. 
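+  // `tools` defaults to an empty list rather than a null value so that both
+  // plain truthy checks and the normalized `tools is none` forms evaluate to
+  // "no tools available"; `date_string` mirrors the fixed fallback date that
+  // Llama 3.x reference templates use when no date is supplied.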
+ params["tools"] = jinja2::ValuesList(); + params["tool_choice"] = jinja2::Value(); + params["date_string"] = std::string("26 Jul 2024"); + params["chat_template_kwargs"] = jinja2::ValuesMap(); + + auto rendered = compiled_template_->RenderAsString(params); + if (!rendered) { + throw std::runtime_error( + "Failed to render chat template: " + rendered.error().ToString()); + } + return rendered.value(); +} + +ChatTemplateType parseChatTemplateType(const std::string& type_str) { + const std::string lower = toLower(type_str); + if (lower == "none") { + return ChatTemplateType::None; + } + if (lower == "llama3") { + return ChatTemplateType::Llama3; + } + if (lower == "llama3.2" || lower == "llama32" || lower == "llama3_2") { + return ChatTemplateType::Llama32; + } + if (lower == "gemma3") { + return ChatTemplateType::Gemma3; + } + if (lower == "custom" || lower == "jinja") { + return ChatTemplateType::Custom; + } + return ChatTemplateType::None; +} + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/jinja_chat_formatter.h b/extension/llm/runner/jinja_chat_formatter.h new file mode 100644 index 00000000000..e589a9ed18e --- /dev/null +++ b/extension/llm/runner/jinja_chat_formatter.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace jinja2 { +class Template; +} + +namespace executorch::extension::llm { + +class JinjaChatFormatter { + public: + static std::unique_ptr fromTemplate(ChatTemplateType type); + static std::unique_ptr fromString( + const std::string& template_str); + static std::unique_ptr fromFile( + const std::filesystem::path& path); + + ~JinjaChatFormatter(); + + std::string format( + const std::string& prompt, + const std::string& system_prompt = "") const; + std::string formatConversation(const ChatConversation& conversation) const; + + bool includesBos() const { + return includes_bos_; + } + + const ModelTokens& getModelTokens() const { + return model_tokens_; + } + + private: + JinjaChatFormatter(const std::string& template_str, ChatTemplateType type); + + std::string template_str_; + ChatTemplateType type_; + ModelTokens model_tokens_; + bool includes_bos_ = false; + std::unique_ptr compiled_template_; +}; + +ChatTemplateType parseChatTemplateType(const std::string& type_str); + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 0744c09e641..ddaf280a9ee 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -158,7 +158,9 @@ std::unordered_set get_eos_ids( tokenizers::Tokenizer* tokenizer, Module* module) { std::unordered_set eos_ids = {tokenizer->eos_tok()}; - // Get EOS IDs if available + ET_LOG(Info, "Primary eos_tok = %" PRIu64, tokenizer->eos_tok()); + + // Get EOS IDs from model metadata if available auto method_names_result = module->method_names(); if (method_names_result.error() != Error::Ok) { ET_LOG(Error, "Failed reading method names"); @@ -167,7 +169,6 @@ std::unordered_set get_eos_ids( const auto& method_names = method_names_result.get(); if (method_names.count(llm::kEosIds)) { - eos_ids.clear(); auto execute_result = module->execute(llm::kEosIds); if (execute_result.error() != Error::Ok) { ET_LOG(Error, "Failed to execute %s", llm::kEosIds); @@ -175,8 +176,10 @@ std::unordered_set get_eos_ids( } for (const auto& eos_id : execute_result.get()) { auto value = eos_id.toScalar().to(); - eos_ids.emplace(value); - ET_LOG(Info, "eos_id = %" PRId64, 
value); + auto [_, inserted] = eos_ids.emplace(value); + if (inserted) { + ET_LOG(Info, "Added eos_id from model metadata: %" PRId64, value); + } } } return eos_ids; diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 0d4ed99308d..f8ceed67bf1 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -111,11 +111,14 @@ def define_common_targets(): runtime.cxx_library( name = "runner_lib" + aten_suffix, exported_headers = [ + "chat_types.h", + "jinja_chat_formatter.h", "text_llm_runner.h", "llm_runner_helper.h", "constants.h", ], srcs = [ + "jinja_chat_formatter.cpp", "text_llm_runner.cpp", "llm_runner_helper.cpp", "multimodal_runner.cpp", @@ -131,6 +134,7 @@ def define_common_targets(): ":text_decoder_runner" + aten_suffix, ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, + "//executorch/extension/llm/chat_template:chat_templates", "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, "//executorch/extension/memory_allocator:cpu_caching_allocator", "//pytorch/tokenizers:hf_tokenizer", @@ -138,5 +142,6 @@ def define_common_targets(): "//pytorch/tokenizers:sentencepiece", "//pytorch/tokenizers:tekken", "//pytorch/tokenizers:tiktoken", + "@fbsource//third-party/jinja2cpp:jinja2cpp", ], ) diff --git a/extension/llm/runner/test/BUCK b/extension/llm/runner/test/BUCK index 9ed85acadb7..cb8e8fcfb7e 100644 --- a/extension/llm/runner/test/BUCK +++ b/extension/llm/runner/test/BUCK @@ -10,32 +10,21 @@ oncall("executorch") # targets.bzl. This file can contain fbcode-only targets. load(":targets.bzl", "define_common_targets") - +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") non_fbcode_target(_kind = define_common_targets,) -# !!!! fbcode/executorch/extension/llm/runner/test/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!! - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Any targets that should be shared between fbcode and xplat must be defined in -# targets.bzl. This file can contain fbcode-only targets. 
- -load(":targets.bzl", "define_common_targets") -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") - fbcode_target(_kind = define_common_targets,) +# fbcode-only test that requires access to fbcode//executorch/test/models fbcode_target(_kind = runtime.cxx_test, name = "test_text_decoder_runner", srcs = ["test_text_decoder_runner.cpp"], deps = [ - "//executorch/extension/llm/runner:runner_lib", "//executorch/extension/llm/runner/io_manager:io_manager", + "//executorch/extension/llm/runner:text_decoder_runner", + "//executorch/extension/module:module", + "//executorch/extension/tensor:tensor", "//executorch/kernels/portable:generated_lib", "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], @@ -43,5 +32,5 @@ fbcode_target(_kind = runtime.cxx_test, "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])", "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])", "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])", - } + }, ) diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 81b69c0ab9a..64d89ff0ec4 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -19,6 +19,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs test_generation_config.cpp + test_jinja_chat_formatter.cpp test_text_llm_runner.cpp test_text_prefiller.cpp test_text_decoder_runner.cpp diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 08044de2d35..a8e3858af9d 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -17,6 +17,14 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "test_jinja_chat_formatter", + srcs = ["test_jinja_chat_formatter.cpp"], + deps = [ + "//executorch/extension/llm/runner:runner_lib", + ], + ) + runtime.cxx_test( name = "test_text_llm_runner", srcs = ["test_text_llm_runner.cpp"], diff --git a/extension/llm/runner/test/test_jinja_chat_formatter.cpp b/extension/llm/runner/test/test_jinja_chat_formatter.cpp new file mode 100644 index 00000000000..d1a3d6d9fff --- /dev/null +++ b/extension/llm/runner/test/test_jinja_chat_formatter.cpp @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using executorch::extension::llm::ChatConversation; +using executorch::extension::llm::ChatMessage; +using executorch::extension::llm::ChatTemplateType; +using executorch::extension::llm::JinjaChatFormatter; +using executorch::extension::llm::parseChatTemplateType; +using testing::HasSubstr; + +TEST(JinjaChatFormatter, Llama3SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const std::string prompt = "Test prompt"; + const std::string system_prompt = "You are a helpful assistant."; + // Note: The Jinja template uses {%- ... -%} which strips whitespace, + // so the output has \n\n after each <|end_header_id|> and content, + // but no trailing \n\n at the end due to the {%- endif -%} stripping. 
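+  // Reminder on Jinja whitespace control: the `-` in `{%- ... -%}` trims
+  // whitespace next to the tag, so "A {%- if x -%} B {%- endif -%} C" renders
+  // as "ABC" when x is truthy, while the same template without the hyphens
+  // keeps the surrounding spaces.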
+ const std::string expected = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + + system_prompt + + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + prompt + + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; + + EXPECT_EQ(formatter->format(prompt, system_prompt), expected); +} + +TEST(JinjaChatFormatter, Gemma3SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("")); + EXPECT_THAT(result, HasSubstr("user")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("model")); +} + +TEST(JinjaChatFormatter, Llama3WithoutSystemPrompt) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); + // Should not contain system header when no system prompt + EXPECT_THAT(result, ::testing::Not(HasSubstr("<|start_header_id|>system"))); +} + +TEST(JinjaChatFormatter, Llama3IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Gemma3IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Llama3ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<|begin_of_text|>"); + EXPECT_EQ(tokens.eos_token, "<|eot_id|>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +TEST(JinjaChatFormatter, Gemma3ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, ""); + EXPECT_EQ(tokens.eos_token, ""); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +TEST(JinjaChatFormatter, FormatConversationMultiTurn) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + + ChatConversation conversation; + conversation.bos_token = "<|begin_of_text|>"; + conversation.eos_token = "<|eot_id|>"; + conversation.add_generation_prompt = true; + conversation.messages = { + {"user", "Hello"}, + {"assistant", "Hi there!"}, + {"user", "How are you?"}, + }; + + const std::string result = formatter->formatConversation(conversation); + + EXPECT_THAT(result, HasSubstr("Hello")); + EXPECT_THAT(result, HasSubstr("Hi there!")); + EXPECT_THAT(result, HasSubstr("How are you?")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +} + +TEST(JinjaChatFormatter, FromStringLlama3Template) { + const std::string llama_template = + "{{ bos_token }}<|start_header_id|>user<|end_header_id|>\n\n" + "{{ messages[0].content }}<|eot_id|>"; + + auto formatter = JinjaChatFormatter::fromString(llama_template); + + // Should detect Llama3 type from the template content + const std::string result = formatter->format("Test"); + EXPECT_THAT(result, HasSubstr("Test")); +} + +TEST(JinjaChatFormatter, FromStringGemma3Template) { + const std::string gemma_template = + "{{ bos_token }}user\n" + "{{ messages[0].content }}"; + + auto 
formatter = JinjaChatFormatter::fromString(gemma_template); + + const std::string result = formatter->format("Test"); + EXPECT_THAT(result, HasSubstr("Test")); +} + +TEST(JinjaChatFormatter, UnsupportedTemplateTypeThrows) { + EXPECT_THROW( + JinjaChatFormatter::fromTemplate(ChatTemplateType::None), + std::runtime_error); +} + +// Tests for parseChatTemplateType +TEST(ParseChatTemplateType, ParseNone) { + EXPECT_EQ(parseChatTemplateType("none"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("None"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("NONE"), ChatTemplateType::None); +} + +TEST(ParseChatTemplateType, ParseLlama3) { + EXPECT_EQ(parseChatTemplateType("llama3"), ChatTemplateType::Llama3); + EXPECT_EQ(parseChatTemplateType("LLAMA3"), ChatTemplateType::Llama3); + EXPECT_EQ(parseChatTemplateType("Llama3"), ChatTemplateType::Llama3); +} + +TEST(ParseChatTemplateType, ParseLlama32Variants) { + EXPECT_EQ(parseChatTemplateType("llama3.2"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("llama32"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("llama3_2"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("LLAMA3.2"), ChatTemplateType::Llama32); +} + +TEST(ParseChatTemplateType, ParseGemma3) { + EXPECT_EQ(parseChatTemplateType("gemma3"), ChatTemplateType::Gemma3); + EXPECT_EQ(parseChatTemplateType("GEMMA3"), ChatTemplateType::Gemma3); + EXPECT_EQ(parseChatTemplateType("Gemma3"), ChatTemplateType::Gemma3); +} + +TEST(ParseChatTemplateType, ParseCustom) { + EXPECT_EQ(parseChatTemplateType("custom"), ChatTemplateType::Custom); + EXPECT_EQ(parseChatTemplateType("jinja"), ChatTemplateType::Custom); + EXPECT_EQ(parseChatTemplateType("CUSTOM"), ChatTemplateType::Custom); +} + +TEST(ParseChatTemplateType, ParseUnknownReturnsNone) { + EXPECT_EQ(parseChatTemplateType("unknown"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType(""), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("invalid"), ChatTemplateType::None); +} + +TEST(JinjaChatFormatter, Llama32SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + const std::string result = formatter->format("Hello!"); + + // Llama32 uses the same template as Llama3 + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +} + +TEST(JinjaChatFormatter, Llama32IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Llama32ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<|begin_of_text|>"); + EXPECT_EQ(tokens.eos_token, "<|eot_id|>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +// Universal Jinja support: any HuggingFace / vLLM-style template string +// (passed via fromString or fromFile) should work, not just the embedded +// Llama3/Gemma3 templates. This validates the renderer with a generic +// HuggingFace-style template that uses the standard `messages`, +// `bos_token`, and `add_generation_prompt` variables. +TEST(JinjaChatFormatter, UniversalJinjaGenericTemplate) { + // Generic chat template inspired by HuggingFace tokenizer_config.json + // examples. 
Uses only standard Jinja2 features and standard chat variables. + const std::string generic_template = + "{{ bos_token }}" + "{%- for message in messages -%}" + "<|{{ message.role }}|>\n{{ message.content }}<|end|>\n" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}<|assistant|>\n{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(generic_template); + ChatConversation conv; + conv.bos_token = ""; + conv.add_generation_prompt = true; + conv.messages.push_back(ChatMessage{"user", "Hi there"}); + + const std::string result = formatter->formatConversation(conv); + + EXPECT_THAT(result, HasSubstr("")); + EXPECT_THAT(result, HasSubstr("<|user|>")); + EXPECT_THAT(result, HasSubstr("Hi there")); + EXPECT_THAT(result, HasSubstr("<|assistant|>")); +} + +// Universal Jinja support: templates that reference `tools` (e.g. vLLM's +// tool_chat_template_*.jinja files) should not crash even when no tools +// are passed. The formatter injects `tools = []` (an empty list) so that +// truthy/none checks evaluate consistently. With our normalization, +// `tools is not none` is rewritten to `tools` (a truthy check), which means +// an empty list (no tools) is treated as "no tools available" — matching +// the typical template intent. +TEST(JinjaChatFormatter, UniversalJinjaToolsAwareTemplate) { + const std::string tools_template = + "{%- if tools is not none -%}" + "tools_present" + "{%- else -%}" + "no_tools" + "{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(tools_template); + ChatConversation conv; + conv.add_generation_prompt = false; + + // tools defaults to [] (empty list). After normalization this is a + // truthy check, so the empty-list (no tools) branch should be selected. + EXPECT_EQ(formatter->formatConversation(conv), "no_tools"); +} + +// Universal Jinja support: a template using "not tools is none" (semantically +// "tools is not none") should now safely evaluate without skipping the +// "no tools" branch. This guards the regression where the normalizer mapped +// "not tools is none" -> "not tools", which incorrectly evaluated to true +// for an empty list and would have rendered a tool block when none was +// intended. +TEST(JinjaChatFormatter, UniversalJinjaNormalizedNotToolsIsNone) { + const std::string template_str = + "{%- if not tools is none -%}defined{%- else -%}none{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(template_str); + ChatConversation conv; + conv.add_generation_prompt = false; + + // tools = [] is falsy after normalization (`not tools is none` -> `tools`). + // The "else" branch should be selected (no tools available). + EXPECT_EQ(formatter->formatConversation(conv), "none"); +} +} diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index cf7ab50b9c8..3010aea6583 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -20,8 +20,51 @@ #include #include +#include + namespace executorch::extension::llm { +namespace { +// Known special tokens used by LLM chat templates. +// When echo=false, these are filtered out of the streamed output so users +// see clean assistant text. When echo=true, the user explicitly asked for +// raw model output, so we emit them unchanged. 
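+// In addition to this fixed list, is_special_token() below treats any piece
+// of the form <|...|> as special, so new Llama-style control tokens are
+// filtered without having to extend this set.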
+const std::unordered_set kKnownSpecialTokens = { + // Llama 3.x tokens + "<|begin_of_text|>", + "<|end_of_text|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eot_id|>", + // Gemma tokens + "", + "", + "", + "", + // Common tokens + "", + "", + "", + "", +}; + +bool is_special_token(const std::string& text) { + if (text.empty()) { + return false; + } + // Check against known special tokens + if (kKnownSpecialTokens.count(text) > 0) { + return true; + } + // Match Llama-style tokens: <|...|> + if (text.size() >= 4 && text.front() == '<' && text[1] == '|' && + text.back() == '>' && text[text.size() - 2] == '|') { + return true; + } + return false; +} +} // namespace + using ::executorch::extension::Module; using ::executorch::runtime::Error; using ::executorch::runtime::Result; @@ -96,8 +139,14 @@ Error TextLLMRunner::generate( std::function wrapped_callback = [token_callback, config](const std::string& piece) { if (!config.warming) { - llm::safe_printf(piece.c_str()); - fflush(stdout); + // When echo=false, filter out special tokens (e.g. <|eot_id|>) so + // users get clean assistant output. When echo=true, the user asked + // for raw model output including any chat-template tokens, so emit + // everything unchanged. + if (config.echo || !is_special_token(piece)) { + llm::safe_printf(piece.c_str()); + fflush(stdout); + } } if (token_callback) { token_callback(piece); diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index b0545b8ce18..971dde20867 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -36,7 +36,6 @@ PROGRAM_NO_PRIM_OPS_SRCS = [ "method.cpp", "method_meta.cpp", "program.cpp", - "program_validation.cpp", "tensor_parser_exec_aten.cpp", ] @@ -357,6 +356,7 @@ EXTENSION_RUNNER_UTIL_SRCS = [ ] EXTENSION_LLM_RUNNER_SRCS = [ + "extension/llm/runner/jinja_chat_formatter.cpp", "extension/llm/runner/llm_runner_helper.cpp", "extension/llm/runner/multimodal_prefiller.cpp", "extension/llm/runner/multimodal_runner.cpp", @@ -476,7 +476,6 @@ XNNPACK_BACKEND_BUCK_SRCS = [ "runtime/XNNPACKBackend.cpp", "runtime/XNNWeightsCache.cpp", "runtime/XNNWorkspaceManager.cpp", - "runtime/XnnpackBackendOptions.cpp", "runtime/profiling/XNNProfiler.cpp", ]
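Taken together, the new headers give runner code a small, self-contained API for prompt construction. A minimal usage sketch (not part of this patch; include paths assume the installed SDK layout described above):

#include <executorch/extension/llm/chat_template/chat_templates.h>
#include <executorch/extension/llm/runner/chat_types.h>
#include <executorch/extension/llm/runner/jinja_chat_formatter.h>

#include <iostream>

using namespace executorch::extension::llm;

int main() {
  // Build a formatter from the embedded Llama 3 template.
  auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3);

  // Assemble a multi-turn conversation using the model's own special tokens.
  ChatConversation conversation;
  conversation.bos_token = formatter->getModelTokens().bos_token;
  conversation.eos_token = formatter->getModelTokens().eos_token;
  conversation.add_generation_prompt = true;
  conversation.messages = {
      {"system", "You are a helpful assistant."},
      {"user", "What is ExecuTorch?"},
  };

  // Render the prompt string that would be fed to the tokenizer.
  std::cout << formatter->formatConversation(conversation) << "\n";
  return 0;
}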