Skip to content

[JAX] Grouped GEMM Refactor to use first_dims and last_dims#2749

Open
jberchtold-nvidia wants to merge 25 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/gmm-refactor
Open

[JAX] Grouped GEMM Refactor to use first_dims and last_dims#2749
jberchtold-nvidia wants to merge 25 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/gmm-refactor

Conversation

@jberchtold-nvidia
Copy link
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Mar 10, 2026

Description

This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.

Addresses issue: #2648

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • grouped_gemm API signature change: replaced the single group_sizes positional argument with six keyword
    arguments — lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims — each an
    optional (G,) int32 array describing per-group sizes along that tensor axis (empty (0,) arrays indicate a
    uniform/non-ragged dimension)
  • Removed explicit M/N/K parameters from C++ FFI: matrix dimensions are now derived automatically from XLA buffer
    shapes inside the C++ handler, eliminating manual dimension computation in Python
  • Removed is_grouped_dense_wgrad flag: the wgrad vs. forward distinction is now inferred from which dimension
    arrays are non-empty (non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a
    (num_groups, M, N) output)
  • New C++ config struct GroupedGemmV2Config: consolidates lhs_is_trans, rhs_is_trans, and scaling_mode into a
    single FFI attribute struct, replacing individual attribute bindings
  • New C++ helper make_grouped_tensor() overload: accepts first_dims/last_dims buffers, converts int32 group-size
    arrays to int64 in partitioned int64_workspace slots, and returns updated workspace offset to avoid aliasing
  • dense.py updated: _grouped_dense_fwd_rule and _grouped_dense_bwd_rule updated to pass group_sizes via the
    appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad)
  • Tests updated: TestGroupedDense test cases migrated to the new keyword-argument API with explicit empty_gs =
    jnp.empty((0,), jnp.int32) sentinels for non-ragged axes

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
tensor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 10, 2026 17:24
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR replaces the single group_sizes parameter in the JAX grouped GEMM API with six per-tensor dimension arrays (lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims), removes explicit M/N/K attributes from the C++ FFI, and introduces GroupedNoScaleTensor to unify the noop and quantized code paths. The wgrad distinction is now inferred from which dim array is non-empty rather than a dedicated flag.

Key observations from this review:

  • Open bugs from prior review threads: Several significant issues flagged in earlier review rounds remain unresolved, including: uninitialized int64_sizes_ptr for out-only ragged callers in GroupedGemmFFI (any_ragged excludes output dims); incorrect M derivation when lhs_is_trans=True in the non-wgrad path (both Python and C++); divide-by-zero when num_gemms=0 and rhs_is_trans=True; removed K-dimension consistency check between lhs and rhs; incorrect wgrad/fwd discriminant in the rhs_layout_is_T block; assert replacing raise ValueError for bias validation (silently dropped under -O); and scaling_mode left as NO_SCALING when lhs=GroupedNoScaleTensor, rhs=GroupedScaledTensor1x.
  • FSDP regression (dense.py lines 411–412, 467–468): The previous FP8+FSDP code path (all-gather, per-tensor quantization, psum-scatter) is now unconditionally blocked via assert not kernel_fsdp_enabled. This is a runtime-breaking regression for any caller relying on kernel_fsdp_info.
  • Dequantizer fallback gap (dequantizer.py line 278): When first_dims is an empty sentinel (not None, .size==0) and last_dims is also a non-None empty array, the uniform-group fallback to original_shape[0] is bypassed, leaving group_sizes as an empty array that would break downstream jnp.repeat.
  • ScaledTensorFactory.create_1x condition widened: The grouped path is now entered whenever original_shape is not None, even without first_dims/last_dims. No current non-grouped caller passes original_shape, so there is no regression today, but it is fragile for future callers.
  • The GroupedNoScaleTensor pytree correctly inherits tree_unflatten from AbstractBaseTensor via cls(*children, *aux_data), and the child/aux ordering aligns with the dataclass field order.

