[JAX] Grouped GEMM Refactor to use first_dims and last_dims#2749
jberchtold-nvidia wants to merge 25 commits into NVIDIA:main from jberchtold/gmm-refactor
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary: This PR replaces the single group_sizes parameter with six per-tensor ragged-dimension arguments.
Confidence Score: 2/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant C as grouped_dense (Python)
    participant GG as grouped_gemm (Python)
    participant P as GroupedGemmPrimitive
    participant V2 as GroupedGemmV2FFI (C++)
    participant V1 as GroupedGemmFFI (C++)
    C->>C: grouped_quantize(x, quantizer) → GroupedNoScaleTensor / GroupedScaledTensor1x
    C->>C: grouped_quantize(kernel, quantizer) → GroupedNoScaleTensor / GroupedScaledTensor1x
    C->>GG: grouped_gemm(lhs, rhs, contracting_dims=...)
    GG->>GG: extract lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims
    GG->>GG: infer out_first_dims / out_last_dims from ragged operand
    GG->>GG: compute lhs/rhs_axis_boundary, lhs/rhs_left/right_size, out_shape
    GG->>P: bind(lhs_data, rhs_data, lhs_first_dims…out_last_dims, lhs_axis_boundary…)
    alt use_v2_ffi (SM100+, BF16, no bias)
        P->>V2: GroupedGemmV2FFI(…, GroupedGemmV2Config)
        V2->>V2: grouped_gemm_num_gemms() from non-empty dims or alpha.count
        V2->>V2: make_grouped_tensor(rhs, rhs_first_dims, rhs_last_dims, int64_workspace)
        V2->>V2: make_grouped_tensor(lhs, lhs_first_dims, lhs_last_dims, int64_workspace)
        V2->>V2: make_grouped_tensor(output, out_first_dims, out_last_dims, int64_workspace)
        V2->>V2: nvte_grouped_gemm(rhs, lhs, output)
    else legacy path (FP8 / bias / older GPU)
        P->>V1: GroupedGemmFFI(…, GroupedGemmConfig)
        V1->>V1: derive m/n/k from lhs/rhs_left/right_size + is_rhs_ragged
        V1->>V1: cudaMemcpyAsync group_sizes D→H (if not async)
        V1->>V1: per-group GEMM loop
    end
```
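The grouped_gemm_num_gemms() fallback referenced in the diagram can be sketched in pure Python (signature assumed; the real C++ reads buffer extents rather than Python sequences):

```python
def grouped_gemm_num_gemms(lhs_first_dims, rhs_first_dims, alpha_count):
    """Group count G: taken from the first non-empty dims array,
    falling back to the number of alpha entries when all are empty."""
    for dims in (lhs_first_dims, rhs_first_dims):
        if len(dims) > 0:
            return len(dims)
    return alpha_count
```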
Reviews (6): Last reviewed commit: "Merge branch 'main' into jberchtold/gmm-..."
jberchtold-nvidia force-pushed from 35171af to 88bb7da
jberchtold-nvidia force-pushed from 20fadc7 to 025f598
jberchtold-nvidia force-pushed from a427b9e to 089e530
/te-ci
/te-ci
```python
def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int:
    """Non-contracting output size M from the 2-D LHS buffer."""
    return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0]


def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int:
    """Non-contracting output size N from the 2-D RHS buffer."""
    return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1]
```
Suggest calling these lhs_non_contracting_dims and rhs_non_contracting_dims, as M and N are still ambiguous.
Besides, I think we should not assume that lhs and rhs are 2D; they can be N-D.
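A hypothetical N-D generalization along the lines this comment suggests (helper name assumed, not part of the PR):

```python
import math

def non_contracting_size(shape, contracting_dims):
    """Product of all axis sizes not taking part in the contraction,
    so the helper works for N-D operands, not just 2-D buffers."""
    cdims = {d % len(shape) for d in contracting_dims}  # normalize negative axes
    return math.prod(s for i, s in enumerate(shape) if i not in cdims)
```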
```diff
 Args:
-    lhs_data: Left-hand side input matrix data, 1D flattened array
+    lhs_data: Left-hand side input matrix data, 2D array [rows, cols]
```
When the LHS needs to be transposed, we won't be able to have a 2D shape.
Also, I would prefer us not to reshape/merge any axes until C++. Looking into the future, especially when we have a solution to handle the EP part, we may not need to go with shard_map anymore.
```python
rhs_first_dims_aval,
rhs_last_dims_aval,
out_first_dims_aval,
out_last_dims_aval,
```
Why do the out_xxx_dims_aval need to be inputs to the primitive? Can't the primitive come up with them after having the other dims and contracting-dims info?
Agreed that it doesn't need to be an input to the grouped_gemm API. To avoid differing inner/outer primitive signatures, I've kept this as an arg to the primitive but am now deriving the out first and last dims from the inputs inside the grouped_gemm function instead of requiring the user to specify them.
```python
lhs_first_dims: jnp.ndarray = None,  # (G,) int32 if LHS squashed first dim varies, else None/(0,)
lhs_last_dims: jnp.ndarray = None,   # (G,) int32 if LHS squashed last dim varies, else None/(0,)
rhs_first_dims: jnp.ndarray = None,  # (G,) int32 if RHS squashed first dim varies, else None/(0,)
rhs_last_dims: jnp.ndarray = None,   # (G,) int32 if RHS squashed last dim varies, else None/(0,)
out_first_dims: jnp.ndarray = None,  # (G,) int32 if output first dim varies, else None/(0,)
out_last_dims: jnp.ndarray = None,   # (G,) int32 if output last dim varies, else None/(0,)
```
Either the GroupedScaledTensor should carry this information, or one should be able to derive it from group_sizes + contracting_dims.
```python
first_dims is not None
or last_dims is not None
or (original_shape is not None and group_axis is not None)
):
```
group_axis is not None is always True — condition is wider than intended
group_axis has a default value of 0, so group_axis is not None evaluates to True for every caller that does not explicitly pass group_axis=None. This means the third branch of the or:

```python
or (original_shape is not None and group_axis is not None)
```

reduces to simply original_shape is not None, which is a much broader guard than the old group_sizes is not None. Any call to ScaledTensorFactory.make_grouped(…, original_shape=shape) — even without first_dims or last_dims — now enters the grouped path and returns a GroupedScaledTensor1x with both dim arrays set to None. This silently changes the return type for callers that provided original_shape for informational purposes only, and those callers will now see num_groups derived implicitly from original_shape[group_axis] instead of receiving a plain ScaledTensor1x.
The condition should be restricted to the cases where grouping is actually requested:

```diff
-    first_dims is not None
-    or last_dims is not None
-    or (original_shape is not None and group_axis is not None)
-):
+if (
+    first_dims is not None
+    or last_dims is not None
+):
```
If the "uniform grouped" case (kernel rhs without explicit per-group sizes) needs to be handled here, it should be expressed with an explicit sentinel argument rather than overloading original_shape.
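The pitfall reduces to a minimal standalone sketch (simplified signature, not the actual ScaledTensorFactory code):

```python
def wants_grouped(first_dims=None, last_dims=None, original_shape=None, group_axis=0):
    # group_axis defaults to 0, so `group_axis is not None` is True unless
    # the caller explicitly passes group_axis=None; the third branch then
    # collapses to just `original_shape is not None`.
    return (
        first_dims is not None
        or last_dims is not None
        or (original_shape is not None and group_axis is not None)
    )
```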
```python
if isinstance(rhs, GroupedNoScaleTensor):
    rhs_data = rhs.data
    rhs_shape = rhs.original_shape
    rhs_scale_inv = jnp.empty((0,), jnp.float32)
    rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
    rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
elif isinstance(rhs, GroupedScaledTensor1x):
    rhs_shape = rhs.original_shape
    rhs_data = rhs.data.reshape(rhs_shape)
    rhs_scale_inv = rhs.scale_inv
    rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
    rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
    if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
        raise ValueError(
            f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
            f" rhs.scaling_mode={rhs.scaling_mode}"
        )
```
scaling_mode left as NO_SCALING when lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x
When lhs is a GroupedNoScaleTensor, the lhs block sets scaling_mode = ScalingMode.NO_SCALING. The subsequent rhs block only overrides scaling_mode when isinstance(lhs, GroupedScaledTensor1x):

```python
if isinstance(lhs, GroupedScaledTensor1x):
    scaling_mode = lhs.scaling_mode  # never executes for GroupedNoScaleTensor lhs
```

So if a caller passes lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x, scaling_mode stays NO_SCALING while rhs_scale_inv holds real scale values. C++ will then use NO_SCALING logic and ignore the rhs scales entirely, producing silently wrong numerical results rather than a clear error.
The scaling-mode consistency check that guards against mismatched GroupedScaledTensor1x pairs does not fire here either, because isinstance(lhs, GroupedScaledTensor1x) is False.
Add an explicit cross-type guard early in the rhs block:

```python
elif isinstance(rhs, GroupedScaledTensor1x):
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x; "
            "both operands must use the same tensor type."
        )
    ...
```
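The fall-through can also be condensed into a standalone repro (class and mode names simplified from the real ones):

```python
class GroupedNoScaleTensor:
    pass

class GroupedScaledTensor1x:
    def __init__(self, scaling_mode):
        self.scaling_mode = scaling_mode

def resolve_scaling_mode(lhs, rhs):
    # lhs block: a NoScale lhs pins scaling_mode to NO_SCALING.
    scaling_mode = "NO_SCALING"
    if isinstance(lhs, GroupedScaledTensor1x):
        scaling_mode = lhs.scaling_mode
    # rhs block never overrides for a NoScale lhs, so a scaled rhs is
    # silently treated as unscaled instead of raising an error.
    return scaling_mode

mixed = resolve_scaling_mode(GroupedNoScaleTensor(), GroupedScaledTensor1x("MXFP8"))
```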
/te-ci
```python
flatten_axis,
original_shape,
group_axis=0,
last_dims=None,
```
Nit: I think last_dims and first_dims should be positioned together.
```python
self.first_dims = first_dims
self.last_dims = last_dims
self.original_shape = original_shape
self.group_axis = group_axis
```
Since we store first_dims and last_dims now, I think we no longer need the group_axis.
```python
def tree_unflatten(cls, aux_data, children):
    """Reconstructs the tensor from its flattened representation."""
```
I think we should be able to reuse the base tree_unflatten as the order is still cls(*children, *aux_data).
transformer_engine/jax/dense.py
Outdated
```python
first_dims=ctx_kernel.first_dims,
last_dims=ctx_kernel.last_dims,
```
I think we don't need to pass dims here, as the tensors should already carry them.
transformer_engine/jax/dense.py
Outdated
```python
if is_noop_quantizer_set:
    grouped_gemm_x = GroupedNoScaleTensor(
        data=grouped_gemm_x,
        first_dims=group_sizes,
        last_dims=None,
        group_axis=0,
        original_shape=grouped_gemm_x.shape,
    )
    grouped_gemm_kernel = GroupedNoScaleTensor(
        data=grouped_gemm_kernel,
        first_dims=None,
        last_dims=None,
        group_axis=0,
        original_shape=grouped_gemm_kernel.shape,
    )
```
How about making grouped_quantize return GroupedNoScaleTensor when the quantizer set is empty?
```python
out_first_dims_aval,
out_last_dims_aval,
```
But the out_xxx_dims could be return buffers. Why should they be input buffers?
```python
lhs_shape = lhs_data_aval.shape
rhs_shape = rhs_data_aval.shape
```
Can't do this, as the inputs could both be 1D.
/te-ci
```diff
 if quantizer is None:
-    if isinstance(x, NoScaleTensor):
+    if isinstance(x, GroupedNoScaleTensor):
         assert amax is None, (
```
After cleaning the FSDP, I think amax is no longer relevant and should be cleaned as well.
```diff
 FFI::Bind()
     .Ctx<FFI_Stream_Type>()  // stream
-    .Arg<Buffer_Type>()      // lhs_data
+    .Arg<Buffer_Type>()      // lhs_data (2D)
```
```cpp
if (lhs_first_dims.element_count() > 0) {
  return lhs_first_dims.dimensions()[0];
```
If first_dims is intended to have shape (G,), shouldn't element_count() and dimensions()[0] return the same result?
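The equivalence the comment points at can be shown with a numpy analogue (XLA's Buffer_Type exposes element_count(), the total element count, and dimensions(), the shape; for a rank-1 buffer the two accessors coincide):

```python
import numpy as np

first_dims = np.array([3, 5, 2], dtype=np.int32)  # a (G,) buffer with G = 3
element_count = first_dims.size    # ~ Buffer_Type::element_count()
leading_dim = first_dims.shape[0]  # ~ Buffer_Type::dimensions()[0]
assert element_count == leading_dim  # identical for any rank-1 buffer
```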
/te-ci
Description
This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.
Addresses issue: #2648
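As a pure-Python sketch of that derivation (hypothetical helper mirroring the 2-D `_grouped_gemm_lhs_M`/`_grouped_gemm_rhs_N` snippets discussed in review, not the actual C++ implementation):

```python
def derive_mnk(lhs_shape_2d, rhs_shape_2d, lhs_is_trans, rhs_is_trans, num_groups):
    """Derive per-group GEMM sizes M, N, K from flattened 2-D buffer shapes
    and transpose flags, with no dimensions passed in from the caller."""
    m = lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0]
    k = lhs_shape_2d[0] if lhs_is_trans else lhs_shape_2d[1]
    n = rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1]
    return m, n, k
```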
Type of change
Changes
Please list the changes introduced in this PR:
- Replaced the single group_sizes parameter with six new arguments — lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims — each an optional (G,) int32 array describing per-group sizes along that tensor axis (empty (0,) arrays indicate a uniform/non-ragged dimension)
- M, N, and K are now derived from XLA buffer shapes inside the C++ handler, eliminating manual dimension computation in Python
- Output shape is inferred from which dims arrays are non-empty (non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a (num_groups, M, N) output)
- GEMM configuration is packed into a single FFI attribute struct, replacing individual attribute bindings
- make_grouped_tensor copies the dims arrays to int64 in partitioned int64_workspace slots, and returns an updated workspace offset to avoid aliasing
- Callers now pass group sizes through the appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad)
- Non-ragged axes use jnp.empty((0,), jnp.int32) sentinels
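To illustrate the sentinel convention (numpy used in place of jax.numpy, which behaves identically here; variable names are assumed for illustration):

```python
import numpy as np

group_sizes = np.array([3, 5, 2, 6], dtype=np.int32)  # G = 4 ragged groups
lhs_first_dims = group_sizes                          # ragged axis: (G,) per-group sizes
rhs_first_dims = np.empty((0,), dtype=np.int32)       # uniform axis: (0,) sentinel
is_rhs_ragged = rhs_first_dims.size > 0               # handlers branch on emptiness
```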
Checklist: