From ef51ab40335dd461ecf1a76fd6e95c3d89ef5489 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 23 Mar 2026 00:36:21 +0000 Subject: [PATCH 1/7] done Signed-off-by: Varun Thumbe --- .../pytorch/tensor/float8_blockwise_tensor.py | 119 +++++++----------- 1 file changed, 42 insertions(+), 77 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ab496d5a9e..8a7aeb1dc5 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -10,7 +10,7 @@ from typing import Any, Optional, Tuple, Union import torch - +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe @@ -634,42 +634,31 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m "layout has M in dim1, which is incompatible with FSDP2 dim0 all-gather." ) - block_len = self._quantizer.block_len # 128 - - # Prepare rowwise tensors — for 2D scaling, M is in dim0 of both data and scale_inv, - # so they naturally align with FSDP2's dim0 all-gather. No unpadding needed. - rowwise_data = self._rowwise_data - rowwise_scale_inv = self._rowwise_scale_inv - - # Prepare columnwise tensors — columnwise data is transposed (K, M) and - # columnwise scale_inv is (ceil(K/128), round_up(ceil(M/128), 4)). - # M is in dim1 for both, so we must transpose to put M in dim0 for all-gather. - columnwise_data = self._columnwise_data - columnwise_scale_inv = self._columnwise_scale_inv - - if columnwise_data is not None: - # Transpose (K, shard_M) -> (shard_M, K) so M is in dim0 - columnwise_data = columnwise_data.t().contiguous() - - if columnwise_scale_inv is not None: - # Original shape: (ceil(K/128), round_up(ceil(shard_M/128), 4)) - # Strip padding from dim1 (the M-block dimension), transpose, then all-gather - shard_M = math.prod(self.shape[:-1]) - m_blocks = (shard_M + block_len - 1) // block_len # ceil(shard_M/128) - columnwise_scale_inv = columnwise_scale_inv[:, :m_blocks] # unpad dim1 - columnwise_scale_inv = columnwise_scale_inv.t().contiguous() # (m_blocks, k_blocks) - - # Always send both rowwise and columnwise data. - # Unlike MXFP8 (where both forms share the same shape), Float8Blockwise has - # differently-shaped rowwise (M, K) and columnwise (K, M) data. The GEMM kernel - # needs both forms available to perform forward and backward operations, so we - # cannot optimize by sending only one usage based on forward/backward pass. - rowwise_usage = True - sharded_tensors = (rowwise_data, rowwise_scale_inv) - columnwise_usage = self._quantizer.columnwise_usage - if columnwise_usage: - sharded_tensors += (columnwise_data, columnwise_scale_inv) - + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Rowwise data must be available for FSDP2 all-gather with 2D block scaling." + + fsdp_state = _get_module_fsdp_state(module) + reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + + # If weights are resharded after forward pass, only the relevant usage + # is needed based on whether it's a forward or backward pass. + # If not resharded, the same all-gathered weights are reused in backward, + # so both usages may be needed. + if reshard_after_forward: + training_state = fsdp_state._fsdp_param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + rowwise_usage = not is_backward_pass + columnwise_usage = is_backward_pass + else: + rowwise_usage = True + columnwise_usage = self._quantizer.columnwise_usage + + # For 2D block scaling (128x128 blocks), columnwise data and scales are + # the transpose of rowwise data and scales. Only all-gather the rowwise + # tensors; columnwise will be derived locally via _create_columnwise() + # in post_all_gather, halving all-gather communication volume. + sharded_tensors = (self._rowwise_data, self._rowwise_scale_inv) metadata = (self._fp8_dtype, self._is_2D_scaled, rowwise_usage, columnwise_usage) return sharded_tensors, metadata @@ -694,59 +683,35 @@ def fsdp_post_all_gather( """ fp8_dtype, is_2D_scaled, rowwise_usage, columnwise_usage = metadata - # Extract rowwise tensors from all-gather outputs - rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None) - - # Extract columnwise tensors — they were transposed in pre_all_gather, - # so we need to transpose them back. - columnwise_data, columnwise_scale_inv = ( - all_gather_outputs[-2:] if columnwise_usage else (None, None) - ) - - if columnwise_data is not None: - # All-gathered shape is (full_M, K), transpose back to (K, full_M) - columnwise_data = columnwise_data.t().contiguous() - - if columnwise_scale_inv is not None: - # All-gathered shape is (full_m_blocks, k_blocks), - # transpose back to (k_blocks, full_m_blocks) - columnwise_scale_inv = columnwise_scale_inv.t().contiguous() - # Repad dim1 (M-block dimension) to multiple of 4 for GEMM alignment - current_m_blocks = columnwise_scale_inv.shape[1] - pad_amount = (4 - current_m_blocks % 4) % 4 - if pad_amount > 0: - columnwise_scale_inv = torch.nn.functional.pad( - columnwise_scale_inv, (0, pad_amount) - ) - - # Determine the logical shape from the all-gathered data - if rowwise_data is not None: - data_shape = rowwise_data.shape - else: - # columnwise_data is (K, full_M), logical shape is (full_M, K) - data_shape = (columnwise_data.shape[1], columnwise_data.shape[0]) + # Only rowwise data+scales were all-gathered (columnwise is derived locally). + rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] + data_shape = rowwise_data.shape if out is not None: - # Update existing tensor in-place (subsequent iterations) out._rowwise_data = rowwise_data out._rowwise_scale_inv = rowwise_scale_inv - out._columnwise_data = columnwise_data - out._columnwise_scale_inv = columnwise_scale_inv else: - # Construct new tensor (first iteration). - # Float8BlockwiseQTensor constructor copies the quantizer, - # so the sharded tensor's quantizer remains independent. out = Float8BlockwiseQTensor( shape=data_shape, dtype=param_dtype, fp8_dtype=fp8_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, + columnwise_data=None, + columnwise_scale_inv=None, quantizer=self._quantizer, is_2D_scaled=is_2D_scaled, ) + + # For 2D block scaling, derive columnwise data and scales from rowwise + # via local fp8 transpose instead of all-gathering them separately. + if columnwise_usage: + out._create_columnwise() + # remove usages if not needed. + out.update_usage( + rowwise_usage=rowwise_usage, + columnwise_usage=columnwise_usage, + ) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs @@ -976,4 +941,4 @@ def backward( requires_grad=grad.requires_grad, ) return dgrad, None - return grad.view(ctx.shape), None + return grad.view(ctx.shape), None \ No newline at end of file From d504f05bc934fd050bddb0656bba4b5ffcef7c0b Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 23 Mar 2026 00:40:59 +0000 Subject: [PATCH 2/7] one review comment form greptile Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 8a7aeb1dc5..cb4571a092 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -10,7 +10,6 @@ from typing import Any, Optional, Tuple, Union import torch -from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe @@ -625,6 +624,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m metadata: Metadata needed for reconstructing the tensor after all-gather. """ # pylint: disable=unused-argument + from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState from transformer_engine.pytorch.distributed import _get_module_fsdp_state if not self._is_2D_scaled: From e5594bcfd635e581a46d7e933fc5abf6f6f8f9e3 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 23 Mar 2026 00:43:52 +0000 Subject: [PATCH 3/7] instead part of the comment not needed Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index cb4571a092..ff9c0c06d3 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -704,7 +704,7 @@ def fsdp_post_all_gather( ) # For 2D block scaling, derive columnwise data and scales from rowwise - # via local fp8 transpose instead of all-gathering them separately. + # via local fp8 transpose. if columnwise_usage: out._create_columnwise() # remove usages if not needed. From bb70a33cdd1c6edb44673d8eb8ba79e0bd892357 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 00:45:53 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ff9c0c06d3..51f5baa430 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -941,4 +941,4 @@ def backward( requires_grad=grad.requires_grad, ) return dgrad, None - return grad.view(ctx.shape), None \ No newline at end of file + return grad.view(ctx.shape), None From 06c0952708b9b8934c0b634c2bf897ebac4005c6 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 23 Mar 2026 05:52:22 +0000 Subject: [PATCH 5/7] address review comments Signed-off-by: Varun Thumbe --- .../pytorch/tensor/float8_blockwise_tensor.py | 13 +++++++++++-- transformer_engine/pytorch/tensor/float8_tensor.py | 10 ++++++++-- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 10 ++++++++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ff9c0c06d3..7b0ed68289 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -624,6 +624,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m metadata: Metadata needed for reconstructing the tensor after all-gather. """ # pylint: disable=unused-argument + # PyTorch FSDP2 private API – tested with PyTorch 2.5+; from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState from transformer_engine.pytorch.distributed import _get_module_fsdp_state @@ -639,14 +640,20 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m ), "Rowwise data must be available for FSDP2 all-gather with 2D block scaling." fsdp_state = _get_module_fsdp_state(module) - reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + param_group = fsdp_state._fsdp_param_group + if param_group is None: + raise RuntimeError( + "FSDP state for this module has no parameter group; " + "cannot determine reshard_after_forward." + ) + reshard_after_forward = param_group._reshard_after_forward # If weights are resharded after forward pass, only the relevant usage # is needed based on whether it's a forward or backward pass. # If not resharded, the same all-gathered weights are reused in backward, # so both usages may be needed. if reshard_after_forward: - training_state = fsdp_state._fsdp_param_group._training_state + training_state = param_group._training_state is_backward_pass = training_state == TrainingState.PRE_BACKWARD rowwise_usage = not is_backward_pass columnwise_usage = is_backward_pass @@ -690,6 +697,8 @@ def fsdp_post_all_gather( if out is not None: out._rowwise_data = rowwise_data out._rowwise_scale_inv = rowwise_scale_inv + out._columnwise_data = None + out._columnwise_scale_inv = None else: out = Float8BlockwiseQTensor( shape=data_shape, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5f00bc8017..e8284eaa53 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -860,14 +860,20 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m self._quantizer.with_amax_reduction = True fsdp_state = _get_module_fsdp_state(module) - reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + param_group = fsdp_state._fsdp_param_group + if param_group is None: + raise RuntimeError( + "FSDP state for this module has no parameter group; " + "cannot determine reshard_after_forward." + ) + reshard_after_forward = param_group._reshard_after_forward # If weights are resharded after forward pass, then its enough to set the quantizer usages # based on whether its forward or backward pass for the allgathered weights. # If not resharded after forward pass, the same weights allgathered in forward # are used again in backward and so we dont change the quantizer usages which might need # both rowwise and columnwise usages. if reshard_after_forward: - training_state = fsdp_state._fsdp_param_group._training_state + training_state = param_group._training_state is_backward_pass = training_state == TrainingState.PRE_BACKWARD # In case of hopper/L40, only one of data/transpose is needed # based on forward or backward pass. So setting the quantizer usages appropriately. diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index baff9cc2aa..965f59b320 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -634,7 +634,13 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # Get FSDP state fsdp_state = _get_module_fsdp_state(module) - reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + param_group = fsdp_state._fsdp_param_group + if param_group is None: + raise RuntimeError( + "FSDP state for this module has no parameter group; " + "cannot determine reshard_after_forward." + ) + reshard_after_forward = param_group._reshard_after_forward # Remove padding from scale inverses before allgather # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] @@ -662,7 +668,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # are used again in backward. And hence if we need the columnwise data/scale_inv, # we need to send them as well for allgather in forward pass itself. if reshard_after_forward: - training_state = fsdp_state._fsdp_param_group._training_state + training_state = param_group._training_state is_backward_pass = training_state == TrainingState.PRE_BACKWARD # Allgather only the necessary tensors based on forward/backward pass rowwise_usage = not is_backward_pass From 347c2761bdf8299a7d6296d8765be5509c68d094 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 23 Mar 2026 08:54:38 -0700 Subject: [PATCH 6/7] Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 --- .../pytorch/tensor/float8_blockwise_tensor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index c1b44a4bee..16c91dcbe5 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -635,9 +635,10 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m "layout has M in dim1, which is incompatible with FSDP2 dim0 all-gather." ) - assert ( - self._rowwise_data is not None and self._rowwise_scale_inv is not None - ), "Rowwise data must be available for FSDP2 all-gather with 2D block scaling." + if self._rowwise_data is None or self._rowwise_scale_inv is None: + raise RuntimeError( + "Rowwise data must be available for FSDP2 all-gather with 2D block scaling." + ) fsdp_state = _get_module_fsdp_state(module) param_group = fsdp_state._fsdp_param_group From 0750751bcef91ad27c425dc12bc9df6a1abf4727 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 23 Mar 2026 08:57:11 -0700 Subject: [PATCH 7/7] No need to set it to None Remove unnecessary columnwise data and scale inv assignments. Signed-off-by: vthumbe1503 --- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 16c91dcbe5..bbfc43e9bb 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -698,8 +698,6 @@ def fsdp_post_all_gather( if out is not None: out._rowwise_data = rowwise_data out._rowwise_scale_inv = rowwise_scale_inv - out._columnwise_data = None - out._columnwise_scale_inv = None else: out = Float8BlockwiseQTensor( shape=data_shape,