Confidence Score: 2/5

  • Not safe to merge — multiple open correctness bugs from prior review rounds remain unaddressed, plus a new FSDP runtime regression.
  • Eight concrete bugs identified in prior review threads are still open (uninitialized workspace pointer, wrong M derivation, divide-by-zero, removed K-check, incorrect discriminant, assert-vs-ValueError, scaling_mode left as NO_SCALING). This review adds two more: the FSDP path hard-crash regression in dense.py and the dequantizer empty-sentinel fallback gap. The V1 and V2 C++ FFI handlers both carry the uninitialized-pointer and wrong-M bugs, meaning the primary execution paths for non-SM100 hardware are affected. Score reflects the accumulation of unresolved issues across rounds.
  • transformer_engine/jax/cpp_extensions/gemm.py (M derivation, num_gemms=0, assert-bias, discriminant), transformer_engine/jax/csrc/extensions/gemm.cpp (uninitialized int64_sizes_ptr, M derivation in both FFI handlers), transformer_engine/jax/dense.py (FSDP regression), transformer_engine/jax/quantize/dequantizer.py (empty-sentinel fallback)

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Core API refactor replacing group_sizes+M/N/K with six per-tensor dim arrays and tensor wrapper types. Several previously-flagged bugs remain open: incorrect M derivation when lhs_is_trans=True, divide-by-zero when all dim arrays are empty, removed K-consistency check, incorrect wgrad/fwd discriminant in rhs_layout_is_T block, and assert used instead of raise ValueError for bias validation.
transformer_engine/jax/csrc/extensions/gemm.cpp Replaces explicit M/N/K attrs and is_grouped_dense_wgrad flag with GroupedGemmV2Config/GroupedGemmConfig structs and six per-tensor dims buffers. New make_grouped_tensor V2 overload handles int32→int64 conversion per-slot. Previously flagged: any_ragged excludes out-dims so int64_sizes_ptr can be uninitialized for out-only ragged callers; non-wgrad M derivation incorrect when lhs_is_trans=True for both FFI handlers.
transformer_engine/jax/dense.py Fwd/bwd rules unified: the noop and FP8 branches are merged via grouped_quantize returning GroupedNoScaleTensor for None quantizers. FSDP support was silently removed via assertion (functionality regression). flatten_axis_k is now always computed (previously None for noop) but has no impact since the noop quantizer path ignores it.
transformer_engine/jax/quantize/tensor.py Adds GroupedNoScaleTensor pytree class and migrates GroupedScaledTensor1x from group_sizes/group_axis to first_dims/last_dims. ScaledTensorFactory.create_1x condition widened from group_sizes is not None to first_dims is not None or last_dims is not None or original_shape is not None, which could pull non-grouped callers into the grouped path if they happen to pass original_shape.
transformer_engine/jax/quantize/dequantizer.py Updated to use first_dims/last_dims instead of group_sizes/group_axis. Fallback guard for deriving uniform group sizes has a subtle gap: an empty-sentinel last_dims (not None, but .size==0) would bypass the original_shape fallback and produce an empty group_sizes.

Sequence Diagram

sequenceDiagram
    participant C as grouped_dense (Python)
    participant GG as grouped_gemm (Python)
    participant P as GroupedGemmPrimitive
    participant V2 as GroupedGemmV2FFI (C++)
    participant V1 as GroupedGemmFFI (C++)

    C->>C: grouped_quantize(x, quantizer) → GroupedNoScaleTensor / GroupedScaledTensor1x
    C->>C: grouped_quantize(kernel, quantizer) → GroupedNoScaleTensor / GroupedScaledTensor1x
    C->>GG: grouped_gemm(lhs, rhs, contracting_dims=...)
    GG->>GG: extract lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims
    GG->>GG: infer out_first_dims / out_last_dims from ragged operand
    GG->>GG: compute lhs/rhs_axis_boundary, lhs/rhs_left/right_size, out_shape
    GG->>P: bind(lhs_data, rhs_data, lhs_first_dims…out_last_dims, lhs_axis_boundary…)
    alt use_v2_ffi (SM100+, BF16, no bias)
        P->>V2: GroupedGemmV2FFI(…, GroupedGemmV2Config)
        V2->>V2: grouped_gemm_num_gemms() from non-empty dims or alpha.count
        V2->>V2: make_grouped_tensor(rhs, rhs_first_dims, rhs_last_dims, int64_workspace)
        V2->>V2: make_grouped_tensor(lhs, lhs_first_dims, lhs_last_dims, int64_workspace)
        V2->>V2: make_grouped_tensor(output, out_first_dims, out_last_dims, int64_workspace)
        V2->>V2: nvte_grouped_gemm(rhs, lhs, output)
    else legacy path (FP8 / bias / older GPU)
        P->>V1: GroupedGemmFFI(…, GroupedGemmConfig)
        V1->>V1: derive m/n/k from lhs/rhs_left/right_size + is_rhs_ragged
        V1->>V1: cudaMemcpyAsync group_sizes D→H (if not async)
        V1->>V1: per-group GEMM loop
    end
Loading

Reviews (6): Last reviewed commit: "Merge branch 'main' into jberchtold/gmm-..." | Re-trigger Greptile

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from 35171af to 88bb7da Compare March 10, 2026 18:56
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from 20fadc7 to 025f598 Compare March 10, 2026 23:26
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from a427b9e to 089e530 Compare March 10, 2026 23:59
jberchtold-nvidia and others added 3 commits March 10, 2026 17:04
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…mm-refactor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…' into jberchtold/gmm-refactor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review March 11, 2026 20:01
Comment on lines +1334 to +1341
def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int:
"""Non-contracting output size M from the 2-D LHS buffer."""
return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0]


def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int:
"""Non-contracting output size N from the 2-D RHS buffer."""
return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest calling it lhs_non_contracting_dims and rhs_non_contracting_dims as M and N are still ambiguous.

Besides, I think we should not assume that lhs and rhs are 2D but can be N-D.


Args:
lhs_data: Left-hand side input matrix data, 1D flattened array
lhs_data: Left-hand side input matrix data, 2D array [rows, cols]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the LHS needs to be transposed, we won't be able to have a 2D shape.

Also, I would prefer us not to reshape/merge any axes until C++. Looking into the future, especially when we have a solution to handle the EP part, we may not need to go with shard_map anymore.

rhs_first_dims_aval,
rhs_last_dims_aval,
out_first_dims_aval,
out_last_dims_aval,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does out_xxx_dims_aval need to be the inputs for the primitives? Can't the primitive come up with that after having other dims and contracting dims info?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that it doesn't need to be an input the the grouped_gemm API. To avoid differing inner/outer primitive signatures, I've kept this as an arg to the primitive but am now deriving out first and last dims from the inputs inside the grouped_gemm function instead of requiring the user to specify it.

Comment on lines +1977 to +1982
lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,)
lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,)
rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,)
rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,)
out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,)
out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either the GroupedScaledTensor should carry this information, or one should be able to interpolate this from grouped_sizes + contracting_dims.

jberchtold-nvidia and others added 2 commits March 12, 2026 14:50
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment on lines +696 to +699
first_dims is not None
or last_dims is not None
or (original_shape is not None and group_axis is not None)
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group_axis is not None is always True — condition is wider than intended

group_axis has a default value of 0, so group_axis is not None evaluates to True for every caller that does not explicitly pass group_axis=None. This means the third branch of the or:

or (original_shape is not None and group_axis is not None)

reduces to simply original_shape is not None, which is a much broader guard than the old group_sizes is not None. Any call to ScaledTensorFactory.make_grouped(…, original_shape=shape) — even without first_dims or last_dims — now enters the grouped path and returns a GroupedScaledTensor1x with both dim arrays set to None. This silently changes the return type for callers that provided original_shape for informational purposes only, and those callers will now see num_groups derived implicitly from original_shape[group_axis] instead of receiving a plain ScaledTensor1x.

The condition should be restricted to the cases where grouping is actually requested:

Suggested change
first_dims is not None
or last_dims is not None
or (original_shape is not None and group_axis is not None)
):
if (
first_dims is not None
or last_dims is not None
):

If the "uniform grouped" case (kernel rhs without explicit per-group sizes) needs to be handled here, it should be expressed with an explicit sentinel argument rather than overloading original_shape.

Comment on lines +2040 to 2056

if isinstance(rhs, GroupedNoScaleTensor):
rhs_data = rhs.data
rhs_shape = rhs.original_shape
rhs_scale_inv = jnp.empty((0,), jnp.float32)
rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
elif isinstance(rhs, GroupedScaledTensor1x):
rhs_shape = rhs.original_shape
rhs_data = rhs.data.reshape(rhs_shape)
rhs_scale_inv = rhs.scale_inv
if lhs.scaling_mode != rhs.scaling_mode:
rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
raise ValueError(
f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
f" rhs.scaling_mode={rhs.scaling_mode}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaling_mode left as NO_SCALING when lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x

When lhs is a GroupedNoScaleTensor, the lhs block sets scaling_mode = ScalingMode.NO_SCALING. The subsequent rhs block only overrides scaling_mode when isinstance(lhs, GroupedScaledTensor1x):

if isinstance(lhs, GroupedScaledTensor1x):
    scaling_mode = lhs.scaling_mode   # never executes for GroupedNoScaleTensor lhs

So if a caller passes lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x, scaling_mode stays NO_SCALING while rhs_scale_inv holds real scale values. C++ will then use NO_SCALING logic and ignore the rhs scales entirely, producing silently wrong numerical results rather than a clear error.

The scaling-mode consistency check that guards against mismatched GroupedScaledTensor1x pairs does not fire here either because isinstance(lhs, GroupedScaledTensor1x) is False.

Add an explicit cross-type guard early in the rhs block:

elif isinstance(rhs, GroupedScaledTensor1x):
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x; "
            "both operands must use the same tensor type."
        )
    ...

@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

flatten_axis,
original_shape,
group_axis=0,
last_dims=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think last_dims and first_dims should be positioned together.

self.first_dims = first_dims
self.last_dims = last_dims
self.original_shape = original_shape
self.group_axis = group_axis
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we store first_dims and last_dims now, I think we no longer need the group_axis.

Comment on lines +474 to +475
def tree_unflatten(cls, aux_data, children):
"""Reconstructs the tensor from its flattened representation."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be able to reuse the base tree_unflatten as the order is still cls(*children, *aux_data).

Comment on lines +494 to +495
first_dims=ctx_kernel.first_dims,
last_dims=ctx_kernel.last_dims,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need to pass dims here, as the tensors should already carry them.

Comment on lines +524 to +538
if is_noop_quantizer_set:
grouped_gemm_x = GroupedNoScaleTensor(
data=grouped_gemm_x,
first_dims=group_sizes,
last_dims=None,
group_axis=0,
original_shape=grouped_gemm_x.shape,
)
grouped_gemm_kernel = GroupedNoScaleTensor(
data=grouped_gemm_kernel,
first_dims=None,
last_dims=None,
group_axis=0,
original_shape=grouped_gemm_kernel.shape,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making the grouped_quantize to return GroupedNoScaleTensor when the quantizer set is empty?

Comment on lines +1390 to +1391
out_first_dims_aval,
out_last_dims_aval,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the out_xxx_dims could be the return buffers. Why should it be input buffers?

Comment on lines +1457 to +1458
lhs_shape = lhs_data_aval.shape
rhs_shape = rhs_data_aval.shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't do this as the input could be both 1D.

jberchtold-nvidia and others added 5 commits March 17, 2026 10:08
…mm-refactor

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

if quantizer is None:
if isinstance(x, NoScaleTensor):
if isinstance(x, GroupedNoScaleTensor):
assert amax is None, (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After cleaning the FSDP, I think amax is no longer relevant and should be cleaned as well.

FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs_data
.Arg<Buffer_Type>() // lhs_data (2D)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 2D?

Comment on lines +669 to +670
if (lhs_first_dims.element_count() > 0) {
return lhs_first_dims.dimensions()[0];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the first_dims is intended to be with shape (G,), isn't element_count() and dimensions()[0] should return the same result?

jberchtold-nvidia and others added 4 commits March 24, 2026 14:12
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants