From 0898aa3db56fbffa12434c6e9051ede3d823e180 Mon Sep 17 00:00:00 2001 From: Young Han Date: Tue, 12 May 2026 17:27:40 -0700 Subject: [PATCH 1/4] [llm][1/4] Add Jinja2Cpp-based chat template formatter library MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foundation PR for the chat-template support stack. Adds the Jinja2Cpp-based JinjaChatFormatter, supporting chat-types, embedded Llama3/Llama3.2/Gemma3 templates, build glue (CMake/Buck), and a focused C++ unit-test suite. This PR is reviewable in isolation — it has no behavior change for any existing runner; downstream PRs (B/C/D) plug it in. This is part 1 of a 4-PR stack split out of #16987 per reviewer request: 1/4 (this PR) Library + tests 2/4 TextLLMRunner echo-gated special-token filter + EOS merge 3/4 Python bindings + Python LlamaRunner integration 4/4 llama_main CLI flags + chat_formatter wrapper + docs What this PR adds ----------------- * extension/llm/chat_template/{chat_templates.h, BUCK, CMakeLists.txt, targets.bzl} — embedded Llama3/Llama3.2/Gemma3 templates and the ChatTemplateType enum + ModelTokens. The CMake file FetchContent's Jinja2Cpp 1.3.2, with SUPPORT_REGEX_LOOKAHEAD set BEFORE FetchContent_MakeAvailable so it propagates correctly, plus header staging for nonstd headers that some Jinja2Cpp installations omit. Installs chat_templates.h so SDK consumers can include it. * extension/llm/runner/{chat_types.h, jinja_chat_formatter.{h,cpp}} — the Universal Jinja chat formatter that supports any HuggingFace / vLLM chat template, not just the embedded ones. Loadable via fromTemplate (built-in), fromString (any string), or fromFile (any .jinja file). formatConversation injects vLLM/HuggingFace-standard params (tools=[], tool_choice=None, date_string, chat_template_kwargs) so any template that references those variables renders correctly. * normalizeTemplate handles vLLM/HF template quirks for Jinja2Cpp: notably, 'not tools is none' maps to 'tools' (truthy check), preserving the intent of 'tools is not none' for empty-list defaults. * extension/llm/runner/{CMakeLists.txt, targets.bzl} — link extension_llm_runner against jinja2cpp (PRIVATE) and define EXECUTORCH_USE_JINJA2CPP. * extension/llm/runner/test/{test_jinja_chat_formatter.cpp, CMakeLists.txt, targets.bzl, BUCK} — unit tests covering Llama3 / Llama3.2 / Gemma3 embedded templates, parseChatTemplateType (case-insensitive), and three universal-Jinja regression tests: - generic HuggingFace-style template (proves it's not Llama-specific) - tools-aware template (validates the tools=[] default) - 'not tools is none' normalization regression test * CMakeLists.txt — adds add_subdirectory(extension/llm/chat_template) guarded by EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER. * shim_et/xplat/executorch/build/build_variables.bzl — adds jinja_chat_formatter.cpp to the runner sources. Notes ----- * No behavior change for existing TextLLMRunner / MultimodalRunner users: the formatter is opt-in, only invoked when downstream code calls it. * Sample vLLM templates are NOT checked in (per reviewer feedback); documentation in the follow-up CLI PR points users to vLLM's examples directory and HuggingFace tokenizer_config.json files. Original PR (full stack): https://github.com/pytorch/executorch/pull/16987 --- CMakeLists.txt | 11 +- extension/llm/chat_template/BUCK | 18 ++ extension/llm/chat_template/CMakeLists.txt | 116 ++++++++ extension/llm/chat_template/chat_templates.h | 51 ++++ extension/llm/chat_template/targets.bzl | 16 ++ extension/llm/runner/CMakeLists.txt | 20 +- extension/llm/runner/chat_types.h | 20 ++ extension/llm/runner/jinja_chat_formatter.cpp | 236 +++++++++++++++ extension/llm/runner/jinja_chat_formatter.h | 51 ++++ extension/llm/runner/targets.bzl | 5 + extension/llm/runner/test/BUCK | 23 +- extension/llm/runner/test/CMakeLists.txt | 1 + extension/llm/runner/test/targets.bzl | 8 + .../runner/test/test_jinja_chat_formatter.cpp | 270 ++++++++++++++++++ .../executorch/build/build_variables.bzl | 3 +- 15 files changed, 821 insertions(+), 28 deletions(-) create mode 100644 extension/llm/chat_template/BUCK create mode 100644 extension/llm/chat_template/CMakeLists.txt create mode 100644 extension/llm/chat_template/chat_templates.h create mode 100644 extension/llm/chat_template/targets.bzl create mode 100644 extension/llm/runner/chat_types.h create mode 100644 extension/llm/runner/jinja_chat_formatter.cpp create mode 100644 extension/llm/runner/jinja_chat_formatter.h create mode 100644 extension/llm/runner/test/test_jinja_chat_formatter.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ce0def6000b..ec552518b9b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,6 +103,7 @@ include(${PROJECT_SOURCE_DIR}/tools/cmake/Codegen.cmake) include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake) include(CMakeDependentOption) include(ExternalProject) +include(FetchContent) include(GNUInstallDirs) if(NOT CMAKE_CXX_STANDARD) @@ -406,6 +407,14 @@ set(_common_include_directories $ ) +if(TARGET jinja2cpp) + install( + TARGETS jinja2cpp + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} + ) +endif() + # # The `__srcs` lists are defined by executorch_load_build_variables. # @@ -803,7 +812,7 @@ endif() if(EXECUTORCH_BUILD_EXTENSION_LLM) if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) - set(SUPPORT_REGEX_LOOKAHEAD ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/chat_template) # llama/runner/CMakeLists.txt builds a shared library libllama_runner.so # that transitively depends on tokenizers. Need to build tokenizers with # -fPIC. diff --git a/extension/llm/chat_template/BUCK b/extension/llm/chat_template/BUCK new file mode 100644 index 00000000000..bd8ea6e199e --- /dev/null +++ b/extension/llm/chat_template/BUCK @@ -0,0 +1,18 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +oncall("executorch") + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +non_fbcode_target(_kind = define_common_targets,) + +# !!!! fbcode/executorch/extension/llm/chat_template/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!! + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +fbcode_target(_kind = define_common_targets,) diff --git a/extension/llm/chat_template/CMakeLists.txt b/extension/llm/chat_template/CMakeLists.txt new file mode 100644 index 00000000000..88ce80b157f --- /dev/null +++ b/extension/llm/chat_template/CMakeLists.txt @@ -0,0 +1,116 @@ +if(NOT EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) + return() +endif() + +include(FetchContent) +cmake_policy(SET CMP0077 NEW) + +FetchContent_Declare( + jinja2cpp + GIT_REPOSITORY https://github.com/jinja2cpp/Jinja2Cpp.git + GIT_TAG 1.3.2 + GIT_SUBMODULES_RECURSE TRUE +) + +set(JINJA2CPP_BUILD_TESTS + OFF + CACHE BOOL "" + FORCE +) +set(JINJA2CPP_BUILD_SHARED + OFF + CACHE BOOL "" + FORCE +) +set(JINJA2CPP_INSTALL + OFF + CACHE BOOL "" + FORCE +) +# Enable PCRE2-based regex lookahead support in Jinja2Cpp. This must be set +# BEFORE FetchContent_MakeAvailable(jinja2cpp) so it propagates to the +# Jinja2Cpp configure step. +set(SUPPORT_REGEX_LOOKAHEAD + ON + CACHE BOOL "" + FORCE +) + +FetchContent_MakeAvailable(jinja2cpp) +if(NOT TARGET jinja2cpp) + message(FATAL_ERROR "Jinja2Cpp target not found after FetchContent.") +endif() + +if(DEFINED jinja2cpp_SOURCE_DIR) + function(executorch_copy_nonstd_header dep_name target header_name dest_root) + set(_copied FALSE) + if(TARGET ${target}) + get_target_property(_aliased ${target} ALIASED_TARGET) + if(_aliased) + set(_resolved_target ${_aliased}) + else() + set(_resolved_target ${target}) + endif() + get_target_property( + _include_dirs ${_resolved_target} INTERFACE_INCLUDE_DIRECTORIES + ) + foreach(_dir IN LISTS _include_dirs) + if(EXISTS "${_dir}/nonstd/${header_name}") + file(MAKE_DIRECTORY "${dest_root}/nonstd") + file( + COPY "${_dir}/nonstd/${header_name}" + DESTINATION "${dest_root}/nonstd" + ) + set(_copied TRUE) + break() + endif() + endforeach() + endif() + if(NOT _copied) + set( + _fallback_path + "${CMAKE_BINARY_DIR}/_deps/${dep_name}-src/include/nonstd/${header_name}" + ) + if(EXISTS "${_fallback_path}") + file(MAKE_DIRECTORY "${dest_root}/nonstd") + file(COPY "${_fallback_path}" DESTINATION "${dest_root}/nonstd") + endif() + endif() + endfunction() + + set(_jinja2cpp_nonstd_root + "${jinja2cpp_SOURCE_DIR}/thirdparty/nonstd" + ) + executorch_copy_nonstd_header( + expected-lite + nonstd::expected-lite + expected.hpp + "${_jinja2cpp_nonstd_root}/expected-lite/include" + ) + executorch_copy_nonstd_header( + variant-lite + nonstd::variant-lite + variant.hpp + "${_jinja2cpp_nonstd_root}/variant-lite/include" + ) + executorch_copy_nonstd_header( + optional-lite + nonstd::optional-lite + optional.hpp + "${_jinja2cpp_nonstd_root}/optional-lite/include" + ) + executorch_copy_nonstd_header( + string-view-lite + nonstd::string-view-lite + string_view.hpp + "${_jinja2cpp_nonstd_root}/string-view-lite/include" + ) +endif() + +# Install the chat_templates.h header so that downstream consumers of the +# installed ExecuTorch SDK can include +# . +install( + FILES chat_templates.h + DESTINATION include/executorch/extension/llm/chat_template +) diff --git a/extension/llm/chat_template/chat_templates.h b/extension/llm/chat_template/chat_templates.h new file mode 100644 index 00000000000..b9e059b80e6 --- /dev/null +++ b/extension/llm/chat_template/chat_templates.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +enum class ChatTemplateType { + None, + Llama3, + Llama32, + Gemma3, + Custom, +}; + +constexpr std::string_view kLlama3Template = R"({{ bos_token }}{%- for message in messages -%}<|start_header_id|>{{ message.role }}<|end_header_id|> + +{{ message.content }}<|eot_id|>{%- endfor -%}{%- if add_generation_prompt -%}<|start_header_id|>assistant<|end_header_id|> + +{%- endif -%})"; + +constexpr std::string_view kGemma3Template = R"({{ bos_token }}{%- for message in messages -%}{%- if message.role == 'assistant' -%}model +{%- else -%}{{ message.role }} +{%- endif -%}{{ message.content }}{%- endfor -%}{%- if add_generation_prompt -%}model +{%- endif -%})"; + +inline const std::unordered_map + kEmbeddedTemplates = { + {ChatTemplateType::Llama3, kLlama3Template}, + {ChatTemplateType::Llama32, kLlama3Template}, + {ChatTemplateType::Gemma3, kGemma3Template}, + }; + +struct ModelTokens { + std::string bos_token; + std::string eos_token; + std::vector stop_tokens; +}; + +inline const std::unordered_map kModelTokens = { + {ChatTemplateType::Llama3, + {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Llama32, + {"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}}, + {ChatTemplateType::Gemma3, + {"", "", {"", ""}}}, +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/chat_template/targets.bzl b/extension/llm/chat_template/targets.bzl new file mode 100644 index 00000000000..197e5da9fb3 --- /dev/null +++ b/extension/llm/chat_template/targets.bzl @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "chat_templates", + exported_headers = [ + "chat_templates.h", + ], + visibility = ["PUBLIC"], + ) diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index 43b89f0a908..da5b42c3f4a 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -54,6 +54,14 @@ endif() list(APPEND runner_deps kernels_util_all_deps) target_link_libraries(extension_llm_runner PUBLIC ${runner_deps}) +target_link_libraries(extension_llm_runner PRIVATE $) +target_include_directories( + extension_llm_runner PRIVATE + $ +) +target_compile_definitions( + extension_llm_runner PUBLIC EXECUTORCH_USE_JINJA2CPP +) set_target_properties( extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON ) @@ -116,23 +124,19 @@ if(EXECUTORCH_BUILD_PYBIND) portable_lib ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} ) + # Set properties for the Python extension set_target_properties( _llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_VISIBILITY_PRESET "hidden" INTERPROCEDURAL_OPTIMIZATION TRUE - CXX_STANDARD 20 ) if(APPLE) - set(RPATH - "@loader_path/../../pybindings;@loader_path/../../../../torch/lib" - ) + set(RPATH "@loader_path/../../pybindings") else() - set(RPATH "$ORIGIN/../../pybindings:$ORIGIN/../../../../torch/lib") + set(RPATH "$ORIGIN/../../pybindings") endif() - set_target_properties( - _llm_runner PROPERTIES BUILD_RPATH "${RPATH}" INSTALL_RPATH "${RPATH}" - ) + set_target_properties(_llm_runner PROPERTIES INSTALL_RPATH ${RPATH}) # Add include directories target_include_directories( _llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS} diff --git a/extension/llm/runner/chat_types.h b/extension/llm/runner/chat_types.h new file mode 100644 index 00000000000..6a7cd2b625f --- /dev/null +++ b/extension/llm/runner/chat_types.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace executorch::extension::llm { + +struct ChatMessage { + std::string role; + std::string content; +}; + +struct ChatConversation { + std::vector messages; + std::string bos_token; + std::string eos_token; + bool add_generation_prompt = true; +}; + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/jinja_chat_formatter.cpp b/extension/llm/runner/jinja_chat_formatter.cpp new file mode 100644 index 00000000000..00054554886 --- /dev/null +++ b/extension/llm/runner/jinja_chat_formatter.cpp @@ -0,0 +1,236 @@ +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { +namespace { + +std::string readFileToString(const std::filesystem::path& path) { + std::ifstream file(path); + if (!file) { + throw std::runtime_error("Failed to open template file: " + path.string()); + } + std::ostringstream buffer; + buffer << file.rdbuf(); + return buffer.str(); +} + +bool templateIncludesBos( + const std::string& template_str, + const ModelTokens& model_tokens) { + if (!model_tokens.bos_token.empty() && + template_str.find(model_tokens.bos_token) != std::string::npos) { + return true; + } + return template_str.find("bos_token") != std::string::npos; +} + +std::string normalizeTemplate(std::string input) { + // These replacements normalize vLLM/HuggingFace Jinja templates so they + // compile/render correctly with Jinja2Cpp, which has stricter parser + // semantics than Python Jinja2. + // + // IMPORTANT: "not tools is none" in Python Jinja means "tools is not none" + // (truthy when tools is defined and non-null), so we map it to a simple + // truthy check on `tools`. Mapping to "not tools" was a bug that would + // skip tool blocks for non-empty tools lists. + constexpr std::array, 10> + replacements = {{ + {"tools = none", "tools = []"}, + {"tools = None", "tools = []"}, + {"tools is not none", "tools"}, + {"tools is not None", "tools"}, + {"not tools is none", "tools"}, + {"not tools is None", "tools"}, + {"tools is none", "not tools"}, + {"tools is None", "not tools"}, + {"messages[1:]", "messages_tail"}, + {"{ \"output\": message.content } | tojson", "message.content | tojson"}, + }}; + // Handle special case that can't be constexpr due to escape sequence + const std::pair gemmaReplacement = { + "{{'model\\n'}}", "{{ 'model\\n' }}"}; + for (const auto& replacement : replacements) { + size_t pos = 0; + while ((pos = input.find(replacement.first, pos)) != std::string::npos) { + input.replace(pos, replacement.first.size(), replacement.second); + pos += replacement.second.size(); + } + } + // Apply the gemma replacement separately + size_t pos = 0; + while ((pos = input.find(gemmaReplacement.first, pos)) != std::string::npos) { + input.replace(pos, gemmaReplacement.first.size(), gemmaReplacement.second); + pos += gemmaReplacement.second.size(); + } + return input; +} + +ChatTemplateType detectTemplateType(const std::string& template_str) { + if (template_str.find("") != std::string::npos) { + return ChatTemplateType::Gemma3; + } + if (template_str.find("<|start_header_id|>") != std::string::npos) { + return ChatTemplateType::Llama3; + } + return ChatTemplateType::Custom; +} + +std::string toLower(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return value; +} + +} // namespace + +} // namespace executorch::extension::llm + +namespace jinja2 { + +// NOLINTBEGIN(facebook-hte-MisplacedTemplateSpecialization,facebook-hte-ShadowingClass) +// This template specialization must be in the jinja2 namespace for the library +// to find it via ADL during template instantiation. +template <> +struct TypeReflection + : TypeReflected { + static auto& GetAccessors() { + static std::unordered_map accessors = { + {"role", + [](const executorch::extension::llm::ChatMessage& msg) { + return jinja2::Reflect(msg.role); + }}, + {"content", + [](const executorch::extension::llm::ChatMessage& msg) { + return jinja2::Reflect(msg.content); + }}, + }; + return accessors; + } +}; +// NOLINTEND(facebook-hte-MisplacedTemplateSpecialization,facebook-hte-ShadowingClass) + +} // namespace jinja2 + +namespace executorch::extension::llm { + +JinjaChatFormatter::JinjaChatFormatter( + const std::string& template_str, + ChatTemplateType type) + : template_str_(template_str), type_(type) { + auto tokens_it = kModelTokens.find(type_); + if (tokens_it != kModelTokens.end()) { + model_tokens_ = tokens_it->second; + } + includes_bos_ = templateIncludesBos(template_str_, model_tokens_); + const std::string normalized_template = normalizeTemplate(template_str_); + compiled_template_ = std::make_unique(); + auto load_result = compiled_template_->Load(normalized_template); + if (!load_result) { + throw std::runtime_error( + "Failed to parse chat template: " + + load_result.error().ToString()); + } +} + +JinjaChatFormatter::~JinjaChatFormatter() = default; + +std::unique_ptr JinjaChatFormatter::fromTemplate( + ChatTemplateType type) { + auto it = kEmbeddedTemplates.find(type); + if (it == kEmbeddedTemplates.end()) { + throw std::runtime_error("Unsupported embedded chat template type."); + } + return std::unique_ptr( + new JinjaChatFormatter(std::string(it->second), type)); +} + +std::unique_ptr JinjaChatFormatter::fromString( + const std::string& template_str) { + const ChatTemplateType inferred_type = detectTemplateType(template_str); + return std::unique_ptr( + new JinjaChatFormatter(template_str, inferred_type)); +} + +std::unique_ptr JinjaChatFormatter::fromFile( + const std::filesystem::path& path) { + return fromString(readFileToString(path)); +} + +std::string JinjaChatFormatter::format( + const std::string& prompt, + const std::string& system_prompt) const { + ChatConversation conversation; + if (!system_prompt.empty()) { + conversation.messages.push_back({"system", system_prompt}); + } + conversation.messages.push_back({"user", prompt}); + conversation.bos_token = model_tokens_.bos_token; + conversation.eos_token = model_tokens_.eos_token; + conversation.add_generation_prompt = true; + return formatConversation(conversation); +} + +std::string JinjaChatFormatter::formatConversation( + const ChatConversation& conversation) const { + jinja2::ValuesMap params; + params["messages"] = jinja2::ValuesList(); + params["messages_tail"] = jinja2::ValuesList(); + bool is_first = true; + for (const auto& msg : conversation.messages) { + params["messages"].asList().push_back(jinja2::Reflect(msg)); + if (!is_first) { + params["messages_tail"].asList().push_back(jinja2::Reflect(msg)); + } + is_first = false; + } + params["bos_token"] = conversation.bos_token; + params["eos_token"] = conversation.eos_token; + params["add_generation_prompt"] = conversation.add_generation_prompt; + // Provide vLLM/HuggingFace-style defaults that templates often reference. + // Templates that don't use these will simply ignore them. + params["tools"] = jinja2::ValuesList(); + params["tool_choice"] = jinja2::Value(); + params["date_string"] = std::string("26 Jul 2024"); + params["chat_template_kwargs"] = jinja2::ValuesMap(); + + auto rendered = compiled_template_->RenderAsString(params); + if (!rendered) { + throw std::runtime_error( + "Failed to render chat template: " + rendered.error().ToString()); + } + return rendered.value(); +} + +ChatTemplateType parseChatTemplateType(const std::string& type_str) { + const std::string lower = toLower(type_str); + if (lower == "none") { + return ChatTemplateType::None; + } + if (lower == "llama3") { + return ChatTemplateType::Llama3; + } + if (lower == "llama3.2" || lower == "llama32" || lower == "llama3_2") { + return ChatTemplateType::Llama32; + } + if (lower == "gemma3") { + return ChatTemplateType::Gemma3; + } + if (lower == "custom" || lower == "jinja") { + return ChatTemplateType::Custom; + } + return ChatTemplateType::None; +} + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/jinja_chat_formatter.h b/extension/llm/runner/jinja_chat_formatter.h new file mode 100644 index 00000000000..e589a9ed18e --- /dev/null +++ b/extension/llm/runner/jinja_chat_formatter.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace jinja2 { +class Template; +} + +namespace executorch::extension::llm { + +class JinjaChatFormatter { + public: + static std::unique_ptr fromTemplate(ChatTemplateType type); + static std::unique_ptr fromString( + const std::string& template_str); + static std::unique_ptr fromFile( + const std::filesystem::path& path); + + ~JinjaChatFormatter(); + + std::string format( + const std::string& prompt, + const std::string& system_prompt = "") const; + std::string formatConversation(const ChatConversation& conversation) const; + + bool includesBos() const { + return includes_bos_; + } + + const ModelTokens& getModelTokens() const { + return model_tokens_; + } + + private: + JinjaChatFormatter(const std::string& template_str, ChatTemplateType type); + + std::string template_str_; + ChatTemplateType type_; + ModelTokens model_tokens_; + bool includes_bos_ = false; + std::unique_ptr compiled_template_; +}; + +ChatTemplateType parseChatTemplateType(const std::string& type_str); + +} // namespace executorch::extension::llm diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 0d4ed99308d..f8ceed67bf1 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -111,11 +111,14 @@ def define_common_targets(): runtime.cxx_library( name = "runner_lib" + aten_suffix, exported_headers = [ + "chat_types.h", + "jinja_chat_formatter.h", "text_llm_runner.h", "llm_runner_helper.h", "constants.h", ], srcs = [ + "jinja_chat_formatter.cpp", "text_llm_runner.cpp", "llm_runner_helper.cpp", "multimodal_runner.cpp", @@ -131,6 +134,7 @@ def define_common_targets(): ":text_decoder_runner" + aten_suffix, ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, + "//executorch/extension/llm/chat_template:chat_templates", "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, "//executorch/extension/memory_allocator:cpu_caching_allocator", "//pytorch/tokenizers:hf_tokenizer", @@ -138,5 +142,6 @@ def define_common_targets(): "//pytorch/tokenizers:sentencepiece", "//pytorch/tokenizers:tekken", "//pytorch/tokenizers:tiktoken", + "@fbsource//third-party/jinja2cpp:jinja2cpp", ], ) diff --git a/extension/llm/runner/test/BUCK b/extension/llm/runner/test/BUCK index 9ed85acadb7..cb8e8fcfb7e 100644 --- a/extension/llm/runner/test/BUCK +++ b/extension/llm/runner/test/BUCK @@ -10,32 +10,21 @@ oncall("executorch") # targets.bzl. This file can contain fbcode-only targets. load(":targets.bzl", "define_common_targets") - +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") non_fbcode_target(_kind = define_common_targets,) -# !!!! fbcode/executorch/extension/llm/runner/test/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!! - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Any targets that should be shared between fbcode and xplat must be defined in -# targets.bzl. This file can contain fbcode-only targets. - -load(":targets.bzl", "define_common_targets") -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") - fbcode_target(_kind = define_common_targets,) +# fbcode-only test that requires access to fbcode//executorch/test/models fbcode_target(_kind = runtime.cxx_test, name = "test_text_decoder_runner", srcs = ["test_text_decoder_runner.cpp"], deps = [ - "//executorch/extension/llm/runner:runner_lib", "//executorch/extension/llm/runner/io_manager:io_manager", + "//executorch/extension/llm/runner:text_decoder_runner", + "//executorch/extension/module:module", + "//executorch/extension/tensor:tensor", "//executorch/kernels/portable:generated_lib", "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], @@ -43,5 +32,5 @@ fbcode_target(_kind = runtime.cxx_test, "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])", "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])", "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])", - } + }, ) diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 81b69c0ab9a..64d89ff0ec4 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -19,6 +19,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs test_generation_config.cpp + test_jinja_chat_formatter.cpp test_text_llm_runner.cpp test_text_prefiller.cpp test_text_decoder_runner.cpp diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 08044de2d35..a8e3858af9d 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -17,6 +17,14 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "test_jinja_chat_formatter", + srcs = ["test_jinja_chat_formatter.cpp"], + deps = [ + "//executorch/extension/llm/runner:runner_lib", + ], + ) + runtime.cxx_test( name = "test_text_llm_runner", srcs = ["test_text_llm_runner.cpp"], diff --git a/extension/llm/runner/test/test_jinja_chat_formatter.cpp b/extension/llm/runner/test/test_jinja_chat_formatter.cpp new file mode 100644 index 00000000000..d1a3d6d9fff --- /dev/null +++ b/extension/llm/runner/test/test_jinja_chat_formatter.cpp @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using executorch::extension::llm::ChatConversation; +using executorch::extension::llm::ChatMessage; +using executorch::extension::llm::ChatTemplateType; +using executorch::extension::llm::JinjaChatFormatter; +using executorch::extension::llm::parseChatTemplateType; +using testing::HasSubstr; + +TEST(JinjaChatFormatter, Llama3SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const std::string prompt = "Test prompt"; + const std::string system_prompt = "You are a helpful assistant."; + // Note: The Jinja template uses {%- ... -%} which strips whitespace, + // so the output has \n\n after each <|end_header_id|> and content, + // but no trailing \n\n at the end due to the {%- endif -%} stripping. + const std::string expected = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + + system_prompt + + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" + prompt + + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; + + EXPECT_EQ(formatter->format(prompt, system_prompt), expected); +} + +TEST(JinjaChatFormatter, Gemma3SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("")); + EXPECT_THAT(result, HasSubstr("user")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("model")); +} + +TEST(JinjaChatFormatter, Llama3WithoutSystemPrompt) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const std::string result = formatter->format("Hello!"); + + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); + // Should not contain system header when no system prompt + EXPECT_THAT(result, ::testing::Not(HasSubstr("<|start_header_id|>system"))); +} + +TEST(JinjaChatFormatter, Llama3IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Gemma3IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Llama3ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<|begin_of_text|>"); + EXPECT_EQ(tokens.eos_token, "<|eot_id|>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +TEST(JinjaChatFormatter, Gemma3ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, ""); + EXPECT_EQ(tokens.eos_token, ""); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +TEST(JinjaChatFormatter, FormatConversationMultiTurn) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3); + + ChatConversation conversation; + conversation.bos_token = "<|begin_of_text|>"; + conversation.eos_token = "<|eot_id|>"; + conversation.add_generation_prompt = true; + conversation.messages = { + {"user", "Hello"}, + {"assistant", "Hi there!"}, + {"user", "How are you?"}, + }; + + const std::string result = formatter->formatConversation(conversation); + + EXPECT_THAT(result, HasSubstr("Hello")); + EXPECT_THAT(result, HasSubstr("Hi there!")); + EXPECT_THAT(result, HasSubstr("How are you?")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +} + +TEST(JinjaChatFormatter, FromStringLlama3Template) { + const std::string llama_template = + "{{ bos_token }}<|start_header_id|>user<|end_header_id|>\n\n" + "{{ messages[0].content }}<|eot_id|>"; + + auto formatter = JinjaChatFormatter::fromString(llama_template); + + // Should detect Llama3 type from the template content + const std::string result = formatter->format("Test"); + EXPECT_THAT(result, HasSubstr("Test")); +} + +TEST(JinjaChatFormatter, FromStringGemma3Template) { + const std::string gemma_template = + "{{ bos_token }}user\n" + "{{ messages[0].content }}"; + + auto formatter = JinjaChatFormatter::fromString(gemma_template); + + const std::string result = formatter->format("Test"); + EXPECT_THAT(result, HasSubstr("Test")); +} + +TEST(JinjaChatFormatter, UnsupportedTemplateTypeThrows) { + EXPECT_THROW( + JinjaChatFormatter::fromTemplate(ChatTemplateType::None), + std::runtime_error); +} + +// Tests for parseChatTemplateType +TEST(ParseChatTemplateType, ParseNone) { + EXPECT_EQ(parseChatTemplateType("none"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("None"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("NONE"), ChatTemplateType::None); +} + +TEST(ParseChatTemplateType, ParseLlama3) { + EXPECT_EQ(parseChatTemplateType("llama3"), ChatTemplateType::Llama3); + EXPECT_EQ(parseChatTemplateType("LLAMA3"), ChatTemplateType::Llama3); + EXPECT_EQ(parseChatTemplateType("Llama3"), ChatTemplateType::Llama3); +} + +TEST(ParseChatTemplateType, ParseLlama32Variants) { + EXPECT_EQ(parseChatTemplateType("llama3.2"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("llama32"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("llama3_2"), ChatTemplateType::Llama32); + EXPECT_EQ(parseChatTemplateType("LLAMA3.2"), ChatTemplateType::Llama32); +} + +TEST(ParseChatTemplateType, ParseGemma3) { + EXPECT_EQ(parseChatTemplateType("gemma3"), ChatTemplateType::Gemma3); + EXPECT_EQ(parseChatTemplateType("GEMMA3"), ChatTemplateType::Gemma3); + EXPECT_EQ(parseChatTemplateType("Gemma3"), ChatTemplateType::Gemma3); +} + +TEST(ParseChatTemplateType, ParseCustom) { + EXPECT_EQ(parseChatTemplateType("custom"), ChatTemplateType::Custom); + EXPECT_EQ(parseChatTemplateType("jinja"), ChatTemplateType::Custom); + EXPECT_EQ(parseChatTemplateType("CUSTOM"), ChatTemplateType::Custom); +} + +TEST(ParseChatTemplateType, ParseUnknownReturnsNone) { + EXPECT_EQ(parseChatTemplateType("unknown"), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType(""), ChatTemplateType::None); + EXPECT_EQ(parseChatTemplateType("invalid"), ChatTemplateType::None); +} + +TEST(JinjaChatFormatter, Llama32SingleMessage) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + const std::string result = formatter->format("Hello!"); + + // Llama32 uses the same template as Llama3 + EXPECT_THAT(result, HasSubstr("<|begin_of_text|>")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>user<|end_header_id|>")); + EXPECT_THAT(result, HasSubstr("Hello!")); + EXPECT_THAT(result, HasSubstr("<|start_header_id|>assistant<|end_header_id|>")); +} + +TEST(JinjaChatFormatter, Llama32IncludesBos) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + EXPECT_TRUE(formatter->includesBos()); +} + +TEST(JinjaChatFormatter, Llama32ModelTokens) { + auto formatter = JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama32); + const auto& tokens = formatter->getModelTokens(); + + EXPECT_EQ(tokens.bos_token, "<|begin_of_text|>"); + EXPECT_EQ(tokens.eos_token, "<|eot_id|>"); + EXPECT_EQ(tokens.stop_tokens.size(), 2); +} + +// Universal Jinja support: any HuggingFace / vLLM-style template string +// (passed via fromString or fromFile) should work, not just the embedded +// Llama3/Gemma3 templates. This validates the renderer with a generic +// HuggingFace-style template that uses the standard `messages`, +// `bos_token`, and `add_generation_prompt` variables. +TEST(JinjaChatFormatter, UniversalJinjaGenericTemplate) { + // Generic chat template inspired by HuggingFace tokenizer_config.json + // examples. Uses only standard Jinja2 features and standard chat variables. + const std::string generic_template = + "{{ bos_token }}" + "{%- for message in messages -%}" + "<|{{ message.role }}|>\n{{ message.content }}<|end|>\n" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}<|assistant|>\n{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(generic_template); + ChatConversation conv; + conv.bos_token = ""; + conv.add_generation_prompt = true; + conv.messages.push_back(ChatMessage{"user", "Hi there"}); + + const std::string result = formatter->formatConversation(conv); + + EXPECT_THAT(result, HasSubstr("")); + EXPECT_THAT(result, HasSubstr("<|user|>")); + EXPECT_THAT(result, HasSubstr("Hi there")); + EXPECT_THAT(result, HasSubstr("<|assistant|>")); +} + +// Universal Jinja support: templates that reference `tools` (e.g. vLLM's +// tool_chat_template_*.jinja files) should not crash even when no tools +// are passed. The formatter injects `tools = []` (an empty list) so that +// truthy/none checks evaluate consistently. With our normalization, +// `tools is not none` is rewritten to `tools` (a truthy check), which means +// an empty list (no tools) is treated as "no tools available" — matching +// the typical template intent. +TEST(JinjaChatFormatter, UniversalJinjaToolsAwareTemplate) { + const std::string tools_template = + "{%- if tools is not none -%}" + "tools_present" + "{%- else -%}" + "no_tools" + "{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(tools_template); + ChatConversation conv; + conv.add_generation_prompt = false; + + // tools defaults to [] (empty list). After normalization this is a + // truthy check, so the empty-list (no tools) branch should be selected. + EXPECT_EQ(formatter->formatConversation(conv), "no_tools"); +} + +// Universal Jinja support: a template using "not tools is none" (semantically +// "tools is not none") should now safely evaluate without skipping the +// "no tools" branch. This guards the regression where the normalizer mapped +// "not tools is none" -> "not tools", which incorrectly evaluated to true +// for an empty list and would have rendered a tool block when none was +// intended. +TEST(JinjaChatFormatter, UniversalJinjaNormalizedNotToolsIsNone) { + const std::string template_str = + "{%- if not tools is none -%}defined{%- else -%}none{%- endif -%}"; + + auto formatter = JinjaChatFormatter::fromString(template_str); + ChatConversation conv; + conv.add_generation_prompt = false; + + // tools = [] is falsy after normalization (`not tools is none` -> `tools`). + // The "else" branch should be selected (no tools available). + EXPECT_EQ(formatter->formatConversation(conv), "none"); +} +} diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index b0545b8ce18..971dde20867 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -36,7 +36,6 @@ PROGRAM_NO_PRIM_OPS_SRCS = [ "method.cpp", "method_meta.cpp", "program.cpp", - "program_validation.cpp", "tensor_parser_exec_aten.cpp", ] @@ -357,6 +356,7 @@ EXTENSION_RUNNER_UTIL_SRCS = [ ] EXTENSION_LLM_RUNNER_SRCS = [ + "extension/llm/runner/jinja_chat_formatter.cpp", "extension/llm/runner/llm_runner_helper.cpp", "extension/llm/runner/multimodal_prefiller.cpp", "extension/llm/runner/multimodal_runner.cpp", @@ -476,7 +476,6 @@ XNNPACK_BACKEND_BUCK_SRCS = [ "runtime/XNNPACKBackend.cpp", "runtime/XNNWeightsCache.cpp", "runtime/XNNWorkspaceManager.cpp", - "runtime/XnnpackBackendOptions.cpp", "runtime/profiling/XNNProfiler.cpp", ] From 0b6a51d680e2be6e15ed041abf6aa3c7dee556d2 Mon Sep 17 00:00:00 2001 From: Young Han Date: Tue, 12 May 2026 17:32:59 -0700 Subject: [PATCH 2/4] [llm][2/4] Echo-gated special-token filtering and EOS metadata merge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part 2 of the chat-template support stack split out of #16987. What this PR adds ----------------- * extension/llm/runner/text_llm_runner.cpp: Add 'is_special_token()' with a small kKnownSpecialTokens set covering Llama 3.x, Gemma, and generic /// tokens, plus a regex-style match for Llama-format <|...|> tokens. wrapped_callback now suppresses these from the printed stream when GenerationConfig.echo == false. When echo == true, raw model output (including chat-template tokens) is emitted unchanged - this preserves backward compatibility for users who explicitly want to see raw tokens. * extension/llm/runner/llm_runner_helper.cpp: get_eos_ids() now MERGES the tokenizer's primary eos_tok() with any additional EOS IDs the model metadata exports under kEosIds, instead of clearing the set when metadata is present. This is correct for HF-tokenizer models (e.g. Llama 3.x) where eos_tok() = <|end_of_text|> but the model also wants <|eot_id|> as a stop token. Also logs the primary tok and only logs metadata IDs that are newly inserted. Why this is split out --------------------- These are runner-behavior changes that affect ALL TextLLMRunner users, not just the new chat-template path. They deserve focused review for backward-compat impact (echo gating) and EOS-set semantics (merge vs clear). Depends on: PR-A (extension/llm/chat_template/* + JinjaChatFormatter library) — only for stack ordering; this PR has no include or symbol dependency on that library. Original PR (full stack): https://github.com/pytorch/executorch/pull/16987 --- extension/llm/runner/llm_runner_helper.cpp | 11 +++-- extension/llm/runner/text_llm_runner.cpp | 53 +++++++++++++++++++++- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 0744c09e641..ddaf280a9ee 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -158,7 +158,9 @@ std::unordered_set get_eos_ids( tokenizers::Tokenizer* tokenizer, Module* module) { std::unordered_set eos_ids = {tokenizer->eos_tok()}; - // Get EOS IDs if available + ET_LOG(Info, "Primary eos_tok = %" PRIu64, tokenizer->eos_tok()); + + // Get EOS IDs from model metadata if available auto method_names_result = module->method_names(); if (method_names_result.error() != Error::Ok) { ET_LOG(Error, "Failed reading method names"); @@ -167,7 +169,6 @@ std::unordered_set get_eos_ids( const auto& method_names = method_names_result.get(); if (method_names.count(llm::kEosIds)) { - eos_ids.clear(); auto execute_result = module->execute(llm::kEosIds); if (execute_result.error() != Error::Ok) { ET_LOG(Error, "Failed to execute %s", llm::kEosIds); @@ -175,8 +176,10 @@ std::unordered_set get_eos_ids( } for (const auto& eos_id : execute_result.get()) { auto value = eos_id.toScalar().to(); - eos_ids.emplace(value); - ET_LOG(Info, "eos_id = %" PRId64, value); + auto [_, inserted] = eos_ids.emplace(value); + if (inserted) { + ET_LOG(Info, "Added eos_id from model metadata: %" PRId64, value); + } } } return eos_ids; diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index cf7ab50b9c8..3010aea6583 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -20,8 +20,51 @@ #include #include +#include + namespace executorch::extension::llm { +namespace { +// Known special tokens used by LLM chat templates. +// When echo=false, these are filtered out of the streamed output so users +// see clean assistant text. When echo=true, the user explicitly asked for +// raw model output, so we emit them unchanged. +const std::unordered_set kKnownSpecialTokens = { + // Llama 3.x tokens + "<|begin_of_text|>", + "<|end_of_text|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eot_id|>", + // Gemma tokens + "", + "", + "", + "", + // Common tokens + "", + "", + "", + "", +}; + +bool is_special_token(const std::string& text) { + if (text.empty()) { + return false; + } + // Check against known special tokens + if (kKnownSpecialTokens.count(text) > 0) { + return true; + } + // Match Llama-style tokens: <|...|> + if (text.size() >= 4 && text.front() == '<' && text[1] == '|' && + text.back() == '>' && text[text.size() - 2] == '|') { + return true; + } + return false; +} +} // namespace + using ::executorch::extension::Module; using ::executorch::runtime::Error; using ::executorch::runtime::Result; @@ -96,8 +139,14 @@ Error TextLLMRunner::generate( std::function wrapped_callback = [token_callback, config](const std::string& piece) { if (!config.warming) { - llm::safe_printf(piece.c_str()); - fflush(stdout); + // When echo=false, filter out special tokens (e.g. <|eot_id|>) so + // users get clean assistant output. When echo=true, the user asked + // for raw model output including any chat-template tokens, so emit + // everything unchanged. + if (config.echo || !is_special_token(piece)) { + llm::safe_printf(piece.c_str()); + fflush(stdout); + } } if (token_callback) { token_callback(piece); From 13af6b1ed15c0e1e1acebea10ef942406d369d04 Mon Sep 17 00:00:00 2001 From: Young Han Date: Tue, 12 May 2026 17:35:01 -0700 Subject: [PATCH 3/4] [llm][3/4] Python bindings for JinjaChatFormatter + LlamaRunner integration Part 3 of the chat-template support stack split out of #16987. What this PR adds ----------------- * extension/llm/runner/pybindings.cpp: New pybind11 classes: - ChatMessage(role, content) - ChatConversation(messages, bos_token, eos_token, add_generation_prompt) - ChatTemplateType enum (None_, Llama3, Llama32, Gemma3, Custom) - JinjaChatFormatter with from_template / from_string / from_file static factories, format(prompt, system_prompt) and format_conversation(ChatConversation) methods, includes_bos(). * extension/llm/runner/__init__.py: re-exports the new bindings via __all__. * extension/llm/runner/_llm_runner.pyi: type stubs for the new classes so consumers get IDE / mypy support. * extension/llm/runner/test/test_runner_pybindings.py: Python tests covering the new bindings end-to-end. * examples/models/llama/runner/generation.py: LlamaRunner now accepts chat_format / system_prompt / chat_template_file kwargs and exposes _format_prompt + chat_completion using the JinjaChatFormatter. Default chat_format is 'none' (matches llama_main, preserves backward compatibility for existing EagerLlamaRunner / NativeLlamaRunner callers). _resolve_template_type maps 'llama3.2' / 'llama32' / 'llama3_2' to ChatTemplateType.Llama32 (consistent with C++ parseChatTemplateType). * examples/models/llama/runner/eager.py: adds --chat_template_file CLI flag for chat mode. Why this is split out --------------------- Python changes are independently testable and reviewers may want different eyes on the Python vs C++ paths. Also isolates the backward-compat concern around the chat_format default. Depends on: PR-A (extension/llm/chat_template/* + JinjaChatFormatter library headers/symbols). Original PR (full stack): https://github.com/pytorch/executorch/pull/16987 --- examples/models/llama/runner/eager.py | 18 ++++ examples/models/llama/runner/generation.py | 63 ++++++++++++-- extension/llm/runner/__init__.py | 13 ++- extension/llm/runner/_llm_runner.pyi | 34 ++++++++ extension/llm/runner/pybindings.cpp | 57 +++++++++++++ .../llm/runner/test/test_runner_pybindings.py | 83 +++++++++++++++++++ 6 files changed, 256 insertions(+), 12 deletions(-) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 7e662317509..fe498f3a552 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -82,6 +82,20 @@ def build_args_parser() -> argparse.ArgumentParser: help="Have multi-turn chat with the model", ) + parser.add_argument( + "--system_prompt", + type=str, + default="", + help="System prompt for chat formatting (optional).", + ) + + parser.add_argument( + "--chat_template_file", + type=str, + default="", + help="Path to a custom Jinja2 chat template file for chat mode.", + ) + parser.add_argument( "--tokenizer_config_path", type=str, @@ -104,6 +118,8 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None: temperature = args.temperature show_tokens = args.show_tokens chat_mode = args.chat + system_prompt = args.system_prompt + chat_template_file = args.chat_template_file tokenizer_config_path = args.tokenizer_config_path use_attention_sink = args.use_attention_sink @@ -113,6 +129,8 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None: llm_config=llm_config, tokenizer_config_path=tokenizer_config_path, use_attention_sink=use_attention_sink, + system_prompt=system_prompt, + chat_template_file=chat_template_file or None, ) generated_tokens = ( diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 2baa8f5cd14..a56f16c2413 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -9,7 +9,6 @@ from typing import List, Optional import torch - from pytorch_tokenizers import get_tokenizer @@ -56,6 +55,9 @@ def __init__( max_batch_size: int, use_kv_cache: bool, vocab_size: int, + chat_format: str = "none", + system_prompt: str = "", + chat_template_file: Optional[str] = None, device: str = "cpu", ): """ @@ -74,6 +76,10 @@ def __init__( self.use_kv_cache = use_kv_cache self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path) self.device = device + self.chat_format = chat_format + self.system_prompt = system_prompt + self.chat_template_file = chat_template_file + self._chat_formatter = None # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 if vocab_size != self.tokenizer.n_words: print( @@ -207,9 +213,14 @@ def chat_completion( prompt = input("Me: ") while prompt and prompt != exit_prompt: print("LLM: ", end="", flush=True) - prompt_tokens = self.tokenizer.encode( - self._format_prompt(prompt), bos=True, eos=False + formatter = self._get_chat_formatter() + formatted_prompt = ( + formatter.format(prompt, self.system_prompt) + if formatter is not None + else prompt ) + bos = not (formatter is not None and formatter.includes_bos()) + prompt_tokens = self.tokenizer.encode(formatted_prompt, bos=bos, eos=False) generated_tokens = self.generate( prompt_tokens=pre_stop_token + prompt_tokens, max_seq_len=max_seq_len, @@ -227,8 +238,44 @@ def chat_completion( return tokens def _format_prompt(self, prompt: str) -> str: - return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> - -You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|> - -{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" + formatter = self._get_chat_formatter() + if formatter is None: + return prompt + return formatter.format(prompt, self.system_prompt) + + def _resolve_template_type(self, chat_template_type) -> Optional["object"]: + normalized = (self.chat_format or "").lower() + if normalized in ("", "none"): + return None + if normalized == "llama3": + return chat_template_type.Llama3 + if normalized in ("llama3.2", "llama32", "llama3_2"): + return chat_template_type.Llama32 + if normalized == "gemma3": + return chat_template_type.Gemma3 + return None + + def _get_chat_formatter(self): + if self._chat_formatter is not None: + return self._chat_formatter + try: + from executorch.extension.llm.runner import ( + ChatTemplateType, + JinjaChatFormatter, + ) + except ImportError as exc: + raise RuntimeError( + "Jinja chat templates require ExecuTorch pybindings. " + "Build with EXECUTORCH_BUILD_PYBIND=ON." + ) from exc + + if self.chat_template_file: + self._chat_formatter = JinjaChatFormatter.from_file(self.chat_template_file) + return self._chat_formatter + + template_type = self._resolve_template_type(ChatTemplateType) + if template_type is None: + return None + + self._chat_formatter = JinjaChatFormatter.from_template(template_type) + return self._chat_formatter diff --git a/extension/llm/runner/__init__.py b/extension/llm/runner/__init__.py index 4e0ced33b21..a8b7328793a 100644 --- a/extension/llm/runner/__init__.py +++ b/extension/llm/runner/__init__.py @@ -11,13 +11,15 @@ enabling processing of mixed inputs (text, images, audio) and text generation. """ -import torch # preload libtorch shared libs for _llm_runner - try: # Import shared components from the compiled C++ extension from executorch.extension.llm.runner._llm_runner import ( # noqa: F401 + ChatConversation, + ChatMessage, + ChatTemplateType, GenerationConfig, Image, + JinjaChatFormatter, make_audio_input, make_image_input, make_raw_audio_input, @@ -26,7 +28,6 @@ MultimodalInput, MultimodalRunner, Stats, - TextLLMRunner, ) except ImportError: raise RuntimeError( @@ -37,6 +38,7 @@ import logging from typing import Callable, List, Optional, Union +import torch from transformers.feature_extraction_utils import BatchFeature @@ -224,8 +226,12 @@ def generate_text_hf( __all__ = [ + "ChatConversation", + "ChatMessage", + "ChatTemplateType", "GenerationConfig", "Image", + "JinjaChatFormatter", "make_audio_input", "make_image_input", "make_raw_audio_input", @@ -233,6 +239,5 @@ def generate_text_hf( "make_token_input", "MultimodalInput", "MultimodalRunner", - "TextLLMRunner", "Stats", ] diff --git a/extension/llm/runner/_llm_runner.pyi b/extension/llm/runner/_llm_runner.pyi index 271cf1e1540..16599e30767 100644 --- a/extension/llm/runner/_llm_runner.pyi +++ b/extension/llm/runner/_llm_runner.pyi @@ -4,6 +4,7 @@ Type stubs for _llm_runner module. This file provides type annotations for the ExecuTorch LLM Runner Python bindings. """ +from enum import Enum from typing import Callable, List, Optional, overload import torch @@ -64,6 +65,39 @@ class GenerationConfig: def __repr__(self) -> str: ... +class ChatTemplateType(Enum): + None_ = 0 + Llama3 = 1 + Llama32 = 2 + Gemma3 = 3 + Custom = 4 + +class ChatMessage: + role: str + content: str + + def __init__(self, role: str, content: str) -> None: ... + def __repr__(self) -> str: ... + +class ChatConversation: + messages: List[ChatMessage] + bos_token: str + eos_token: str + add_generation_prompt: bool + + def __init__(self) -> None: ... + +class JinjaChatFormatter: + @staticmethod + def from_template(template_type: ChatTemplateType) -> "JinjaChatFormatter": ... + @staticmethod + def from_string(template_str: str) -> "JinjaChatFormatter": ... + @staticmethod + def from_file(path: str) -> "JinjaChatFormatter": ... + def format(self, prompt: str, system_prompt: str = "") -> str: ... + def format_conversation(self, conversation: ChatConversation) -> str: ... + def includes_bos(self) -> bool: ... + class Stats: """Statistics for LLM generation performance.""" diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index 3188b5390c4..063c39c2e74 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -12,7 +12,10 @@ #include #include +#include #include +#include +#include #include #include #include @@ -308,6 +311,60 @@ PYBIND11_MODULE(_llm_runner, m) { " warming=" + (config.warming ? "True" : "False") + ">"; }); + // Bind chat template helpers + py::class_(m, "ChatMessage") + .def( + py::init(), + py::arg("role"), + py::arg("content")) + .def_readwrite("role", &ChatMessage::role) + .def_readwrite("content", &ChatMessage::content) + .def("__repr__", [](const ChatMessage& msg) { + std::string content_preview = msg.content.substr(0, 50); + if (msg.content.length() > 50) { + content_preview += "..."; + } + return ""; + }); + + py::class_(m, "ChatConversation") + .def(py::init<>()) + .def_readwrite("messages", &ChatConversation::messages) + .def_readwrite("bos_token", &ChatConversation::bos_token) + .def_readwrite("eos_token", &ChatConversation::eos_token) + .def_readwrite( + "add_generation_prompt", &ChatConversation::add_generation_prompt); + + py::enum_(m, "ChatTemplateType") + .value("None_", ChatTemplateType::None) + .value("Llama3", ChatTemplateType::Llama3) + .value("Llama32", ChatTemplateType::Llama32) + .value("Gemma3", ChatTemplateType::Gemma3) + .value("Custom", ChatTemplateType::Custom); + + py::class_(m, "JinjaChatFormatter") + .def_static( + "from_template", + &JinjaChatFormatter::fromTemplate, + py::arg("template_type")) + .def_static( + "from_string", + &JinjaChatFormatter::fromString, + py::arg("template_str")) + .def_static( + "from_file", &JinjaChatFormatter::fromFile, py::arg("path")) + .def( + "format", + &JinjaChatFormatter::format, + py::arg("prompt"), + py::arg("system_prompt") = "") + .def( + "format_conversation", + &JinjaChatFormatter::formatConversation, + py::arg("conversation")) + .def("includes_bos", &JinjaChatFormatter::includesBos); + // Bind Stats py::class_(m, "Stats") .def_readonly( diff --git a/extension/llm/runner/test/test_runner_pybindings.py b/extension/llm/runner/test/test_runner_pybindings.py index 5619e586c4b..32689d18d5e 100644 --- a/extension/llm/runner/test/test_runner_pybindings.py +++ b/extension/llm/runner/test/test_runner_pybindings.py @@ -18,8 +18,12 @@ import torch from executorch.extension.llm.runner import ( + ChatConversation, + ChatMessage, + ChatTemplateType, GenerationConfig, Image, + JinjaChatFormatter, make_image_input, make_text_input, MultimodalInput, @@ -264,3 +268,82 @@ def test_make_image_input(self): img_tensor_rgba = torch.ones((4, 50, 50), dtype=torch.uint8) * 128 image_input_rgba = make_image_input(img_tensor_rgba) self.assertTrue(image_input_rgba.is_image()) + + +class TestChatTemplateBindings(unittest.TestCase): + """Test Jinja chat template bindings.""" + + def test_format_llama3(self): + formatter = JinjaChatFormatter.from_template(ChatTemplateType.Llama3) + result = formatter.format("Hello!", "System prompt") + self.assertIn("<|begin_of_text|>", result) + self.assertIn("System prompt", result) + self.assertIn("Hello!", result) + self.assertIn("<|start_header_id|>assistant", result) + + def test_format_conversation(self): + formatter = JinjaChatFormatter.from_template(ChatTemplateType.Gemma3) + conversation = ChatConversation() + conversation.bos_token = "" + conversation.eos_token = "" + conversation.add_generation_prompt = True + conversation.messages = [ + ChatMessage("user", "Hi"), + ChatMessage("assistant", "Hello"), + ] + result = formatter.format_conversation(conversation) + self.assertIn("user", result) + self.assertIn("Hi", result) + self.assertIn("model", result) + + def test_format_llama3_without_system_prompt(self): + formatter = JinjaChatFormatter.from_template(ChatTemplateType.Llama3) + result = formatter.format("Hello!") + self.assertIn("<|begin_of_text|>", result) + self.assertIn("Hello!", result) + self.assertIn("<|start_header_id|>user", result) + # Should not contain system when no system prompt provided + self.assertNotIn("<|start_header_id|>system", result) + + def test_format_gemma3(self): + formatter = JinjaChatFormatter.from_template(ChatTemplateType.Gemma3) + result = formatter.format("Test message", "Be helpful") + self.assertIn("", result) + self.assertIn("Test message", result) + self.assertIn("model", result) + + def test_includes_bos_llama3(self): + formatter = JinjaChatFormatter.from_template(ChatTemplateType.Llama3) + self.assertTrue(formatter.includes_bos()) + + def test_includes_bos_gemma3(self): + formatter = JinjaChatFormatter.from_template(ChatTemplateType.Gemma3) + self.assertTrue(formatter.includes_bos()) + + def test_from_string_llama_template(self): + template = "{{ bos_token }}<|start_header_id|>user<|end_header_id|>\n\n{{ messages[0].content }}<|eot_id|>" + formatter = JinjaChatFormatter.from_string(template) + result = formatter.format("Test") + self.assertIn("Test", result) + + def test_from_string_gemma_template(self): + template = "{{ bos_token }}user\n{{ messages[0].content }}" + formatter = JinjaChatFormatter.from_string(template) + result = formatter.format("Test") + self.assertIn("Test", result) + + def test_chat_message_creation(self): + msg = ChatMessage("user", "Hello world") + self.assertEqual(msg.role, "user") + self.assertEqual(msg.content, "Hello world") + + def test_chat_conversation_creation(self): + conv = ChatConversation() + conv.bos_token = "" + conv.eos_token = "" + conv.add_generation_prompt = False + conv.messages = [ChatMessage("user", "Hi")] + self.assertEqual(conv.bos_token, "") + self.assertEqual(conv.eos_token, "") + self.assertFalse(conv.add_generation_prompt) + self.assertEqual(len(conv.messages), 1) From 8338d005307792f4de3542f724bcdb6c879504fc Mon Sep 17 00:00:00 2001 From: Young Han Date: Tue, 12 May 2026 17:37:07 -0700 Subject: [PATCH 4/4] [llm][4/4] llama_main CLI flags + chat_formatter wrapper + universal Jinja docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part 4 of the chat-template support stack split out of #16987. This is the user-facing surface that wires everything together. What this PR adds ----------------- * examples/models/llama/runner/chat_formatter.h: Example-local ChatFormatter abstraction with NoChatFormatter and a JinjaChatFormatterAdapter wrapping executorch::extension::llm::JinjaChatFormatter. parse_chat_format is case-insensitive and trims whitespace, so 'Llama3', ' llama3 ', 'LLAMA3' all map correctly. create_chat_formatter throws std::invalid_argument when chat_format=jinja is passed without --chat_template_file (no more silent no-op). * examples/models/llama/main.cpp: Adds --chat_format, --chat_template_file, --system_prompt, --echo flags. Wraps the prompt with the chat formatter, catches invalid_argument / std::exception from formatter creation with clear error messages. Wires GenerationConfig.echo from the new --echo flag. * examples/models/llama/runner/CMakeLists.txt + targets.bzl: link llama_runner against jinja2cpp (transitive include in chat_formatter.h). * examples/models/llama/CMakeLists.txt: add a guarded FetchContent_Declare(jinja2cpp) so the example builds standalone (when the parent build hasn't already added jinja2cpp via extension/llm/chat_template), without redeclaring when it has. * examples/models/llama/README.md: documents the new flags AND the recommended workflow of passing any HuggingFace / vLLM Jinja template via --chat_template_file (universal Jinja support). * extension/llm/runner/README.md: documents universal Jinja support for the LLM runner library — points at vLLM examples and HF tokenizer_config.json files as supported template sources. Why this is split out --------------------- This is the user-facing CLI integration that depends on PRs A and C. It's the most reviewable in isolation since it's example code with lower blast radius — reviewers can focus on the CLI ergonomics and docs without re-reading library internals. Sample vLLM templates are NOT checked in (per reviewer feedback); documentation here points users to vLLM's examples directory and HuggingFace tokenizer_config.json files, which the universal Jinja support handles directly. Depends on: - PR-A: extension/llm/chat_template/* + JinjaChatFormatter library - PR-C: chat_formatter.h includes JinjaChatFormatter (header-only), but generation.py / eager.py changes are independent Original PR (full stack): https://github.com/pytorch/executorch/pull/16987 --- examples/models/llama/CMakeLists.txt | 51 +++-- examples/models/llama/README.md | 90 ++++++++- examples/models/llama/main.cpp | 64 ++++++ examples/models/llama/runner/CMakeLists.txt | 1 + examples/models/llama/runner/chat_formatter.h | 190 ++++++++++++++++++ examples/models/llama/runner/targets.bzl | 19 +- extension/llm/runner/README.md | 93 +++++---- 7 files changed, 424 insertions(+), 84 deletions(-) create mode 100644 examples/models/llama/runner/chat_formatter.h diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index 6d5b5cc2566..96f1dd78fad 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -48,6 +48,28 @@ set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch) include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +# Jinja2Cpp is required by the chat formatter that the llama runner links +# against. When this example CMake file is run standalone (i.e. without the +# parent ExecuTorch build that already pulls Jinja2Cpp from +# extension/llm/chat_template), declare it here as well so the target exists. +# This guard prevents redeclaring it when the parent build has already added +# the dependency. +if(NOT TARGET jinja2cpp) + include(FetchContent) + cmake_policy(SET CMP0077 NEW) + FetchContent_Declare( + jinja2cpp + GIT_REPOSITORY https://github.com/jinja2cpp/Jinja2Cpp.git + GIT_TAG 1.3.2 + GIT_SUBMODULES_RECURSE TRUE + ) + set(JINJA2CPP_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(JINJA2CPP_BUILD_SHARED OFF CACHE BOOL "" FORCE) + set(JINJA2CPP_INSTALL OFF CACHE BOOL "" FORCE) + set(SUPPORT_REGEX_LOOKAHEAD ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(jinja2cpp) +endif() + if(NOT PYTHON_EXECUTABLE) resolve_python_executable() endif() @@ -107,13 +129,8 @@ else() endif() # quantized_ops_lib: Register quantized op kernels into the runtime -if(TARGET quantized_ops_lib) - list(APPEND link_libraries quantized_kernels quantized_ops_lib) - get_target_property(_quantized_imported quantized_ops_lib IMPORTED) - if(NOT _quantized_imported) - executorch_target_link_options_shared_lib(quantized_ops_lib) - endif() -endif() +executorch_target_link_options_shared_lib(quantized_ops_lib) +list(APPEND link_libraries quantized_kernels quantized_ops_lib) if(TARGET custom_ops) executorch_target_link_options_shared_lib(custom_ops) @@ -168,15 +185,6 @@ if(TARGET xnnpack_backend) executorch_target_link_options_shared_lib(xnnpack_backend) endif() -# CUDA backend -if(EXECUTORCH_BUILD_CUDA) - find_package(CUDAToolkit REQUIRED) - list(APPEND link_libraries aoti_cuda_backend) - if(NOT MSVC) - executorch_target_link_options_shared_lib(aoti_cuda_backend) - endif() -endif() - # Vulkan backend if(TARGET vulkan_backend) list(APPEND link_libraries vulkan_backend) @@ -203,12 +211,6 @@ if(TARGET mpsdelegate) executorch_target_link_options_shared_lib(mpsdelegate) endif() -# MLX backend -if(TARGET mlxdelegate) - list(APPEND link_libraries mlxdelegate mlx) - executorch_target_link_options_shared_lib(mlxdelegate) -endif() - # Openvino backend if(TARGET openvino_backend) find_package(OpenVINO REQUIRED) @@ -237,11 +239,6 @@ endif() add_executable(llama_main ${_srcs}) -# Copy MLX metallib for runtime if MLX delegate is enabled -if(TARGET mlxdelegate) - executorch_target_copy_mlx_metallib(llama_main) -endif() - # Only strip symbols for Release and MinSizeRel builds. if(CMAKE_BUILD_TYPE STREQUAL "Release" OR CMAKE_BUILD_TYPE STREQUAL "MinSizeRel" diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index f674d454987..d8083eebbec 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -238,6 +238,10 @@ If you're interested in deploying on non-CPU backends, [please refer the non-cpu ``` cmake --workflow llm-release ``` +If you build with `make llama-cpu` and hit a RapidJSON CMake error, run it as: +``` +CMAKE_POLICY_VERSION_MINIMUM=3.5 make llama-cpu +``` Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the section of Common Issues and Mitigations below for solutions. 2. Build llama runner. @@ -252,6 +256,87 @@ popd cmake-out/examples/models/llama/llama_main --model_path= --tokenizer_path= --prompt= ``` +### Chat Format for Instruct Models + +For **Instruct models** (e.g., Llama-3.2-1B-Instruct), use either +`--chat_format` (built-in) or `--chat_template_file` (any HuggingFace / +vLLM-style Jinja template) to wrap your prompt in the appropriate chat +template. Without this, Instruct models may not generate end-of-turn tokens +and will run until max tokens. + +#### Universal Jinja templates (recommended) + +The runner supports **any HuggingFace / vLLM-style Jinja2 template** via +`--chat_template_file`. Templates from +[vLLM's examples directory](https://github.com/vllm-project/vllm/tree/main/examples) +or HuggingFace `tokenizer_config.json` files work out of the box: + +```bash +# Use any Jinja template from vLLM, HuggingFace, or your own: +cmake-out/examples/models/llama/llama_main \ + --model_path= \ + --tokenizer_path= \ + --chat_template_file=path/to/template.jinja \ + --prompt="What is the capital of France?" +``` + +#### Built-in formats (convenience) + +```bash +# Basic usage with chat format +cmake-out/examples/models/llama/llama_main \ + --model_path= \ + --tokenizer_path= \ + --chat_format=llama3 \ + --prompt="What is the capital of France?" +``` + +**Template/model compatibility:** +- Use Llama templates (`llama3` or the Llama vLLM template) with Llama models. +- Using a Gemma template with a Llama model will cause the model to echo Gemma tokens. + +**Supported chat formats:** +| Format | Models | Template Style | +|--------|--------|----------------| +| `llama3` | Llama 3.x Instruct | `<\|begin_of_text\|><\|start_header_id\|>user...` | +| `gemma3` | Gemma 3 Instruct | `user...` | +| `jinja` | Custom template | Jinja2 chat template from file (requires `--chat_template_file`) | +| `none` | Base models (default) | No formatting | + +**Additional options:** +| Flag | Description | Default | +|------|-------------|---------| +| `--chat_format` | Chat template format (llama3, gemma3, jinja, none) | `none` | +| `--chat_template_file` | Path to custom Jinja2 template (overrides `--chat_format`) | (empty) | +| `--system_prompt` | System prompt to set assistant behavior | (empty) | +| `--echo` | Echo input prompt in output (set to false for clean output) | `true` | + +**Example with system prompt and clean output:** +```bash +cmake-out/examples/models/llama/llama_main \ + --model_path= \ + --tokenizer_path= \ + --chat_format=llama3 \ + --system_prompt="You are a helpful assistant. Be concise." \ + --echo=false \ + --prompt="What is the capital of France?" + +# Output: The capital of France is Paris. +``` + +**Example with a custom template file:** +```bash +cmake-out/examples/models/llama/llama_main \ + --model_path= \ + --tokenizer_path= \ + --chat_template_file=./my_template.jinja \ + --prompt="Hello!" +``` + +**Build note:** If you see a CMake error about RapidJSON requiring +`CMAKE_POLICY_VERSION_MINIMUM=3.5`, add `CMAKE_POLICY_VERSION_MINIMUM=3.5` +to your build environment when running `make llama-cpu`. + To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON` If you an error about "RE2 failed to compile pattern with lookahead:...SUPPORT_REGEX_LOOKAHEAD=ON", add "-DSUPPORT_REGEX_LOOKAHEAD=ON" when building the runner. @@ -277,9 +362,6 @@ cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ - -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \ - -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \ -DEXECUTORCH_ENABLE_LOGGING=1 \ -DPYTHON_EXECUTABLE=python \ -DEXECUTORCH_BUILD_XNNPACK=ON \ @@ -303,7 +385,7 @@ cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ - -DSUPPORT_REGEX_LOOKAHEAD=ON \ + -DSUPPORT_REGEX_LOOKAHEAD=ON -Bcmake-out-android/examples/models/llama \ examples/models/llama diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp index 364efb2b7e8..9de1bbb2ba8 100644 --- a/examples/models/llama/main.cpp +++ b/examples/models/llama/main.cpp @@ -7,6 +7,7 @@ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */ +#include #include #include #include @@ -87,6 +88,27 @@ DEFINE_string( "forward", "Method name to execute in the model (e.g., 'forward', 'lora_forward')."); +DEFINE_string( + chat_format, + "none", + "Chat template format for Instruct models. Supported formats: llama3, gemma3, jinja, none (default: none). " + "When set, the prompt will be wrapped in the appropriate chat template."); + +DEFINE_string( + chat_template_file, + "", + "Path to a custom Jinja2 chat template file. Overrides --chat_format."); + +DEFINE_string( + system_prompt, + "", + "System prompt for chat format (optional). Sets the behavior/personality of the assistant."); + +DEFINE_bool( + echo, + true, + "Echo the input prompt in the output. Set to false to only show generated text."); + // Helper function to parse comma-separated string lists std::vector parseStringList(const std::string& input) { std::vector result; @@ -143,6 +165,47 @@ int32_t main(int32_t argc, char** argv) { prompt = prompt_storage.c_str(); } + // Parse chat format and create formatter + auto chat_format = example::parse_chat_format(FLAGS_chat_format); + std::unique_ptr chat_formatter; + try { + chat_formatter = + example::create_chat_formatter(chat_format, FLAGS_chat_template_file); + } catch (const std::invalid_argument& ex) { + ET_LOG(Error, "Invalid chat format configuration: %s", ex.what()); + return 1; + } catch (const std::exception& ex) { + ET_LOG(Error, "Failed to load chat template: %s", ex.what()); + return 1; + } + const bool using_chat_template = + chat_format != example::ChatFormat::None || + !FLAGS_chat_template_file.empty(); + + // Apply chat formatting to the prompt (no-op when chat_format=none and no + // template file provided). + std::string formatted_prompt = + chat_formatter->format(prompt, FLAGS_system_prompt); + prompt = formatted_prompt.c_str(); + + if (using_chat_template) { + if (!FLAGS_chat_template_file.empty()) { + ET_LOG( + Info, + "Using chat template file: %s", + FLAGS_chat_template_file.c_str()); + } else { + ET_LOG(Info, "Using chat format: %s", FLAGS_chat_format.c_str()); + } + if (FLAGS_num_bos > 0 && chat_formatter->includes_bos()) { + ET_LOG( + Info, + "Note: Chat format '%s' already includes BOS token. " + "Consider setting --num_bos=0 to avoid duplicate BOS tokens.", + FLAGS_chat_format.c_str()); + } + } + float temperature = FLAGS_temperature; int32_t seq_len = FLAGS_seq_len; @@ -200,6 +263,7 @@ int32_t main(int32_t argc, char** argv) { } // generate executorch::extension::llm::GenerationConfig config{ + .echo = FLAGS_echo, .temperature = temperature}; config.ignore_eos = FLAGS_ignore_eos; diff --git a/examples/models/llama/runner/CMakeLists.txt b/examples/models/llama/runner/CMakeLists.txt index 7c6c5413ab3..95d15681602 100644 --- a/examples/models/llama/runner/CMakeLists.txt +++ b/examples/models/llama/runner/CMakeLists.txt @@ -53,6 +53,7 @@ set(llama_runner_deps target_link_libraries(llama_runner PUBLIC ${llama_runner_deps}) target_link_libraries(llama_runner PUBLIC tokenizers::tokenizers) +target_link_libraries(llama_runner PRIVATE jinja2cpp) target_include_directories( llama_runner PUBLIC ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include diff --git a/examples/models/llama/runner/chat_formatter.h b/examples/models/llama/runner/chat_formatter.h new file mode 100644 index 00000000000..0458fc45203 --- /dev/null +++ b/examples/models/llama/runner/chat_formatter.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace example { + +/** + * Supported chat template formats for different model families. + */ +enum class ChatFormat { + None, // No formatting (pass-through) + Llama3, // Llama 3.x Instruct models + Gemma3, // Gemma 3 Instruct models + Jinja, // Custom Jinja template from file +}; + +/** + * Abstract base class for chat formatters. + * Implementations format user prompts into model-specific chat templates. + */ +class ChatFormatter { + public: + virtual ~ChatFormatter() = default; + + /** + * Format a user prompt into the model's expected chat template. + * + * @param prompt The user's input message + * @param system_prompt Optional system prompt to set model behavior + * @return Formatted string ready for tokenization + */ + virtual std::string format( + const std::string& prompt, + const std::string& system_prompt = "") const = 0; + + /** + * Whether this formatter includes BOS token in the template. + * If true, the runner should not prepend additional BOS tokens. + */ + virtual bool includes_bos() const = 0; +}; + +/** + * No formatting (pass-through). + * Use when the prompt is already formatted or for base models. + */ +class NoChatFormatter : public ChatFormatter { + public: + std::string format( + const std::string& prompt, + const std::string& system_prompt = "") const override { + (void)system_prompt; // Unused in pass-through mode + return prompt; + } + + bool includes_bos() const override { + return false; // User controls BOS token + } +}; + +class JinjaChatFormatterAdapter : public ChatFormatter { + public: + explicit JinjaChatFormatterAdapter( + std::unique_ptr formatter) + : formatter_(std::move(formatter)) {} + + std::string format( + const std::string& prompt, + const std::string& system_prompt = "") const override { + return formatter_->format(prompt, system_prompt); + } + + bool includes_bos() const override { + return formatter_->includesBos(); + } + + private: + std::unique_ptr formatter_; +}; + +/** + * Parse a chat format string into the corresponding enum value. + * + * The lookup is case-insensitive and tolerant of surrounding whitespace, + * so values like "Llama3", " llama3 ", or "LLAMA3" all map to + * ChatFormat::Llama3. + * + * @param format_str String identifier (e.g., "llama3", "none") + * @return ChatFormat enum value, defaults to None for unknown formats + */ +inline ChatFormat parse_chat_format(const std::string& format_str) { + static const std::unordered_map format_map = { + {"none", ChatFormat::None}, + {"llama3", ChatFormat::Llama3}, + {"llama3.2", ChatFormat::Llama3}, + {"llama32", ChatFormat::Llama3}, + {"llama3_2", ChatFormat::Llama3}, + {"gemma3", ChatFormat::Gemma3}, + {"jinja", ChatFormat::Jinja}, + }; + + // Trim whitespace and lowercase to make CLI input forgiving. + std::string normalized = format_str; + const auto first_non_ws = + normalized.find_first_not_of(" \t\n\r\f\v"); + const auto last_non_ws = normalized.find_last_not_of(" \t\n\r\f\v"); + if (first_non_ws == std::string::npos) { + return ChatFormat::None; + } + normalized = normalized.substr(first_non_ws, last_non_ws - first_non_ws + 1); + std::transform( + normalized.begin(), + normalized.end(), + normalized.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + auto it = format_map.find(normalized); + if (it != format_map.end()) { + return it->second; + } + return ChatFormat::None; +} + +/** + * Get a human-readable list of supported chat formats. + */ +inline std::string get_supported_formats() { + return "llama3, gemma3, jinja, none"; +} + +/** + * Factory function to create the appropriate ChatFormatter instance. + * + * Universal Jinja support: when `template_file` is non-empty, any + * HuggingFace-style or vLLM-style chat template (e.g. files from + * https://github.com/vllm-project/vllm/tree/main/examples) can be passed in + * regardless of the `format` value. The Jinja formatter will load and + * render the template directly. + * + * @param format The chat format to use + * @param template_file Optional path to a Jinja2 chat template file. If + * provided, takes precedence over `format`. + * @return Unique pointer to a ChatFormatter instance + * @throws std::invalid_argument if `format == ChatFormat::Jinja` and no + * `template_file` is provided. + */ +inline std::unique_ptr create_chat_formatter( + ChatFormat format, + const std::string& template_file = "") { + using executorch::extension::llm::ChatTemplateType; + using executorch::extension::llm::JinjaChatFormatter; + + if (!template_file.empty()) { + return std::make_unique( + JinjaChatFormatter::fromFile(template_file)); + } + + switch (format) { + case ChatFormat::Llama3: + return std::make_unique( + JinjaChatFormatter::fromTemplate(ChatTemplateType::Llama3)); + case ChatFormat::Gemma3: + return std::make_unique( + JinjaChatFormatter::fromTemplate(ChatTemplateType::Gemma3)); + case ChatFormat::Jinja: + throw std::invalid_argument( + "chat_format=jinja requires --chat_template_file="); + case ChatFormat::None: + default: + return std::make_unique(); + } +} + +} // namespace example diff --git a/examples/models/llama/runner/targets.bzl b/examples/models/llama/runner/targets.bzl index 81a2df117f4..8dd84beebbd 100644 --- a/examples/models/llama/runner/targets.bzl +++ b/examples/models/llama/runner/targets.bzl @@ -6,21 +6,6 @@ def _get_operator_lib(aten = False): else: return ["//executorch/configurations:optimized_native_cpu_ops", "//executorch/extension/llm/custom_ops:custom_ops"] -def _get_torchao_lowbit_deps(): - """Returns torchao lowbit kernel deps for shared embedding and linear on ARM builds.""" - if runtime.is_oss: - return [] - else: - # Use select to conditionally include torchao lowbit kernels only on ARM64 builds - # These kernels are only available for aarch64 architecture - return select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": [ - "//xplat/pytorch/ao/torchao/csrc/cpu/shared_kernels/embedding_xbit:op_embedding_xbit_executorch", - "//xplat/pytorch/ao/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight:op_linear_8bit_act_xbit_weight_executorch", - ], - }) - def get_qnn_dependency(): # buck build -c executorch.enable_qnn=true //executorch/examples/models/llama/runner:runner # Check if QNN is enabled before including the dependency @@ -41,6 +26,7 @@ def define_common_targets(): "runner.cpp", ], exported_headers = [ + "chat_formatter.h", "runner.h", ], deps = [ @@ -62,8 +48,7 @@ def define_common_targets(): "//executorch/examples/models/llama/tokenizer:tiktoken", "//pytorch/tokenizers:llama2c_tokenizer", "//pytorch/tokenizers:hf_tokenizer", - "//pytorch/tokenizers:regex_lookahead", - ] + (_get_operator_lib(aten)) + _get_torchao_lowbit_deps() + ([ + ] + (_get_operator_lib(aten)) + ([ # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE) # Therefore enable it explicitly for now to avoid failing tests "//executorch/backends/vulkan:vulkan_backend_lib", diff --git a/extension/llm/runner/README.md b/extension/llm/runner/README.md index 4fa3b079039..9a0b8df2ec1 100644 --- a/extension/llm/runner/README.md +++ b/extension/llm/runner/README.md @@ -89,6 +89,44 @@ MultimodalRunner Supported Model Architecture: ## Quick Start +## Chat Templates (Jinja2) + +The runner supports any **HuggingFace / vLLM-style Jinja2 chat template**. +The implementation is built on Jinja2Cpp and accepts the same templates that +[vLLM ships](https://github.com/vllm-project/vllm/tree/main/examples) or the +`chat_template` field exposed by HuggingFace `tokenizer_config.json` files. + +### Universal Jinja Support + +You can point `--chat_template_file` at any `.jinja` file from: +- HuggingFace tokenizer configs (`tokenizer_config.json` → `chat_template`). +- vLLM's example templates: +- Your own custom Jinja2 chat templates. + +The template is rendered with the standard variables expected by HuggingFace: +`messages`, `bos_token`, `eos_token`, `add_generation_prompt`, `tools`, +`tool_choice`, and `date_string`. Templates that don't use these variables +will simply ignore them. + +### Quick example + +```bash +# Download a chat template from HuggingFace or vLLM's examples directory: +curl -L -o llama3.2_chat.jinja \ + https://raw.githubusercontent.com/vllm-project/vllm/main/examples/tool_chat_template_llama3.2_pythonic.jinja + +cmake-out/examples/models/llama/llama_main \ + --model_path= \ + --tokenizer_path= \ + --chat_template_file=llama3.2_chat.jinja \ + --prompt="Hello" +``` + +Notes: +- Match the template to the model family (e.g., Llama templates for Llama models). +- For clean text output, pass `--echo=false` so prompt formatting tokens are not printed. +- `--chat_template_file` always takes precedence over `--chat_format`. + ### TextLLMRunner Example ```cpp @@ -173,26 +211,9 @@ The LLM Runner framework provides Python bindings for easy integration with Pyth Build the Python bindings as part of the ExecuTorch build: ```bash -# Option 1: Use the install script (includes pybindings by default) +# Build from source with Python bindings enabled: +# In executorch root directory bash install_executorch.sh - -# Option 2: Build with CMake directly -cmake -B cmake-out \ - -DEXECUTORCH_BUILD_PYBIND=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - cmake/ -cmake --build cmake-out -j$(nproc) --target install - -# Option 3: pip install from source (includes pybindings) -pip install -e . --no-build-isolation -``` - -The key CMake flag is `EXECUTORCH_BUILD_PYBIND=ON`, which builds the `_llm_runner` extension module providing `TextLLMRunner`, `MultimodalRunner`, `GenerationConfig`, and related classes. - -Verify the installation: - -```python -from executorch.extension.llm.runner import TextLLMRunner, GenerationConfig ``` ### Quick Start Examples @@ -201,7 +222,7 @@ from executorch.extension.llm.runner import TextLLMRunner, GenerationConfig ```python from executorch.extension.llm.runner import ( - GenerationConfig, MultimodalRunner, + GenerationConfig, MultimodalRunner, make_text_input, make_image_input, make_audio_input ) import torch @@ -246,7 +267,7 @@ runner.generate(inputs, config, token_callback, stats_callback) ```python from executorch.extension.llm.runner import ( MultimodalRunner, GenerationConfig, - make_text_input, make_token_input, make_image_input, + make_text_input, make_token_input, make_image_input, make_audio_input, make_raw_audio_input ) import torch @@ -314,8 +335,8 @@ inputs_hf = processor(prompt, image, return_tensors="pt") # Generate using HF inputs directly config = GenerationConfig(max_new_tokens=100, temperature=0.7) runner.generate_hf( - inputs_hf, - config, + inputs_hf, + config, image_token_id=processor.tokenizer.convert_tokens_to_ids(""), token_callback=lambda token: print(token, end='', flush=True) ) @@ -330,13 +351,13 @@ class ChatSession: def __init__(self, model_path: str, tokenizer_path: str): self.runner = MultimodalRunner(model_path, tokenizer_path) self.config = GenerationConfig(max_new_tokens=150, temperature=0.7, echo=False) - + def send_message(self, message: str) -> str: """Send a message and get response""" inputs = [make_text_input(message)] response = self.runner.generate_text(inputs, self.config) return response - + def send_multimodal(self, text: str, image_tensor: torch.Tensor) -> str: """Send text + image and get response""" inputs = [ @@ -345,7 +366,7 @@ class ChatSession: ] response = self.runner.generate_text(inputs, self.config) return response - + def reset_conversation(self): """Reset the conversation state""" self.runner.reset() @@ -388,7 +409,7 @@ config.max_new_tokens = 50 #### MultimodalInput Types ```python from executorch.extension.llm.runner import ( - MultimodalInput, make_text_input, make_token_input, + MultimodalInput, make_text_input, make_token_input, make_image_input, make_audio_input ) @@ -424,25 +445,25 @@ def detailed_stats_callback(stats): print(f"\n=== Generation Statistics ===") print(f"Prompt tokens: {stats.num_prompt_tokens}") print(f"Generated tokens: {stats.num_generated_tokens}") - + # Timing breakdown model_load_time = stats.model_load_end_ms - stats.model_load_start_ms if model_load_time > 0: print(f"Model load time: {model_load_time}ms") - + inference_time = stats.inference_end_ms - stats.inference_start_ms if inference_time > 0: print(f"Total inference time: {inference_time}ms") - + # Calculate throughput tokens_per_sec = stats.num_generated_tokens * 1000 / inference_time print(f"Generation speed: {tokens_per_sec:.1f} tokens/sec") - + # Time to first token if stats.first_token_ms > stats.inference_start_ms: ttft = stats.first_token_ms - stats.inference_start_ms print(f"Time to first token: {ttft}ms") - + # Export to JSON for logging json_stats = stats.to_json_string() print(f"JSON stats: {json_stats}") @@ -459,17 +480,17 @@ import torch try: runner = MultimodalRunner("model.pte", "tokenizer.bin") - + # Invalid image tensor will raise RuntimeError invalid_image = torch.rand(2, 224, 224, 3) # Wrong number of dimensions inputs = [make_image_input(invalid_image)] - + config = GenerationConfig(max_new_tokens=50) runner.generate_text(inputs, config) - + except RuntimeError as e: print(f"Generation failed: {e}") - + except FileNotFoundError as e: print(f"Model or tokenizer file not found: {e}") ```