
Unify MultimodalRunner under IRunner with multimodal prefill#17741

Merged
kirklandsign merged 24 commits into main from llm/unify-runner-prefill-interface
Feb 27, 2026

Conversation

@kirklandsign
Contributor

kirklandsign commented Feb 26, 2026

Summary

Make MultimodalRunner inherit from IRunner so callers can hold a single IRunner* regardless of model type. Add prefill(vector<MultimodalInput>, num_bos, num_eos) to IRunner returning Result<uint64_t> (the predicted next token), with a default NotSupported implementation.

Resolves #17728

Changes

irunner.h — Forward-declare MultimodalInput. Add virtual prefill(vector<MultimodalInput>) returning Result<uint64_t> with default NotSupported.

multimodal_runner.h/cpp — MultimodalRunner now inherits IRunner. generate(string) override is a pure wrapper that delegates to generate(vector). generate(vector) handles both non-empty inputs (prefill + decode) and empty inputs (consume prefill_next_token_ from a prior prefill() call). Decode loop extracted into private decode_from_token() to avoid duplication. is_loaded() becomes const override; stop()/reset()/load() gain override. String convenience prefill(string) provided inline.

text_llm_runner.h/cpp — New prefill(vector<MultimodalInput>) override handles text inputs (encode + prefill KV cache), returns predicted next token. generate("") allowed after prefill() — consumes stored prefill_next_token_. Old prefill(string, GenerationConfig) preserved as deprecated wrapper. String convenience methods defined in .cpp to avoid header dependency on multimodal_input.h.

pybindings.cpp — Adapts to Result<uint64_t> return type from prefill().

_llm_runner.pyi — Updated MultimodalRunner.prefill docstring.

Design decisions

  • prefill() returns Result<uint64_t> — the sampled next token from the final forward pass. This is stored internally in prefill_next_token_ for the prefill() → generate("")/generate({}) workflow, and also returned to callers who may want the token directly.
  • prefill() takes num_bos/num_eos instead of GenerationConfig — those are the only fields relevant to prefill (for tokenizer encoding).
  • BOS is only applied when pos_ == 0 (start of conversation) in MultimodalRunner. TextLLMRunner trusts the caller's num_bos value.
  • generate(string) is always a pure wrapper to generate(vector) in MultimodalRunner — empty string passes an empty vector, non-empty wraps as MultimodalInput.
  • text_llm_runner.h avoids #include multimodal_input.h — only the forward declaration from irunner.h is needed in the header. String convenience methods are defined in the .cpp.
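
Seen from the caller's side, these decisions boil down to ordinary virtual dispatch with a NotSupported default. The following is a minimal standalone mock of that shape — names mirror the PR (IRunner, MultimodalInput) but the real declarations live in extension/llm/runner/irunner.h, and Result<uint64_t> is simplified to a small struct:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-ins for the real executorch types.
struct MultimodalInput {
  std::string text;
};

struct TokenResult {
  bool supported = true;  // stands in for Result<uint64_t>'s error channel
  uint64_t token = 0;
};

struct IRunner {
  virtual ~IRunner() = default;
  // Default implementation: runners that cannot prefill multimodal
  // inputs report NotSupported instead of forcing every subclass to
  // override.
  virtual TokenResult prefill(
      const std::vector<MultimodalInput>& /*inputs*/,
      int32_t /*num_bos*/ = 0,
      int32_t /*num_eos*/ = 0) {
    return {false, 0};
  }
};

struct MockMultimodalRunner : IRunner {
  TokenResult prefill(
      const std::vector<MultimodalInput>& inputs,
      int32_t = 0,
      int32_t = 0) override {
    // Pretend the final forward pass of prefill sampled token 42.
    return inputs.empty() ? TokenResult{false, 0} : TokenResult{true, 42};
  }
};

// Callers hold a single IRunner* (or reference) regardless of model type.
TokenResult run_prefill(IRunner& r, const std::vector<MultimodalInput>& in) {
  return r.prefill(in);
}
```

A text-only runner that never overrides prefill(vector) still compiles and links; it simply reports NotSupported at runtime, which is the compatibility story the PR relies on.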

Backward compatibility

  • C++ source-compatible: MultimodalRunner::prefill(inputs) still compiles (new params have defaults). TextLLMRunner::prefill(string, GenerationConfig) preserved as deprecated wrapper. Return type changed from Error to Result<uint64_t> — callers checking .error() still work.
  • ABI-breaking: Expected for ET_EXPERIMENTAL APIs.
  • Python: Fully compatible, interface signatures unchanged.

Test plan

  • Existing C++ tests: test_text_llm_runner.cpp, test_text_prefiller.cpp, test_generation_config.cpp — verify no regressions
  • Build: cmake --build for extension_llm_runner target

Make MultimodalRunner inherit from IRunner so callers can hold a single
IRunner* regardless of model type. Add prefill(vector<MultimodalInput>,
num_bos, num_eos) to IRunner with a default NotSupported implementation,
plus a non-virtual string convenience overload.

Both TextLLMRunner and MultimodalRunner now store the predicted next
token from prefill in prefill_next_token_, enabling a prefill-then-
generate("") workflow: callers prefill chat history / multimodal inputs
in one call, then start decoding with an empty prompt.
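
A standalone sketch of that handoff, assuming the member name prefill_next_token_ from this PR (everything else here is mocked; the real runners also run a full decode loop and error types):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

class MockRunner {
 public:
  // Fill the KV cache (omitted) and sample the next token from the
  // final forward pass. Token value is faked for the sketch.
  uint64_t prefill(const std::vector<std::string>& inputs) {
    uint64_t next = 100 + static_cast<uint64_t>(inputs.size());
    prefill_next_token_ = next;  // stash for a later generate("")
    return next;                 // also handed back to the caller
  }

  // generate("") decodes from the stashed token instead of re-prefilling.
  bool generate(const std::string& prompt, std::vector<uint64_t>& out) {
    uint64_t first = 0;
    if (prompt.empty()) {
      if (!prefill_next_token_.has_value()) {
        return false;  // InvalidState: nothing was prefilled
      }
      first = *prefill_next_token_;
      prefill_next_token_.reset();  // avoid reusing a stale token later
    } else {
      first = prefill({prompt});
      prefill_next_token_.reset();
    }
    out.push_back(first);  // the real decode loop would continue here
    return true;
  }

 private:
  std::optional<uint64_t> prefill_next_token_;
};
```

The reset() after consumption mirrors the fix later in this PR where a stale prefill_next_token_ could otherwise be replayed by a second generate("").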

Key changes:
- IRunner gains virtual prefill() accepting MultimodalInput vector
- TextLLMRunner::prefill accepts MultimodalInput (text/tokens only)
- TextLLMRunner::generate handles empty prompt via prefill_next_token_
- MultimodalRunner inherits IRunner, overrides generate/prefill/stop/reset
- MultimodalRunner::generate(string) delegates or decodes from prefill
- MultimodalRunner::generate(vector) refactored to call prefill() internally
- Pybindings updated to use new prefill signatures

This PR was authored with the assistance of Claude.
- Add back TextLLMRunner::prefill(string, GenerationConfig) as an
  inline deprecated wrapper that delegates to prefill(string, num_bos,
  num_eos), preserving source compatibility for existing callers
- Initialize cur_token and num_prompt_tokens to 0 to silence compiler
  warnings about potentially uninitialized variables
- Add has_value() guard before consuming prefill_next_token_ in
  MultimodalRunner::generate(vector) for defensive safety

prefill() semantics: only fill KV cache and update pos_, discard the
sampled next token. The worst case of discarding is wasting one
logits_to_token call (trivially cheap) — the KV cache entries are
preserved via pos_ tracking.

Callers use prefill(history) → generate(last_turn), where generate()
does its own internal prefill to get the first decode token.

To avoid duplicating the prefill loop between prefill() and
generate(vector) in MultimodalRunner, extract a private
prefill_and_sample() that returns Result<uint64_t>. prefill() calls
it and discards the token; generate(vector) calls it and uses it.

TextLLMRunner only handles text inputs; skip everything else.
Raw token prefill can be added later if needed.

@pytorch-bot

pytorch-bot bot commented Feb 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17741

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Pending

As of commit cb1284a with merge base a29539d:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 26, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Add back docstrings for generate(vector) and prefill(vector) that were
removed during refactoring. Also add docstring for the new
generate(string) override.

@kirklandsign kirklandsign marked this pull request as ready for review February 26, 2026 19:32
Copilot AI review requested due to automatic review settings February 26, 2026 19:32
Contributor

Copilot AI left a comment


Pull request overview

Unifies text-only and multimodal LLM runners behind a single IRunner interface by adding a multimodal prefill() API to IRunner and updating both runners/bindings to use it, enabling callers to hold a single IRunner* regardless of model type.

Changes:

  • Added IRunner::prefill(vector<MultimodalInput>, num_bos, num_eos) with a default NotSupported implementation plus a string convenience overload.
  • Updated MultimodalRunner to inherit from IRunner, add a generate(string) override, and refactor shared prefill logic into a helper.
  • Updated TextLLMRunner, Python bindings, and Python stubs to match the new prefill shape/signatures.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Summary per file:

  • extension/llm/runner/irunner.h — Adds multimodal prefill API (virtual) + string convenience overload to unify runner interface.
  • extension/llm/runner/text_llm_runner.h — Declares the new prefill(vector<MultimodalInput>, num_bos, num_eos) override and keeps the legacy prefill(string, GenerationConfig) wrapper.
  • extension/llm/runner/text_llm_runner.cpp — Implements multimodal prefill for the text runner by iterating inputs and prefilling supported modalities.
  • extension/llm/runner/multimodal_runner.h — Makes MultimodalRunner inherit IRunner, adds overrides, and introduces a shared private helper for prefill + first-token sampling.
  • extension/llm/runner/multimodal_runner.cpp — Refactors multimodal prefill/generate to share prefill logic and adds generate(string) wrapper.
  • extension/llm/runner/pybindings.cpp — Updates Python bindings to call the updated prefill signatures.
  • extension/llm/runner/_llm_runner.pyi — Updates Python stub documentation for prefill usage.


The old wording suggested calling generate("") after prefill, but
the Python API takes List[MultimodalInput] and the C++ runner
enforces a non-empty prompt. Updated to describe the correct
post-prefill pattern.

Only zero num_bos/num_eos after actually processing a text input,
not after skipping a non-text input. Previously, if a non-text input
preceded a text input, the BOS would be lost.

irunner is a header-only Buck target with no deps, so it cannot
include multimodal_input.h (owned by multimodal_runner_lib which
requires aten-variant deps). Use a forward declaration instead.

Move the string convenience prefill(prompt, num_bos, num_eos) from
IRunner to TextLLMRunner and MultimodalRunner, where
multimodal_input.h is available for constructing MultimodalInput.

Copilot AI review requested due to automatic review settings February 26, 2026 20:51
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.



Comment on lines +143 to +148
virtual runtime::Error prefill(
    const std::vector<MultimodalInput>& inputs,
    int32_t num_bos = 0,
    int32_t num_eos = 0) {
  return runtime::Error::NotSupported;
}
Copilot AI Feb 26, 2026

IRunner now exposes prefill(vector, num_bos, num_eos), but the PR description mentions a non-virtual convenience overload prefill(string, ...) on IRunner and it is not present here. If callers are expected to work through an IRunner* with text-only prompts, consider adding that overload to IRunner (or updating the PR description/usage guidance accordingly).

Comment on lines +254 to +256
auto encode_res = tokenizer_->encode(
input.get_text(), /*bos=*/num_bos, /*eos=*/num_eos);
ET_CHECK_TK_OK_OR_RETURN_ERROR(
Copilot AI Feb 26, 2026

TextLLMRunner::prefill() applies num_bos/num_eos to the first text input unconditionally. If prefill() is called after prior prefill/generate (pos_ > 0), this will inject BOS/EOS into the middle of the conversation. Consider gating BOS/EOS application on pos_ == 0 (then clearing them) to avoid mid-context special tokens.
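
The gating suggested here, plus the companion rule that num_bos is zeroed only after a text input is actually consumed (so a leading non-text input does not swallow the BOS), can be sketched as a small helper. Types and names are illustrative, not the real runner code:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

struct Input {
  bool is_text;
  std::string text;
};

// Returns how many BOS tokens each input would receive under two rules:
//  1. never inject BOS mid-conversation (pos != 0), and
//  2. zero num_bos only after a text input is actually consumed.
std::vector<int32_t> bos_per_input(
    const std::vector<Input>& inputs, int64_t pos, int32_t num_bos) {
  if (pos != 0) {
    num_bos = 0;  // KV cache already has context; no mid-context BOS
  }
  std::vector<int32_t> applied;
  for (const auto& in : inputs) {
    if (!in.is_text) {
      applied.push_back(0);  // skipped input: keep num_bos for later
      continue;
    }
    applied.push_back(num_bos);
    num_bos = 0;  // consumed by the first text input only
  }
  return applied;
}
```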

Comment on lines +252 to +256
for (const auto& input : inputs) {
if (input.is_text()) {
auto encode_res = tokenizer_->encode(
input.get_text(), /*bos=*/num_bos, /*eos=*/num_eos);
ET_CHECK_TK_OK_OR_RETURN_ERROR(
Copilot AI Feb 26, 2026

TextLLMRunner::prefill() currently processes only input.is_text() and skips token inputs (input.is_tokens()). Since MultimodalInput supports TOKENS, this prevents prefilling already-tokenized prompts. Consider handling is_tokens() by copying input.get_tokens() into a mutable vector and passing it to text_prefiller_->prefill().
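
A sketch of the suggested token-input handling. MultimodalInput is mocked here as a std::variant (the real class lives in extension/llm/runner/multimodal_input.h), and a fake one-token-per-character encoder stands in for the tokenizer:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

using MockInput = std::variant<std::string, std::vector<uint64_t>>;

// Collects the token stream that would be fed to the prefiller,
// handling both text and already-tokenized inputs.
std::vector<uint64_t> tokens_to_prefill(const std::vector<MockInput>& inputs) {
  std::vector<uint64_t> out;
  for (const auto& in : inputs) {
    if (const auto* text = std::get_if<std::string>(&in)) {
      // Text path: encode, then append (fake per-character encoding).
      for (char c : *text) {
        out.push_back(static_cast<uint64_t>(c));
      }
    } else if (const auto* toks = std::get_if<std::vector<uint64_t>>(&in)) {
      // Token path: copy into the mutable buffer the prefiller expects.
      out.insert(out.end(), toks->begin(), toks->end());
    }
  }
  return out;
}
```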

Comment on lines +244 to +247
  Error TextLLMRunner::prefill(
-     const std::string& prompt,
-     const GenerationConfig& config) {
+     const std::vector<MultimodalInput>& inputs,
+     int32_t num_bos,
+     int32_t num_eos) {
Copilot AI Feb 26, 2026

New behavior in TextLLMRunner::prefill(vector, num_bos, num_eos) isn’t covered by existing runner tests (only generate()/warmup() are exercised). Consider adding unit tests that verify: (1) text inputs advance pos_, (2) non-text inputs are skipped, (3) BOS is only applied when pos_==0, and (4) token inputs are handled if supported.

Error MultimodalRunner::prefill(
const std::vector<MultimodalInput>& inputs,
int32_t num_bos,
int32_t num_eos) {
Copilot AI Feb 26, 2026

MultimodalRunner::prefill() currently treats an empty inputs vector as success (no-op). For consistency with generate() (which rejects empty inputs), consider validating !inputs.empty() here and returning InvalidArgument, or documenting that empty prefill is an intentional no-op.

Suggested change:
  int32_t num_eos) {
+ ET_CHECK_OR_RETURN_ERROR(
+     !inputs.empty(), InvalidArgument, "Inputs cannot be empty");

Both TextLLMRunner and MultimodalRunner now store the sampled next
token from prefill() in prefill_next_token_. When generate() is
called with an empty prompt, it consumes this token and starts
decoding directly without re-prefilling.

This enables the workflow:
  runner->prefill("system prompt", 1, 0);
  runner->prefill("user turn", 0, 0);
  runner->generate("", config, callback);  // decode from KV cache

prefill() now returns the sampled next token directly, making the
API more natural. Callers get the token if they want it, and
internally it's also stored in prefill_next_token_ for the
generate("") workflow.

This eliminates the need for the private prefill_and_sample() helper
since prefill() itself returns the token — generate(vector) can
call prefill() directly.

Copilot AI review requested due to automatic review settings February 26, 2026 21:33
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 9 comments.



  py::gil_scoped_release release;
- Error error = runner_->prefill(inputs);
- THROW_IF_ERROR(error, "Prefill failed");
+ auto result = runner_->prefill(inputs, 0, 0);
Copilot AI Feb 26, 2026

The Python binding for MultimodalRunner.prefill() hardcodes num_bos=0 and num_eos=0, preventing Python users from specifying BOS/EOS tokens during prefill. This is inconsistent with TextLLMRunner.prefill(), which accepts a GenerationConfig parameter (lines 118-127) to allow specifying num_bos/num_eos. For API consistency, MultimodalRunner.prefill() should also accept a GenerationConfig parameter or at least optional num_bos/num_eos parameters.

Both generate(string) empty-prompt path and generate(vector) shared
an identical decode loop (decode first token, resolve max_new_tokens,
run text_token_generator, update stats). Extract this into a private
decode_from_token() method called by both paths.

Move the two string convenience prefill methods from inline in
text_llm_runner.h to text_llm_runner.cpp. This removes the header's
dependency on multimodal_input.h — only a forward declaration of
MultimodalInput (from irunner.h) is needed in the header.

Also fix a bug where MultimodalRunner::generate(vector) left
prefill_next_token_ set after consuming the token. A subsequent
generate("") would incorrectly reuse the stale token. Clear it
after extracting the value.

Copilot AI review requested due to automatic review settings February 26, 2026 21:53
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.



/**
* Generate tokens from a text prompt. Wraps the prompt as a MultimodalInput
* and delegates to generate(vector).
* @param prompt The text prompt to generate from. Must be non-empty.
Copilot AI Feb 26, 2026

The comment at line 115 says "Must be non-empty", but the implementation actually supports empty prompts starting at line 230. This generates a discrepancy between the documented behavior and the actual implementation. Consider updating the documentation to accurately reflect that empty prompts are allowed when there is a prior prefill() call.

generate(string) is now a pure wrapper — empty string passes an
empty vector, non-empty wraps as MultimodalInput. The "decode from
prior prefill" logic moves to generate(vector) when inputs is empty,
giving both overloads consistent semantics.

Copilot AI review requested due to automatic review settings February 26, 2026 22:20
- Update docstring: empty prompt is allowed after prefill()
- Initialize cur_token = 0 to silence compiler warnings

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

extension/llm/runner/text_llm_runner.cpp:175

  • max_context_len is computed as metadata_.at(kMaxContextLen) - pos_ (remaining capacity), but later it’s used with GenerationConfig::resolve_max_new_tokens(), which is documented in irunner.h as taking the model’s maximum context length. This becomes incorrect when config.seq_len is set and pos_ is already non-zero (e.g., after history prefill), because the seq_len constraint can be violated. Consider passing the full metadata_.at(kMaxContextLen) and using the total prompt length (e.g., current pos_ after prefill) as the num_prompt_tokens argument.
  // Determine max_new_tokens using the GenerationConfig's resolve method,
  // then subtract pos_ for max_new_tokens.
  int max_new_tokens =
      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);



Three new tests:
- PrefillReturnsNextToken: verifies prefill() returns the predicted
  next token from the mock text_prefiller
- PrefillThenGenerateEmpty: verifies the prefill() → generate("")
  workflow produces the expected number of tokens
- GenerateEmptyWithoutPrefillFails: verifies generate("") without
  prior prefill() returns InvalidState error

Copilot AI review requested due to automatic review settings February 26, 2026 22:36
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.



Comment on lines +258 to +288
Result<uint64_t> TextLLMRunner::prefill(
    const std::vector<MultimodalInput>& inputs,
    int32_t num_bos,
    int32_t num_eos) {
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  for (const auto& input : inputs) {
    if (input.is_text()) {
      auto encode_res = tokenizer_->encode(
          input.get_text(), /*bos=*/num_bos, /*eos=*/num_eos);
      ET_CHECK_TK_OK_OR_RETURN_ERROR(
          encode_res.error(),
          "Failed to encode prompt %s",
          input.get_text().c_str());
      std::vector<uint64_t> tokens = encode_res.get();
      auto prefill_res = text_prefiller_->prefill(tokens, pos_);
      ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
      prefill_next_token_ = prefill_res.get();
      num_bos = 0;
      num_eos = 0;
    }
    // Skip non-text inputs — text-only runner
  }

  if (!prefill_next_token_.has_value()) {
    return Error::InvalidArgument;
  }
  return prefill_next_token_.value();
}
Copilot AI Feb 26, 2026

The prefill() method returns InvalidArgument when inputs contains no text elements, but the error message would be generic and not indicate the actual problem (empty or non-text inputs). Consider adding an explicit check at the start for empty inputs, and providing a more descriptive error message like "No text inputs to prefill" when the loop completes without processing any text.

Comment on lines +145 to +146
* chat history. Call generate() with a non-empty prompt afterwards to
* start decoding.
Copilot AI Feb 26, 2026

The documentation comment states "Call generate() with a non-empty prompt afterwards to start decoding." However, the implementation actually supports calling generate() with an empty prompt (empty inputs vector) which will consume the prefill_next_token_. The comment should be updated to reflect this capability, e.g., "Call generate() afterwards to start decoding (can use empty inputs to consume the prefilled token)."

Suggested change:
- * chat history. Call generate() with a non-empty prompt afterwards to
- * start decoding.
+ * chat history. Call generate() afterwards to start decoding (you may use
+ * empty inputs to consume the prefilled token).

@meta-codesync
Contributor

meta-codesync bot commented Feb 27, 2026

@kirklandsign has imported this pull request. If you are a Meta employee, you can view this in D94621001.

@kirklandsign kirklandsign merged commit 67bc28b into main Feb 27, 2026
167 of 170 checks passed
@kirklandsign kirklandsign deleted the llm/unify-runner-prefill-interface branch February 27, 2026 05:57
kirklandsign added a commit that referenced this pull request Feb 27, 2026
Replace the dual-runner pattern (runner_ + multi_modal_runner_) with a
single IRunner* that holds either TextLLMRunner or MultimodalRunner,
leveraging MultimodalRunner's new IRunner inheritance from #17741.

Each prefill method (text, images, audio) now immediately calls
IRunner::prefill(vector<MultimodalInput>) instead of buffering inputs
for later consumption by generate(). A needs_bos_ flag tracks whether
the next prefill should apply BOS tokens — MultimodalRunner also
guards this via pos_==0 internally, but TextLLMRunner trusts the caller.

generate(), stop(), load(), and reset() no longer branch on
model_type_category_; all dispatch through the unified runner_.

Rename all JNI native methods from append* to prefill* to match the
existing Java public API naming.
larryliu0820 added a commit that referenced this pull request Mar 2, 2026
The IRunner refactoring in #17741 split generate() into generate() +
decode_from_token(), adding an extra call frame at the deepest point of
execution. decode_from_token() also took two std::function parameters by
value (128 bytes of copies on MSVC). Combined with the deep AOTI CUDA
call stack, this exceeded the default 1 MB Windows thread stack.

Fix by:
- Passing std::function params to decode_from_token by const ref
- Increasing the voxtral_runner stack to 8 MB on Windows (matching
  the Linux default)
- Printing runner stderr in the Windows E2E test for diagnostics
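
The by-value copy cost is easy to see in isolation. A standalone sketch that counts callable copies (decode_by_value/decode_by_ref are illustrative stand-ins for the real decode_from_token signature):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// A callable that counts how many times it is copied.
struct CopyCounter {
  int* copies;
  explicit CopyCounter(int* c) : copies(c) {}
  CopyCounter(const CopyCounter& other) : copies(other.copies) {
    ++*copies;  // every copy of the callable increments the counter
  }
  void operator()(uint64_t) const {}
};

// By value: the std::function (and the callable it wraps) is copied
// into this frame on every call.
void decode_by_value(std::function<void(uint64_t)> cb) { cb(1); }

// By const reference: same behavior, no copy, smaller frame.
void decode_by_ref(const std::function<void(uint64_t)>& cb) { cb(1); }
```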

larryliu0820 added a commit that referenced this pull request Mar 3, 2026
The IRunner refactoring in #17741 added a redundant is_loaded() check
in MultimodalRunner::prefill() that is already performed by generate().
The is_loaded() path calls MultimodalPrefiller::is_method_loaded()
which invokes module_->method_names() — this second call corrupts state
and causes STATUS_STACK_BUFFER_OVERRUN (0xC0000409) on Windows.

Fix by:
- Removing the redundant is_loaded() guard from prefill() (callers
  like generate() already ensure the model is loaded)
- Passing std::function params to decode_from_token by const ref
- Increasing voxtral_runner stack to 8 MB on Windows as a safety net
- Printing runner stderr in the Windows E2E test for diagnostics

larryliu0820 added a commit that referenced this pull request Mar 3, 2026
The IRunner refactoring in #17741 split generate() into separate
prefill() and decode_from_token() calls. On Windows, calling any
sub-method from generate() triggers STATUS_STACK_BUFFER_OVERRUN
(0xC0000409) — this appears to be a Windows-specific issue with the
function call pattern (confirmed by SSH debugging: inlining the prefill
loop works, but calling it as a method crashes even with 8 MB stack).

Fix by restoring the monolithic generate(vector, ...) implementation
that keeps all prefill and decode logic inline, matching the pre-#17741
pattern that works on Windows. The separate prefill() and
decode_from_token() methods are retained for external callers and the
prefill-then-generate workflow.

Also:
- Pass std::function params to decode_from_token by const ref
- Increase voxtral_runner stack to 8 MB on Windows as a safety net
- Print runner stderr in the Windows E2E test for diagnostics

larryliu0820 added a commit that referenced this pull request Mar 3, 2026
## Summary

The IRunner refactoring in #17741 split `MultimodalRunner::generate()` into separate `prefill()` and `decode_from_token()` method calls. On Windows with CUDA, this triggers `STATUS_STACK_BUFFER_OVERRUN` (0xC0000409) — confirmed via SSH debugging that inlining the logic works but calling sub-methods from `generate()` crashes, even with an 8 MB stack.

- Restore the monolithic `generate(vector, ...)` implementation with all prefill and decode logic inline, matching the pre-#17741 pattern
- The separate `prefill()` and `decode_from_token()` methods are retained for external callers and the prefill-then-generate workflow
- Pass `std::function` params to `decode_from_token` by const ref
- Increase `voxtral_runner` stack to 8 MB on Windows (`/STACK:8388608`)
- Print runner stderr in the Windows E2E test for diagnostics

## Test plan

- [x] Verified fix on Windows CUDA CI machine via SSH
- [x] `test-model-cuda-windows-e2e (mistralai, Voxtral-Mini-3B-2507, non-quantized)` CI passes

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.


Development

Successfully merging this pull request may close these issues.

[RFC] Unify TextLLMRunner and MultimodalRunner under IRunner
