[JAX] Warmup FFIs with "initialize" stage#2800
[JAX] Warmup FFIs with "initialize" stage#2800jberchtold-nvidia wants to merge 5 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 jax |
Greptile SummaryThis PR adds XLA FFI Confidence Score: 5/5Safe to merge; all initialize handlers are correct one-line trampolines with matching signatures and no logic changes to execute paths. No P0 or P1 issues found. All initialize FFI signatures were verified against their execute counterparts across all 7 changed files. The GroupedGemmV2 argument-mismatch flagged in a prior review is confirmed resolved. No remaining findings rise above P2. No files require special attention.
|
| Filename | Overview |
|---|---|
| transformer_engine/jax/csrc/extensions.h | Adds XLA_FFI_DECLARE_HANDLER_SYMBOL declarations for 15 new initialize handlers; purely mechanical, no logic. |
| transformer_engine/jax/csrc/extensions/softmax.cpp | Adds 6 softmax initialize handlers; ScaledMaskedSoftmaxBackwardInitialize correctly reuses ScaledSoftmaxBackwardInitializeFFI, matching the existing execute-handler design. |
| transformer_engine/jax/csrc/extensions/quantization.cpp | Adds DBiasQuantizeInitializeFFI and DequantizeInitializeFFI with correct signatures; GroupedQuantize intentionally left without an initialize handler. |
| transformer_engine/jax/csrc/extensions/attention.cpp | Adds FusedAttnForward/BackwardInitializeFFI with correct RemainingArgs variadic slot; registered alongside existing CudnnHandleInitHandler prepare stage. |
| transformer_engine/jax/csrc/extensions/gemm.cpp | Adds GemmInitializeFFI, GemmV2InitializeFFI, and GroupedGemmV2InitializeFFI; all signatures match execute counterparts; prior GroupedGemmV2 argument-mismatch is resolved. |
| transformer_engine/jax/csrc/extensions/router.cpp | Adds 4 MoE router initialize handlers (TopK fwd/bwd, AuxLoss fwd/bwd); all signatures match execute counterparts exactly. |
| transformer_engine/jax/csrc/extensions/pybind.cpp | Converts 12 bare FFI registrations to initialize+execute dicts; intentional omissions for GroupedGemm and GroupedQuantize are consistent with the existing non-graph-safe design. |
Sequence Diagram
sequenceDiagram
participant JAX as JAX Runtime
participant Init as *InitializeFFI
participant WSC as wrapInStreamCapture
participant Exec as *ExecuteFFI (called internally)
participant CUDA as CUDA Driver
Note over JAX,CUDA: kInitialize stage (warmup — once per model)
JAX->>Init: call initialize handler
Init->>WSC: wrapInStreamCapture(ExecuteFFI, stream, args...)
WSC->>CUDA: cudaStreamBeginCapture(stream, Relaxed)
WSC->>Exec: ExecuteFFI(stream, args...) — warms cuBLAS/cuDNN state
Exec-->>WSC: Error_Type
WSC->>CUDA: cudaStreamEndCapture → cudaGraph_t
WSC->>CUDA: cudaGraphDestroy (graph discarded)
WSC-->>Init: Error_Type
Init-->>JAX: return
Note over JAX,CUDA: kExecute stage (every inference step)
JAX->>Exec: ExecuteFFI(stream, args...) — now graph-capture-safe
Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile
|
LGTM! |
|
/te-ci JAX L1 |
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 jax |
Description
Add "initialize" stage to TE FFIs that didn't previously have them.
Type of change
Changes
Checklist: