Skip to content
Open
123 changes: 48 additions & 75 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any, Optional, Tuple, Union

import torch

import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
Expand Down Expand Up @@ -625,6 +624,8 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
metadata: Metadata needed for reconstructing the tensor after all-gather.
"""
# pylint: disable=unused-argument
# PyTorch FSDP2 private API – tested with PyTorch 2.5+;
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
Comment on lines +627 to +628
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Inconsistent import style for TrainingState

TrainingState is imported at the module level (line 10) in float8_tensor.py and at line 13 in mxfp8_tensor.py, but here it's imported lazily inside fsdp_pre_all_gather. While the inline comment about the private API and PyTorch version is valuable, the inconsistency across the three sibling files may confuse readers.

Consider either:

  • Moving the TrainingState import to the module level and placing the version comment there (matching the other two files), or
  • Adding the same lazy-import pattern and version comment to float8_tensor.py and mxfp8_tensor.py for symmetry.
Suggested change
# PyTorch FSDP2 private API – tested with PyTorch 2.5+;
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
# PyTorch FSDP2 private API – tested with PyTorch 2.5+;
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
from transformer_engine.pytorch.distributed import _get_module_fsdp_state

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

from transformer_engine.pytorch.distributed import _get_module_fsdp_state

if not self._is_2D_scaled:
Expand All @@ -634,42 +635,38 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
"layout has M in dim1, which is incompatible with FSDP2 dim0 all-gather."
)

block_len = self._quantizer.block_len # 128

# Prepare rowwise tensors — for 2D scaling, M is in dim0 of both data and scale_inv,
# so they naturally align with FSDP2's dim0 all-gather. No unpadding needed.
rowwise_data = self._rowwise_data
rowwise_scale_inv = self._rowwise_scale_inv

# Prepare columnwise tensors — columnwise data is transposed (K, M) and
# columnwise scale_inv is (ceil(K/128), round_up(ceil(M/128), 4)).
# M is in dim1 for both, so we must transpose to put M in dim0 for all-gather.
columnwise_data = self._columnwise_data
columnwise_scale_inv = self._columnwise_scale_inv

if columnwise_data is not None:
# Transpose (K, shard_M) -> (shard_M, K) so M is in dim0
columnwise_data = columnwise_data.t().contiguous()

if columnwise_scale_inv is not None:
# Original shape: (ceil(K/128), round_up(ceil(shard_M/128), 4))
# Strip padding from dim1 (the M-block dimension), transpose, then all-gather
shard_M = math.prod(self.shape[:-1])
m_blocks = (shard_M + block_len - 1) // block_len # ceil(shard_M/128)
columnwise_scale_inv = columnwise_scale_inv[:, :m_blocks] # unpad dim1
columnwise_scale_inv = columnwise_scale_inv.t().contiguous() # (m_blocks, k_blocks)

# Always send both rowwise and columnwise data.
# Unlike MXFP8 (where both forms share the same shape), Float8Blockwise has
# differently-shaped rowwise (M, K) and columnwise (K, M) data. The GEMM kernel
# needs both forms available to perform forward and backward operations, so we
# cannot optimize by sending only one usage based on forward/backward pass.
rowwise_usage = True
sharded_tensors = (rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
sharded_tensors += (columnwise_data, columnwise_scale_inv)
if self._rowwise_data is None or self._rowwise_scale_inv is None:
raise RuntimeError(
"Rowwise data must be available for FSDP2 all-gather with 2D block scaling."
)

fsdp_state = _get_module_fsdp_state(module)
param_group = fsdp_state._fsdp_param_group
if param_group is None:
raise RuntimeError(
"FSDP state for this module has no parameter group; "
"cannot determine reshard_after_forward."
)
reshard_after_forward = param_group._reshard_after_forward

# If weights are resharded after forward pass, only the relevant usage
# is needed based on whether it's a forward or backward pass.
# If not resharded, the same all-gathered weights are reused in backward,
# so both usages may be needed.
if reshard_after_forward:
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
else:
rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage
Comment on lines +656 to +663
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 columnwise_usage not derived from training state in non-resharded path

When reshard_after_forward=False, the same all-gathered weight is reused through both forward and backward passes. The code sets:

rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage

This means whether columnwise data gets derived locally (and kept) is entirely controlled by the sharded quantizer's setting, not the actual pass. The comment in the previous code explicitly noted that both forms were needed when not resharding. If self._quantizer.columnwise_usage is False (e.g. on an architecture that doesn't need the transpose), columnwise data won't be created and won't be available for the backward pass GEMM.

This matches the pre-existing float8_tensor.py behavior (same pattern there), so it's presumably already validated by the existing usage assumptions — but it would be worth a brief comment here documenting that self._quantizer.columnwise_usage must be True whenever the backward GEMM needs columnwise access for the non-resharding path.


# For 2D block scaling (128x128 blocks), columnwise data and scales are
# the transpose of rowwise data and scales. Only all-gather the rowwise
# tensors; columnwise will be derived locally via _create_columnwise()
# in post_all_gather, halving all-gather communication volume.
sharded_tensors = (self._rowwise_data, self._rowwise_scale_inv)
metadata = (self._fp8_dtype, self._is_2D_scaled, rowwise_usage, columnwise_usage)
return sharded_tensors, metadata

Expand All @@ -694,59 +691,35 @@ def fsdp_post_all_gather(
"""
fp8_dtype, is_2D_scaled, rowwise_usage, columnwise_usage = metadata

# Extract rowwise tensors from all-gather outputs
rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None)

# Extract columnwise tensors — they were transposed in pre_all_gather,
# so we need to transpose them back.
columnwise_data, columnwise_scale_inv = (
all_gather_outputs[-2:] if columnwise_usage else (None, None)
)

if columnwise_data is not None:
# All-gathered shape is (full_M, K), transpose back to (K, full_M)
columnwise_data = columnwise_data.t().contiguous()

if columnwise_scale_inv is not None:
# All-gathered shape is (full_m_blocks, k_blocks),
# transpose back to (k_blocks, full_m_blocks)
columnwise_scale_inv = columnwise_scale_inv.t().contiguous()
# Repad dim1 (M-block dimension) to multiple of 4 for GEMM alignment
current_m_blocks = columnwise_scale_inv.shape[1]
pad_amount = (4 - current_m_blocks % 4) % 4
if pad_amount > 0:
columnwise_scale_inv = torch.nn.functional.pad(
columnwise_scale_inv, (0, pad_amount)
)

# Determine the logical shape from the all-gathered data
if rowwise_data is not None:
data_shape = rowwise_data.shape
else:
# columnwise_data is (K, full_M), logical shape is (full_M, K)
data_shape = (columnwise_data.shape[1], columnwise_data.shape[0])
# Only rowwise data+scales were all-gathered (columnwise is derived locally).
rowwise_data, rowwise_scale_inv = all_gather_outputs[:2]
data_shape = rowwise_data.shape

if out is not None:
# Update existing tensor in-place (subsequent iterations)
out._rowwise_data = rowwise_data
out._rowwise_scale_inv = rowwise_scale_inv
out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv
else:
# Construct new tensor (first iteration).
# Float8BlockwiseQTensor constructor copies the quantizer,
# so the sharded tensor's quantizer remains independent.
out = Float8BlockwiseQTensor(
shape=data_shape,
dtype=param_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=self._quantizer,
is_2D_scaled=is_2D_scaled,
)

# For 2D block scaling, derive columnwise data and scales from rowwise
# via local fp8 transpose.
if columnwise_usage:
out._create_columnwise()
# remove usages if not needed.
out.update_usage(
rowwise_usage=rowwise_usage,
columnwise_usage=columnwise_usage,
)
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
return out, all_gather_outputs

Expand Down
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,14 +860,20 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
self._quantizer.with_amax_reduction = True

fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
param_group = fsdp_state._fsdp_param_group
if param_group is None:
raise RuntimeError(
"FSDP state for this module has no parameter group; "
"cannot determine reshard_after_forward."
)
reshard_after_forward = param_group._reshard_after_forward
# If weights are resharded after the forward pass, it's enough to set the quantizer usages
# based on whether it's the forward or backward pass for the all-gathered weights.
# If not resharded after the forward pass, the same weights all-gathered in forward
# are used again in backward, so we don't change the quantizer usages, which might need
# both rowwise and columnwise usages.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately.
Expand Down
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,13 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m

# Get FSDP state
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
param_group = fsdp_state._fsdp_param_group
if param_group is None:
raise RuntimeError(
"FSDP state for this module has no parameter group; "
"cannot determine reshard_after_forward."
)
reshard_after_forward = param_group._reshard_after_forward

# Remove padding from scale inverses before allgather
# Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128]
Expand Down Expand Up @@ -662,7 +668,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m
# are used again in backward. And hence if we need the columnwise data/scale_inv,
# we need to send them as well for allgather in forward pass itself.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
training_state = param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
rowwise_usage = not is_backward_pass
Expand Down
Loading