diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl
index 0d4ed99308d..d3e12266adc 100644
--- a/extension/llm/runner/targets.bzl
+++ b/extension/llm/runner/targets.bzl
@@ -68,6 +68,7 @@ def define_common_targets():
             visibility = ["PUBLIC"],
             exported_deps = [
                 ":text_decoder_runner" + aten_suffix,
+                "//executorch/extension/llm/sampler:sampler" + aten_suffix,
                 "//pytorch/tokenizers:headers",
                 "//executorch/extension/module:module" + aten_suffix,
                 "//executorch/extension/tensor:tensor" + aten_suffix,
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 7e7fbbf1341..3627cacf3c3 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -10,9 +10,12 @@
 #pragma once
 
 #include <atomic>
+#include <memory>
+#include <vector>
 
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/sampler/logit_processor.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
 
@@ -38,6 +41,20 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     ignore_eos_ = ignore_eos;
   }
 
+  void add_logit_processor(std::shared_ptr<LogitProcessor> processor) {
+    if (processor) {
+      logit_processors_.push_back(std::move(processor));
+    }
+  }
+
+  void clear_logit_processors() {
+    logit_processors_.clear();
+  }
+
+  size_t num_logit_processors() const {
+    return logit_processors_.size();
+  }
+
   virtual ~TextTokenGenerator() = default;
 
   /**
@@ -109,6 +126,10 @@ class ET_EXPERIMENTAL TextTokenGenerator {
 
       prev_token = cur_token;
 
+      for (auto& processor : logit_processors_) {
+        ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
+      }
+
       stats_->on_sampling_begin();
       cur_token =
           text_decoder_runner_->logits_to_token(logits_tensor, temperature);
@@ -189,6 +210,8 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   bool use_kv_cache_;
   bool ignore_eos_ = false;
 
+  std::vector<std::shared_ptr<LogitProcessor>> logit_processors_;
+
   // state machine
   std::atomic<bool> should_stop_{false};
 
diff --git a/extension/llm/sampler/logit_processor.h b/extension/llm/sampler/logit_processor.h
new file mode 100644
index 00000000000..1e499cc18a6
--- /dev/null
+++ b/extension/llm/sampler/logit_processor.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/platform/compiler.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+/**
+ * In-place logit transform applied between the model forward pass and the
+ * sampler. Examples: grammar masks, logit bias, repetition penalty.
+ *
+ * `TextTokenGenerator` runs registered processors in order; each sees
+ * prior processors' edits. Called once per decoded token — keep it cheap.
+ *
+ * Tensor contract:
+ *   rank 2 [batch, vocab] — operate on the full last dim
+ *   rank 3 [batch, seq, vocab] — operate on the LAST sequence position
+ *   other ranks — undefined behavior
+ *
+ * Implementations dispatch their own dtype (the chain runner neither casts
+ * nor copies the tensor). Return non-Ok to abort the chain.
+ */
+class ET_EXPERIMENTAL LogitProcessor {
+ public:
+  virtual ~LogitProcessor() = default;
+
+  virtual ::executorch::runtime::Error process(
+      ::executorch::aten::Tensor logits) = 0;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
diff --git a/extension/llm/sampler/targets.bzl b/extension/llm/sampler/targets.bzl
index 42551e248e5..94a62745d6a 100644
--- a/extension/llm/sampler/targets.bzl
+++ b/extension/llm/sampler/targets.bzl
@@ -7,6 +7,7 @@ def define_common_targets():
         runtime.cxx_library(
             name = "sampler" + aten_suffix,
             exported_headers = [
+                "logit_processor.h",
                 "sampler.h",
                 "util.h",
             ],
diff --git a/extension/llm/sampler/test/targets.bzl b/extension/llm/sampler/test/targets.bzl
index 83b3d31e4cb..b95649ba4b9 100644
--- a/extension/llm/sampler/test/targets.bzl
+++ b/extension/llm/sampler/test/targets.bzl
@@ -22,3 +22,14 @@ def define_common_targets():
             "//caffe2:torch-cpp",
         ],
     )
