Merged

24 commits
d69117c
Unify MultimodalRunner under IRunner with multimodal prefill
kirklandsign Feb 26, 2026
222f76d
Preserve old prefill(string, GenerationConfig) API, fix minor issues
kirklandsign Feb 26, 2026
5cd9cf5
Remove prefill_next_token_ — prefill just fills KV cache
kirklandsign Feb 26, 2026
52d8eb9
Restore removed comments in multimodal_runner.cpp
kirklandsign Feb 26, 2026
e7e5562
Restore removed comments in text_llm_runner.cpp
kirklandsign Feb 26, 2026
297ad1b
Remove tokens input handling from TextLLMRunner::prefill
kirklandsign Feb 26, 2026
06963e2
Restore and improve docstrings in multimodal_runner.h
kirklandsign Feb 26, 2026
8a22d1e
Fix lint formatting in irunner.h and multimodal_runner.cpp
kirklandsign Feb 26, 2026
6ab9e7c
Fix docstring: TextLLMRunner::prefill only handles text inputs
kirklandsign Feb 26, 2026
e0ee4dc
Fix misleading prefill docstring in _llm_runner.pyi
kirklandsign Feb 26, 2026
37159cd
Fix BOS zeroing in TextLLMRunner::prefill for skipped inputs
kirklandsign Feb 26, 2026
bcc7616
Fix build: forward-declare MultimodalInput in irunner.h
kirklandsign Feb 26, 2026
f03b171
Store prefill output token to enable prefill() → generate("") workflow
kirklandsign Feb 26, 2026
dd91e4f
Change prefill() to return Result<uint64_t>
kirklandsign Feb 26, 2026
f865cff
Update prefill @return docstring in irunner.h
kirklandsign Feb 26, 2026
72d6034
Extract decode_from_token to eliminate duplicate decode loop
kirklandsign Feb 26, 2026
c05c80a
Move string prefill overloads to .cpp, fix stale prefill_next_token_
kirklandsign Feb 26, 2026
2c44df7
Fix lint formatting in text_llm_runner.h and .cpp
kirklandsign Feb 26, 2026
e068f7e
Simplify generate(string) to always delegate to generate(vector)
kirklandsign Feb 26, 2026
9e6e103
Fix generate(string) docstring and initialize cur_token
kirklandsign Feb 26, 2026
e83bd7e
Fix lint formatting in text_llm_runner.cpp
kirklandsign Feb 26, 2026
7e944e9
Add unit tests for prefill() and prefill-then-generate workflow
kirklandsign Feb 26, 2026
cb1284a
Fix lint formatting in test_text_llm_runner.cpp
kirklandsign Feb 26, 2026
736fd9b
Merge branch 'main' into llm/unify-runner-prefill-interface
kirklandsign Feb 27, 2026
3 changes: 2 additions & 1 deletion extension/llm/runner/_llm_runner.pyi
@@ -479,7 +479,8 @@ class MultimodalRunner:
def prefill(self, inputs: List[MultimodalInput]) -> None:
"""
Prefill multimodal inputs (e.g., to rebuild KV cache from chat history)
without generating tokens.
without generating tokens. After prefill, call generate() with a
non-empty final text input to start decoding.

Args:
inputs: List of multimodal inputs to prefill
20 changes: 20 additions & 0 deletions extension/llm/runner/irunner.h
@@ -14,14 +14,18 @@
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>

namespace executorch {
namespace extension {
namespace llm {

class MultimodalInput; // Forward declaration

// Configuration struct for generation parameters, fields should be sorted in
// alphabetic order
struct GenerationConfig {
@@ -128,6 +132,22 @@ class ET_EXPERIMENTAL IRunner {
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;

/**
* Prefill multimodal inputs into the KV cache without generating.
*
* @param inputs A vector of MultimodalInput objects (text, tokens, images,
* audio)
* @param num_bos Number of BOS tokens to prepend during encoding
* @param num_eos Number of EOS tokens to append during encoding
* @return The next token predicted after prefill, or an error
*/
virtual runtime::Result<uint64_t> prefill(
const std::vector<MultimodalInput>& inputs,
int32_t num_bos = 0,
int32_t num_eos = 0) {
return runtime::Error::NotSupported;
}

/**
* Stop the generation process.
*/
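For orientation, here is a minimal sketch of how a caller might drive the new virtual prefill() through the IRunner interface. The helper name warm_cache, the error handling, and the multimodal_input.h include path are illustrative assumptions, not part of this diff:

```cpp
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/multimodal_input.h>

using namespace executorch::extension::llm;

// Hypothetical helper: rebuild the KV cache from saved chat history and
// return the token the model predicts next. Runners that do not override
// prefill() inherit the base-class default, which returns
// Error::NotSupported.
executorch::runtime::Result<uint64_t> warm_cache(
    IRunner& runner,
    const std::vector<MultimodalInput>& history) {
  auto next = runner.prefill(history, /*num_bos=*/1, /*num_eos=*/0);
  if (!next.ok()) {
    return next.error(); // e.g. NotSupported on a runner without an override
  }
  return next.get();
}
```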
175 changes: 106 additions & 69 deletions extension/llm/runner/multimodal_runner.cpp
@@ -53,7 +53,7 @@ MultimodalRunner::MultimodalRunner(
#endif
}

bool MultimodalRunner::is_loaded() {
bool MultimodalRunner::is_loaded() const {
return multimodal_prefiller_->is_method_loaded() &&
text_token_generator_->is_loaded();
}
@@ -85,89 +85,57 @@ Error MultimodalRunner::load() {
ET_LOG(Info, format, __VA_ARGS__); \
}

Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
if (!is_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
}
for (auto& input : inputs) {
auto prefill_result = multimodal_prefiller_->prefill(input, pos_);
if (!prefill_result.ok()) {
return prefill_result.error();
}
}
return Error::Ok;
}

Error MultimodalRunner::generate(
Result<uint64_t> MultimodalRunner::prefill(
const std::vector<MultimodalInput>& inputs,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
if (inputs.empty()) {
ET_LOG(Error, "MultimodalInput vector cannot be empty");
return Error::InvalidArgument;
}

int32_t num_bos,
int32_t num_eos) {
if (!is_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
}

if (config.warming) {
ET_LOG(Info, "Doing a warmup run...");
}

RUNNER_ET_LOG(
config.warming,
"RSS after loading model: %f MiB (0 if unsupported)",
get_rss_bytes() / 1024.0 / 1024.0);

// Wrap the token_callback with print function
std::function<void(const std::string&)> wrapped_callback =
[token_callback, config](const std::string& piece) {
if (!config.warming) {
safe_printf(piece.c_str());
fflush(stdout);
}
if (token_callback) {
token_callback(piece);
}
};

// Reset internal state and start inference
stats_->inference_start_ms = time_in_ms();

uint64_t prefill_next_token = 0;
// Process multimodal inputs in order
uint64_t last_token = 0;
for (size_t i = 0; i < inputs.size(); ++i) {
const MultimodalInput& input = inputs[i];
ET_LOG(
Info,
"Prefilling input %zu/%zu, type: %s",
i,
inputs.size(),
input.type_name());
if (config.echo && i == inputs.size() - 1 && input.is_text()) {
wrapped_callback(input.get_text());
}
const auto& input = inputs[i];
int32_t bos = 0;
int32_t eos = 0;
if (i == 0 && input.is_text()) {
bos = config.num_bos;
eos = config.num_eos;
if (i == 0 && pos_ == 0) {
if (input.is_text() || input.is_tokens()) {
bos = num_bos;
eos = num_eos;
} else if (num_bos > 0) {
// Non-text first input: prepend BOS via a token input
auto it = metadata_.find(kBosId);
if (it != metadata_.end()) {
std::vector<uint64_t> bos_tokens(
num_bos, static_cast<uint64_t>(it->second));
MultimodalInput bos_input(std::move(bos_tokens));
auto bos_result = multimodal_prefiller_->prefill(bos_input, pos_);
if (!bos_result.ok()) {
return bos_result.error();
}
last_token = bos_result.get();
}
}
}
auto prefill_result = multimodal_prefiller_->prefill(input, pos_, bos, eos);
if (!prefill_result.ok()) {
return prefill_result.error();
}
prefill_next_token = prefill_result.get();
last_token = prefill_result.get();
}
prefill_next_token_ = last_token;
return last_token;
}

Error MultimodalRunner::decode_from_token(
uint64_t cur_token,
const GenerationConfig& config,
std::function<void(const std::string&)> wrapped_callback,
std::function<void(const Stats&)> stats_callback) {
stats_->first_token_ms = time_in_ms();
stats_->prompt_eval_end_ms = time_in_ms();
stats_->num_prompt_tokens = pos_;

auto decode_result =
tokenizer_->decode(prefill_next_token, prefill_next_token);
auto decode_result = tokenizer_->decode(cur_token, cur_token);
if (!decode_result.ok()) {
ET_LOG(
Error,
@@ -183,8 +151,7 @@ Error MultimodalRunner::generate(
get_rss_bytes() / 1024.0 / 1024.0);

// Resolve max_new_tokens based on config
int64_t max_context_len =
metadata_.at(kMaxContextLen) - 0; // No start_pos offset
int64_t max_context_len = metadata_.at(kMaxContextLen);
int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);

ET_LOG(
@@ -204,7 +171,7 @@ Error MultimodalRunner::generate(
text_token_generator_->set_ignore_eos(config.ignore_eos);

// Generate tokens using the text token generator
std::vector<uint64_t> prompt_tokens = {prefill_next_token};
std::vector<uint64_t> prompt_tokens = {cur_token};
auto generate_result = text_token_generator_->generate(
/*tokens=*/prompt_tokens,
/*start_pos=*/pos_,
@@ -249,4 +216,74 @@ Error MultimodalRunner::generate(
return Error::Ok;
}

Error MultimodalRunner::generate(
const std::string& prompt,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
std::vector<MultimodalInput> inputs;
if (!prompt.empty()) {
inputs.emplace_back(MultimodalInput(prompt));
}
return generate(inputs, config, token_callback, stats_callback);
}

Error MultimodalRunner::generate(
const std::vector<MultimodalInput>& inputs,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
if (!is_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
}

if (config.warming) {
ET_LOG(Info, "Doing a warmup run...");
}

RUNNER_ET_LOG(
config.warming,
"RSS after loading model: %f MiB (0 if unsupported)",
get_rss_bytes() / 1024.0 / 1024.0);

// Wrap the token_callback with print function
std::function<void(const std::string&)> wrapped_callback =
[token_callback, config](const std::string& piece) {
if (!config.warming) {
safe_printf(piece.c_str());
fflush(stdout);
}
if (token_callback) {
token_callback(piece);
}
};

// Reset internal state and start inference
stats_->inference_start_ms = time_in_ms();

uint64_t cur_token = 0;
if (!inputs.empty()) {
// Echo the last text input if enabled
if (config.echo && inputs.back().is_text()) {
wrapped_callback(inputs.back().get_text());
}

// Prefill all inputs and get the first decode token
auto prefill_result = prefill(inputs, config.num_bos, config.num_eos);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
cur_token = prefill_result.get();
prefill_next_token_.reset();
} else {
// Empty inputs: consume token from a prior prefill() call
ET_CHECK_OR_RETURN_ERROR(
prefill_next_token_.has_value(),
InvalidState,
"Empty inputs requires a prior prefill() call");
cur_token = prefill_next_token_.value();
prefill_next_token_.reset();
}

return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
}

} // namespace executorch::extension::llm
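To make the new control flow concrete, here is a hedged usage sketch of the prefill() then generate("") workflow introduced above; runner construction is elided, and the history text, config values, and callback body are placeholders:

```cpp
// Sketch: rebuild chat history with prefill(), then pass an empty prompt so
// generate() consumes the stored prefill_next_token_ and starts decoding.
std::vector<MultimodalInput> history;
history.emplace_back(MultimodalInput("User: summarize our last conversation."));

GenerationConfig config; // defaults are fine for this sketch

auto next = runner->prefill(history, /*num_bos=*/1, /*num_eos=*/0);
if (next.ok()) {
  runner->generate(
      /*prompt=*/"", // empty prompt: decode from the prefilled token
      config,
      [](const std::string& piece) { /* stream piece to the caller */ });
}
```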
63 changes: 53 additions & 10 deletions extension/llm/runner/multimodal_runner.h
@@ -15,6 +15,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>

@@ -74,7 +75,7 @@ namespace llm {
*
* runner->generate(inputs, config, token_callback, stats_callback);
*/
class ET_EXPERIMENTAL MultimodalRunner {
class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
public:
/**
* @brief Constructor for MultimodalRunner with dependency injection
@@ -105,8 +106,24 @@ class ET_EXPERIMENTAL MultimodalRunner {
std::unique_ptr<TextTokenGenerator> text_token_generator,
std::unique_ptr<Stats> stats);

virtual bool is_loaded();
virtual ::executorch::runtime::Error load();
bool is_loaded() const override;
::executorch::runtime::Error load() override;

/**
* Generate tokens from a text prompt. Wraps the prompt as a MultimodalInput
* and delegates to generate(vector). Empty prompt is allowed if prefill()
* was called beforehand.
* @param prompt The text prompt to generate from.
* @param config Generation configuration parameters.
* @param token_callback Callback function called for each generated token.
* @param stats_callback Callback function for generation statistics.
* @return The error code. KV cache position is tracked internally in pos_.
*/
::executorch::runtime::Error generate(
const std::string& prompt,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;

/**
* Generate tokens from the given multimodal inputs using GenerationConfig.
@@ -124,24 +141,42 @@
std::function<void(const Stats&)> stats_callback = {});

/**
* Prefill multimodal inputs, for example to reload chat history.
* Prefill multimodal inputs to fill the KV cache, for example to reload
* chat history. Call generate() with a non-empty prompt afterwards to
* start decoding.
Comment on lines +145 to +146

Copilot AI Feb 26, 2026

The documentation comment states "Call generate() with a non-empty prompt afterwards to start decoding." However, the implementation actually supports calling generate() with an empty prompt (empty inputs vector) which will consume the prefill_next_token_. The comment should be updated to reflect this capability, e.g., "Call generate() afterwards to start decoding (can use empty inputs to consume the prefilled token)."

Suggested change
* chat history. Call generate() with a non-empty prompt afterwards to
* start decoding.
* chat history. Call generate() afterwards to start decoding (you may use
* empty inputs to consume the prefilled token).
* @param inputs A vector of MultimodalInput objects containing images and
* text.
* @return The error code. KV cache position is tracked internally in pos_.
* @param num_bos Number of BOS tokens to prepend during encoding.
* @param num_eos Number of EOS tokens to append during encoding.
* @return The next token predicted after prefill, or an error.
* KV cache position is tracked internally in pos_.
*/
::executorch::runtime::Result<uint64_t> prefill(
const std::vector<MultimodalInput>& inputs,
int32_t num_bos = 0,
int32_t num_eos = 0) override;

/**
* Convenience overload: prefill a single text prompt.
*/
virtual ::executorch::runtime::Error prefill(
const std::vector<MultimodalInput>& inputs);
::executorch::runtime::Result<uint64_t>
prefill(const std::string& prompt, int32_t num_bos = 0, int32_t num_eos = 0) {
std::vector<MultimodalInput> inputs;
inputs.emplace_back(MultimodalInput(prompt));
return prefill(inputs, num_bos, num_eos);
}

inline void stop() {
void stop() override {
text_token_generator_->stop();
}

inline void reset() {
void reset() override {
pos_ = 0;
stats_->reset();
prefill_next_token_.reset();
}

virtual ~MultimodalRunner() = default;
~MultimodalRunner() override = default;

protected:
// Components
Expand All @@ -160,7 +195,15 @@ class ET_EXPERIMENTAL MultimodalRunner {
#endif

// Internal state
std::optional<uint64_t> prefill_next_token_;
int64_t pos_;

private:
::executorch::runtime::Error decode_from_token(
uint64_t cur_token,
const GenerationConfig& config,
std::function<void(const std::string&)> wrapped_callback,
std::function<void(const Stats&)> stats_callback);
};

} // namespace llm
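Finally, a short sketch of the single-prompt prefill() convenience overload and reset() declared above, assuming an already-constructed runner; the prompt text and log line are illustrative:

```cpp
// The string overload wraps the prompt in a MultimodalInput vector and
// forwards to the main prefill().
auto first = runner->prefill("Hello!", /*num_bos=*/1);
if (first.ok()) {
  ET_LOG(
      Info,
      "Token predicted after prefill: %llu",
      static_cast<unsigned long long>(first.get()));
}
runner->reset(); // clears pos_, stats, and prefill_next_token_
```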