From bc88957001f6e655b151a7cc2cbb876af3e99b42 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 13 May 2026 15:22:44 -0700
Subject: [PATCH] Add LogitProcessor interface for pre-sampling logit
 transforms (#19517)

Summary:
Introduces a `LogitProcessor` abstract interface that lets callers mutate
logits in place between the model forward pass and the sampler. This enables
grammar-constrained decoding, logit biasing, repetition penalties, and
similar pre-sampling transforms without modifying the core generation loop.

Changes:
- `LogitProcessor` (new): abstract class in `extension/llm/sampler/` with a
  constructor that takes `vocab_size` and a pure virtual `process(float*)`
  method (see the sketch after this list). The `vocab_size` is fixed per
  model and stored as a member, avoiding a redundant per-call argument.
- `TextTokenGenerator`: gains `add_logit_processor()`,
  `clear_logit_processors()`, and `num_logit_processors()`. The processor
  chain runs after the model step and before `logits_to_token()`. When no
  processors are registered, behavior is unchanged.
- `apply_logit_processors_()`: private helper that advances to the
  last-position logits for 3D tensors (mirroring `logits_to_token()`) and
  invokes each processor in order. Supports the Float, Half, BFloat16, and
  UInt16 dtypes: Float logits are processed in place (zero-copy); for the
  other dtypes, logits are cast into a temporary float buffer, processed,
  then cast back to the original dtype.
- Buck: `logit_processor.h` is exported from the sampler target;
  `text_token_generator` gains a direct dep on the sampler; a test target
  is added.
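
For illustration, a minimal sketch of a concrete processor. `BanTokenProcessor`
is a hypothetical example written for this summary, not part of the diff:

```cpp
#include <executorch/extension/llm/sampler/logit_processor.h>

#include <cstdint>
#include <limits>

// Hypothetical example (not in this diff): bans a single token id by
// pinning its logit to -inf, so the sampler can never select it.
class BanTokenProcessor : public ::executorch::extension::llm::LogitProcessor {
 public:
  BanTokenProcessor(int32_t vocab_size, int32_t banned_token)
      : LogitProcessor(vocab_size), banned_token_(banned_token) {}

  void process(float* logits) override {
    if (banned_token_ >= 0 && banned_token_ < vocab_size()) {
      logits[banned_token_] = -std::numeric_limits<float>::infinity();
    }
  }

 private:
  int32_t banned_token_;
};
```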

Processors must be configured before calling `generate()`; modifying the
chain while generation is running is not safe.
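
A rough usage sketch, assuming a `TextTokenGenerator` instance named
`generator` and reusing the hypothetical `BanTokenProcessor` above; the
vocab size and token id are illustrative:

```cpp
// Configure the chain once, before generation starts; generate() then
// applies it on every decoding step, in registration order.
generator.add_logit_processor(
    std::make_shared<BanTokenProcessor>(/*vocab_size=*/32000,
                                        /*banned_token=*/11));
// generator.generate(...) can now never sample token 11.
generator.clear_logit_processors(); // drops the whole chain
```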

Differential Revision: D104767967
---
 extension/llm/runner/targets.bzl              |   1 +
 extension/llm/runner/text_token_generator.h   |  84 +++++++++++
 extension/llm/sampler/logit_processor.h       |  57 ++++++++
 extension/llm/sampler/targets.bzl             |   1 +
 extension/llm/sampler/test/targets.bzl        |  10 ++
 .../llm/sampler/test/test_logit_processor.cpp | 136 ++++++++++++++++++
 6 files changed, 289 insertions(+)
 create mode 100644 extension/llm/sampler/logit_processor.h
 create mode 100644 extension/llm/sampler/test/test_logit_processor.cpp

diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl
index 0d4ed99308d..d3e12266adc 100644
--- a/extension/llm/runner/targets.bzl
+++ b/extension/llm/runner/targets.bzl
@@ -68,6 +68,7 @@ def define_common_targets():
         visibility = ["PUBLIC"],
         exported_deps = [
             ":text_decoder_runner" + aten_suffix,
+            "//executorch/extension/llm/sampler:sampler" + aten_suffix,
             "//pytorch/tokenizers:headers",
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/extension/tensor:tensor" + aten_suffix,
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 7e7fbbf1341..14e62ecf503 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -10,10 +10,14 @@
 
 #pragma once
 
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/sampler/logit_processor.h>
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <memory>
 #include <pytorch/tokenizers/tokenizer.h>
+#include <vector>
 
 namespace executorch {
@@ -38,6 +42,20 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     ignore_eos_ = ignore_eos;
   }
 
+  void add_logit_processor(std::shared_ptr<LogitProcessor> processor) {
+    if (processor) {
+      logit_processors_.push_back(std::move(processor));
+    }
+  }
+
+  void clear_logit_processors() {
+    logit_processors_.clear();
+  }
+
+  size_t num_logit_processors() const {
+    return logit_processors_.size();
+  }
+
   virtual ~TextTokenGenerator() = default;
 
   /**
@@ -109,6 +127,10 @@ class ET_EXPERIMENTAL TextTokenGenerator {
 
     prev_token = cur_token;
 
+    if (!logit_processors_.empty()) {
+      ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors_(logits_tensor));
+    }
+
     stats_->on_sampling_begin();
     cur_token =
         text_decoder_runner_->logits_to_token(logits_tensor, temperature);
@@ -177,6 +199,66 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   }
 
  private:
+  inline ::executorch::runtime::Error apply_logit_processors_(
+      ::executorch::aten::Tensor& logits_tensor) {
+    ET_CHECK_OR_RETURN_ERROR(
+        logits_tensor.dim() >= 2,
+        InvalidArgument,
+        "LogitProcessor expects logits with dim >= 2, got %d",
+        static_cast<int>(logits_tensor.dim()));
+
+    const int32_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
+    int32_t offset = 0;
+    if (logits_tensor.dim() == 3) {
+      const int32_t num_tokens = logits_tensor.size(1);
+      ET_CHECK_OR_RETURN_ERROR(
+          num_tokens > 0,
+          InvalidArgument,
+          "LogitProcessor expects a non-empty sequence dimension");
+      offset = (num_tokens - 1) * vocab_size;
+    }
+
+    // Float logits are processed in place; no copy is needed.
+    if (logits_tensor.scalar_type() == ::executorch::aten::ScalarType::Float) {
+      auto* logits = logits_tensor.mutable_data_ptr<float>() + offset;
+      for (auto& processor : logit_processors_) {
+        processor->process(logits);
+      }
+      return ::executorch::runtime::Error::Ok;
+    }
+
+    // Minimal error context for ET_SWITCH_*; unsupported dtypes abort.
+    struct {
+      [[noreturn]] void fail(::executorch::runtime::Error /* error */) {
+        ET_CHECK_MSG(false, "Unsupported dtype in apply_logit_processors_");
+      }
+    } ctx;
+
+    // Other dtypes: round-trip through a temporary float buffer.
+    std::vector<float> temp(vocab_size);
+    ET_SWITCH_THREE_TYPES(
+        Half,
+        BFloat16,
+        UInt16,
+        logits_tensor.scalar_type(),
+        ctx,
+        "apply_logit_processors_",
+        CTYPE,
+        [&]() {
+          auto* logits = logits_tensor.mutable_data_ptr<CTYPE>() + offset;
+          for (int32_t i = 0; i < vocab_size; ++i) {
+            temp[i] = static_cast<float>(logits[i]);
+          }
+          for (auto& processor : logit_processors_) {
+            processor->process(temp.data());
+          }
+          for (int32_t i = 0; i < vocab_size; ++i) {
+            logits[i] = static_cast<CTYPE>(temp[i]);
+          }
+        });
+    return ::executorch::runtime::Error::Ok;
+  }
+
   /**
    * Note: TextTokenGenerator does not own the tokenizer_ and
    * text_decoder_runner_. The lifecycle of these objects should be managed
@@ -189,6 +271,8 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   bool use_kv_cache_;
   bool ignore_eos_ = false;
 
+  std::vector<std::shared_ptr<LogitProcessor>> logit_processors_;
+
   // state machine
   std::atomic<bool> should_stop_{false};
 
diff --git a/extension/llm/sampler/logit_processor.h b/extension/llm/sampler/logit_processor.h
new file mode 100644
index 00000000000..28c727c8cea
--- /dev/null
+++ b/extension/llm/sampler/logit_processor.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+#include <executorch/runtime/platform/compiler.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+/**
+ * Interface for in-place logit transformations applied between the model's
+ * forward pass and the sampler. Examples include:
+ * - Grammar / constrained-decoding masks (set disallowed tokens to -inf)
+ * - Logit bias (additive per-token bias)
+ * - Custom debug instrumentation
+ *
+ * A `TextTokenGenerator` may be configured with a chain of processors. They
+ * are invoked in order on every decoding step, before the sampler sees the
+ * logits. Each processor mutates the buffer in place; later processors
+ * observe earlier processors' modifications.
+ *
+ * Implementations should be cheap to call: `process()` runs on the
+ * critical path of every generated token.
+ */
+class ET_EXPERIMENTAL LogitProcessor {
+ public:
+  explicit LogitProcessor(int32_t vocab_size) : vocab_size_(vocab_size) {}
+  virtual ~LogitProcessor() = default;
+
+  /**
+   * Modify logits in place for the current decoding step.
+   *
+   * @param logits Mutable pointer to the logits buffer for the current
+   *     step. Must contain at least `vocab_size()` elements.
+   */
+  virtual void process(float* logits) = 0;
+
+  int32_t vocab_size() const {
+    return vocab_size_;
+  }
+
+ private:
+  int32_t vocab_size_;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
diff --git a/extension/llm/sampler/targets.bzl b/extension/llm/sampler/targets.bzl
index 42551e248e5..94a62745d6a 100644
--- a/extension/llm/sampler/targets.bzl
+++ b/extension/llm/sampler/targets.bzl
@@ -7,6 +7,7 @@ def define_common_targets():
     runtime.cxx_library(
         name = "sampler" + aten_suffix,
         exported_headers = [
+            "logit_processor.h",
            "sampler.h",
            "util.h",
         ],
diff --git a/extension/llm/sampler/test/targets.bzl b/extension/llm/sampler/test/targets.bzl
index 83b3d31e4cb..05138bea0d8 100644
--- a/extension/llm/sampler/test/targets.bzl
+++ b/extension/llm/sampler/test/targets.bzl
@@ -22,3 +22,13 @@ def define_common_targets():
             "//caffe2:torch-cpp",
         ],
     )
+
+    runtime.cxx_test(
+        name = "test_logit_processor",
+        srcs = [
+            "test_logit_processor.cpp",
+        ],
+        deps = [
+            "//executorch/extension/llm/sampler:sampler",
+        ],
+    )
diff --git a/extension/llm/sampler/test/test_logit_processor.cpp b/extension/llm/sampler/test/test_logit_processor.cpp
new file mode 100644
index 00000000000..ed2f76cb711
--- /dev/null
+++ b/extension/llm/sampler/test/test_logit_processor.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/sampler/logit_processor.h>
+
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+using ::executorch::extension::llm::LogitProcessor;
+
+namespace {
+
+// Adds a fixed bias to every logit slot. Records how many times it was
+// invoked so tests can verify chain ordering.
+class AddBiasProcessor : public LogitProcessor {
+ public:
+  AddBiasProcessor(int32_t vocab_size, float bias)
+      : LogitProcessor(vocab_size), bias_(bias) {}
+
+  void process(float* logits) override {
+    ++call_count_;
+    for (int32_t i = 0; i < vocab_size(); ++i) {
+      logits[i] += bias_;
+    }
+  }
+
+  int call_count() const {
+    return call_count_;
+  }
+
+ private:
+  float bias_;
+  int call_count_ = 0;
+};
+
+class MultiplyProcessor : public LogitProcessor {
+ public:
+  MultiplyProcessor(int32_t vocab_size, float factor)
+      : LogitProcessor(vocab_size), factor_(factor) {}
+
+  void process(float* logits) override {
+    for (int32_t i = 0; i < vocab_size(); ++i) {
+      logits[i] *= factor_;
+    }
+  }
+
+ private:
+  float factor_;
+};
+
+class MaskTokenProcessor : public LogitProcessor {
+ public:
+  MaskTokenProcessor(int32_t vocab_size, int32_t banned_token)
+      : LogitProcessor(vocab_size), banned_token_(banned_token) {}
+
+  void process(float* logits) override {
+    if (banned_token_ >= 0 && banned_token_ < vocab_size()) {
+      logits[banned_token_] = -std::numeric_limits<float>::infinity();
+    }
+  }
+
+ private:
+  int32_t banned_token_;
+};
+
+} // namespace
+
+// A single processor sees the buffer and may mutate it in place.
+TEST(LogitProcessorTest, SingleProcessorMutatesLogits) {
+  std::vector<float> logits = {1.0f, 2.0f, 3.0f, 4.0f};
+  AddBiasProcessor bias{static_cast<int32_t>(logits.size()), 10.0f};
+
+  bias.process(logits.data());
+
+  const std::vector<float> expected = {11.0f, 12.0f, 13.0f, 14.0f};
+  EXPECT_EQ(logits, expected);
+  EXPECT_EQ(bias.call_count(), 1);
+}
+
+// Multiply(×2) then Add(+1) gives (x*2)+1, which differs from
+// Add(+1) then Multiply(×2) = (x+1)*2. Non-commutative operations
+// verify that processors run in registration order.
+TEST(LogitProcessorTest, ProcessorChainAppliesInOrder) {
+  std::vector<float> logits = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  const int32_t vocab_size = static_cast<int32_t>(logits.size());
+  std::vector<std::shared_ptr<LogitProcessor>> chain;
+  chain.push_back(std::make_shared<MultiplyProcessor>(vocab_size, 2.0f));
+  chain.push_back(std::make_shared<AddBiasProcessor>(vocab_size, 1.0f));
+
+  for (auto& p : chain) {
+    // Later processors observe earlier processors' output.
+    p->process(logits.data());
+  }
+
+  // (x*2)+1, NOT (x+1)*2
+  const std::vector<float> expected = {3.0f, 5.0f, 7.0f, 9.0f};
+  EXPECT_EQ(logits, expected);
+}
+
+// A masking processor sets a banned token's logit to -inf. This is the
+// pattern grammar-constrained decoding processors follow.
+TEST(LogitProcessorTest, MaskTokenDrivesArgmaxAway) {
+  std::vector<float> logits = {0.1f, 0.2f, 0.99f, 0.4f}; // argmax = 2
+
+  MaskTokenProcessor mask{
+      static_cast<int32_t>(logits.size()), /*banned_token=*/2};
+  mask.process(logits.data());
+
+  const std::vector<float> expected = {
+      0.1f, 0.2f, -std::numeric_limits<float>::infinity(), 0.4f};
+  EXPECT_EQ(logits, expected);
+}
+
+TEST(LogitProcessorTest, MaskTokenOutOfRangeIsNoOp) {
+  std::vector<float> logits = {1.0f, 2.0f, 3.0f};
+  const std::vector<float> snapshot = logits;
+
+  MaskTokenProcessor mask_over{
+      static_cast<int32_t>(logits.size()), /*banned_token=*/99};
+  mask_over.process(logits.data());
+  EXPECT_EQ(logits, snapshot);
+
+  MaskTokenProcessor mask_neg{
+      static_cast<int32_t>(logits.size()), /*banned_token=*/-1};
+  mask_neg.process(logits.data());
+  EXPECT_EQ(logits, snapshot);
+}