Skip to content

[JAX] TE GMM v2 enforcement Env Var#2794

Draft
jberchtold-nvidia wants to merge 1 commit intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/gmm-debug-flags
Draft

[JAX] TE GMM v2 enforcement Env Var#2794
jberchtold-nvidia wants to merge 1 commit intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/gmm-debug-flags

Conversation

@jberchtold-nvidia
Copy link
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 23, 2026 22:52
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 23, 2026

Greptile Summary

This PR adds a NVTE_JAX_ENFORCE_V2_GROUPED_GEMM environment variable that allows users to assert the V2 (CUDA-graphable) grouped GEMM path is taken, turning silent fallbacks into loud RuntimeErrors useful for debugging and CI enforcement.

Key changes in transformer_engine/jax/cpp_extensions/gemm.py:

  • Captures _v2_grouped_gemm_available_reason at import time so the cuBLAS version error can be surfaced in the enforcement message.
  • Adds a @cache-decorated _should_enforce_v2_grouped_gemm() helper that reads the env var exactly once per process.
  • Extends _can_use_v2_grouped_gemm with three guard-point checks: cuBLAS version availability, SM compute capability, and dtype/quantization/bias restrictions.
  • One minor style note: get_device_compute_capability(0) is called twice in the compute-capability guard (once for the comparison and once inside the error message string).

Confidence Score: 5/5

  • Safe to merge — no logic or runtime errors introduced; the feature is additive and opt-in via an env var.
  • All three enforcement guard points are logically correct and consistent with the existing fallback logic. Error messages are clear and include diagnostic context. The only finding is a trivial style suggestion (double get_device_compute_capability(0) call) that does not affect correctness or reliability.
  • No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Adds NVTE_JAX_ENFORCE_V2_GROUPED_GEMM env-var support: a cached helper reads it once per process, and _can_use_v2_grouped_gemm raises descriptive RuntimeErrors (instead of silently returning False) at each of the three guard points when enforcement is active. Logic is sound; one minor style note about a redundant get_device_compute_capability(0) call.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[_can_use_v2_grouped_gemm called] --> B[read enforce_v2_gmm via cached env-var]
    B --> C{_v2_grouped_gemm_available?}
    C -- No --> D{enforce_v2_gmm?}
    D -- Yes --> E[raise RuntimeError with reason string]
    D -- No --> F[return False]
    C -- Yes --> G{device SM < 100?}
    G -- Yes --> H{enforce_v2_gmm?}
    H -- Yes --> I[raise RuntimeError with compute cap]
    H -- No --> J[return False]
    G -- No --> K{NO_SCALING + BF16 + no bias?}
    K -- Yes --> L[return True  → use V2 path]
    K -- No --> M{enforce_v2_gmm?}
    M -- Yes --> N[raise RuntimeError with dtype/bias/mode]
    M -- No --> O[return False  → use V1 path]
Loading

Reviews (1): Last reviewed commit: "TE GMM v2 enforcement env var" | Re-trigger Greptile

Comment on lines 1957 to +1963
if get_device_compute_capability(0) < 100:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device"
f" compute capability of GPU 0 is {get_device_compute_capability(0)} and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant get_device_compute_capability(0) call

get_device_compute_capability(0) is already called in the if condition on line 1957, and then called a second time inside the error message on line 1961. While the call is likely cheap, it is cleaner and more efficient to capture the result once and reuse it.

Suggested change
if get_device_compute_capability(0) < 100:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device"
f" compute capability of GPU 0 is {get_device_compute_capability(0)} and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
cap = get_device_compute_capability(0)
if cap < 100:
if enforce_v2_gmm:
raise RuntimeError(
"The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device"
f" compute capability of GPU 0 is {cap} and"
" NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled."
)
return False

@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant