Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
/te-ci L1 pytorch
Greptile Summary: This PR adds FSDP2 all-gather hooks to NVFP4Tensor. The implementation correctly handles NVFP4-specific layout details, including FP4 packing (two 4-bit values per byte). Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant FSDP2
    participant NVFP4Tensor
    participant AllGather
    Note over FSDP2,NVFP4Tensor: Before all-gather (per shard)
    FSDP2->>NVFP4Tensor: fsdp_pre_all_gather()
    NVFP4Tensor->>NVFP4Tensor: assert shard_M % 16 == 0
    NVFP4Tensor->>NVFP4Tensor: rowwise_scale[:shard_M, :] (unpad dim0)
    NVFP4Tensor->>NVFP4Tensor: columnwise_data.t() (K,M//2) → (M//2,K)
    NVFP4Tensor->>NVFP4Tensor: columnwise_scale[:, :m_blocks].t() (m_blocks,K_pad)
    NVFP4Tensor-->>FSDP2: sharded_tensors=(rw_data, rw_scale, cw_data, cw_scale)<br/>metadata=(fp4_dtype, usages, amax_row, amax_col, K)
    Note over FSDP2,AllGather: All-gather across ranks
    FSDP2->>AllGather: concat sharded_tensors along dim0
    AllGather-->>FSDP2: all_gather_outputs (full_M across all ranks)
    Note over FSDP2,NVFP4Tensor: After all-gather (full param)
    FSDP2->>NVFP4Tensor: fsdp_post_all_gather(all_gather_outputs, metadata)
    NVFP4Tensor->>NVFP4Tensor: rowwise_scale repad dim0 → round_up(full_M,128)
    NVFP4Tensor->>NVFP4Tensor: columnwise_data.t() (full_M//2,K) → (K,full_M//2)
    NVFP4Tensor->>NVFP4Tensor: columnwise_scale.t() + repad dim1 → round_up(ceil(M/16),4)
    NVFP4Tensor->>NVFP4Tensor: amax values injected from metadata (not all-gathered)
    NVFP4Tensor-->>FSDP2: NVFP4Tensor(shape=(full_M, K)) or in-place out update
    Note over FSDP2,NVFP4Tensor: Dispatch handlers (called by FSDP2 internals)
    FSDP2->>NVFP4Tensor: aten.as_strided (identity noop)
    NVFP4Tensor-->>FSDP2: make_like(tensor)
    FSDP2->>NVFP4Tensor: aten.slice.Tensor (full-dim noop)
    NVFP4Tensor-->>FSDP2: make_like(tensor)
    FSDP2->>NVFP4Tensor: aten.record_stream
    NVFP4Tensor->>NVFP4Tensor: record_stream on each sub-tensor
    NVFP4Tensor-->>FSDP2: None
```
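The shape bookkeeping in the diagram can be sketched as plain shape arithmetic. This is an illustrative sketch, not the TE implementation: `round_up`, `pre_all_gather_shapes`, and the assumption that rowwise scales have one entry per 16-column block are mine, inferred from the diagram.

```python
import math

def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple

def pre_all_gather_shapes(shard_m: int, k: int) -> dict:
    """Shapes of the four tensors sent to all-gather for one shard.

    Mirrors the diagram: the rowwise scale is unpadded to shard_M rows,
    and the columnwise data/scale are transposed so dim0 is the sharded
    dimension (what FSDP2 concatenates along).
    """
    assert shard_m % 16 == 0, "per-shard rows must align to the 16-row scaling block"
    m_blocks = math.ceil(shard_m / 16)
    return {
        "rowwise_data": (shard_m, k // 2),     # two FP4 values packed per byte
        "rowwise_scale": (shard_m, k // 16),   # assumed: one scale per 16-col block
        "columnwise_data": (shard_m // 2, k),  # (K, M//2) transposed to (M//2, K)
        "columnwise_scale": (m_blocks, k),     # (m_blocks, K_pad); K_pad ~ K here
    }
```

For example, a 32×64 shard sends a (32, 32) packed rowwise buffer and a (16, 64) transposed columnwise buffer; after the all-gather, `fsdp_post_all_gather` transposes back and repads to `round_up(full_M, 128)` / `round_up(ceil(M/16), 4)`.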
Reviews (3): Last reviewed commit: "addresses greptile comment"
```python
torch.manual_seed(42)
torch.cuda.manual_seed(42)

@pytest.mark.parametrize("shape", _test_shapes)
```
setup_class will TypeError at runtime
setup_class is decorated with @staticmethod, but its signature accepts a cls parameter. Pytest invokes class-level setup by calling TestNVFP4FSDP2Hooks.setup_class() with no arguments. Because this is a static method, Python does not inject cls, so the call has 0 arguments against a 1-argument signature — raising TypeError: setup_class() missing 1 required positional argument: 'cls' before any test in the class executes.
The seed initialisation would silently never run. Change the decorator to @classmethod:
Suggested change:

```python
@classmethod
def setup_class(cls) -> None:
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
```
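The failure mode is easy to reproduce outside pytest: a `@staticmethod` with a `cls` parameter receives no implicit first argument, so a zero-argument call fails. The class names below are illustrative.

```python
class Broken:
    @staticmethod
    def setup_class(cls) -> None:  # cls is a plain positional arg here
        pass

class Fixed:
    @classmethod
    def setup_class(cls) -> None:  # Python injects the class as cls
        cls.ready = True

# Broken.setup_class() has 0 arguments against a 1-argument signature.
try:
    Broken.setup_class()
    raised = False
except TypeError:
    raised = True

Fixed.setup_class()  # works: the class is passed automatically
```

This mirrors exactly what pytest does when it calls `TestNVFP4FSDP2Hooks.setup_class()` with no arguments.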
Good catch — fixed. Changed to @classmethod so pytest actually invokes the seed initialization.
```python
if func == aten.slice.Tensor:
    tensor = args[0]
    dim = args[1]
    start = args[2]
    length = args[3]
    if start == 0 and length == tensor.size(dim):
        return NVFP4Tensor.make_like(tensor)
```
slice.Tensor handler uses wrong variable name for end and skips step check
The fourth positional argument of aten.slice.Tensor is end, not length. The variable name is misleading, but more importantly the handler unconditionally indexes args[3] without checking whether a fifth step argument is provided and equals 1. A non-unit step with start=0, end=size would still pass the existing check and incorrectly return an unmodified NVFP4Tensor, silently skipping the stride.
Suggested change:

```python
if func == aten.slice.Tensor:
    tensor = args[0]
    dim = args[1] if len(args) > 1 else 0
    start = args[2] if len(args) > 2 else None
    end = args[3] if len(args) > 3 else None
    step = args[4] if len(args) > 4 else 1
    if step == 1 and (start is None or start == 0) and (end is None or end >= tensor.size(dim)):
        return NVFP4Tensor.make_like(tensor)
```
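The defensive defaulting in the suggestion can be factored into a small standalone helper. This is a sketch, not TE code; `normalize_slice_args` and `is_full_dim_noop` are hypothetical names, and the defaults match the ATen schema for `slice.Tensor` (dim=0, start/end optional, step=1).

```python
def normalize_slice_args(args):
    """Fill in ATen-style defaults for aten.slice.Tensor positional args.

    args is (tensor, dim?, start?, end?, step?); missing trailing
    arguments take the defaults dim=0, start=None, end=None, step=1.
    """
    dim = args[1] if len(args) > 1 else 0
    start = args[2] if len(args) > 2 else None
    end = args[3] if len(args) > 3 else None
    step = args[4] if len(args) > 4 else 1
    return dim, start, end, step

def is_full_dim_noop(size_along_dim, start, end, step):
    """True only for a unit-step slice covering the whole dimension."""
    return (
        step == 1
        and (start is None or start == 0)
        and (end is None or end >= size_along_dim)
    )
```

Note that `is_full_dim_noop(8, 0, 8, 2)` is False: a strided slice over the full range is not an identity, which is exactly the case the original handler missed.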
Fixed. Renamed length → end, added defensive len(args) checks for the optional positional args, and added a step == 1 guard. Also handles the case where start/end are None (the ATen defaults).
```python
# Always send both orientations (GEMM needs both for fwd/bwd)
rowwise_usage = True
sharded_tensors = (rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
    sharded_tensors += (columnwise_data, columnwise_scale_inv)
```
Hardcoded rowwise_usage = True can silently pass None to FSDP2
rowwise_usage is hardcoded to True without checking whether _rowwise_data is actually populated. If an NVFP4Tensor is created with only columnwise data (rowwise=False), rowwise_data will be None but it will still be included in sharded_tensors. FSDP2 would then try to all-gather None, which will crash or corrupt silently.
Consider guarding this with an explicit check:

```python
rowwise_usage = self._rowwise_data is not None
```

Or, if both orientations are always required for FSDP2 training, add an assertion to catch misconfigured tensors early:

```python
rowwise_usage = True
assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)
```
Added an assertion to fail early if _rowwise_data is None:

```python
assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)
```
```python
# Always send both orientations (GEMM needs both for fwd/bwd)
```
@jomitchellnv we shouldn't need both orientations. For the forward pass only rowwise is needed, and for the backward pass only columnwise is needed.
Can you refer to this PR for the optimizations: #2789
Right now the columnwise all-gather implementation transposes, all-gathers in pre_all_gather, and then transposes again after post_all_gather, which is expensive.
You can make the columnwise all-gather better in one of two ways:
- Always all-gather rowwise data, similar to the above PR. If columnwise usage is needed, just transpose it in post_all_gather.
- All-gather only columnwise data if columnwise usage is set, and then interleave the columnwise stacked data.
Option 1 is fine to have for now.
I would also add assertions for 2D scaling, similar to the FP8 blockscaling implementation.
Thanks for the pointer to #2789 — that makes sense. I'll adopt option 1: only all-gather rowwise data +
scales, and derive columnwise locally in post_all_gather if columnwise_usage is set. Will update.
On the 2D scaling assertion — NVFP4 doesn't have _is_2D_scaled like Float8Blockwise, but I can add a guard on with_2d_quantization. Could you clarify what layout constraint you're thinking of? The NVFP4 rowwise scale has M in dim0 (round_up(M, 128), ...) regardless of that flag, so it should be compatible with dim0 all-gather either way. Want to make sure I add the right check.
Yes, exactly: with_2d_quantization should be asserted to be True. If with_2d_quantization is False, then columnwise is not derivable from the rowwise data/scales (although this is not a common use-case).
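On unpacked values, option 1 amounts to a local transpose after the rowwise all-gather. The toy sketch below shows only that idea; `derive_columnwise` is a hypothetical name, and the real NVFP4 path is more involved because data is packed two 4-bit values per byte with per-block scales, so the actual kernel must unpack, transpose, and requantize.

```python
def derive_columnwise(rowwise):
    """Toy post-all-gather step for option 1: transpose the gathered
    rowwise values locally instead of all-gathering a second copy.

    Only valid when columnwise values are derivable from the rowwise
    data/scales (the with_2d_quantization constraint discussed above).
    """
    m = len(rowwise)
    k = len(rowwise[0])
    # Column j of the rowwise matrix becomes row j of the columnwise one.
    return [[rowwise[i][j] for i in range(m)] for j in range(k)]
```

The payoff is that only one orientation crosses the network; the transpose is a purely local post-processing step.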
for more information, see https://pre-commit.ci
Force-pushed from 37a1b01 to abb7b04
Force-pushed from 178c7c3 to 238f2df
```python
if current_m_blocks < target_m_blocks:
    columnwise_scale_inv = torch.nn.functional.pad(
        columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
    )
elif current_m_blocks > target_m_blocks:
    columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks]
```
Silent data corruption when shard_M is not a multiple of NVFP4_BLOCK_SCALING_SIZE
The elif current_m_blocks > target_m_blocks trim branch silently discards valid scale data rather than surfacing an alignment error.
Here is why this happens: fsdp_pre_all_gather computes m_blocks = ceil(shard_M / 16) per shard. After all-gather, the concatenated total is world_size * ceil(shard_M / 16). The target is round_up(ceil(full_M / 16), 4). When shard_M % 16 != 0, world_size * ceil(shard_M / 16) > ceil(full_M / 16), and the trim condition fires.
Concrete example: full_M = 136, world_size = 8, shard_M = 17.
- Each rank: ceil(17/16) = 2 m-blocks
- All-gathered: 16 m-blocks
- Target: round_up(ceil(136/16), 4) = round_up(9, 4) = 12
- The elif trims from 16 → 12, discarding 4 real m-block scale columns
Beyond the trim, this represents a deeper problem: each rank's m-block 0 covers its local rows 0–15, not the same global rows. When shard_M % 16 != 0, the per-rank m-blocks do not align with global m-blocks, and the all-gathered scale tensor is fundamentally scrambled before the trim even runs.
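The mismatch is pure integer arithmetic and can be checked directly. The helper names below are illustrative, not TE code, and `NVFP4_BLOCK_SCALING_SIZE` is taken to be 16 as in the discussion above.

```python
import math

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

def gathered_vs_target_m_blocks(full_m: int, world_size: int):
    """m-blocks actually concatenated by the all-gather vs. the target count.

    When shard_M % 16 != 0, each rank rounds its block count up locally,
    so the concatenated total overshoots the global target.
    """
    shard_m = full_m // world_size
    gathered = world_size * math.ceil(shard_m / 16)       # per-rank ceil, summed
    target = round_up(math.ceil(full_m / 16), 4)          # global block count, padded
    return gathered, target

# full_M=136, world_size=8 → shard_M=17: 16 gathered blocks vs. a target of 12,
# so the elif trim would discard 4 real m-block scale columns.
# With shard_M a multiple of 16 (e.g. full_M=256, world_size=8), the two agree.
```

This is why asserting `shard_M % 16 == 0` up front is the right fix: it makes the overshoot case impossible rather than papering over it with a trim.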
The fix is to assert the constraint at the top of fsdp_pre_all_gather rather than silently masking the symptom:

```python
assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, (
    f"FSDP2 requires shard_M ({shard_M}) to be a multiple of "
    f"NVFP4_BLOCK_SCALING_SIZE ({NVFP4_BLOCK_SCALING_SIZE}). "
    "Ensure the weight's row count is divisible by world_size * 16."
)
```

And the elif trim should then become an unreachable assertion (or be removed), so any unexpected mismatch is caught loudly:
Suggested change:

```python
if current_m_blocks < target_m_blocks:
    columnwise_scale_inv = torch.nn.functional.pad(
        columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
    )
else:
    assert current_m_blocks == target_m_blocks, (
        f"Unexpected m_block count after all-gather: "
        f"got {current_m_blocks}, expected {target_m_blocks}"
    )
```
Fixed. Added an assertion in fsdp_pre_all_gather that shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, which
guarantees m-blocks align across ranks after all-gather. Converted the elif trim to an unreachable assertion —
if the shard alignment holds, current_m_blocks can never exceed target_m_blocks.
```python
if func == aten.as_strided.default:
    tensor = args[0]
    shape = args[1]
    strides = args[2]
    if (
        len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)
        and tuple(shape) == tuple(tensor.size())
    ):
        return NVFP4Tensor.make_like(tensor)
```
as_strided handler silently falls through for non-standard calls
When the shape/stride check does not match (e.g., FSDP2 applies as_strided with a non-unit stride, a shape mismatch, or a non-zero storage offset), the handler does not return anything and falls through to super().__torch_dispatch__(). That super call is unlikely to know how to handle an NVFP4Tensor, so the result will be wrong or raise an obscure error.
The storage_offset argument (args[3] when present) is also never inspected — if FSDP2 ever supplies a non-zero offset, the identity check still fires and silently returns the unmodified tensor, ignoring the offset.
Since the intent is to handle only the no-op identity case and let everything else fall through deliberately, consider making that explicit with a comment or raising NotImplementedError for the non-identity case so failures are surfaced early rather than producing undefined behaviour from the super dispatch.
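The identity condition being guarded, including the storage offset the current handler ignores, can be captured in one predicate. This is a sketch; `is_identity_as_strided` is a hypothetical name, and it assumes a row-major contiguous 2-D tensor as in the handler above.

```python
def is_identity_as_strided(size, shape, strides, storage_offset=0):
    """True iff as_strided(shape, strides, storage_offset) is a no-op
    on a contiguous 2-D tensor of the given size."""
    return (
        len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)   # row-major contiguous strides
        and tuple(shape) == tuple(size)
        and storage_offset in (None, 0)        # a non-zero offset is NOT identity
    )
```

With the offset folded into the predicate, a non-identity call can raise NotImplementedError explicitly instead of silently matching or falling through to the super dispatch.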
Description
Summary
Enables FSDP2 training with NVFP4BlockScaling by adding FSDP2 all-gather hooks and the dispatch operations FSDP2 requires to NVFP4Tensor.
Details
Without FSDP2 hooks, FSDP2 attempts data_ptr() on the NVFP4Tensor wrapper subclass and crashes. This PR follows
the Float8BlockwiseQTensor FSDP2 hooks pattern (the closest analog since NVFP4 also stores columnwise data
transposed), with NVFP4-specific adjustments:
- Columnwise scales are padded to round_up(ceil(M/16), 4) — both unpadded before all-gather and repadded after
- amax values are passed through metadata rather than all-gathered, since they're scalar and identical across ranks
Test plan
- Unit tests: integrity, dequantize correctness, in-place update path, swizzled-scale rejection, and dispatch handlers
- tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -k "fp8_master_weights and NVFP4" — multi-GPU FSDP2 + NVFP4 integration; FSDP2 fused_adam regression