+
+    runtime.cxx_test(
+        name = "test_logit_processor",
+        srcs = [
+            "test_logit_processor.cpp",
+        ],
+        deps = [
+            "//executorch/extension/llm/sampler:sampler",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+    )
diff --git a/extension/llm/sampler/test/test_logit_processor.cpp b/extension/llm/sampler/test/test_logit_processor.cpp
new file mode 100644
index 00000000000..edc3838df9c
--- /dev/null
+++ b/extension/llm/sampler/test/test_logit_processor.cpp
@@ -0,0 +1,217 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/sampler/logit_processor.h>
+
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <gtest/gtest.h>
+
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::extension::llm::LogitProcessor;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::testing::TensorFactory;
+
+namespace {
+
+// Shared by the test processors below to advance to the last sequence
+// position (rank-3 case), per the LogitProcessor shape contract.
+inline float* float_data_at_last_position(Tensor logits) {
+  auto* data = logits.mutable_data_ptr<float>();
+  if (logits.dim() == 3) {
+    data += (logits.size(1) - 1) * logits.size(logits.dim() - 1);
+  }
+  return data;
+}
+
+// Adds a fixed bias to every logit slot in the last position. Records how
+// many times it was invoked so tests can verify chain ordering.
+class AddBiasProcessor : public LogitProcessor {
+ public:
+  explicit AddBiasProcessor(float bias) : bias_(bias) {}
+
+  Error process(Tensor logits) override {
+    ++call_count_;
+    if (logits.scalar_type() != ScalarType::Float) {
+      return Error::InvalidArgument;
+    }
+    auto* data = float_data_at_last_position(logits);
+    const auto vocab_size = logits.size(logits.dim() - 1);
+    for (ssize_t i = 0; i < vocab_size; ++i) {
+      data[i] += bias_;
+    }
+    return Error::Ok;
+  }
+
+  int call_count() const {
+    return call_count_;
+  }
+
+ private:
+  float bias_;
+  int call_count_ = 0;
+};
+
+class MultiplyProcessor : public LogitProcessor {
+ public:
+  explicit MultiplyProcessor(float factor) : factor_(factor) {}
+
+  Error process(Tensor logits) override {
+    if (logits.scalar_type() != ScalarType::Float) {
+      return Error::InvalidArgument;
+    }
+    auto* data = float_data_at_last_position(logits);
+    const auto vocab_size = logits.size(logits.dim() - 1);
+    for (ssize_t i = 0; i < vocab_size; ++i) {
+      data[i] *= factor_;
+    }
+    return Error::Ok;
+  }
+
+ private:
+  float factor_;
+};
+
+class MaskTokenProcessor : public LogitProcessor {
+ public:
+  explicit MaskTokenProcessor(int32_t banned_token)
+      : banned_token_(banned_token) {}
+
+  Error process(Tensor logits) override {
+    if (logits.scalar_type() != ScalarType::Float) {
+      return Error::InvalidArgument;
+    }
+    auto* data = float_data_at_last_position(logits);
+    const auto vocab_size = logits.size(logits.dim() - 1);
+    if (banned_token_ >= 0 && banned_token_ < vocab_size) {
+      data[banned_token_] = -std::numeric_limits<float>::infinity();
+    }
+    return Error::Ok;
+  }
+
+ private:
+  int32_t banned_token_;
+};
+
+} // namespace
+
+// A single processor mutates the rank-2 logits tensor in place.
+TEST(LogitProcessorTest, SingleProcessorMutatesLogits) {
+  TensorFactory<ScalarType::Float> tf;
+  auto logits = tf.make({1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});
+
+  AddBiasProcessor bias{10.0f};
+  ASSERT_EQ(bias.process(logits), Error::Ok);
+
+  auto* data = logits.mutable_data_ptr<float>();
+  EXPECT_FLOAT_EQ(data[0], 11.0f);
+  EXPECT_FLOAT_EQ(data[3], 14.0f);
+  EXPECT_EQ(bias.call_count(), 1);
+}
+
+// Multiply(×2) then Add(+1) gives (x*2)+1, which differs from
+// Add(+1) then Multiply(×2) = (x+1)*2. Non-commutative operations
+// verify that processors run in registration order.
+TEST(LogitProcessorTest, ProcessorChainAppliesInOrder) {
+  TensorFactory<ScalarType::Float> tf;
+  auto logits = tf.make({1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});
+
+  std::vector<std::shared_ptr<LogitProcessor>> chain;
+  chain.push_back(std::make_shared<MultiplyProcessor>(2.0f));
+  chain.push_back(std::make_shared<AddBiasProcessor>(1.0f));
+
+  for (auto& p : chain) {
+    ASSERT_EQ(p->process(logits), Error::Ok);
+  }
+
+  // (x*2)+1, NOT (x+1)*2
+  auto* data = logits.mutable_data_ptr<float>();
+  EXPECT_FLOAT_EQ(data[0], 3.0f);
+  EXPECT_FLOAT_EQ(data[1], 5.0f);
+  EXPECT_FLOAT_EQ(data[2], 7.0f);
+  EXPECT_FLOAT_EQ(data[3], 9.0f);
+}
+
+// A masking processor sets a specific token's logit to -inf. This is the
+// pattern grammar processors will follow.
+TEST(LogitProcessorTest, MaskTokenDrivesArgmaxAway) {
+  TensorFactory<ScalarType::Float> tf;
+  auto logits = tf.make({1, 4}, {0.1f, 0.2f, 0.99f, 0.4f}); // argmax = 2
+
+  MaskTokenProcessor mask{/*banned_token=*/2};
+  ASSERT_EQ(mask.process(logits), Error::Ok);
+
+  auto* data = logits.mutable_data_ptr<float>();
+  EXPECT_EQ(data[2], -std::numeric_limits<float>::infinity());
+  // Other slots untouched.
+  EXPECT_FLOAT_EQ(data[0], 0.1f);
+  EXPECT_FLOAT_EQ(data[1], 0.2f);
+  EXPECT_FLOAT_EQ(data[3], 0.4f);
+}
+
+// Out-of-range banned token id is silently ignored — defensive behavior
+// for grammar processors that may pass an EOS-or-similar id that the
+// underlying vocab doesn't actually contain.
+TEST(LogitProcessorTest, MaskTokenOutOfRangeIsNoOp) {
+  TensorFactory<ScalarType::Float> tf;
+  auto logits = tf.make({1, 3}, {1.0f, 2.0f, 3.0f});
+  const std::vector<float> snapshot = {1.0f, 2.0f, 3.0f};
+
+  MaskTokenProcessor mask_over{/*banned_token=*/99};
+  ASSERT_EQ(mask_over.process(logits), Error::Ok);
+  auto* data = logits.mutable_data_ptr<float>();
+  EXPECT_FLOAT_EQ(data[0], snapshot[0]);
+  EXPECT_FLOAT_EQ(data[1], snapshot[1]);
+  EXPECT_FLOAT_EQ(data[2], snapshot[2]);
+
+  MaskTokenProcessor mask_neg{/*banned_token=*/-1};
+  ASSERT_EQ(mask_neg.process(logits), Error::Ok);
+  EXPECT_FLOAT_EQ(data[0], snapshot[0]);
+  EXPECT_FLOAT_EQ(data[1], snapshot[1]);
+  EXPECT_FLOAT_EQ(data[2], snapshot[2]);
+}
+
+// On a rank-3 [batch, seq, vocab] tensor, the processor must only mutate
+// the LAST sequence position. Earlier positions stay untouched.
+TEST(LogitProcessorTest, RespectsLastPositionOf3DTensor) {
+  TensorFactory<ScalarType::Float> tf;
+  // Shape [batch=1, seq=2, vocab=4]. First-position values are sentinels.
+  auto logits = tf.make(
+      {1, 2, 4},
+      {
+          99.0f, 99.0f, 99.0f, 99.0f, // first position — must NOT change
+          1.0f, 2.0f, 3.0f, 4.0f, // last position — gets +10 from bias
+      });
+
+  AddBiasProcessor bias{10.0f};
+  ASSERT_EQ(bias.process(logits), Error::Ok);
+
+  auto* data = logits.mutable_data_ptr<float>();
+  // First position untouched.
+  EXPECT_FLOAT_EQ(data[0], 99.0f);
+  EXPECT_FLOAT_EQ(data[3], 99.0f);
+  // Last position got +10.
+  EXPECT_FLOAT_EQ(data[4], 11.0f);
+  EXPECT_FLOAT_EQ(data[7], 14.0f);
+}
+
+// Each processor declares its own dtype expectations. The test processors
+// here only support Float; passing a Half tensor must surface
+// InvalidArgument rather than silently corrupt memory.
+TEST(LogitProcessorTest, ProcessorRejectsUnsupportedDtype) {
+  TensorFactory<ScalarType::Half> tf;
+  auto logits = tf.make({1, 4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+  AddBiasProcessor bias{1.0f};
+  EXPECT_EQ(bias.process(logits), Error::InvalidArgument);
+}