
[JAX] Warmup FFIs with "initialize" stage #2800

Open
jberchtold-nvidia wants to merge 5 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/warmup-xla-ffis

Conversation

@jberchtold-nvidia (Collaborator) commented Mar 25, 2026

Description

Add an "initialize" stage to the TE FFIs that didn't previously have one.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add initialize FFI handlers in JAX .cpp extensions
  • Register them as "initialize" stage in pybind.cpp

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 25, 2026 21:19
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci L1 jax

greptile-apps (Bot, Contributor) commented Mar 25, 2026

Greptile Summary

This PR adds XLA FFI kInitialize-stage handlers to JAX transformer engine extensions that previously lacked them, enabling CUDA graph warmup via wrapInStreamCapture. Each new initialize function is a one-line trampoline that captures and immediately destroys a CUDA graph to prime cuBLAS/cuDNN state before any live graph capture occurs. All signatures match their execute counterparts; the previously-reported GroupedGemmV2InitializeFFI argument mismatch has been resolved.
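The one-line trampoline pattern described above might look roughly like the following. This is a hedged sketch: the handler and argument names (GemmFFI, Buffer_Type, Result_Type) are illustrative stand-ins, and only wrapInStreamCapture is a name taken from the summary itself.

```cpp
// Sketch only: Buffer_Type/Result_Type/Error_Type stand in for the XLA FFI
// argument types used by the real handlers; GemmFFI is the existing execute
// handler the trampoline would wrap.
Error_Type GemmInitializeFFI(cudaStream_t stream, Buffer_Type lhs,
                             Buffer_Type rhs, Result_Type out) {
  // Same signature as the execute handler; the body just runs it once inside
  // a throwaway stream capture to prime cuBLAS/cuDNN state before any live
  // CUDA graph capture occurs.
  return wrapInStreamCapture(GemmFFI, stream, lhs, rhs, out);
}
```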

Confidence Score: 5/5

Safe to merge; all initialize handlers are correct one-line trampolines with matching signatures and no logic changes to execute paths.

No P0 or P1 issues found. All initialize FFI signatures were verified against their execute counterparts across all 7 changed files. The GroupedGemmV2 argument-mismatch flagged in a prior review is confirmed resolved. No remaining findings rise above P2.

No files require special attention.

Vulnerabilities

No security concerns identified. The PR only adds warmup trampolines that begin a relaxed-mode CUDA stream capture, call the existing execute FFI, and immediately discard the resulting graph via cudaGraphDestroy. No new I/O, memory ownership transfer, or untrusted input parsing is introduced.
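The capture-and-discard flow described here could be sketched as below. The signature and template shape are assumptions; the real wrapInStreamCapture in the TE sources may differ.

```cpp
#include <cuda_runtime.h>
#include <utility>

// Hedged sketch of the warmup helper: begin a relaxed-mode stream capture,
// run the execute handler once (triggering lazy cuBLAS/cuDNN initialization),
// then end the capture and immediately destroy the resulting graph.
template <typename Fn, typename... Args>
Error_Type wrapInStreamCapture(Fn&& execute, cudaStream_t stream, Args&&... args) {
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
  Error_Type err = execute(stream, std::forward<Args>(args)...);
  cudaGraph_t graph = nullptr;
  cudaStreamEndCapture(stream, &graph);
  if (graph != nullptr) cudaGraphDestroy(graph);  // warmup graph is discarded
  return err;
}
```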

Important Files Changed

  • transformer_engine/jax/csrc/extensions.h: Adds XLA_FFI_DECLARE_HANDLER_SYMBOL declarations for 15 new initialize handlers; purely mechanical, no logic.
  • transformer_engine/jax/csrc/extensions/softmax.cpp: Adds 6 softmax initialize handlers; ScaledMaskedSoftmaxBackwardInitialize correctly reuses ScaledSoftmaxBackwardInitializeFFI, matching the existing execute-handler design.
  • transformer_engine/jax/csrc/extensions/quantization.cpp: Adds DBiasQuantizeInitializeFFI and DequantizeInitializeFFI with correct signatures; GroupedQuantize is intentionally left without an initialize handler.
  • transformer_engine/jax/csrc/extensions/attention.cpp: Adds FusedAttnForward/BackwardInitializeFFI with the correct RemainingArgs variadic slot; registered alongside the existing CudnnHandleInitHandler prepare stage.
  • transformer_engine/jax/csrc/extensions/gemm.cpp: Adds GemmInitializeFFI, GemmV2InitializeFFI, and GroupedGemmV2InitializeFFI; all signatures match their execute counterparts; the prior GroupedGemmV2 argument mismatch is resolved.
  • transformer_engine/jax/csrc/extensions/router.cpp: Adds 4 MoE router initialize handlers (TopK fwd/bwd, AuxLoss fwd/bwd); all signatures match their execute counterparts exactly.
  • transformer_engine/jax/csrc/extensions/pybind.cpp: Converts 12 bare FFI registrations to initialize+execute dicts; the intentional omissions for GroupedGemm and GroupedQuantize are consistent with the existing non-graph-safe design.
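The pybind.cpp change presumably follows JAX's stage-keyed dict convention, where jax.ffi.register_ffi_target accepts a dict mapping XLA FFI stage names to handler capsules. A hypothetical sketch (EncapsulateFFI and the handler names are assumptions, not the exact TE code):

```cpp
// Before: registrations exposed a bare execute-handler capsule.
// After: a dict keyed by FFI stage, so JAX registers both handlers.
pybind11::dict GemmRegistration() {
  pybind11::dict d;
  d["initialize"] = EncapsulateFFI(GemmInitializeHandler);  // runs once, warmup
  d["execute"] = EncapsulateFFI(GemmHandler);               // runs every call
  return d;
}
```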

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Runtime
    participant Init as *InitializeFFI
    participant WSC as wrapInStreamCapture
    participant Exec as *ExecuteFFI (called internally)
    participant CUDA as CUDA Driver

    Note over JAX,CUDA: kInitialize stage (warmup — once per model)
    JAX->>Init: call initialize handler
    Init->>WSC: wrapInStreamCapture(ExecuteFFI, stream, args...)
    WSC->>CUDA: cudaStreamBeginCapture(stream, Relaxed)
    WSC->>Exec: ExecuteFFI(stream, args...) — warms cuBLAS/cuDNN state
    Exec-->>WSC: Error_Type
    WSC->>CUDA: cudaStreamEndCapture → cudaGraph_t
    WSC->>CUDA: cudaGraphDestroy (graph discarded)
    WSC-->>Init: Error_Type
    Init-->>JAX: return

    Note over JAX,CUDA: kExecute stage (every inference step)
    JAX->>Exec: ExecuteFFI(stream, args...) — now graph-capture-safe

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review March 30, 2026 19:28
phu0ngng previously approved these changes Apr 1, 2026
@phu0ngng (Collaborator) commented Apr 6, 2026

LGTM!

@phu0ngng (Collaborator) commented Apr 7, 2026

/te-ci JAX L1

(Outdated comment thread on transformer_engine/jax/csrc/extensions/gemm.cpp)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci L1 jax
