
GEMM + Swiglu fused Grouped MLP for MXFP8 #2769

Open
ksivaman wants to merge 11 commits into NVIDIA:main from ksivaman:fused_mxfp8_grouped_mlp

Conversation

@ksivaman
Member

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR adds a fused GEMM + SwiGLU forward/backward kernel for MXFP8 Grouped MLP on SM100+ (Blackwell) hardware, wiring together experimental CuTe DSL kernels from the cuDNN front-end with the existing GroupedLinear + ScaledSwiGLU op pipeline. It also introduces a single_grouped_parameter mode that packs per-group weights into a single GroupedTensor parameter, fixes a pre-existing columnwise scale shape bug (get_scale_shape(..., False → True)), adds CUDA-graph-aware grouped tensor construction, and lowers the cuBLAS minimum version requirement from 13.3 to 13.2 by explicitly handling zero-work groups in C++.

Key concerns remaining from previous review rounds (not yet resolved per the thread history):

  • global_alpha_tensor stale-size bug when two instances have different num_groups
  • Backward _get_kernel_constants missing the elif guard for norm_const_tensor
  • sf_vec_size=32 hardcoded in backward instead of using MXFP8_BLOCK_SCALING_SIZE
  • Debug assert messages with !!!! in grouped_linear.py
  • from pickle import TRUE unused import in backward_grouped_mlp.py
  • Hard assert is_supported() in the test (should be pytest.skip)
  • No validation that total token count is divisible by 128

New findings from this review:

  • cuBLAS version downgrade to 13.2 with stale comments (cublaslt_grouped_gemm.cu): Both macros were lowered but comments still say "requires 13.3+". More importantly, the original rationale for 13.3 was a documented cuBLAS 13.2 wgrad bug (uninitialized gradient data when a group has k=0). The new zero-work guard only fires when all tensors in the batch have zero work — it does not protect against the mixed-batch case where a single group has k=0.
  • Inconsistent scale_inv shape in backward (backward_grouped_mlp.py:323): fc1_dy_row_scale is stored as 2D (M, N//32) before passing to make_grouped_tensor_from_buffers, while the equivalent forward-pass tensor (fc2_in_row_scale) is explicitly .reshape(-1) to 1D. This inconsistency is currently harmless (contiguous memory layout is identical) but is confusing and fragile.
  • Redundant quantizer.set_usage() call for FC2 non-single-parameter weights (forward_grouped_mlp.py:204): set_usage is called before and inside the if not is_quantized_tensor branch with identical arguments.
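The distinction in the first finding (all-zero-work vs. mixed-batch zero groups) can be sketched as a hypothetical predicate mirroring the described guard semantics; the function name and tuple layout are illustrative, not the actual C++ code:

```python
def all_zero_work(group_dims):
    """True only when every (m, n, k) group in the batch has zero work."""
    return all(m == 0 or n == 0 or k == 0 for (m, n, k) in group_dims)

# The guard fires for an all-empty batch...
print(all_zero_work([(0, 8, 16), (16, 8, 0)]))         # True
# ...but not for a mixed batch containing a single k=0 group, which is
# exactly the case the documented cuBLAS 13.2 wgrad bug affects.
print(all_zero_work([(128, 256, 0), (128, 256, 64)]))  # False
```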

Confidence Score: 2/5

  • Not yet safe to merge — multiple confirmed bugs and a potential correctness regression from the cuBLAS version downgrade remain open.
  • Several issues flagged in the previous review round are still unresolved (stale global_alpha_tensor, missing elif for norm_const_tensor, hardcoded sf_vec_size, unused import, hard test assertion, missing token divisibility guard). This review adds a new concern: lowering the cuBLAS floor to 13.2 reintroduces a documented wgrad correctness bug for mixed-batch zero-group inputs that the new all-zero-work guard does not fully cover. The fused kernel path itself is non-trivial (6D tensor permutations, swizzled MXFP8 scales, split pointer arithmetic), so the cumulative unresolved risk is high.
  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu (cuBLAS version regression), transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py (norm_const_tensor elif, sf_vec_size, scale shape), transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py (global_alpha_tensor stale size)

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py New fused op for MXFP8 FC1+SwiGLU+FC2 forward pass using CuTe DSL kernel; contains a minor redundant set_usage call and carries the stale global_alpha_tensor bug and missing token divisibility guard flagged in prior threads.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py New fused op for MXFP8 backward pass (dSwiGLU + dgrad/wgrad GEMMs); carries the hardcoded sf_vec_size=32, missing elif for norm_const_tensor, and an inconsistent 2D row-scale shape vs the forward's 1D convention flagged in this review.
transformer_engine/pytorch/ops/basic/grouped_linear.py Adds single_grouped_parameter mode (all groups packed into one GroupedTensor parameter), accumulate_into_main_grad support, and associated grad-routing logic; debug assert messages with "!!!!" remain unaddressed.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Lowers cuBLAS minimum version from 13.3 to 13.2 and adds zero-work fast-paths; header comments still say "requires 13.3+" and the original bug for mixed-batch k=0 groups may be reintroduced.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Adds make_tensor_offsets and make_grouped_tensor_from_rowwise_data helpers, fixes a pre-existing columnwise scale shape bug (False→True in get_scale_shape), and handles CUDA graph capture gracefully.
tests/pytorch/test_fusible_ops.py Expands grouped MLP test with single_grouped_parameter and accumulate_into_main_grad axes, adds CUDA-graph test; the hard assert on is_supported() (no skip guard) was flagged in prior threads.

Sequence Diagram

sequenceDiagram
    participant Input as Input Tensor<br/>(M×K, BF16)
    participant FwdFused as ForwardGroupedMLP<br/>CuTeGEMMSwiGLU_MXFP8
    participant CuDNN as cudnn<br/>grouped_gemm_swiglu<br/>_wrapper_sm100
    participant FC2GEMM as general_grouped_gemm<br/>FC2 (cuBLAS)
    participant BwdFused as BackwardGroupedMLP<br/>CuTeGEMMDSwiGLU_MXFP8
    participant CuDNN2 as cudnn<br/>grouped_gemm_dswiglu<br/>_wrapper_sm100

    Note over Input,FC2GEMM: Forward Pass (MXFP8, SM100+)

    Input->>FwdFused: x [M×K], split_sizes, scales
    FwdFused->>FwdFused: group_quantize(x) → grouped_fc1_x (FP8)
    FwdFused->>FwdFused: quantize FC1/FC2 weights (FP8)
    FwdFused->>CuDNN: fc1_x_data, fc1_w_data,<br/>fc1_x_scales, fc1_w_scales,<br/>split_points, alpha_tensor, prob_scales
    CuDNN-->>FwdFused: c_tensor (BF16, pre-SwiGLU)<br/>d_tensor (FP8, post-SwiGLU row)<br/>d_col_tensor (FP8, post-SwiGLU col)<br/>sfd_row/col (MXFP8 scales)
    FwdFused->>FwdFused: pack grouped_fc2_x<br/>from d_tensor + sfd scales
    FwdFused->>FC2GEMM: grouped_fc2_weight × grouped_fc2_x
    FC2GEMM-->>FwdFused: fc2_out [M×N]
    FwdFused->>FwdFused: save swiglu_in, fc1/fc2 weights,<br/>fc1/fc2 inputs for backward

    Note over Input,CuDNN2: Backward Pass (MXFP8, SM100+)

    FwdFused->>BwdFused: grad_output [M×N]
    BwdFused->>BwdFused: group_quantize(dy) → grouped_fc2_dy (FP8)
    BwdFused->>CuDNN2: fc2_dy_data, fc2_w_col_data,<br/>swiglu_in, scales, split_points
    CuDNN2-->>BwdFused: d_row/col_tensor (FC1 grad, FP8)<br/>sfd_row/col (MXFP8 scales)<br/>dprob_tensor (grad_scales)
    BwdFused->>FC2GEMM: FC2 wgrad: grouped_fc2_x^T × grouped_fc2_dy
    BwdFused->>FC2GEMM: FC1 dgrad: grouped_fc1_weight × grouped_fc1_dy
    BwdFused->>FC2GEMM: FC1 wgrad: grouped_fc1_x^T × grouped_fc1_dy
    FC2GEMM-->>BwdFused: grad_input, fc1_wgrad, fc2_wgrad

Reviews (5). Last reviewed commit: "Merge branch 'main' into fused_mxfp8_gro..."

import os
import functools
import math
from pickle import TRUE
Contributor

P1 Unused import from pickle import TRUE

TRUE is imported from the pickle module but is never used anywhere in this file. pickle.TRUE is an internal pickle opcode byte string (b'I01\n'), not a Python True value. This is almost certainly a leftover from development and should be removed.

Suggested change:

- from pickle import TRUE
+ from typing import Optional
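The bot's claim about `pickle.TRUE` is easy to verify in a Python REPL:

```python
import pickle

# pickle.TRUE is a protocol-0/1 pickle byte sequence, not the boolean True
print(repr(pickle.TRUE))    # b'I01\n'
print(pickle.TRUE is True)  # False
```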

Comment on lines +202 to +205
return instance

self.with_gemm_swizzled_scales = with_gemm_swizzled_scales

Contributor

P1 Unreachable code after return

The line self.with_gemm_swizzled_scales = with_gemm_swizzled_scales is placed after return instance in __new__, so it will never execute. The intent seems to be handled already in _initialize_storage_fields (where instance.with_gemm_swizzled_scales = with_gemm_swizzled_scales is correctly set), so this line is both unreachable and redundant.

Suggested change:

- return instance
- self.with_gemm_swizzled_scales = with_gemm_swizzled_scales
+ return instance
def has_data(self) -> bool:

Comment on lines +484 to +500
global global_alpha_tensor
alpha_tensor = self._mxfp8_alpha_tensor
norm_const_tensor = self._mxfp8_norm_const_tensor
if (
alpha_tensor is None
or alpha_tensor.numel() != num_groups
or alpha_tensor.dtype != dtype
or alpha_tensor.device != device
):
if global_alpha_tensor is None:
global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device)
alpha_tensor = global_alpha_tensor
norm_const_tensor = alpha_tensor[:1]
self._mxfp8_alpha_tensor = alpha_tensor
self._mxfp8_norm_const_tensor = norm_const_tensor
elif (
norm_const_tensor is None
Contributor

P1 global_alpha_tensor stale for multiple instances with different num_groups

In _get_kernel_constants, the module-level global_alpha_tensor is created only once (when it is None). If a second ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 instance is created with a different num_groups, the stale global (wrong size) is assigned to self._mxfp8_alpha_tensor. On every subsequent call, the condition alpha_tensor.numel() != num_groups will remain True, and the wrong-sized tensor from the global will keep being assigned—without ever being recreated.

Concretely:

  1. Instance A (4 groups): creates global_alpha_tensor with 4 elements ✓
  2. Instance B (8 groups): global_alpha_tensor is None is False → skips creation → assigns the 4-element tensor as alpha_tensor → wrong size passed to the kernel.

The fix is to always recreate the global when the cached size/dtype/device doesn't match:

if (
    global_alpha_tensor is None
    or global_alpha_tensor.numel() != num_groups
    or global_alpha_tensor.dtype != dtype
    or global_alpha_tensor.device != device
):
    global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device)
alpha_tensor = global_alpha_tensor

The same issue exists in backward_grouped_mlp.py's _get_kernel_constants.
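The failure mode is reproducible with a minimal pure-Python stand-in for the module-level cache (list length stands in for tensor size; names are illustrative, not the actual code):

```python
_buggy_cache = None
_fixed_cache = None

def get_alpha_buggy(num_groups):
    # mirrors the current code: the global is created only on first use
    global _buggy_cache
    if _buggy_cache is None:
        _buggy_cache = [1.0] * num_groups
    return _buggy_cache

def get_alpha_fixed(num_groups):
    # mirrors the suggested fix: recreate whenever the cached size is wrong
    global _fixed_cache
    if _fixed_cache is None or len(_fixed_cache) != num_groups:
        _fixed_cache = [1.0] * num_groups
    return _fixed_cache

print(len(get_alpha_buggy(4)))  # 4
print(len(get_alpha_buggy(8)))  # 4  <- stale, wrong size for 8 groups
print(len(get_alpha_fixed(4)))  # 4
print(len(get_alpha_fixed(8)))  # 8
```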

Comment on lines +541 to +542
assert hasattr(weight_param, "main_grad"), "MAIN GRAD NOT FOUND !!!!"
assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!"
Contributor

P2 Debug assertion messages should use proper error messages

These assert statements with !!!! in the messages look like debugging leftovers. While the assertions themselves are reasonable sanity checks, the message style is not production-appropriate. Consider using RuntimeError or a cleaner assert message:

Suggested change:

- assert hasattr(weight_param, "main_grad"), "MAIN GRAD NOT FOUND !!!!"
- assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!"
+ if self._accumulate_into_main_grad:
+     if not hasattr(weight_param, "main_grad"):
+         raise RuntimeError(
+             "Expected 'main_grad' attribute on weight parameter, but it was not found."
+         )
+     if weight_param.main_grad is None:
+         raise RuntimeError("'main_grad' on weight parameter is None.")


Member

Agreed.

norm_const_tensor=norm_const_tensor,
d_dtype=torch.float8_e4m3fn,
cd_major="n",
sf_vec_size=32,
Contributor

P2 sf_vec_size hardcoded instead of using the shared constant

The backward kernel call uses sf_vec_size=32 hardcoded, while the forward kernel in forward_grouped_mlp.py correctly uses sf_vec_size=MXFP8_BLOCK_SCALING_SIZE (which equals 32). Both values happen to be the same today, but for consistency and maintainability, the backward should also import and use MXFP8_BLOCK_SCALING_SIZE.

Suggested change:

- sf_vec_size=32,
+ sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,

(add from ...constants import MXFP8_BLOCK_SCALING_SIZE to the imports)

Member

Agreed.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from vthumbe1503 March 17, 2026 04:55
ksivaman and others added 4 commits March 16, 2026 21:59
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman marked this pull request as ready for review March 17, 2026 05:38
Comment on lines +3617 to +3637
dtype=torch.float32,
)
fc1_weight.main_grad.fill_(value)
fc2_weight.main_grad.fill_(value)

def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]:
if single_grouped_parameter:
fc1_main_grad = fc1.weight.main_grad.detach().clone()
fc2_main_grad = fc2.weight.main_grad.detach().clone()
else:
fc1_main_grad = torch.stack(
[
getattr(fc1, f"weight{group_idx}").main_grad.detach().clone()
for group_idx in range(group_size)
],
dim=0,
)
fc2_main_grad = torch.stack(
[
getattr(fc2, f"weight{group_idx}").main_grad.detach().clone()
for group_idx in range(group_size)
Contributor

P1 Hard assert on is_supported() causes test failure instead of skip

assert te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() will raise AssertionError on any Blackwell system where the cudnn front-end package is installed for MXFP8 quantization but the specific grouped-GEMM-SwiGLU kernel is not available. In that scenario the test should be skipped, not failed. The maybe_skip_quantization guard only checks whether MXFP8 quantization in general is supported — it does not gate on the availability of the fused kernel.

Consider wrapping the fusion-presence block in a skip guard:

if (
    quantization == "mxfp8"
    and dtype in (torch.bfloat16, torch.float16)
    and not bias
    and glu_interleave_size == 32
):
    if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
        pytest.skip("MXFP8 fused grouped MLP kernel not available on this system")
    forward_ops = module._module_groups[0]._forward_ops
    ...

Comment on lines +722 to +742
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]

# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]

# Return list of ops
out.extend(window)
return out


# Register fusion if available
if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
register_backward_fusion(fuse_backward_ops, prepend=True)
Contributor

P2 norm_const_tensor not refreshed when only alpha_tensor is stale

The backward's _get_kernel_constants is missing the elif guard that the forward version has. If self._mxfp8_alpha_tensor is somehow valid but self._mxfp8_norm_const_tensor is None (e.g., after a partial object state restore or future refactoring), the function returns None for norm_const_tensor, which will crash the kernel call.

The forward version correctly handles this with:

elif (
    norm_const_tensor is None
    or norm_const_tensor.numel() != 1
    or norm_const_tensor.dtype != dtype
    or norm_const_tensor.device != device
):
    norm_const_tensor = alpha_tensor[:1]
    self._mxfp8_norm_const_tensor = norm_const_tensor

Add the same defensive elif branch here for consistency and safety.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +267 to +280
fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1])
fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn)
fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0)
fc1_x_scales = grouped_fc1_x.scale_inv
fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
fc1_x_scales = fc1_x_scales.view(
1,
in_shape[0] // 128,
in_shape[1] // 128,
32,
4,
4,
)
fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0)
Contributor

P1 No validation that total token count is divisible by 128

The scale tensor view at lines 272–279 uses integer division in_shape[0] // 128 to reshape the MXFP8 scale buffer. If in_shape[0] (i.e., sum(split_sizes)) is not divisible by 128, the view shape product will not match the actual buffer size and either produce incorrect behavior (wrong permute dimensions) or a runtime error with a confusing message.

The constructor checks that in_features % 256 == 0 and out_features % 256 == 0, but nothing validates that the token dimension sum(split_sizes) is divisible by 128 (required by the MXFP8 block-scaling layout). A user passing split sizes like [64, 65] would hit this silently.

The same assumption appears in the backward pass at backward_grouped_mlp.py lines 243–250.

Consider adding a guard before the view:

if in_shape[0] % 128 != 0:
    raise ValueError(
        f"Total token count must be divisible by 128 for MXFP8 fused kernel, "
        f"but got sum(split_sizes)={in_shape[0]}."
    )
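A standalone version of the suggested guard (hypothetical helper name, not the PR's actual code) behaves like this:

```python
def validate_token_count(split_sizes, block=128):
    """Raise if the MXFP8 token dimension is not a multiple of the scaling block."""
    total = sum(split_sizes)
    if total % block != 0:
        raise ValueError(
            f"Total token count must be divisible by {block} for MXFP8 fused kernel, "
            f"but got sum(split_sizes)={total}."
        )
    return total

print(validate_token_count([128, 128]))  # 256
try:
    validate_token_count([64, 65])       # 129 tokens -> raises ValueError
except ValueError as exc:
    print(exc)
```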

bool with_gemm_swizzled = false;
if (py::hasattr(tensor, "_with_gemm_swizzled_scales")) {
with_gemm_swizzled = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
} else if (py::hasattr(tensor, "with_gemm_swizzled_scales") &&
Member

Why do we need a second attribute name? Shouldn't we just use the _with_gemm_swizzled_scales attribute?

Collaborator

This probably came from merging the changes. These file changes are not needed.

Collaborator

@ksivaman can you remove these changes?

Collaborator

I think we should get rid of the "_" and just use with_gemm_swizzled_scales everywhere

) -> None:
super().__init__(name)

# Temporary for quick testing.
Member

Hmmmmm.....

Member
@ptrendx left a comment

Not yet done with the full review, but cursory glance shows some leftover debugging code and some other random things that should be cleaned up.

Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);

nvte_grouped_gemm_with_discrete_out(grouped_A.get_handle(),
Member

Not a fan of this name, but it was added in another PR, so not a problem here.

getTensorShape(*tensor_offsets));
}
nvte_set_grouped_tensor_swizzled_scales(out_cpp.data(),
static_cast<uint8_t>(with_gemm_swizzled_scales));
Collaborator

set_with_gemm_swizzled_scales below already swizzles the scales, so this is no longer needed. MR509 was not fully cleaned up/synced with main; I guess that's why it was left over.

None of this file's changes are really needed.

Collaborator

Agreed, please clean it up

Comment on lines +3672 to +3673
del warmup_out, warmup_x, warmup_probs, warmup_dy
gc.collect()
Member

Do we expect users to do the same things when capturing this?

Comment on lines +563 to +565
*
* \param[in] tensor Grouped tensor.
* \param[in] val 1 if scales are swizzled, 0 otherwise.
Collaborator

This file change is also not needed. Got rid of this experimental API now.

return 0;
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.with_gemm_swizzled_scales ? 1 : 0;
Collaborator
@vthumbe1503 commented Mar 17, 2026

Changes in this file can be reverted

} // namespace

// MXFP8 support for grouped GEMM requires cuBLAS 13.3+
#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300
Collaborator

Need to revert these changes.

* - scale_inv is stored in row-major per group.
* - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
* - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
*/
Collaborator

Already added in a previous PR; not needed.

dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets
dst.logical_shape = src.logical_shape
dst.quantized_tensors = src.quantized_tensors
dst.with_gemm_swizzled_scales = src.with_gemm_swizzled_scales
Collaborator

There needs to be a "_" in here; the attribute for swizzled scales uses "_".


// MXFP8 support for grouped GEMM requires cuBLAS 13.3+
- #define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300
+ #define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130200
Member

This should not be changed - cuBLAS 13.2 has issues with wgrad.

}

TEST(SwizzleGroupedTestSuite, TestGroupedSwizzleMXFP8) {
performTestGroupedSwizzleMXFP8(3, 256, 256);
Member

This is very limited coverage - what about the cases where

  • we have different M or K?
  • M/K are not nice numbers?
  • Some of M/K are 0?

quantizer = self.get_quantizer("forward", 1)

recipe = None if quantizer is None else quantizer._get_compatible_recipe()
if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()):
Member

Shouldn't this error out instead of silently returning?

maybe_dequantize,
)

global_alpha_tensor = None
Member

Why do we have this as global? This effectively makes it impossible to run multiple instances of those ops in parallel without silent data corruption.

maybe_dequantize,
)

global_alpha_tensor = None
Member

Same comment as in the other file. Why do we need to have this?

Member Author

To ensure that the alpha pointer is constant for CUDA graphs. Without caching this tensor, the pointer changes and we get NaNs during graph replay.

Comment on lines +3461 to +3462
assert te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()
assert te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported()
Member

We should skip those tests instead I think? Or do we expect this to be a core requirement to have those fusions supported?

Comment on lines +122 to +124
# TODO(ksivaman): Proper support for meta device.
# We do not want to reset params later as it wipes off
# main_grad and related attributes.
Member

Is this going to be done in this PR?

* - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
* - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
*/
void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
Member

I don't see the code of this function in this PR?

instance.quantized_tensors = None
instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales

instance.with_gemm_swizzled_scales = with_gemm_swizzled_scales
Collaborator

Same comment as @vthumbe1503, but I think we should get rid of _with_gemm_swizzled_scales

@staticmethod
def make_tensor_offsets(first_dims: torch.Tensor, logical_last_dim: int) -> torch.Tensor:
"""Calculate GPU offsets from first dim splits."""
return torch.cat(
Collaborator

This op can leverage the new fused kernel from @ksivaman
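For reference, the offsets computation can be sketched in plain Python, assuming (per the docstring) that each offset is the element count preceding a group in a row-major packed buffer; the real implementation builds a GPU tensor via torch.cat, and the helper name here is illustrative:

```python
def make_tensor_offsets_sketch(first_dims, logical_last_dim):
    # offsets[i] = number of elements stored before group i
    offsets = [0]
    for first_dim in first_dims:
        offsets.append(offsets[-1] + first_dim * logical_last_dim)
    return offsets

# three groups with first dims 2, 3, 5 and a packed last dim of 4
print(make_tensor_offsets_sketch([2, 3, 5], 4))  # [0, 8, 20, 40]
```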

else noop_cat([w._rowwise_scale_inv for w in grouped_fc1_weight])
)
fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu)
fc1_w_scales = fc1_w_scales.view(
Collaborator

why don't we use the grouped swizzle kernel here?

fc1_w_data = (
grouped_fc1_weight.rowwise_data
if fc1_op.single_grouped_parameter
else noop_cat([w._rowwise_data for w in grouped_fc1_weight])
Collaborator

how about the discrete weight version?

fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1)
fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]).contiguous()
fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"]
fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3)
Collaborator

Why do we choose to swizzle here and then call make_grouped_tensor_from_buffers with with_gemm_swizzled_scales=True? We should be able to leverage the grouped swizzle kernel from @vthumbe1503.

fc2_dy_data = fc2_dy_data.unsqueeze(0).permute(1, 2, 0)
fc2_dy_scales = grouped_fc2_dy.scale_inv
fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu)
fc2_dy_scales = fc2_dy_scales.view(
Collaborator
@zhongbozhu commented Mar 19, 2026

Same as above: we should be able to use the grouped quantize+swizzle fusion here.

logical_last_dim=fc1_weight_shape[1],
)

general_grouped_gemm_for_grouped_tensor(
Collaborator

For FC1 DGRAD, would it be better to use the contiguous-layout GEMM with activation fusion disabled?

weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
else:
Collaborator

Nit: this if/else is pretty long; please add some comments to the else branch.
