-
Notifications
You must be signed in to change notification settings - Fork 675
Optimize fp8 block scaling Allgather for FSDP2 #2789
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ef51ab4
d504f05
e5594bc
bb70a33
06c0952
d46c82b
347c276
0750751
a4e655e
83f0fe8
2295cae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,7 +10,6 @@ | |||||||||||
| from typing import Any, Optional, Tuple, Union | ||||||||||||
|
|
||||||||||||
| import torch | ||||||||||||
|
|
||||||||||||
| import transformer_engine_torch as tex | ||||||||||||
| from transformer_engine_torch import DType as TE_DType | ||||||||||||
| from transformer_engine.common.recipe import Float8BlockScaling, Recipe | ||||||||||||
|
|
@@ -625,6 +624,8 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m | |||||||||||
| metadata: Metadata needed for reconstructing the tensor after all-gather. | ||||||||||||
| """ | ||||||||||||
| # pylint: disable=unused-argument | ||||||||||||
| # PyTorch FSDP2 private API – tested with PyTorch 2.5+; | ||||||||||||
| from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState | ||||||||||||
|
Comment on lines
+627
to
+628
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Consider either:
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||||||||
| from transformer_engine.pytorch.distributed import _get_module_fsdp_state | ||||||||||||
|
|
||||||||||||
| if not self._is_2D_scaled: | ||||||||||||
|
|
@@ -634,42 +635,38 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m | |||||||||||
| "layout has M in dim1, which is incompatible with FSDP2 dim0 all-gather." | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| block_len = self._quantizer.block_len # 128 | ||||||||||||
|
|
||||||||||||
| # Prepare rowwise tensors — for 2D scaling, M is in dim0 of both data and scale_inv, | ||||||||||||
| # so they naturally align with FSDP2's dim0 all-gather. No unpadding needed. | ||||||||||||
| rowwise_data = self._rowwise_data | ||||||||||||
| rowwise_scale_inv = self._rowwise_scale_inv | ||||||||||||
|
|
||||||||||||
| # Prepare columnwise tensors — columnwise data is transposed (K, M) and | ||||||||||||
| # columnwise scale_inv is (ceil(K/128), round_up(ceil(M/128), 4)). | ||||||||||||
| # M is in dim1 for both, so we must transpose to put M in dim0 for all-gather. | ||||||||||||
| columnwise_data = self._columnwise_data | ||||||||||||
| columnwise_scale_inv = self._columnwise_scale_inv | ||||||||||||
|
|
||||||||||||
| if columnwise_data is not None: | ||||||||||||
| # Transpose (K, shard_M) -> (shard_M, K) so M is in dim0 | ||||||||||||
| columnwise_data = columnwise_data.t().contiguous() | ||||||||||||
|
|
||||||||||||
| if columnwise_scale_inv is not None: | ||||||||||||
| # Original shape: (ceil(K/128), round_up(ceil(shard_M/128), 4)) | ||||||||||||
| # Strip padding from dim1 (the M-block dimension), transpose, then all-gather | ||||||||||||
| shard_M = math.prod(self.shape[:-1]) | ||||||||||||
| m_blocks = (shard_M + block_len - 1) // block_len # ceil(shard_M/128) | ||||||||||||
| columnwise_scale_inv = columnwise_scale_inv[:, :m_blocks] # unpad dim1 | ||||||||||||
| columnwise_scale_inv = columnwise_scale_inv.t().contiguous() # (m_blocks, k_blocks) | ||||||||||||
|
|
||||||||||||
| # Always send both rowwise and columnwise data. | ||||||||||||
| # Unlike MXFP8 (where both forms share the same shape), Float8Blockwise has | ||||||||||||
| # differently-shaped rowwise (M, K) and columnwise (K, M) data. The GEMM kernel | ||||||||||||
| # needs both forms available to perform forward and backward operations, so we | ||||||||||||
| # cannot optimize by sending only one usage based on forward/backward pass. | ||||||||||||
| rowwise_usage = True | ||||||||||||
| sharded_tensors = (rowwise_data, rowwise_scale_inv) | ||||||||||||
| columnwise_usage = self._quantizer.columnwise_usage | ||||||||||||
| if columnwise_usage: | ||||||||||||
| sharded_tensors += (columnwise_data, columnwise_scale_inv) | ||||||||||||
| if self._rowwise_data is None or self._rowwise_scale_inv is None: | ||||||||||||
| raise RuntimeError( | ||||||||||||
| "Rowwise data must be available for FSDP2 all-gather with 2D block scaling." | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| fsdp_state = _get_module_fsdp_state(module) | ||||||||||||
| param_group = fsdp_state._fsdp_param_group | ||||||||||||
| if param_group is None: | ||||||||||||
| raise RuntimeError( | ||||||||||||
| "FSDP state for this module has no parameter group; " | ||||||||||||
| "cannot determine reshard_after_forward." | ||||||||||||
| ) | ||||||||||||
| reshard_after_forward = param_group._reshard_after_forward | ||||||||||||
|
|
||||||||||||
| # If weights are resharded after forward pass, only the relevant usage | ||||||||||||
| # is needed based on whether it's a forward or backward pass. | ||||||||||||
| # If not resharded, the same all-gathered weights are reused in backward, | ||||||||||||
| # so both usages may be needed. | ||||||||||||
| if reshard_after_forward: | ||||||||||||
| training_state = param_group._training_state | ||||||||||||
| is_backward_pass = training_state == TrainingState.PRE_BACKWARD | ||||||||||||
| rowwise_usage = not is_backward_pass | ||||||||||||
| columnwise_usage = is_backward_pass | ||||||||||||
| else: | ||||||||||||
| rowwise_usage = True | ||||||||||||
| columnwise_usage = self._quantizer.columnwise_usage | ||||||||||||
|
Comment on lines
+656
to
+663
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage — This means whether columnwise data gets derived locally (and kept) is entirely controlled by the sharded quantizer's setting, not the actual pass. The comment in the previous code explicitly noted that both forms were needed when not resharding. This matches the pre-existing |
||||||||||||
|
|
||||||||||||
| # For 2D block scaling (128x128 blocks), columnwise data and scales are | ||||||||||||
| # the transpose of rowwise data and scales. Only all-gather the rowwise | ||||||||||||
| # tensors; columnwise will be derived locally via _create_columnwise() | ||||||||||||
| # in post_all_gather, halving all-gather communication volume. | ||||||||||||
| sharded_tensors = (self._rowwise_data, self._rowwise_scale_inv) | ||||||||||||
vthumbe1503 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
| metadata = (self._fp8_dtype, self._is_2D_scaled, rowwise_usage, columnwise_usage) | ||||||||||||
| return sharded_tensors, metadata | ||||||||||||
|
|
||||||||||||
|
|
@@ -694,59 +691,35 @@ def fsdp_post_all_gather( | |||||||||||
| """ | ||||||||||||
| fp8_dtype, is_2D_scaled, rowwise_usage, columnwise_usage = metadata | ||||||||||||
|
|
||||||||||||
| # Extract rowwise tensors from all-gather outputs | ||||||||||||
| rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None) | ||||||||||||
|
|
||||||||||||
| # Extract columnwise tensors — they were transposed in pre_all_gather, | ||||||||||||
| # so we need to transpose them back. | ||||||||||||
| columnwise_data, columnwise_scale_inv = ( | ||||||||||||
| all_gather_outputs[-2:] if columnwise_usage else (None, None) | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| if columnwise_data is not None: | ||||||||||||
| # All-gathered shape is (full_M, K), transpose back to (K, full_M) | ||||||||||||
| columnwise_data = columnwise_data.t().contiguous() | ||||||||||||
|
|
||||||||||||
| if columnwise_scale_inv is not None: | ||||||||||||
| # All-gathered shape is (full_m_blocks, k_blocks), | ||||||||||||
| # transpose back to (k_blocks, full_m_blocks) | ||||||||||||
| columnwise_scale_inv = columnwise_scale_inv.t().contiguous() | ||||||||||||
| # Repad dim1 (M-block dimension) to multiple of 4 for GEMM alignment | ||||||||||||
| current_m_blocks = columnwise_scale_inv.shape[1] | ||||||||||||
| pad_amount = (4 - current_m_blocks % 4) % 4 | ||||||||||||
| if pad_amount > 0: | ||||||||||||
| columnwise_scale_inv = torch.nn.functional.pad( | ||||||||||||
| columnwise_scale_inv, (0, pad_amount) | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # Determine the logical shape from the all-gathered data | ||||||||||||
| if rowwise_data is not None: | ||||||||||||
| data_shape = rowwise_data.shape | ||||||||||||
| else: | ||||||||||||
| # columnwise_data is (K, full_M), logical shape is (full_M, K) | ||||||||||||
| data_shape = (columnwise_data.shape[1], columnwise_data.shape[0]) | ||||||||||||
| # Only rowwise data+scales were all-gathered (columnwise is derived locally). | ||||||||||||
| rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] | ||||||||||||
| data_shape = rowwise_data.shape | ||||||||||||
|
|
||||||||||||
| if out is not None: | ||||||||||||
| # Update existing tensor in-place (subsequent iterations) | ||||||||||||
| out._rowwise_data = rowwise_data | ||||||||||||
| out._rowwise_scale_inv = rowwise_scale_inv | ||||||||||||
| out._columnwise_data = columnwise_data | ||||||||||||
| out._columnwise_scale_inv = columnwise_scale_inv | ||||||||||||
| else: | ||||||||||||
vthumbe1503 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
| # Construct new tensor (first iteration). | ||||||||||||
| # Float8BlockwiseQTensor constructor copies the quantizer, | ||||||||||||
| # so the sharded tensor's quantizer remains independent. | ||||||||||||
| out = Float8BlockwiseQTensor( | ||||||||||||
| shape=data_shape, | ||||||||||||
| dtype=param_dtype, | ||||||||||||
| fp8_dtype=fp8_dtype, | ||||||||||||
| rowwise_data=rowwise_data, | ||||||||||||
| rowwise_scale_inv=rowwise_scale_inv, | ||||||||||||
| columnwise_data=columnwise_data, | ||||||||||||
| columnwise_scale_inv=columnwise_scale_inv, | ||||||||||||
| columnwise_data=None, | ||||||||||||
| columnwise_scale_inv=None, | ||||||||||||
| quantizer=self._quantizer, | ||||||||||||
| is_2D_scaled=is_2D_scaled, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # For 2D block scaling, derive columnwise data and scales from rowwise | ||||||||||||
| # via local fp8 transpose. | ||||||||||||
| if columnwise_usage: | ||||||||||||
| out._create_columnwise() | ||||||||||||
| # remove usages if not needed. | ||||||||||||
| out.update_usage( | ||||||||||||
| rowwise_usage=rowwise_usage, | ||||||||||||
| columnwise_usage=columnwise_usage, | ||||||||||||
| ) | ||||||||||||
| out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) | ||||||||||||
| return out, all_gather_outputs | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
Uh oh!
There was an error while loading. Please reload this page.