From cea43ac0da8b1c6cc967eba1510390183fd2abf4 Mon Sep 17 00:00:00 2001 From: Young Han Date: Tue, 12 May 2026 17:27:40 -0700 Subject: [PATCH 1/3] [llm][1/4] Add Jinja2Cpp-based chat template formatter library MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundation PR for the chat-template support stack. Adds the Jinja2Cpp-based JinjaChatFormatter, supporting chat types, embedded Llama3/Llama3.2/Gemma3 templates, build glue (CMake/Buck), and a focused C++ unit-test suite. This PR is reviewable in isolation — it has no behavior change for any existing runner; downstream PRs (B/C/D) plug it in. This is part 1 of a 4-PR stack split out of #16987 per reviewer request: 1/4 (this PR) Library + tests 2/4 TextLLMRunner echo-gated special-token filter + EOS merge 3/4 Python bindings + Python LlamaRunner integration 4/4 llama_main CLI flags + chat_formatter wrapper + docs What this PR adds ----------------- * extension/llm/chat_template/{chat_templates.h, BUCK, CMakeLists.txt, targets.bzl} — embedded Llama3/Llama3.2/Gemma3 templates and the ChatTemplateType enum + ModelTokens. The CMake file fetches Jinja2Cpp 1.3.2 via FetchContent, with SUPPORT_REGEX_LOOKAHEAD set BEFORE FetchContent_MakeAvailable so it propagates correctly, plus header staging for nonstd headers that some Jinja2Cpp installations omit. Installs chat_templates.h so SDK consumers can include it. * extension/llm/runner/{chat_types.h, jinja_chat_formatter.{h,cpp}} — the universal Jinja chat formatter that supports any HuggingFace / vLLM chat template, not just the embedded ones. Loadable via fromTemplate (built-in), fromString (any string), or fromFile (any .jinja file). formatConversation injects vLLM/HuggingFace-standard params (tools=[], tool_choice=None, date_string, chat_template_kwargs) so any template that references those variables renders correctly. * normalizeTemplate handles vLLM/HF template quirks for Jinja2Cpp: notably, 'not tools is none' maps to 'tools' (a truthy check), preserving the intent of 'tools is not none' for empty-list defaults. * extension/llm/runner/{CMakeLists.txt, targets.bzl} — link extension_llm_runner against jinja2cpp (PRIVATE) and define EXECUTORCH_USE_JINJA2CPP. * extension/llm/runner/test/{test_jinja_chat_formatter.cpp, CMakeLists.txt, targets.bzl, BUCK} — unit tests covering Llama3 / Llama3.2 / Gemma3 embedded templates, parseChatTemplateType (case-insensitive), and three universal-Jinja regression tests: - generic HuggingFace-style template (proves it's not Llama-specific) - tools-aware template (validates the tools=[] default) - 'not tools is none' normalization regression test * CMakeLists.txt — adds add_subdirectory(extension/llm/chat_template) guarded by EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER. * shim_et/xplat/executorch/build/build_variables.bzl — adds jinja_chat_formatter.cpp to the runner sources. Notes ----- * No behavior change for existing TextLLMRunner / MultimodalRunner users: the formatter is opt-in and only invoked when downstream code calls it. * Sample vLLM templates are NOT checked in (per reviewer feedback); documentation in the follow-up CLI PR points users to vLLM's examples directory and HuggingFace tokenizer_config.json files.
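Usage sketch for reviewers (illustrative only, not part of the diff; it exercises the API exactly as declared in jinja_chat_formatter.h, with error handling elided and "my_template.jinja" as a placeholder path):

    #include <executorch/extension/llm/runner/jinja_chat_formatter.h>

    using namespace executorch::extension::llm;

    // Embedded template, selected by enum:
    auto llama = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3);
    std::string prompt =
        llama->format("What is ExecuTorch?", "You are a helpful assistant.");

    // Arbitrary HuggingFace/vLLM template, loaded from a .jinja file:
    auto custom = JinjaChatFormatter::fromFile("my_template.jinja");
    ChatConversation conv;
    conv.messages = {
        {"user", "Hello"}, {"assistant", "Hi there!"}, {"user", "Bye"}};
    conv.add_generation_prompt = true;
    std::string rendered = custom->formatConversation(conv);

Both loaders throw std::runtime_error if the template fails to parse.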
Original PR (full stack): https://github.com/pytorch/executorch/pull/16987 --- CMakeLists.txt | 11 +- extension/llm/chat_template/BUCK | 18 ++ extension/llm/chat_template/CMakeLists.txt | 116 ++++++++ extension/llm/chat_template/chat_templates.h | 51 ++++ extension/llm/chat_template/targets.bzl | 16 ++ extension/llm/runner/CMakeLists.txt | 20 +- extension/llm/runner/chat_types.h | 20 ++ extension/llm/runner/jinja_chat_formatter.cpp | 236 +++++++++++++++ extension/llm/runner/jinja_chat_formatter.h | 51 ++++ extension/llm/runner/targets.bzl | 5 + extension/llm/runner/test/BUCK | 23 +- extension/llm/runner/test/CMakeLists.txt | 1 + extension/llm/runner/test/targets.bzl | 8 + .../runner/test/test_jinja_chat_formatter.cpp | 270 ++++++++++++++++++ .../executorch/build/build_variables.bzl | 3 +- 15 files changed, 821 insertions(+), 28 deletions(-) create mode 100644 extension/llm/chat_template/BUCK create mode 100644 extension/llm/chat_template/CMakeLists.txt create mode 100644 extension/llm/chat_template/chat_templates.h create mode 100644 extension/llm/chat_template/targets.bzl create mode 100644 extension/llm/runner/chat_types.h create mode 100644 extension/llm/runner/jinja_chat_formatter.cpp create mode 100644 extension/llm/runner/jinja_chat_formatter.h create mode 100644 extension/llm/runner/test/test_jinja_chat_formatter.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 359a0e0f5e4..30dc0dabccb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,6 +103,7 @@ include(${PROJECT_SOURCE_DIR}/tools/cmake/Codegen.cmake) include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake) include(CMakeDependentOption) include(ExternalProject) +include(FetchContent) include(GNUInstallDirs) if(NOT CMAKE_CXX_STANDARD) @@ -409,6 +410,14 @@ set(_common_include_directories $ ) +if(TARGET jinja2cpp) + install( + TARGETS jinja2cpp + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} + ) +endif() + # # The `__srcs` lists are defined by executorch_load_build_variables. # @@ -806,7 +815,7 @@ endif() if(EXECUTORCH_BUILD_EXTENSION_LLM) if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) - set(SUPPORT_REGEX_LOOKAHEAD ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/chat_template) # llama/runner/CMakeLists.txt builds a shared library libllama_runner.so # that transitively depends on tokenizers. Need to build tokenizers with # -fPIC. diff --git a/extension/llm/chat_template/BUCK b/extension/llm/chat_template/BUCK new file mode 100644 index 00000000000..bd8ea6e199e --- /dev/null +++ b/extension/llm/chat_template/BUCK @@ -0,0 +1,18 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +oncall("executorch") + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +non_fbcode_target(_kind = define_common_targets,) + +# !!!! fbcode/executorch/extension/llm/chat_template/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!! + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. 
+ +load(":targets.bzl", "define_common_targets") + +fbcode_target(_kind = define_common_targets,) diff --git a/extension/llm/chat_template/CMakeLists.txt b/extension/llm/chat_template/CMakeLists.txt new file mode 100644 index 00000000000..88ce80b157f --- /dev/null +++ b/extension/llm/chat_template/CMakeLists.txt @@ -0,0 +1,116 @@ +if(NOT EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) + return() +endif() + +include(FetchContent) +cmake_policy(SET CMP0077 NEW) + +FetchContent_Declare( + jinja2cpp + GIT_REPOSITORY https://github.com/jinja2cpp/Jinja2Cpp.git + GIT_TAG 1.3.2 + GIT_SUBMODULES_RECURSE TRUE +) + +set(JINJA2CPP_BUILD_TESTS + OFF + CACHE BOOL "" + FORCE +) +set(JINJA2CPP_BUILD_SHARED + OFF + CACHE BOOL "" + FORCE +) +set(JINJA2CPP_INSTALL + OFF + CACHE BOOL "" + FORCE +) +# Enable PCRE2-based regex lookahead support in Jinja2Cpp. This must be set +# BEFORE FetchContent_MakeAvailable(jinja2cpp) so it propagates to the +# Jinja2Cpp configure step. +set(SUPPORT_REGEX_LOOKAHEAD + ON + CACHE BOOL "" + FORCE +) + +FetchContent_MakeAvailable(jinja2cpp) +if(NOT TARGET jinja2cpp) + message(FATAL_ERROR "Jinja2Cpp target not found after FetchContent.") +endif() + +if(DEFINED jinja2cpp_SOURCE_DIR) + function(executorch_copy_nonstd_header dep_name target header_name dest_root) + set(_copied FALSE) + if(TARGET ${target}) + get_target_property(_aliased ${target} ALIASED_TARGET) + if(_aliased) + set(_resolved_target ${_aliased}) + else() + set(_resolved_target ${target}) + endif() + get_target_property( + _include_dirs ${_resolved_target} INTERFACE_INCLUDE_DIRECTORIES + ) + foreach(_dir IN LISTS _include_dirs) + if(EXISTS "${_dir}/nonstd/${header_name}") + file(MAKE_DIRECTORY "${dest_root}/nonstd") + file( + COPY "${_dir}/nonstd/${header_name}" + DESTINATION "${dest_root}/nonstd" + ) + set(_copied TRUE) + break() + endif() + endforeach() + endif() + if(NOT _copied) + set( + _fallback_path + "${CMAKE_BINARY_DIR}/_deps/${dep_name}-src/include/nonstd/${header_name}" + ) + if(EXISTS "${_fallback_path}") + file(MAKE_DIRECTORY "${dest_root}/nonstd") + file(COPY "${_fallback_path}" DESTINATION "${dest_root}/nonstd") + endif() + endif() + endfunction() + + set(_jinja2cpp_nonstd_root + "${jinja2cpp_SOURCE_DIR}/thirdparty/nonstd" + ) + executorch_copy_nonstd_header( + expected-lite + nonstd::expected-lite + expected.hpp + "${_jinja2cpp_nonstd_root}/expected-lite/include" + ) + executorch_copy_nonstd_header( + variant-lite + nonstd::variant-lite + variant.hpp + "${_jinja2cpp_nonstd_root}/variant-lite/include" + ) + executorch_copy_nonstd_header( + optional-lite + nonstd::optional-lite + optional.hpp + "${_jinja2cpp_nonstd_root}/optional-lite/include" + ) + executorch_copy_nonstd_header( + string-view-lite + nonstd::string-view-lite + string_view.hpp + "${_jinja2cpp_nonstd_root}/string-view-lite/include" + ) +endif() + +# Install the chat_templates.h header so that downstream consumers of the +# installed ExecuTorch SDK can include +# . 
+install( + FILES chat_templates.h + DESTINATION include/executorch/extension/llm/chat_template +) diff --git a/extension/llm/chat_template/chat_templates.h b/extension/llm/chat_template/chat_templates.h new file mode 100644 index 00000000000..b9e059b80e6 --- /dev/null +++ b/extension/llm/chat_template/chat_templates.h @@ -0,0 +1,51 @@ +#pragma once + +#include <string> +#include <string_view> +#include <unordered_map> +#include <vector> + +namespace executorch::extension::llm { + +enum class ChatTemplateType { + None, + Llama3, + Llama32, + Gemma3, + Custom, +}; + +constexpr std::string_view kLlama3Template = R"({{ bos_token }}{%- for message in messages -%}<|start_header_id|>{{ message.role }}<|end_header_id|> + +{{ message.content }}<|eot_id|>{%- endfor -%}{%- if add_generation_prompt -%}<|start_header_id|>assistant<|end_header_id|> + +{%- endif -%})"; + +constexpr std::string_view kGemma3Template = R"({{ bos_token }}{%- for message in messages -%}{%- if message.role == 'assistant' -%}<start_of_turn>model +{%- else -%}<start_of_turn>{{ message.role }} +{%- endif -%}{{ message.content }}<end_of_turn>{%- endfor -%}{%- if add_generation_prompt -%}<start_of_turn>model +{%- endif -%})"; + +inline const std::unordered_map<ChatTemplateType, std::string_view> + kEmbeddedTemplates = { + {ChatTemplateType::Llama3, kLlama3Template}, + {ChatTemplateType::Llama32, kLlama3Template}, + {ChatTemplateType::Gemma3, kGemma3Template}, + }; + +struct ModelTokens { + std::string bos_token; + std::string eos_token; + std::vector<std::string> stop_tokens; +}; + +inline const std::unordered_map<ChatTemplateType, ModelTokens> kModelTokens = { + {ChatTemplateType::Llama3, + {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Llama32, + {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Gemma3, + {"<bos>", "<end_of_turn>", {"<end_of_turn>", "<eos>"}}}, +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/chat_template/targets.bzl b/extension/llm/chat_template/targets.bzl new file mode 100644 index 00000000000..197e5da9fb3 --- /dev/null +++ b/extension/llm/chat_template/targets.bzl @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
+ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "chat_templates", + exported_headers = [ + "chat_templates.h", + ], + visibility = ["PUBLIC"], + ) diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index 43b89f0a908..da5b42c3f4a 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -54,6 +54,14 @@ endif() list(APPEND runner_deps kernels_util_all_deps) target_link_libraries(extension_llm_runner PUBLIC ${runner_deps}) +target_link_libraries(extension_llm_runner PRIVATE $) +target_include_directories( + extension_llm_runner PRIVATE + $ +) +target_compile_definitions( + extension_llm_runner PUBLIC EXECUTORCH_USE_JINJA2CPP +) set_target_properties( extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON ) @@ -116,23 +124,19 @@ if(EXECUTORCH_BUILD_PYBIND) portable_lib ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} ) + # Set properties for the Python extension set_target_properties( _llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_VISIBILITY_PRESET "hidden" INTERPROCEDURAL_OPTIMIZATION TRUE - CXX_STANDARD 20 ) if(APPLE) - set(RPATH - "@loader_path/../../pybindings;@loader_path/../../../../torch/lib" - ) + set(RPATH "@loader_path/../../pybindings") else() - set(RPATH "$ORIGIN/../../pybindings:$ORIGIN/../../../../torch/lib") + set(RPATH "$ORIGIN/../../pybindings") endif() - set_target_properties( - _llm_runner PROPERTIES BUILD_RPATH "${RPATH}" INSTALL_RPATH "${RPATH}" - ) + set_target_properties(_llm_runner PROPERTIES INSTALL_RPATH ${RPATH}) # Add include directories target_include_directories( _llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS} diff --git a/extension/llm/runner/chat_types.h b/extension/llm/runner/chat_types.h new file mode 100644 index 00000000000..6a7cd2b625f --- /dev/null +++ b/extension/llm/runner/chat_types.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace executorch::extension::llm { + +struct ChatMessage { + std::string role; + std::string content; +}; + +struct ChatConversation { + std::vector messages; + std::string bos_token; + std::string eos_token; + bool add_generation_prompt = true; +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/jinja_chat_formatter.cpp b/extension/llm/runner/jinja_chat_formatter.cpp new file mode 100644 index 00000000000..00054554886 --- /dev/null +++ b/extension/llm/runner/jinja_chat_formatter.cpp @@ -0,0 +1,236 @@ +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { +namespace { + +std::string readFileToString(const std::filesystem::path& path) { + std::ifstream file(path); + if (!file) { + throw std::runtime_error("Failed to open template file: " + path.string()); + } + std::ostringstream buffer; + buffer << file.rdbuf(); + return buffer.str(); +} + +bool templateIncludesBos( + const std::string& template_str, + const ModelTokens& model_tokens) { + if (!model_tokens.bos_token.empty() && + template_str.find(model_tokens.bos_token) != std::string::npos) { + return true; + } + return template_str.find("bos_token") != std::string::npos; +} + +std::string normalizeTemplate(std::string input) { + // These replacements normalize vLLM/HuggingFace Jinja templates so they + // compile/render correctly with Jinja2Cpp, which has stricter parser + // semantics than Python Jinja2. 
+ // + // IMPORTANT: "not tools is none" in Python Jinja means "tools is not none" + // (truthy when tools is defined and non-null), so we map it to a simple + // truthy check on `tools`. Mapping to "not tools" was a bug that would + // skip tool blocks for non-empty tools lists. + constexpr std::array<std::pair<std::string_view, std::string_view>, 10> + replacements = {{ + {"tools = none", "tools = []"}, + {"tools = None", "tools = []"}, + {"tools is not none", "tools"}, + {"tools is not None", "tools"}, + {"not tools is none", "tools"}, + {"not tools is None", "tools"}, + {"tools is none", "not tools"}, + {"tools is None", "not tools"}, + {"messages[1:]", "messages_tail"}, + {"{ \"output\": message.content } | tojson", "message.content | tojson"}, + }}; + // Handle special case that can't be constexpr due to escape sequence + const std::pair<std::string_view, std::string_view> gemmaReplacement = { + "{{'<start_of_turn>model\\n'}}", "{{ '<start_of_turn>model\\n' }}"}; + for (const auto& replacement : replacements) { + size_t pos = 0; + while ((pos = input.find(replacement.first, pos)) != std::string::npos) { + input.replace(pos, replacement.first.size(), replacement.second); + pos += replacement.second.size(); + } + } + // Apply the gemma replacement separately + size_t pos = 0; + while ((pos = input.find(gemmaReplacement.first, pos)) != std::string::npos) { + input.replace(pos, gemmaReplacement.first.size(), gemmaReplacement.second); + pos += gemmaReplacement.second.size(); + } + return input; +} + +ChatTemplateType detectTemplateType(const std::string& template_str) { + if (template_str.find("<start_of_turn>") != std::string::npos) { + return ChatTemplateType::Gemma3; + } + if (template_str.find("<|start_header_id|>") != std::string::npos) { + return ChatTemplateType::Llama3; + } + return ChatTemplateType::Custom; +} + +std::string toLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast<char>(std::tolower(c)); + }); + return value; +} + +} // namespace + +} // namespace executorch::extension::llm + +namespace jinja2 { + +// NOLINTBEGIN(facebook-hte-MisplacedTemplateSpecialization,facebook-hte-ShadowingClass) +// This template specialization must be in the jinja2 namespace because +// explicit specializations must live in the primary template's namespace.
+template <> +struct TypeReflection<executorch::extension::llm::ChatMessage> + : TypeReflected<executorch::extension::llm::ChatMessage> { + static auto& GetAccessors() { + static std::unordered_map<std::string, FieldAccessor> accessors = { + {"role", + [](const executorch::extension::llm::ChatMessage& msg) { + return jinja2::Reflect(msg.role); + }}, + {"content", + [](const executorch::extension::llm::ChatMessage& msg) { + return jinja2::Reflect(msg.content); + }}, + }; + return accessors; + } +}; +// NOLINTEND(facebook-hte-MisplacedTemplateSpecialization,facebook-hte-ShadowingClass) + +} // namespace jinja2 + +namespace executorch::extension::llm { + +JinjaChatFormatter::JinjaChatFormatter( + const std::string& template_str, + ChatTemplateType type) + : template_str_(template_str), type_(type) { + auto tokens_it = kModelTokens.find(type_); + if (tokens_it != kModelTokens.end()) { + model_tokens_ = tokens_it->second; + } + includes_bos_ = templateIncludesBos(template_str_, model_tokens_); + const std::string normalized_template = normalizeTemplate(template_str_); + compiled_template_ = std::make_unique<jinja2::Template>(); + auto load_result = compiled_template_->Load(normalized_template); + if (!load_result) { + throw std::runtime_error( + "Failed to parse chat template: " + + load_result.error().ToString()); + } +} + +JinjaChatFormatter::~JinjaChatFormatter() = default; + +std::unique_ptr<JinjaChatFormatter> JinjaChatFormatter::fromTemplate( + ChatTemplateType type) { + auto it = kEmbeddedTemplates.find(type); + if (it == kEmbeddedTemplates.end()) { + throw std::runtime_error("Unsupported embedded chat template type."); + } + return std::unique_ptr<JinjaChatFormatter>( + new JinjaChatFormatter(std::string(it->second), type)); +} + +std::unique_ptr<JinjaChatFormatter> JinjaChatFormatter::fromString( + const std::string& template_str) { + const ChatTemplateType inferred_type = detectTemplateType(template_str); + return std::unique_ptr<JinjaChatFormatter>( + new JinjaChatFormatter(template_str, inferred_type)); +} + +std::unique_ptr<JinjaChatFormatter> JinjaChatFormatter::fromFile( + const std::filesystem::path& path) { + return fromString(readFileToString(path)); +} + +std::string JinjaChatFormatter::format( + const std::string& prompt, + const std::string& system_prompt) const { + ChatConversation conversation; + if (!system_prompt.empty()) { + conversation.messages.push_back({"system", system_prompt}); + } + conversation.messages.push_back({"user", prompt}); + conversation.bos_token = model_tokens_.bos_token; + conversation.eos_token = model_tokens_.eos_token; + conversation.add_generation_prompt = true; + return formatConversation(conversation); +} + +std::string JinjaChatFormatter::formatConversation( + const ChatConversation& conversation) const { + jinja2::ValuesMap params; + params["messages"] = jinja2::ValuesList(); + params["messages_tail"] = jinja2::ValuesList(); + bool is_first = true; + for (const auto& msg : conversation.messages) { + params["messages"].asList().push_back(jinja2::Reflect(msg)); + if (!is_first) { + params["messages_tail"].asList().push_back(jinja2::Reflect(msg)); + } + is_first = false; + } + params["bos_token"] = conversation.bos_token; + params["eos_token"] = conversation.eos_token; + params["add_generation_prompt"] = conversation.add_generation_prompt; + // Provide vLLM/HuggingFace-style defaults that templates often reference. + // Templates that don't use these will simply ignore them.
+ params["tools"] = jinja2::ValuesList(); + params["tool_choice"] = jinja2::Value(); + params["date_string"] = std::string("26 Jul 2024"); + params["chat_template_kwargs"] = jinja2::ValuesMap(); + + auto rendered = compiled_template_->RenderAsString(params); + if (!rendered) { + throw std::runtime_error( + "Failed to render chat template: " + rendered.error().ToString()); + } + return rendered.value(); +} + +ChatTemplateType parseChatTemplateType(const std::string& type_str) { + const std::string lower = toLower(type_str); + if (lower == "none") { + return ChatTemplateType::None; + } + if (lower == "llama3") { + return ChatTemplateType::Llama3; + } + if (lower == "llama3.2" || lower == "llama32" || lower == "llama3_2") { + return ChatTemplateType::Llama32; + } + if (lower == "gemma3") { + return ChatTemplateType::Gemma3; + } + if (lower == "custom" || lower == "jinja") { + return ChatTemplateType::Custom; + } + return ChatTemplateType::None; +} + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/jinja_chat_formatter.h b/extension/llm/runner/jinja_chat_formatter.h new file mode 100644 index 00000000000..e589a9ed18e --- /dev/null +++ b/extension/llm/runner/jinja_chat_formatter.h @@ -0,0 +1,51 @@ +#pragma once + +#include <executorch/extension/llm/chat_template/chat_templates.h> +#include <executorch/extension/llm/runner/chat_types.h> + +#include <filesystem> +#include <memory> +#include <string> + +namespace jinja2 { +class Template; +} + +namespace executorch::extension::llm { + +class JinjaChatFormatter { + public: + static std::unique_ptr<JinjaChatFormatter> fromTemplate(ChatTemplateType type); + static std::unique_ptr<JinjaChatFormatter> fromString( + const std::string& template_str); + static std::unique_ptr<JinjaChatFormatter> fromFile( + const std::filesystem::path& path); + + ~JinjaChatFormatter(); + + std::string format( + const std::string& prompt, + const std::string& system_prompt = "") const; + std::string formatConversation(const ChatConversation& conversation) const; + + bool includesBos() const { + return includes_bos_; + } + + const ModelTokens& getModelTokens() const { + return model_tokens_; + } + + private: + JinjaChatFormatter(const std::string& template_str, ChatTemplateType type); + + std::string template_str_; + ChatTemplateType type_; + ModelTokens model_tokens_; + bool includes_bos_ = false; + std::unique_ptr<jinja2::Template> compiled_template_; +}; + +ChatTemplateType parseChatTemplateType(const std::string& type_str); + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 0d4ed99308d..f8ceed67bf1 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -111,11 +111,14 @@ def define_common_targets(): runtime.cxx_library( name = "runner_lib" + aten_suffix, exported_headers = [ + "chat_types.h", + "jinja_chat_formatter.h", "text_llm_runner.h", "llm_runner_helper.h", "constants.h", ], srcs = [ + "jinja_chat_formatter.cpp", "text_llm_runner.cpp", "llm_runner_helper.cpp", "multimodal_runner.cpp", @@ -131,6 +134,7 @@ def define_common_targets(): ":text_decoder_runner" + aten_suffix, ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, + "//executorch/extension/llm/chat_template:chat_templates", "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, "//executorch/extension/memory_allocator:cpu_caching_allocator", "//pytorch/tokenizers:hf_tokenizer", @@ -138,5 +142,6 @@ def define_common_targets(): "//pytorch/tokenizers:sentencepiece", "//pytorch/tokenizers:tekken", "//pytorch/tokenizers:tiktoken", + "@fbsource//third-party/jinja2cpp:jinja2cpp", ], ) diff --git a/extension/llm/runner/test/BUCK
b/extension/llm/runner/test/BUCK index 9ed85acadb7..cb8e8fcfb7e 100644 --- a/extension/llm/runner/test/BUCK +++ b/extension/llm/runner/test/BUCK @@ -10,32 +10,21 @@ oncall("executorch") # targets.bzl. This file can contain fbcode-only targets. load(":targets.bzl", "define_common_targets") - +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") non_fbcode_target(_kind = define_common_targets,) -# !!!! fbcode/executorch/extension/llm/runner/test/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!! - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Any targets that should be shared between fbcode and xplat must be defined in -# targets.bzl. This file can contain fbcode-only targets. - -load(":targets.bzl", "define_common_targets") -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") - fbcode_target(_kind = define_common_targets,) +# fbcode-only test that requires access to fbcode//executorch/test/models fbcode_target(_kind = runtime.cxx_test, name = "test_text_decoder_runner", srcs = ["test_text_decoder_runner.cpp"], deps = [ - "//executorch/extension/llm/runner:runner_lib", "//executorch/extension/llm/runner/io_manager:io_manager", + "//executorch/extension/llm/runner:text_decoder_runner", + "//executorch/extension/module:module", + "//executorch/extension/tensor:tensor", "//executorch/kernels/portable:generated_lib", "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], @@ -43,5 +32,5 @@ fbcode_target(_kind = runtime.cxx_test, "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])", "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])", "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])", - } + }, ) diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 81b69c0ab9a..64d89ff0ec4 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -19,6 +19,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs test_generation_config.cpp + test_jinja_chat_formatter.cpp test_text_llm_runner.cpp test_text_prefiller.cpp test_text_decoder_runner.cpp diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 08044de2d35..a8e3858af9d 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -17,6 +17,14 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "test_jinja_chat_formatter", + srcs = ["test_jinja_chat_formatter.cpp"], + deps = [ + "//executorch/extension/llm/runner:runner_lib", + ], + ) + runtime.cxx_test( name = "test_text_llm_runner", srcs = ["test_text_llm_runner.cpp"], diff --git a/extension/llm/runner/test/test_jinja_chat_formatter.cpp b/extension/llm/runner/test/test_jinja_chat_formatter.cpp new file mode 100644 index 00000000000..d1a3d6d9fff --- /dev/null +++ b/extension/llm/runner/test/test_jinja_chat_formatter.cpp @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include <executorch/extension/llm/runner/jinja_chat_formatter.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +using executorch::extension::llm::ChatConversation; +using executorch::extension::llm::ChatMessage; +using executorch::extension::llm::ChatTemplateType; +using executorch::extension::llm::JinjaChatFormatter; +using executorch::extension::llm::parseChatTemplateType; +using testing::HasSubstr; + +TEST(JinjaChatFormatter, Llama3SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const std::string prompt = "Test prompt"; + const std::string system_prompt = "You are a helpful assistant."; + // Note: The Jinja template uses {%- ... -%} which strips whitespace, + // so the output has \n\n after each <|end_header_id|> and content, + // but no trailing \n\n at the end due to the {%- endif -%} stripping. + const std::string expected = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + + system_prompt + + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + prompt + + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; + + EXPECT_EQ(formatter->format(prompt, system_prompt), expected); +} + +TEST(JinjaChatFormatter, Gemma3SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("<bos>")); + EXPECT_THAT(result, HasSubstr("<start_of_turn>user")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<start_of_turn>model")); +} + +TEST(JinjaChatFormatter, Llama3WithoutSystemPrompt) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); + // Should not contain system header when no system prompt + EXPECT_THAT(result, ::testing::Not(HasSubstr("<|start_header_id|>system"))); +} + +TEST(JinjaChatFormatter, Llama3IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Gemma3IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Llama3ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<|begin_of_text|>"); + EXPECT_EQ(tokens.eos_token, "<|eot_id|>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +TEST(JinjaChatFormatter, Gemma3ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<bos>"); + EXPECT_EQ(tokens.eos_token, "<end_of_turn>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +TEST(JinjaChatFormatter, FormatConversationMultiTurn) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + + ChatConversation conversation; + conversation.bos_token = "<|begin_of_text|>"; + conversation.eos_token = "<|eot_id|>"; + conversation.add_generation_prompt = true; + conversation.messages = { + {"user", "Hello"}, + {"assistant", "Hi there!"}, + {"user", "How are you?"}, + }; + + const std::string result = formatter->formatConversation(conversation); + + EXPECT_THAT(result,
HasSubstr("Hello")); + EXPECT_THAT(result, HasSubstr("Hi there!")); + EXPECT_THAT(result, HasSubstr("How are you?")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +} + +TEST(JinjaChatFormatter, FromStringLlama3Template) { + const std::string llama_template = + "{{ bos_token }}<|start_header_id|>user<|end_header_id|>\n\n" + "{{ messages[0].content }}<|eot_id|>"; + + auto formatter = JinjaChatFormatter::fromString(llama_template); + + // Should detect Llama3 type from the template content + const std::string result = formatter->format("Test"); + EXPECT_THAT(result, HasSubstr("Test")); +} + +TEST(JinjaChatFormatter, FromStringGemma3Template) { + const std::string gemma_template = + "{{ bos_token }}user\n" + "{{ messages[0].content }}"; + + auto formatter = JinjaChatFormatter::fromString(gemma_template); + + const std::string result = formatter->format("Test"); + EXPECT_THAT(result, HasSubstr("Test")); +} + +TEST(JinjaChatFormatter, UnsupportedTemplateTypeThrows) { + EXPECT_THROW( + JinjaChatFormatter::fromTemplate(ChatTemplateType::None), + std::runtime_error); +} + +// Tests for parseChatTemplateType +TEST(ParseChatTemplateType, ParseNone) { + EXPECT_EQ(parseChatTemplateType("none"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("None"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("NONE"), ChatTemplateType::None); +} + +TEST(ParseChatTemplateType, ParseLlama3) { + EXPECT_EQ(parseChatTemplateType("llama3"), ChatTemplateType::Llama3); + EXPECT_EQ(parseChatTemplateType("LLAMA3"), ChatTemplateType::Llama3); + EXPECT_EQ(parseChatTemplateType("Llama3"), ChatTemplateType::Llama3); +} + +TEST(ParseChatTemplateType, ParseLlama32Variants) { + EXPECT_EQ(parseChatTemplateType("llama3.2"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("llama32"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("llama3_2"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("LLAMA3.2"), ChatTemplateType::Llama32); +} + +TEST(ParseChatTemplateType, ParseGemma3) { + EXPECT_EQ(parseChatTemplateType("gemma3"), ChatTemplateType::Gemma3); + EXPECT_EQ(parseChatTemplateType("GEMMA3"), ChatTemplateType::Gemma3); + EXPECT_EQ(parseChatTemplateType("Gemma3"), ChatTemplateType::Gemma3); +} + +TEST(ParseChatTemplateType, ParseCustom) { + EXPECT_EQ(parseChatTemplateType("custom"), ChatTemplateType::Custom); + EXPECT_EQ(parseChatTemplateType("jinja"), ChatTemplateType::Custom); + EXPECT_EQ(parseChatTemplateType("CUSTOM"), ChatTemplateType::Custom); +} + +TEST(ParseChatTemplateType, ParseUnknownReturnsNone) { + EXPECT_EQ(parseChatTemplateType("unknown"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType(""), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("invalid"), ChatTemplateType::None); +} + +TEST(JinjaChatFormatter, Llama32SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + const std::string result = formatter->format("Hello!"); + + // Llama32 uses the same template as Llama3 + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +} + +TEST(JinjaChatFormatter, Llama32IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Llama32ModelTokens) { + auto 
formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<|begin_of_text|>"); + EXPECT_EQ(tokens.eos_token, "<|eot_id|>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +// Universal Jinja support: any HuggingFace / vLLM-style template string +// (passed via fromString or fromFile) should work, not just the embedded +// Llama3/Gemma3 templates. This validates the renderer with a generic +// HuggingFace-style template that uses the standard `messages`, +// `bos_token`, and `add_generation_prompt` variables. +TEST(JinjaChatFormatter, UniversalJinjaGenericTemplate) { + // Generic chat template inspired by HuggingFace tokenizer_config.json + // examples. Uses only standard Jinja2 features and standard chat variables. + const std::string generic_template = + "{{ bos_token }}" + "{%- for message in messages -%}" + "<|{{ message.role }}|>\n{{ message.content }}<|end|>\n" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}<|assistant|>\n{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(generic_template); + ChatConversation conv; + conv.bos_token = "<s>"; + conv.add_generation_prompt = true; + conv.messages.push_back(ChatMessage{"user", "Hi there"}); + + const std::string result = formatter->formatConversation(conv); + + EXPECT_THAT(result, HasSubstr("<s>")); + EXPECT_THAT(result, HasSubstr("<|user|>")); + EXPECT_THAT(result, HasSubstr("Hi there")); + EXPECT_THAT(result, HasSubstr("<|assistant|>")); +} + +// Universal Jinja support: templates that reference `tools` (e.g. vLLM's +// tool_chat_template_*.jinja files) should not crash even when no tools +// are passed. The formatter injects `tools = []` (an empty list) so that +// truthy/none checks evaluate consistently. With our normalization, +// `tools is not none` is rewritten to `tools` (a truthy check), which means +// an empty list (no tools) is treated as "no tools available" — matching +// the typical template intent. +TEST(JinjaChatFormatter, UniversalJinjaToolsAwareTemplate) { + const std::string tools_template = + "{%- if tools is not none -%}" + "tools_present" + "{%- else -%}" + "no_tools" + "{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(tools_template); + ChatConversation conv; + conv.add_generation_prompt = false; + + // tools defaults to [] (empty list). After normalization this is a + // truthy check, so the empty-list (no tools) branch should be selected. + EXPECT_EQ(formatter->formatConversation(conv), "no_tools"); +} + +// Universal Jinja support: a template using "not tools is none" (semantically +// "tools is not none") should now safely evaluate without skipping the +// "no tools" branch. This guards the regression where the normalizer mapped +// "not tools is none" -> "not tools", which incorrectly evaluated to true +// for an empty list and would have rendered a tool block when none was +// intended. +TEST(JinjaChatFormatter, UniversalJinjaNormalizedNotToolsIsNone) { + const std::string template_str = + "{%- if not tools is none -%}defined{%- else -%}none{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(template_str); + ChatConversation conv; + conv.add_generation_prompt = false; + + // tools = [] is falsy after normalization (`not tools is none` -> `tools`). + // The "else" branch should be selected (no tools available).
+ EXPECT_EQ(formatter->formatConversation(conv), "none"); +} +} diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index b0545b8ce18..971dde20867 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -36,7 +36,6 @@ PROGRAM_NO_PRIM_OPS_SRCS = [ "method.cpp", "method_meta.cpp", "program.cpp", - "program_validation.cpp", "tensor_parser_exec_aten.cpp", ] @@ -357,6 +356,7 @@ EXTENSION_RUNNER_UTIL_SRCS = [ ] EXTENSION_LLM_RUNNER_SRCS = [ + "extension/llm/runner/jinja_chat_formatter.cpp", "extension/llm/runner/llm_runner_helper.cpp", "extension/llm/runner/multimodal_prefiller.cpp", "extension/llm/runner/multimodal_runner.cpp", @@ -476,7 +476,6 @@ XNNPACK_BACKEND_BUCK_SRCS = [ "runtime/XNNPACKBackend.cpp", "runtime/XNNWeightsCache.cpp", "runtime/XNNWorkspaceManager.cpp", - "runtime/XnnpackBackendOptions.cpp", "runtime/profiling/XNNProfiler.cpp", ] From c9104a76f90f05529b1a1d3adb4c99d4e2d351ea Mon Sep 17 00:00:00 2001 From: Young Han Date: Thu, 14 May 2026 14:46:43 -0700 Subject: [PATCH 2/3] [llm] Address chat formatter review feedback Restore accidentally dropped build sources and tighten the Jinja formatter API/build wiring so the first chat-template PR is reviewable on its own. Co-authored-by: Cursor --- CMakeLists.txt | 9 --- extension/llm/chat_template/CMakeLists.txt | 54 +++++++----------- extension/llm/chat_template/chat_templates.h | 55 +++++++++++++------ extension/llm/runner/CMakeLists.txt | 8 +-- extension/llm/runner/chat_types.h | 8 +++ extension/llm/runner/jinja_chat_formatter.cpp | 40 +++++++++----- extension/llm/runner/jinja_chat_formatter.h | 15 +++-- .../runner/test/test_jinja_chat_formatter.cpp | 15 ++--- .../executorch/build/build_variables.bzl | 2 + 9 files changed, 118 insertions(+), 88 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 30dc0dabccb..ebc93026789 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,7 +103,6 @@ include(${PROJECT_SOURCE_DIR}/tools/cmake/Codegen.cmake) include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake) include(CMakeDependentOption) include(ExternalProject) -include(FetchContent) include(GNUInstallDirs) if(NOT CMAKE_CXX_STANDARD) @@ -410,14 +409,6 @@ set(_common_include_directories $ ) -if(TARGET jinja2cpp) - install( - TARGETS jinja2cpp - EXPORT ExecuTorchTargets - DESTINATION ${CMAKE_INSTALL_LIBDIR} - ) -endif() - # # The `__srcs` lists are defined by executorch_load_build_variables. # diff --git a/extension/llm/chat_template/CMakeLists.txt b/extension/llm/chat_template/CMakeLists.txt index 88ce80b157f..2002a54d37c 100644 --- a/extension/llm/chat_template/CMakeLists.txt +++ b/extension/llm/chat_template/CMakeLists.txt @@ -14,32 +14,31 @@ FetchContent_Declare( set(JINJA2CPP_BUILD_TESTS OFF - CACHE BOOL "" - FORCE + CACHE BOOL "" FORCE ) set(JINJA2CPP_BUILD_SHARED OFF - CACHE BOOL "" - FORCE + CACHE BOOL "" FORCE ) set(JINJA2CPP_INSTALL OFF - CACHE BOOL "" - FORCE + CACHE BOOL "" FORCE ) # Enable PCRE2-based regex lookahead support in Jinja2Cpp. This must be set -# BEFORE FetchContent_MakeAvailable(jinja2cpp) so it propagates to the -# Jinja2Cpp configure step. +# BEFORE FetchContent_MakeAvailable(jinja2cpp) so it propagates to the Jinja2Cpp +# configure step. 
set(SUPPORT_REGEX_LOOKAHEAD ON - CACHE BOOL "" - FORCE + CACHE BOOL "" FORCE ) FetchContent_MakeAvailable(jinja2cpp) if(NOT TARGET jinja2cpp) message(FATAL_ERROR "Jinja2Cpp target not found after FetchContent.") endif() +if(APPLE) + target_compile_options(jinja2cpp PRIVATE -Wno-shorten-64-to-32) +endif() if(DEFINED jinja2cpp_SOURCE_DIR) function(executorch_copy_nonstd_header dep_name target header_name dest_root) @@ -57,9 +56,8 @@ if(DEFINED jinja2cpp_SOURCE_DIR) foreach(_dir IN LISTS _include_dirs) if(EXISTS "${_dir}/nonstd/${header_name}") file(MAKE_DIRECTORY "${dest_root}/nonstd") - file( - COPY "${_dir}/nonstd/${header_name}" - DESTINATION "${dest_root}/nonstd" + file(COPY "${_dir}/nonstd/${header_name}" + DESTINATION "${dest_root}/nonstd" ) set(_copied TRUE) break() @@ -67,9 +65,8 @@ if(DEFINED jinja2cpp_SOURCE_DIR) endforeach() endif() if(NOT _copied) - set( - _fallback_path - "${CMAKE_BINARY_DIR}/_deps/${dep_name}-src/include/nonstd/${header_name}" + set(_fallback_path + "${CMAKE_BINARY_DIR}/_deps/${dep_name}-src/include/nonstd/${header_name}" ) if(EXISTS "${_fallback_path}") file(MAKE_DIRECTORY "${dest_root}/nonstd") @@ -78,31 +75,21 @@ if(DEFINED jinja2cpp_SOURCE_DIR) endif() endfunction() - set(_jinja2cpp_nonstd_root - "${jinja2cpp_SOURCE_DIR}/thirdparty/nonstd" - ) + set(_jinja2cpp_nonstd_root "${jinja2cpp_SOURCE_DIR}/thirdparty/nonstd") executorch_copy_nonstd_header( - expected-lite - nonstd::expected-lite - expected.hpp + expected-lite nonstd::expected-lite expected.hpp "${_jinja2cpp_nonstd_root}/expected-lite/include" ) executorch_copy_nonstd_header( - variant-lite - nonstd::variant-lite - variant.hpp + variant-lite nonstd::variant-lite variant.hpp "${_jinja2cpp_nonstd_root}/variant-lite/include" ) executorch_copy_nonstd_header( - optional-lite - nonstd::optional-lite - optional.hpp + optional-lite nonstd::optional-lite optional.hpp "${_jinja2cpp_nonstd_root}/optional-lite/include" ) executorch_copy_nonstd_header( - string-view-lite - nonstd::string-view-lite - string_view.hpp + string-view-lite nonstd::string-view-lite string_view.hpp "${_jinja2cpp_nonstd_root}/string-view-lite/include" ) endif() @@ -110,7 +97,6 @@ endif() # Install the chat_templates.h header so that downstream consumers of the # installed ExecuTorch SDK can include # <executorch/extension/llm/chat_template/chat_templates.h>. -install( - FILES chat_templates.h - DESTINATION include/executorch/extension/llm/chat_template +install(FILES chat_templates.h + DESTINATION include/executorch/extension/llm/chat_template ) diff --git a/extension/llm/chat_template/chat_templates.h b/extension/llm/chat_template/chat_templates.h index b9e059b80e6..0e6e2d7dceb 100644 --- a/extension/llm/chat_template/chat_templates.h +++ b/extension/llm/chat_template/chat_templates.h @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + + #pragma once #include <string> @@ -15,23 +23,29 @@ enum class ChatTemplateType { Custom, }; -constexpr std::string_view kLlama3Template = R"({{ bos_token }}{%- for message in messages -%}<|start_header_id|>{{ message.role }}<|end_header_id|> +constexpr std::string_view kLlama3Template = + R"({{ bos_token }}{%- for message in messages -%}<|start_header_id|>{{ message.role }}<|end_header_id|> {{ message.content }}<|eot_id|>{%- endfor -%}{%- if add_generation_prompt -%}<|start_header_id|>assistant<|end_header_id|> {%- endif -%})"; -constexpr std::string_view kGemma3Template = R"({{ bos_token }}{%- for message in messages -%}{%- if message.role == 'assistant' -%}<start_of_turn>model +constexpr std::string_view kGemma3Template = + R"({{ bos_token }}{%- for message in messages -%}{%- if message.role == 'assistant' -%}<start_of_turn>model {%- else -%}<start_of_turn>{{ message.role }} {%- endif -%}{{ message.content }}<end_of_turn>{%- endfor -%}{%- if add_generation_prompt -%}<start_of_turn>model {%- endif -%})"; -inline const std::unordered_map<ChatTemplateType, std::string_view> - kEmbeddedTemplates = { - {ChatTemplateType::Llama3, kLlama3Template}, - {ChatTemplateType::Llama32, kLlama3Template}, - {ChatTemplateType::Gemma3, kGemma3Template}, - }; +inline const std::unordered_map<ChatTemplateType, std::string_view>& +getEmbeddedTemplates() { + static const std::unordered_map<ChatTemplateType, std::string_view> + embedded_templates = { + {ChatTemplateType::Llama3, kLlama3Template}, + {ChatTemplateType::Llama32, kLlama3Template}, + {ChatTemplateType::Gemma3, kGemma3Template}, + }; + return embedded_templates; +} struct ModelTokens { std::string bos_token; @@ -39,13 +53,22 @@ struct ModelTokens { std::vector<std::string> stop_tokens; }; -inline const std::unordered_map<ChatTemplateType, ModelTokens> kModelTokens = { - {ChatTemplateType::Llama3, - {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, - {ChatTemplateType::Llama32, - {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, - {ChatTemplateType::Gemma3, - {"<bos>", "<end_of_turn>", {"<end_of_turn>", "<eos>"}}}, -}; +inline const std::unordered_map<ChatTemplateType, ModelTokens>& +getModelTokens() { + static const std::unordered_map<ChatTemplateType, ModelTokens> model_tokens = + { + {ChatTemplateType::Llama3, + {"<|begin_of_text|>", + "<|eot_id|>", + {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Llama32, + {"<|begin_of_text|>", + "<|eot_id|>", + {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Gemma3, + {"<bos>", "<end_of_turn>", {"<end_of_turn>", "<eos>"}}}, + }; + return model_tokens; +} } // namespace executorch::extension::llm diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index da5b42c3f4a..7ab60452ff7 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -56,12 +56,10 @@ list(APPEND runner_deps kernels_util_all_deps) target_link_libraries(extension_llm_runner PUBLIC ${runner_deps}) target_link_libraries(extension_llm_runner PRIVATE $) target_include_directories( - extension_llm_runner PRIVATE - $ -) -target_compile_definitions( - extension_llm_runner PUBLIC EXECUTORCH_USE_JINJA2CPP + extension_llm_runner + PRIVATE $ ) +target_compile_definitions(extension_llm_runner PUBLIC EXECUTORCH_USE_JINJA2CPP) set_target_properties( extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON ) diff --git a/extension/llm/runner/chat_types.h b/extension/llm/runner/chat_types.h index 6a7cd2b625f..fc90a199bc2 100644 --- a/extension/llm/runner/chat_types.h +++ b/extension/llm/runner/chat_types.h @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + + #pragma once #include <string> diff --git a/extension/llm/runner/jinja_chat_formatter.cpp b/extension/llm/runner/jinja_chat_formatter.cpp index 00054554886..c9e2a932360 100644 --- a/extension/llm/runner/jinja_chat_formatter.cpp +++ b/extension/llm/runner/jinja_chat_formatter.cpp @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + #include <executorch/extension/llm/runner/jinja_chat_formatter.h> @@ -15,10 +23,10 @@ namespace executorch::extension::llm { namespace { -std::string readFileToString(const std::filesystem::path& path) { +std::string readFileToString(const std::string& path) { std::ifstream file(path); if (!file) { - throw std::runtime_error("Failed to open template file: " + path.string()); + throw std::runtime_error("Failed to open template file: " + path); } std::ostringstream buffer; buffer << file.rdbuf(); return buffer.str(); @@ -44,6 +52,8 @@ std::string normalizeTemplate(std::string input) { // (truthy when tools is defined and non-null), so we map it to a simple // truthy check on `tools`. Mapping to "not tools" was a bug that would // skip tool blocks for non-empty tools lists. + // Keep longer `tools is ... none` patterns before the shorter + // `tools is none` patterns to avoid partial replacements. constexpr std::array<std::pair<std::string_view, std::string_view>, 10> replacements = {{ {"tools = none", "tools = []"}, @@ -55,7 +65,8 @@ std::string normalizeTemplate(std::string input) { {"tools is none", "not tools"}, {"tools is None", "not tools"}, {"messages[1:]", "messages_tail"}, - {"{ \"output\": message.content } | tojson", "message.content | tojson"}, + {"{ \"output\": message.content } | tojson", + "message.content | tojson"}, }}; // Handle special case that can't be constexpr due to escape sequence const std::pair<std::string_view, std::string_view> gemmaReplacement = { @@ -87,9 +98,10 @@ ChatTemplateType detectTemplateType(const std::string& template_str) { } std::string toLower(std::string value) { - std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { - return static_cast<char>(std::tolower(c)); - }); + std::transform( + value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast<char>(std::tolower(c)); + }); return value; } @@ -129,8 +141,9 @@ JinjaChatFormatter::JinjaChatFormatter( const std::string& template_str, ChatTemplateType type) : template_str_(template_str), type_(type) { - auto tokens_it = kModelTokens.find(type_); - if (tokens_it != kModelTokens.end()) { + const auto& model_tokens = getModelTokens(); + auto tokens_it = model_tokens.find(type_); + if (tokens_it != model_tokens.end()) { model_tokens_ = tokens_it->second; } includes_bos_ = templateIncludesBos(template_str_, model_tokens_); @@ -139,8 +152,7 @@ JinjaChatFormatter::JinjaChatFormatter( auto load_result = compiled_template_->Load(normalized_template); if (!load_result) { throw std::runtime_error( - "Failed to parse chat template: " + - load_result.error().ToString()); + "Failed to parse chat template: " + load_result.error().ToString()); } } @@ -148,8 +160,9 @@ JinjaChatFormatter::~JinjaChatFormatter() = default; std::unique_ptr<JinjaChatFormatter> JinjaChatFormatter::fromTemplate( ChatTemplateType type) { - auto it = kEmbeddedTemplates.find(type); - if (it == kEmbeddedTemplates.end()) { + const auto& embedded_templates = getEmbeddedTemplates(); + auto it = embedded_templates.find(type); + if (it == embedded_templates.end()) { throw std::runtime_error("Unsupported embedded chat template type."); } return std::unique_ptr<JinjaChatFormatter>( new JinjaChatFormatter(std::string(it->second), type)); @@
-164,7 +177,7 @@ std::unique_ptr<JinjaChatFormatter> JinjaChatFormatter::fromString( } std::unique_ptr<JinjaChatFormatter> JinjaChatFormatter::fromFile( - const std::filesystem::path& path) { + const std::string& path) { return fromString(readFileToString(path)); } @@ -202,6 +215,7 @@ std::string JinjaChatFormatter::formatConversation( // Templates that don't use these will simply ignore them. params["tools"] = jinja2::ValuesList(); params["tool_choice"] = jinja2::Value(); + // HuggingFace templates use a fixed date in their own regression fixtures. params["date_string"] = std::string("26 Jul 2024"); params["chat_template_kwargs"] = jinja2::ValuesMap(); diff --git a/extension/llm/runner/jinja_chat_formatter.h b/extension/llm/runner/jinja_chat_formatter.h index e589a9ed18e..79f1bf9beb5 100644 --- a/extension/llm/runner/jinja_chat_formatter.h +++ b/extension/llm/runner/jinja_chat_formatter.h @@ -1,9 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + #pragma once #include <executorch/extension/llm/chat_template/chat_templates.h> #include <executorch/extension/llm/runner/chat_types.h> -#include <filesystem> #include <memory> #include <string> @@ -15,11 +22,11 @@ namespace executorch::extension::llm { class JinjaChatFormatter { public: - static std::unique_ptr<JinjaChatFormatter> fromTemplate(ChatTemplateType type); + static std::unique_ptr<JinjaChatFormatter> fromTemplate( + ChatTemplateType type); static std::unique_ptr<JinjaChatFormatter> fromString( const std::string& template_str); - static std::unique_ptr<JinjaChatFormatter> fromFile( - const std::filesystem::path& path); + static std::unique_ptr<JinjaChatFormatter> fromFile(const std::string& path); ~JinjaChatFormatter(); diff --git a/extension/llm/runner/test/test_jinja_chat_formatter.cpp b/extension/llm/runner/test/test_jinja_chat_formatter.cpp index d1a3d6d9fff..5da056dda81 100644 --- a/extension/llm/runner/test/test_jinja_chat_formatter.cpp +++ b/extension/llm/runner/test/test_jinja_chat_formatter.cpp @@ -26,9 +26,8 @@ TEST(JinjaChatFormatter, Llama3SingleMessage) { // but no trailing \n\n at the end due to the {%- endif -%} stripping.
const std::string expected = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + - system_prompt + - "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + prompt + - "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; + system_prompt + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + + prompt + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; EXPECT_EQ(formatter->format(prompt, system_prompt), expected); } @@ -50,7 +49,8 @@ TEST(JinjaChatFormatter, Llama3WithoutSystemPrompt) { EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); EXPECT_THAT(result, HasSubstr("Hello!")); - EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); + EXPECT_THAT( + result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); // Should not contain system header when no system prompt EXPECT_THAT(result, ::testing::Not(HasSubstr("<|start_header_id|>system"))); } @@ -101,7 +101,8 @@ TEST(JinjaChatFormatter, FormatConversationMultiTurn) { EXPECT_THAT(result, HasSubstr("Hello")); EXPECT_THAT(result, HasSubstr("Hi there!")); EXPECT_THAT(result, HasSubstr("How are you?")); - EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); + EXPECT_THAT( + result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); } TEST(JinjaChatFormatter, FromStringLlama3Template) { @@ -179,7 +180,8 @@ TEST(JinjaChatFormatter, Llama32SingleMessage) { EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); EXPECT_THAT(result, HasSubstr("Hello!")); - EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); + EXPECT_THAT( + result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); } TEST(JinjaChatFormatter, Llama32IncludesBos) { @@ -267,4 +269,3 @@ TEST(JinjaChatFormatter, UniversalJinjaNormalizedNotToolsIsNone) { // The "else" branch should be selected (no tools available). 
EXPECT_EQ(formatter->formatConversation(conv), "none"); } -} diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index 971dde20867..9cb2b632a25 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -36,6 +36,7 @@ PROGRAM_NO_PRIM_OPS_SRCS = [ "method.cpp", "method_meta.cpp", "program.cpp", + "program_validation.cpp", "tensor_parser_exec_aten.cpp", ] @@ -476,6 +477,7 @@ XNNPACK_BACKEND_BUCK_SRCS = [ "runtime/XNNPACKBackend.cpp", "runtime/XNNWeightsCache.cpp", "runtime/XNNWorkspaceManager.cpp", + "runtime/XnnpackBackendOptions.cpp", "runtime/profiling/XNNProfiler.cpp", ] From be76393f5d81d197873b41a268c4ccd283822605 Mon Sep 17 00:00:00 2001 From: Young Han Date: Thu, 14 May 2026 15:18:38 -0700 Subject: [PATCH 3/3] [llm] Cover vLLM Llama 3.2 tool template Co-authored-by: Cursor --- extension/llm/runner/jinja_chat_formatter.cpp | 2 +- .../runner/test/test_jinja_chat_formatter.cpp | 114 ++++++++++++++++++ 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/extension/llm/runner/jinja_chat_formatter.cpp b/extension/llm/runner/jinja_chat_formatter.cpp index c9e2a932360..a8201412d64 100644 --- a/extension/llm/runner/jinja_chat_formatter.cpp +++ b/extension/llm/runner/jinja_chat_formatter.cpp @@ -141,7 +141,7 @@ JinjaChatFormatter::JinjaChatFormatter( const std::string& template_str, ChatTemplateType type) : template_str_(template_str), type_(type) { - const auto& model_tokens = getModelTokens(); + const auto& model_tokens = ::executorch::extension::llm::getModelTokens(); auto tokens_it = model_tokens.find(type_); if (tokens_it != model_tokens.end()) { model_tokens_ = tokens_it->second; diff --git a/extension/llm/runner/test/test_jinja_chat_formatter.cpp b/extension/llm/runner/test/test_jinja_chat_formatter.cpp index 5da056dda81..3975b6e735f 100644 --- a/extension/llm/runner/test/test_jinja_chat_formatter.cpp +++ b/extension/llm/runner/test/test_jinja_chat_formatter.cpp @@ -269,3 +269,117 @@ TEST(JinjaChatFormatter, UniversalJinjaNormalizedNotToolsIsNone) { // The "else" branch should be selected (no tools available). EXPECT_EQ(formatter->formatConversation(conv), "none"); } + +TEST(JinjaChatFormatter, VllmLlama32PythonicToolTemplate) { + // Mirrors vLLM's examples/tool_chat_template_llama3.2_pythonic.jinja. + const std::string template_str = R"({{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = false %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." 
%} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call functions, please respond with a python list of the calls. " }} + {{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a python list for function calls " }} + {{- "with their proper arguments to best answer the given prompt.\n\n" }} + {{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param + '=' -}} + {{- "%s" | format(tool_call.arguments[param]) -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ']<|eot_id|>' -}} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} +)"; + + auto formatter = JinjaChatFormatter::fromString(template_str); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>system<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Today Date: 26 Jul 2024")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT( + result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +}