diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ab496d5a9e..bbfc43e9bb 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -10,7 +10,6 @@ from typing import Any, Optional, Tuple, Union import torch - import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe @@ -625,6 +624,8 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m metadata: Metadata needed for reconstructing the tensor after all-gather. """ # pylint: disable=unused-argument + # PyTorch FSDP2 private API – tested with PyTorch 2.5+; import path may change. + from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState from transformer_engine.pytorch.distributed import _get_module_fsdp_state if not self._is_2D_scaled: @@ -634,42 +635,38 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m "layout has M in dim1, which is incompatible with FSDP2 dim0 all-gather." ) - block_len = self._quantizer.block_len # 128 - - # Prepare rowwise tensors — for 2D scaling, M is in dim0 of both data and scale_inv, - # so they naturally align with FSDP2's dim0 all-gather. No unpadding needed. - rowwise_data = self._rowwise_data - rowwise_scale_inv = self._rowwise_scale_inv - - # Prepare columnwise tensors — columnwise data is transposed (K, M) and - # columnwise scale_inv is (ceil(K/128), round_up(ceil(M/128), 4)). - # M is in dim1 for both, so we must transpose to put M in dim0 for all-gather. 
- columnwise_data = self._columnwise_data - columnwise_scale_inv = self._columnwise_scale_inv - - if columnwise_data is not None: - # Transpose (K, shard_M) -> (shard_M, K) so M is in dim0 - columnwise_data = columnwise_data.t().contiguous() - - if columnwise_scale_inv is not None: - # Original shape: (ceil(K/128), round_up(ceil(shard_M/128), 4)) - # Strip padding from dim1 (the M-block dimension), transpose, then all-gather - shard_M = math.prod(self.shape[:-1]) - m_blocks = (shard_M + block_len - 1) // block_len # ceil(shard_M/128) - columnwise_scale_inv = columnwise_scale_inv[:, :m_blocks] # unpad dim1 - columnwise_scale_inv = columnwise_scale_inv.t().contiguous() # (m_blocks, k_blocks) - - # Always send both rowwise and columnwise data. - # Unlike MXFP8 (where both forms share the same shape), Float8Blockwise has - # differently-shaped rowwise (M, K) and columnwise (K, M) data. The GEMM kernel - # needs both forms available to perform forward and backward operations, so we - # cannot optimize by sending only one usage based on forward/backward pass. - rowwise_usage = True - sharded_tensors = (rowwise_data, rowwise_scale_inv) - columnwise_usage = self._quantizer.columnwise_usage - if columnwise_usage: - sharded_tensors += (columnwise_data, columnwise_scale_inv) + if self._rowwise_data is None or self._rowwise_scale_inv is None: + raise RuntimeError( + "Rowwise data must be available for FSDP2 all-gather with 2D block scaling." + ) + fsdp_state = _get_module_fsdp_state(module) + param_group = fsdp_state._fsdp_param_group + if param_group is None: + raise RuntimeError( + "FSDP state for this module has no parameter group; " + "cannot determine reshard_after_forward." + ) + reshard_after_forward = param_group._reshard_after_forward + + # If weights are resharded after forward pass, only the relevant usage + # is needed based on whether it's a forward or backward pass. 
+ # If not resharded, the same all-gathered weights are reused in backward, + # so both usages may be needed. + if reshard_after_forward: + training_state = param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + rowwise_usage = not is_backward_pass + columnwise_usage = is_backward_pass + else: + rowwise_usage = True + columnwise_usage = self._quantizer.columnwise_usage + + # For 2D block scaling (128x128 blocks), columnwise data and scales are + # the transpose of rowwise data and scales. Only all-gather the rowwise + # tensors; columnwise will be derived locally via _create_columnwise() + # in post_all_gather, halving all-gather communication volume. + sharded_tensors = (self._rowwise_data, self._rowwise_scale_inv) metadata = (self._fp8_dtype, self._is_2D_scaled, rowwise_usage, columnwise_usage) return sharded_tensors, metadata @@ -694,59 +691,35 @@ def fsdp_post_all_gather( """ fp8_dtype, is_2D_scaled, rowwise_usage, columnwise_usage = metadata - # Extract rowwise tensors from all-gather outputs - rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None) - - # Extract columnwise tensors — they were transposed in pre_all_gather, - # so we need to transpose them back. 
- columnwise_data, columnwise_scale_inv = ( - all_gather_outputs[-2:] if columnwise_usage else (None, None) - ) - - if columnwise_data is not None: - # All-gathered shape is (full_M, K), transpose back to (K, full_M) - columnwise_data = columnwise_data.t().contiguous() - - if columnwise_scale_inv is not None: - # All-gathered shape is (full_m_blocks, k_blocks), - # transpose back to (k_blocks, full_m_blocks) - columnwise_scale_inv = columnwise_scale_inv.t().contiguous() - # Repad dim1 (M-block dimension) to multiple of 4 for GEMM alignment - current_m_blocks = columnwise_scale_inv.shape[1] - pad_amount = (4 - current_m_blocks % 4) % 4 - if pad_amount > 0: - columnwise_scale_inv = torch.nn.functional.pad( - columnwise_scale_inv, (0, pad_amount) - ) - - # Determine the logical shape from the all-gathered data - if rowwise_data is not None: - data_shape = rowwise_data.shape - else: - # columnwise_data is (K, full_M), logical shape is (full_M, K) - data_shape = (columnwise_data.shape[1], columnwise_data.shape[0]) + # Only rowwise data+scales were all-gathered (columnwise is derived locally). + rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] + data_shape = rowwise_data.shape if out is not None: - # Update existing tensor in-place (subsequent iterations) out._rowwise_data = rowwise_data out._rowwise_scale_inv = rowwise_scale_inv - out._columnwise_data = columnwise_data - out._columnwise_scale_inv = columnwise_scale_inv else: - # Construct new tensor (first iteration). - # Float8BlockwiseQTensor constructor copies the quantizer, - # so the sharded tensor's quantizer remains independent. 
out = Float8BlockwiseQTensor( shape=data_shape, dtype=param_dtype, fp8_dtype=fp8_dtype, rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, + columnwise_data=None, + columnwise_scale_inv=None, quantizer=self._quantizer, is_2D_scaled=is_2D_scaled, ) + + # For 2D block scaling, derive columnwise data and scales from rowwise + # via local fp8 transpose. + if columnwise_usage: + out._create_columnwise() + # remove usages if not needed. + out.update_usage( + rowwise_usage=rowwise_usage, + columnwise_usage=columnwise_usage, + ) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5f00bc8017..e8284eaa53 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -860,14 +860,20 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m self._quantizer.with_amax_reduction = True fsdp_state = _get_module_fsdp_state(module) - reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + param_group = fsdp_state._fsdp_param_group + if param_group is None: + raise RuntimeError( + "FSDP state for this module has no parameter group; " + "cannot determine reshard_after_forward." + ) + reshard_after_forward = param_group._reshard_after_forward # If weights are resharded after forward pass, then its enough to set the quantizer usages # based on whether its forward or backward pass for the allgathered weights. # If not resharded after forward pass, the same weights allgathered in forward # are used again in backward and so we dont change the quantizer usages which might need # both rowwise and columnwise usages. 
if reshard_after_forward: - training_state = fsdp_state._fsdp_param_group._training_state + training_state = param_group._training_state is_backward_pass = training_state == TrainingState.PRE_BACKWARD # In case of hopper/L40, only one of data/transpose is needed # based on forward or backward pass. So setting the quantizer usages appropriately. diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index baff9cc2aa..965f59b320 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -634,7 +634,13 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # Get FSDP state fsdp_state = _get_module_fsdp_state(module) - reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + param_group = fsdp_state._fsdp_param_group + if param_group is None: + raise RuntimeError( + "FSDP state for this module has no parameter group; " + "cannot determine reshard_after_forward." + ) + reshard_after_forward = param_group._reshard_after_forward # Remove padding from scale inverses before allgather # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] @@ -662,7 +668,7 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m # are used again in backward. And hence if we need the columnwise data/scale_inv, # we need to send them as well for allgather in forward pass itself. if reshard_after_forward: - training_state = fsdp_state._fsdp_param_group._training_state + training_state = param_group._training_state is_backward_pass = training_state == TrainingState.PRE_BACKWARD # Allgather only the necessary tensors based on forward/backward pass rowwise_usage = not is_backward_pass