From 7b7406ca04fda06150e317afb760cc12505e4e9b Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 24 Mar 2026 12:31:24 -0700 Subject: [PATCH 1/6] adds NVFP4 Fused Adam support Signed-off-by: Jonathan Mitchell --- .../fsdp2_tests/run_fsdp2_fused_adam.py | 6 - .../fsdp2_tests/run_fsdp2_model.py | 2 +- .../pytorch/tensor/nvfp4_tensor.py | 184 ++++++++++++++++++ 3 files changed, 185 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index 877fa66795..26817a4aec 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -142,12 +142,6 @@ def test_fused_adam_fp8_master_weights(recipe_name): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "NVFP4BlockScaling": - pytest.xfail( - f"{recipe_name}: quantized_model_init and FSDP2 is not currently supported, since the " - "block tensor is dequantized before we flatten it for FSDP2." - ) - world_size, device = _get_dist_info() model = _build_model(fp8_init=True, recipe=recipe) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py index fce565ed9a..5452305783 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py @@ -363,7 +363,7 @@ def _train(args): @pytest.mark.parametrize("fp8_init", [False, True]) @pytest.mark.parametrize("layer_type", ["LayerNormLinear", "TransformerLayer"]) def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type): - if recipe_name in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init: + if recipe_name == "Float8BlockScaling" and fp8_init: pytest.xfail(f"{recipe_name} + fp8_init: test_fp8_fsdp2_allgather is currently failing.") torch.manual_seed(42) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 8ed1b4682c..5d4f3dec53 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -551,6 +551,152 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self._columnwise_data is not None, } + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Called by FSDP2 before all-gather of weights. + + Prepares sharded NVFP4 tensors for dim0 all-gather by unpadding scales + and transposing columnwise data/scales so M-related dims are in dim0. + """ + # pylint: disable=unused-argument + + if self._with_gemm_swizzled_scales: + raise NotImplementedError( + "FSDP2 is not supported for NVFP4Tensors with GEMM-swizzled scales." + ) + + shard_M = math.prod(self.shape[:-1]) + K = self.shape[-1] + + # Rowwise data: (shard_M, K//2) — M in dim0, pass as-is + rowwise_data = self._rowwise_data + # Rowwise scale: (round_up(shard_M, 128), inner) — unpad dim0 to shard_M + rowwise_scale_inv = self._rowwise_scale_inv + if rowwise_scale_inv is not None: + rowwise_scale_inv = rowwise_scale_inv[:shard_M, :] + + # Columnwise data: (K, shard_M//2) — transpose to (shard_M//2, K) + columnwise_data = self._columnwise_data + # Columnwise scale: (round_up(K, 128), round_up(ceil(shard_M/16), 4)) + columnwise_scale_inv = self._columnwise_scale_inv + + if columnwise_data is not None: + columnwise_data = columnwise_data.t().contiguous() + + if columnwise_scale_inv is not None: + # Unpad dim1 from round_up(ceil(shard_M/16), 4) to ceil(shard_M/16) + m_blocks = math.ceil(shard_M / NVFP4_BLOCK_SCALING_SIZE) + columnwise_scale_inv = columnwise_scale_inv[:, :m_blocks] + # Transpose to (m_blocks, round_up(K, 128)) so M-blocks are in dim0 + columnwise_scale_inv = columnwise_scale_inv.t().contiguous() + + # Always send both orientations (GEMM needs both for fwd/bwd) + rowwise_usage = True + sharded_tensors = (rowwise_data, rowwise_scale_inv) + columnwise_usage = self._quantizer.columnwise_usage + if columnwise_usage: + sharded_tensors += (columnwise_data, columnwise_scale_inv) + + # Pass amax via metadata (scalar, same on all ranks — not all-gathered) + metadata = ( + self._fp4_dtype, + rowwise_usage, + columnwise_usage, + self._amax_rowwise, + self._amax_columnwise, + K, + ) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata, + param_dtype: torch.dtype, + *, + out: Optional[NVFP4Tensor] = None, + ): + """Called by FSDP2 after all-gather of weights. + + Reverses the transforms from fsdp_pre_all_gather: repads rowwise scales, + transposes columnwise data/scales back, and constructs the full NVFP4Tensor. + """ + fp4_dtype, rowwise_usage, columnwise_usage, amax_rowwise, amax_columnwise, K = metadata + + # Extract rowwise tensors + rowwise_data, rowwise_scale_inv = ( + all_gather_outputs[:2] if rowwise_usage else (None, None) + ) + + # Compute full_M from the all-gathered data + if rowwise_data is not None: + full_M = rowwise_data.shape[0] + else: + # columnwise_data after all-gather is (full_M//2, K) + full_M = all_gather_outputs[-2].shape[0] * 2 + + # Repad rowwise scale dim0 to round_up(full_M, 128) + if rowwise_scale_inv is not None: + target_m = round_up_to_nearest_multiple(full_M, 128) + current_m = rowwise_scale_inv.shape[0] + if current_m < target_m: + rowwise_scale_inv = torch.nn.functional.pad( + rowwise_scale_inv, (0, 0, 0, target_m - current_m) + ) + + # Extract columnwise tensors — they were transposed in pre_all_gather + columnwise_data, columnwise_scale_inv = ( + all_gather_outputs[-2:] if columnwise_usage else (None, None) + ) + + if columnwise_data is not None: + # All-gathered shape: (full_M//2, K), transpose back to (K, full_M//2) + columnwise_data = columnwise_data.t().contiguous() + + if columnwise_scale_inv is not None: + # All-gathered shape: (full_m_blocks, round_up(K, 128)) + # Transpose back to (round_up(K, 128), full_m_blocks) + columnwise_scale_inv = columnwise_scale_inv.t().contiguous() + # Repad dim1 (M-block dim) to round_up(ceil(full_M/16), 4) + current_m_blocks = columnwise_scale_inv.shape[1] + target_m_blocks = round_up_to_nearest_multiple( + math.ceil(full_M / NVFP4_BLOCK_SCALING_SIZE), 4 + ) + if current_m_blocks < target_m_blocks: + columnwise_scale_inv = torch.nn.functional.pad( + columnwise_scale_inv, (0, target_m_blocks - current_m_blocks) + ) + elif current_m_blocks > target_m_blocks: + columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks] + + logical_shape = (full_M, K) + + if out is not None: + # Update existing tensor in-place (subsequent iterations) + out._rowwise_data = rowwise_data + out._rowwise_scale_inv = rowwise_scale_inv + out._columnwise_data = columnwise_data + out._columnwise_scale_inv = columnwise_scale_inv + out._amax_rowwise = amax_rowwise + out._amax_columnwise = amax_columnwise + else: + # Construct new tensor (first iteration) + out = NVFP4Tensor( + shape=logical_shape, + dtype=param_dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=self._quantizer, + requires_grad=False, + with_gemm_swizzled_scales=False, + ) + out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) + return out, all_gather_outputs + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -564,6 +710,44 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return tensor.detach() return tensor.view(shape) + # as_strided — FSDP2 applies this on the unsharded param. + # When shape and strides match (no-op), return self to preserve the quantized type. + if func == aten.as_strided.default: + tensor = args[0] + shape = args[1] + strides = args[2] + if ( + len(shape) == len(strides) == 2 + and tuple(strides) == (shape[-1], 1) + and tuple(shape) == tuple(tensor.size()) + ): + return NVFP4Tensor.make_like(tensor) + + # slice — FSDP2 applies this for shard unpadding. + # When the slice covers the full dimension, return self. + if func == aten.slice.Tensor: + tensor = args[0] + dim = args[1] + start = args[2] + length = args[3] + if start == 0 and length == tensor.size(dim): + return NVFP4Tensor.make_like(tensor) + + # record_stream — FSDP2 records streams on all-gathered tensors. + if func == torch.ops.aten.record_stream.default: + qt, stream = args + for t in ( + qt._rowwise_data, + qt._columnwise_data, + qt._rowwise_scale_inv, + qt._columnwise_scale_inv, + qt._amax_rowwise, + qt._amax_columnwise, + ): + if t is not None and t.is_cuda: + t.record_stream(stream) + return None + # NVFP4 dequantize not supported. Add manual support for needed funcs. if func in (aten.empty_like.default, aten.zero_.default): tensor = args[0] From ae7d87b6c2ae258c5d987015bb449329b332a996 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 24 Mar 2026 12:32:56 -0700 Subject: [PATCH 2/6] Adds testing file Signed-off-by: Jonathan Mitchell --- tests/pytorch/test_nvfp4_fsdp2_hooks.py | 282 ++++++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 tests/pytorch/test_nvfp4_fsdp2_hooks.py diff --git a/tests/pytorch/test_nvfp4_fsdp2_hooks.py b/tests/pytorch/test_nvfp4_fsdp2_hooks.py new file mode 100644 index 0000000000..c4589e5100 --- /dev/null +++ b/tests/pytorch/test_nvfp4_fsdp2_hooks.py @@ -0,0 +1,282 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for NVFP4Tensor FSDP2 all-gather hooks. + +These tests verify the pre/post all-gather round-trip logic on a single GPU +without requiring torchrun or multi-GPU setup. +""" + +import math +from typing import List, Tuple + +import pytest +import torch + +import transformer_engine.pytorch as te +from transformer_engine.pytorch import ( + NVFP4Quantizer, + NVFP4Tensor, +) +from transformer_engine.pytorch.utils import round_up_to_nearest_multiple +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +import transformer_engine_torch as tex + +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +# Shapes that exercise various M/K combinations: +# - (512, 256): both dims cleanly divisible by 128 +# - (640, 128): M not a multiple of 128*2 but divisible by 16 +# - (256, 1024): K > M +_test_shapes: List[Tuple[int, int]] = [ + (512, 256), + (640, 128), + (256, 1024), +] + + +def _make_nvfp4_tensor(shape: Tuple[int, int]) -> NVFP4Tensor: + """Create an NVFP4Tensor from random BF16 data.""" + quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + src = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + return quantizer(src) + + +def _simulate_all_gather( + sharded_tensors: Tuple[torch.Tensor, ...], + world_size: int, +) -> Tuple[torch.Tensor, ...]: + """Simulate FSDP2 all-gather by concatenating shards along dim0.""" + return tuple(torch.cat([t] * world_size, dim=0) for t in sharded_tensors) + + +@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4) +class TestNVFP4FSDP2Hooks: + """Tests for fsdp_pre_all_gather / fsdp_post_all_gather round-trip.""" + + @staticmethod + def setup_class(cls) -> None: + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + @pytest.mark.parametrize("shape", _test_shapes) + @pytest.mark.parametrize("world_size", [2, 4]) + def test_round_trip_shapes(self, shape: Tuple[int, int], world_size: int): + """Verify that pre_all_gather -> all_gather -> post_all_gather produces correct shapes.""" + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + + # Pre all-gather + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, + ) + + # Verify sharded tensor shapes + # sharded_tensors = (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv) + assert len(sharded_tensors) == 4, "Expected 4 tensors (rowwise + columnwise)" + + rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv = sharded_tensors + + # Rowwise data: (shard_M, K//2) — unmodified + assert rowwise_data.shape == (shard_M, K // 2) + # Rowwise scale: unpadded dim0 to shard_M + assert rowwise_scale_inv.shape[0] == shard_M + + # Columnwise data: transposed to (shard_M//2, K) + assert columnwise_data.shape == (shard_M // 2, K) + # Columnwise scale: transposed to (m_blocks, round_up(K, 128)) + m_blocks = math.ceil(shard_M / NVFP4_BLOCK_SCALING_SIZE) + assert columnwise_scale_inv.shape == (m_blocks, round_up_to_nearest_multiple(K, 128)) + + # Simulate all-gather + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + + # Post all-gather + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, + ) + + # Verify output is NVFP4Tensor with correct logical shape + assert isinstance(result, NVFP4Tensor) + assert tuple(result.shape) == (M, K) + + # Verify internal data shapes + assert result._rowwise_data.shape == (M, K // 2) + + expected_rowwise_scale_shape = ( + round_up_to_nearest_multiple(M, 128), + round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4), + ) + assert result._rowwise_scale_inv.shape == expected_rowwise_scale_shape + + assert result._columnwise_data.shape == (K, M // 2) + + expected_col_scale_shape = ( + round_up_to_nearest_multiple(K, 128), + round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4), + ) + assert result._columnwise_scale_inv.shape == expected_col_scale_shape + + @pytest.mark.parametrize("shape", _test_shapes) + def test_round_trip_data_integrity(self, shape: Tuple[int, int]): + """Verify that data survives the pre -> all_gather -> post round-trip.""" + world_size = 2 + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + + # Save original internal tensors for comparison + orig_rowwise_data = qt._rowwise_data.clone() + orig_rowwise_scale = qt._rowwise_scale_inv.clone() + orig_columnwise_data = qt._columnwise_data.clone() + orig_columnwise_scale = qt._columnwise_scale_inv.clone() + orig_amax_row = qt._amax_rowwise.clone() + orig_amax_col = qt._amax_columnwise.clone() + + # Pre all-gather + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, orig_size=None, contiguous_orig_stride=None, + module=None, mp_policy=None, + ) + + # Simulate all-gather (world_size copies — data from each "rank" is identical) + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + + # Post all-gather + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, metadata, param_dtype=torch.bfloat16, + ) + + # Since each "rank" has the same data, the full rowwise_data should be + # the original shard repeated world_size times + expected_rowwise_data = torch.cat([orig_rowwise_data] * world_size, dim=0) + assert torch.equal(result._rowwise_data, expected_rowwise_data) + + # Rowwise scale: each shard's unpadded scale is repeated, then repadded + # Check that the first shard_M rows of the scale match the original (unpadded) + assert torch.equal( + result._rowwise_scale_inv[:shard_M, :], + orig_rowwise_scale[:shard_M, :], + ) + + # Columnwise data: each shard contributes (K, shard_M//2) -> repeated gives (K, M//2) + expected_col_data = torch.cat([orig_columnwise_data] * world_size, dim=1) + assert torch.equal(result._columnwise_data, expected_col_data) + + # Amax values passed through metadata — should be preserved + assert torch.equal(result._amax_rowwise, orig_amax_row) + assert torch.equal(result._amax_columnwise, orig_amax_col) + + @pytest.mark.parametrize("shape", _test_shapes) + def test_round_trip_dequantize(self, shape: Tuple[int, int]): + """Verify that dequantized values are preserved through the round-trip.""" + world_size = 2 + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + orig_deq = qt.dequantize() + + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, orig_size=None, contiguous_orig_stride=None, + module=None, mp_policy=None, + ) + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, metadata, param_dtype=torch.bfloat16, + ) + + # The full tensor should dequantize to world_size copies of the shard + result_deq = result.dequantize() + expected_deq = torch.cat([orig_deq] * world_size, dim=0) + torch.testing.assert_close(result_deq, expected_deq) + + @pytest.mark.parametrize("shape", _test_shapes) + def test_in_place_update(self, shape: Tuple[int, int]): + """Verify the out= path (in-place update on subsequent iterations).""" + world_size = 2 + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, orig_size=None, contiguous_orig_stride=None, + module=None, mp_policy=None, + ) + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + + # First call: out=None -> creates new tensor + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, metadata, param_dtype=torch.bfloat16, + ) + first_deq = result.dequantize().clone() + + # Second call: out=result -> in-place update + result2, _ = qt.fsdp_post_all_gather( + all_gather_outputs, metadata, param_dtype=torch.bfloat16, out=result, + ) + assert result2 is result # same object + torch.testing.assert_close(result2.dequantize(), first_deq) + + def test_swizzled_scales_rejected(self): + """Verify that GEMM-swizzled scales raise NotImplementedError.""" + shape = (512, 256) + qt = _make_nvfp4_tensor(shape) + qt._with_gemm_swizzled_scales = True + + with pytest.raises(NotImplementedError, match="GEMM-swizzled"): + qt.fsdp_pre_all_gather( + mesh=None, orig_size=None, contiguous_orig_stride=None, + module=None, mp_policy=None, + ) + + +@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4) +class TestNVFP4DispatchHandlers: + """Tests for as_strided, slice, and record_stream dispatch handlers.""" + + def test_as_strided_noop(self): + """as_strided with matching shape/strides returns NVFP4Tensor.""" + qt = _make_nvfp4_tensor((256, 128)) + M, K = qt.shape + result = torch.ops.aten.as_strided.default(qt, [M, K], [K, 1], 0) + assert isinstance(result, NVFP4Tensor) + assert tuple(result.shape) == (M, K) + + def test_slice_noop(self): + """slice covering full dimension returns NVFP4Tensor.""" + qt = _make_nvfp4_tensor((256, 128)) + M, K = qt.shape + result = torch.ops.aten.slice.Tensor(qt, 0, 0, M) + assert isinstance(result, NVFP4Tensor) + assert tuple(result.shape) == (M, K) + + def test_record_stream(self): + """record_stream completes without error.""" + qt = _make_nvfp4_tensor((256, 128)) + stream = torch.cuda.Stream() + result = torch.ops.aten.record_stream.default(qt, stream) + assert result is None From abb7b04d375411d46a81632a25b1e4a7ae59852c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:34:45 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_nvfp4_fsdp2_hooks.py | 47 ++++++++++++++----- .../pytorch/tensor/nvfp4_tensor.py | 15 +++--- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_nvfp4_fsdp2_hooks.py b/tests/pytorch/test_nvfp4_fsdp2_hooks.py index c4589e5100..4eb9d07c8d 100644 --- a/tests/pytorch/test_nvfp4_fsdp2_hooks.py +++ b/tests/pytorch/test_nvfp4_fsdp2_hooks.py @@ -63,7 +63,7 @@ def _simulate_all_gather( class TestNVFP4FSDP2Hooks: """Tests for fsdp_pre_all_gather / fsdp_post_all_gather round-trip.""" - @staticmethod + @classmethod def setup_class(cls) -> None: torch.manual_seed(42) torch.cuda.manual_seed(42) @@ -155,8 +155,11 @@ def test_round_trip_data_integrity(self, shape: Tuple[int, int]): # Pre all-gather sharded_tensors, metadata = qt.fsdp_pre_all_gather( - mesh=None, orig_size=None, contiguous_orig_stride=None, - module=None, mp_policy=None, + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, ) # Simulate all-gather (world_size copies — data from each "rank" is identical) @@ -164,7 +167,9 @@ def test_round_trip_data_integrity(self, shape: Tuple[int, int]): # Post all-gather result, _ = qt.fsdp_post_all_gather( - all_gather_outputs, metadata, param_dtype=torch.bfloat16, + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, ) # Since each "rank" has the same data, the full rowwise_data should be @@ -199,12 +204,17 @@ def test_round_trip_dequantize(self, shape: Tuple[int, int]): orig_deq = qt.dequantize() sharded_tensors, metadata = qt.fsdp_pre_all_gather( - mesh=None, orig_size=None, contiguous_orig_stride=None, - module=None, mp_policy=None, + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, ) all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) result, _ = qt.fsdp_post_all_gather( - all_gather_outputs, metadata, param_dtype=torch.bfloat16, + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, ) # The full tensor should dequantize to world_size copies of the shard @@ -223,20 +233,28 @@ def test_in_place_update(self, shape: Tuple[int, int]): qt = _make_nvfp4_tensor(shard_shape) sharded_tensors, metadata = qt.fsdp_pre_all_gather( - mesh=None, orig_size=None, contiguous_orig_stride=None, - module=None, mp_policy=None, + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, ) all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) # First call: out=None -> creates new tensor result, _ = qt.fsdp_post_all_gather( - all_gather_outputs, metadata, param_dtype=torch.bfloat16, + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, ) first_deq = result.dequantize().clone() # Second call: out=result -> in-place update result2, _ = qt.fsdp_post_all_gather( - all_gather_outputs, metadata, param_dtype=torch.bfloat16, out=result, + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, + out=result, ) assert result2 is result # same object torch.testing.assert_close(result2.dequantize(), first_deq) @@ -249,8 +267,11 @@ def test_swizzled_scales_rejected(self): with pytest.raises(NotImplementedError, match="GEMM-swizzled"): qt.fsdp_pre_all_gather( - mesh=None, orig_size=None, contiguous_orig_stride=None, - module=None, mp_policy=None, + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, ) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5d4f3dec53..0779cbcbc9 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -623,9 +623,7 @@ def fsdp_post_all_gather( fp4_dtype, rowwise_usage, columnwise_usage, amax_rowwise, amax_columnwise, K = metadata # Extract rowwise tensors - rowwise_data, rowwise_scale_inv = ( - all_gather_outputs[:2] if rowwise_usage else (None, None) - ) + rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None) # Compute full_M from the all-gathered data if rowwise_data is not None: @@ -727,10 +725,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # When the slice covers the full dimension, return self. if func == aten.slice.Tensor: tensor = args[0] - dim = args[1] - start = args[2] - length = args[3] - if start == 0 and length == tensor.size(dim): + dim = args[1] if len(args) > 1 else 0 + start = args[2] if len(args) > 2 else None + end = args[3] if len(args) > 3 else None + step = args[4] if len(args) > 4 else 1 + if step == 1 and (start is None or start == 0) and ( + end is None or end >= tensor.size(dim) + ): return NVFP4Tensor.make_like(tensor) # record_stream — FSDP2 records streams on all-gathered tensors. From 238f2df1ab7242d18d5f3b2d0fec8fd685aa512f Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 24 Mar 2026 14:50:42 -0700 Subject: [PATCH 4/6] responds to comments in review Signed-off-by: Jonathan Mitchell --- 3rdparty/cudnn-frontend | 2 +- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d33027a41a..8d19d3182b 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 +Subproject commit 8d19d3182bfbc304046a15e9236bec9ff31511fc diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 0779cbcbc9..1ff43cef1c 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -591,6 +591,10 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # Always send both orientations (GEMM needs both for fwd/bwd) rowwise_usage = True + assert self._rowwise_data is not None, ( + "FSDP2 requires rowwise data, but _rowwise_data is None. " + "Ensure the NVFP4Quantizer was created with rowwise=True." + ) sharded_tensors = (rowwise_data, rowwise_scale_inv) columnwise_usage = self._quantizer.columnwise_usage if columnwise_usage: From 7e1af130a88666a47a4cd92b287b7f1c16ff9dc6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 21:51:35 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 1ff43cef1c..7fb252b150 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -733,8 +733,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): start = args[2] if len(args) > 2 else None end = args[3] if len(args) > 3 else None step = args[4] if len(args) > 4 else 1 - if step == 1 and (start is None or start == 0) and ( - end is None or end >= tensor.size(dim) + if ( + step == 1 + and (start is None or start == 0) + and (end is None or end >= tensor.size(dim)) ): return NVFP4Tensor.make_like(tensor) From 96cf20f94c3d7f49188bae4efb64d0a14a4febd0 Mon Sep 17 00:00:00 2001 From: Jonathan Mitchell Date: Tue, 24 Mar 2026 15:06:08 -0700 Subject: [PATCH 6/6] addresses greptile comment Signed-off-by: Jonathan Mitchell --- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 7fb252b150..2a1aa6f998 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -567,6 +567,12 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m shard_M = math.prod(self.shape[:-1]) K = self.shape[-1] + assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"FSDP2 requires shard_M ({shard_M}) to be a multiple of " + f"NVFP4_BLOCK_SCALING_SIZE ({NVFP4_BLOCK_SCALING_SIZE}). " + "Adjust model dimensions or world size." + ) + # Rowwise data: (shard_M, K//2) — M in dim0, pass as-is rowwise_data = self._rowwise_data # Rowwise scale: (round_up(shard_M, 128), inner) — unpad dim0 to shard_M @@ -667,8 +673,12 @@ def fsdp_post_all_gather( columnwise_scale_inv = torch.nn.functional.pad( columnwise_scale_inv, (0, target_m_blocks - current_m_blocks) ) - elif current_m_blocks > target_m_blocks: - columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks] + else: + assert current_m_blocks == target_m_blocks, ( + f"Columnwise scale m_blocks mismatch: got {current_m_blocks}, " + f"expected {target_m_blocks}. This should be unreachable when " + "shard_M is a multiple of NVFP4_BLOCK_SCALING_SIZE." + ) logical_shape = (full_M, K)