Optimize fp8 block scaling Allgather for FSDP2#2789
Optimize fp8 block scaling Allgather for FSDP2#2789vthumbe1503 wants to merge 11 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Greptile SummaryThis PR introduces a communication-halving optimization for FP8 block-scaled weights under PyTorch FSDP2: instead of all-gathering both rowwise and columnwise tensors, only rowwise data and scales are all-gathered, and the columnwise view is derived locally via Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant FSDP2
participant fsdp_pre_all_gather
participant AllGather
participant fsdp_post_all_gather
participant _create_columnwise
FSDP2->>fsdp_pre_all_gather: call (module, mesh, …)
fsdp_pre_all_gather->>fsdp_pre_all_gather: check _fsdp_param_group != None
fsdp_pre_all_gather->>fsdp_pre_all_gather: read reshard_after_forward & training_state
alt reshard_after_forward=True, forward pass
fsdp_pre_all_gather-->>FSDP2: sharded=(rowwise_data, rowwise_scale_inv), usage=(row=T, col=F)
else reshard_after_forward=True, PRE_BACKWARD
fsdp_pre_all_gather-->>FSDP2: sharded=(rowwise_data, rowwise_scale_inv), usage=(row=F, col=T)
else reshard_after_forward=False
fsdp_pre_all_gather-->>FSDP2: sharded=(rowwise_data, rowwise_scale_inv), usage=(row=T, col=quantizer.cw)
end
FSDP2->>AllGather: all-gather rowwise_data + rowwise_scale_inv
AllGather-->>FSDP2: full rowwise_data, full rowwise_scale_inv
FSDP2->>fsdp_post_all_gather: call (all_gather_outputs, metadata, param_dtype, out)
fsdp_post_all_gather->>fsdp_post_all_gather: extract rowwise_data, rowwise_scale_inv
alt out is None (first iteration)
fsdp_post_all_gather->>fsdp_post_all_gather: construct Float8BlockwiseQTensor (cw=None)
else out is not None (subsequent iterations)
fsdp_post_all_gather->>fsdp_post_all_gather: update out._rowwise_data / scale_inv in-place
end
alt columnwise_usage=True
fsdp_post_all_gather->>_create_columnwise: derive columnwise via fp8_transpose (reuse buffer)
_create_columnwise-->>fsdp_post_all_gather: _columnwise_data, _columnwise_scale_inv set
end
fsdp_post_all_gather->>fsdp_post_all_gather: update_usage(row, col) — clears unused form
fsdp_post_all_gather-->>FSDP2: (Float8BlockwiseQTensor, all_gather_outputs)
Reviews (5): Last reviewed commit: "Merge branch 'main' into optimize_fp8_bl..." | Re-trigger Greptile |
| fsdp_state = _get_module_fsdp_state(module) | ||
| reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward |
There was a problem hiding this comment.
Unguarded access to
_fsdp_param_group
fsdp_state._fsdp_param_group is typed as Optional[FSDPParamGroup] in PyTorch's FSDP2 internals — it is None for any FSDP module that does not directly manage parameters (e.g. a container module whose children are individually sharded). Accessing ._reshard_after_forward on it unconditionally will raise AttributeError: 'NoneType' object has no attribute '_reshard_after_forward' in that case.
While in practice fsdp_pre_all_gather is only called for tensors managed by a param group, this assumption is implicit. A guard makes the failure mode explicit and easier to diagnose:
fsdp_state = _get_module_fsdp_state(module)
param_group = fsdp_state._fsdp_param_group
if param_group is None:
raise RuntimeError(
"FSDP state for this module has no parameter group; "
"cannot determine reshard_after_forward."
)
reshard_after_forward = param_group._reshard_after_forwardSigned-off-by: Varun Thumbe <vthumbe@nvidia.com>
…03/TransformerEngine into optimize_fp8_blockwise_scaling
|
/te-ci L1 pytorch |
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Remove unnecessary columnwise data and scale inv assignments. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
| # PyTorch FSDP2 private API – tested with PyTorch 2.5+; | ||
| from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState |
There was a problem hiding this comment.
Inconsistent import style for
TrainingState
TrainingState is imported at the module level (line 10) in float8_tensor.py and at line 13 in mxfp8_tensor.py, but here it's imported lazily inside fsdp_pre_all_gather. While the inline comment about the private API and PyTorch version is valuable, the inconsistency across the three sibling files may confuse readers.
Consider either:
- Moving the
TrainingStateimport to the module level and placing the version comment there (matching the other two files), or - Adding the same lazy-import pattern and version comment to
float8_tensor.pyandmxfp8_tensor.pyfor symmetry.
| # PyTorch FSDP2 private API – tested with PyTorch 2.5+; | |
| from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState | |
| # PyTorch FSDP2 private API – tested with PyTorch 2.5+; | |
| from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState | |
| from transformer_engine.pytorch.distributed import _get_module_fsdp_state |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if reshard_after_forward: | ||
| training_state = param_group._training_state | ||
| is_backward_pass = training_state == TrainingState.PRE_BACKWARD | ||
| rowwise_usage = not is_backward_pass | ||
| columnwise_usage = is_backward_pass | ||
| else: | ||
| rowwise_usage = True | ||
| columnwise_usage = self._quantizer.columnwise_usage |
There was a problem hiding this comment.
columnwise_usage not derived from training state in non-resharded path
When reshard_after_forward=False, the same all-gathered weight is reused through both forward and backward passes. The code sets:
rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usageThis means whether columnwise data gets derived locally (and kept) is entirely controlled by the sharded quantizer's setting, not the actual pass. The comment in the previous code explicitly noted that both forms were needed when not resharding. If self._quantizer.columnwise_usage is False (e.g. on an architecture that doesn't need the transpose), columnwise data won't be created and won't be available for the backward pass GEMM.
This matches the pre-existing float8_tensor.py behavior (same pattern there), so it's presumably already validated by the existing usage assumptions — but it would be worth a brief comment here documenting that self._quantizer.columnwise_usage must be True whenever the backward GEMM needs columnwise access for the non-resharding path.
|
Btw I don't see any test files that were updated. I'd expect a test under tests/pytorch/fsdp/ or similar validating that the locally-derived columnwise output matches the old all-gathered columnwise output. |
jomitchellnv
left a comment
There was a problem hiding this comment.
LGTM just hoping theres some test coverage around this new implementation. i think i wrote some last time but not sure
|
/te-ci L1 pytorch |
ksivaman
left a comment
There was a problem hiding this comment.
LGTM, CI pending, consistent with other recipes
Description
Eliminate Columnwise allgather for fp8_model_init with fsdp2. For weights when FP8 blockscaling is used, we typically use 2d. And in such a case, columnwise data and scale inv is just the transpose of the rowwise data and scale inverse. And so allgathering the rowwise data/scales are enough
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: