
adds NVFP4 Fused Adam support#2797

Open
jomitchellnv wants to merge 6 commits into NVIDIA:main from jomitchellnv:jm/nvfp4-block-fused-adam

Conversation

@jomitchellnv
Contributor

Description

Summary

  • Add FSDP2 all-gather hooks (fsdp_pre_all_gather / fsdp_post_all_gather) to NVFP4Tensor, enabling end-to-end
    FSDP2 training with NVFP4BlockScaling
  • Add aten.as_strided, aten.slice, and aten.record_stream dispatch handlers required by FSDP2's internal tensor
    operations
  • Remove NVFP4BlockScaling xfails from FSDP2 integration tests now that the hooks are in place

Details

Without FSDP2 hooks, FSDP2 attempts data_ptr() on the NVFP4Tensor wrapper subclass and crashes. This PR follows
the Float8BlockwiseQTensor FSDP2 hooks pattern (the closest analog since NVFP4 also stores columnwise data
transposed), with NVFP4-specific adjustments:

  • FP4 packing: Data last dim is K//2 (two FP4 values packed per uint8 byte)
  • Scale dtype: uint8 (vs float32 for Float8Blockwise)
  • Block size: 16 (NVFP4_BLOCK_SCALING_SIZE) vs 128 for Float8Blockwise
  • Scale padding: Rowwise scale dim0 padded to round_up(M, 128), columnwise scale dim1 padded to
    round_up(ceil(M/16), 4) — both unpadded before all-gather and repadded after
  • Amax tensors: _amax_rowwise and _amax_columnwise (shape (1,)) passed via metadata rather than all-gathered,
    since they're scalar and identical across ranks
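The padding arithmetic above can be sketched in plain Python. This is an illustrative helper only; `round_up` here is a local function mirroring the described behaviour, not Transformer Engine's own utility:

```python
import math

NVFP4_BLOCK_SCALING_SIZE = 16  # block size stated in the PR description

def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple (local helper, not TE's utility)."""
    return math.ceil(x / multiple) * multiple

def rowwise_scale_dim0(M: int) -> int:
    # rowwise scale dim0 is padded to round_up(M, 128)
    return round_up(M, 128)

def columnwise_scale_dim1(M: int) -> int:
    # columnwise scale dim1 is padded to round_up(ceil(M / 16), 4)
    return round_up(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4)

# for a hypothetical weight with M = 136 rows:
print(rowwise_scale_dim0(136))     # 256
print(columnwise_scale_dim1(136))  # 12
```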

Test plan

  • pytest tests/pytorch/test_nvfp4_fsdp2_hooks.py -v — 19 single-GPU unit tests covering round-trip shapes, data
    integrity, dequantize correctness, in-place update path, swizzled-scale rejection, and dispatch handlers
  • PYTHONPATH=... torchrun --nproc_per_node=2 -m pytest
    tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -k "fp8_master_weights and NVFP4" — multi-GPU
    FSDP2 + NVFP4 integration
  • PYTHONPATH=... pytest tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_tests -v — full
    FSDP2 fused_adam regression

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Jonathan Mitchell added 2 commits March 24, 2026 12:31
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@greptile-apps
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR adds FSDP2 fsdp_pre_all_gather / fsdp_post_all_gather hooks to NVFP4Tensor, enabling end-to-end FSDP2 training with NVFP4BlockScaling, along with three __torch_dispatch__ handlers (aten.as_strided, aten.slice.Tensor, aten.record_stream) required by FSDP2's internal tensor operations.

The implementation correctly handles NVFP4-specific layout details — FP4 packing (K//2), uint8 scale dtype, block size 16, scale padding/unpadding for both rowwise and columnwise orientations, and columnwise data transposition for dim0 all-gather alignment. Previous review concerns (shard-alignment assertion, rowwise_data null-assertion, setup_class decorator, slice end/step fix) have all been addressed.

Key changes:

  • fsdp_pre_all_gather: unpads rowwise scale, transposes columnwise data/scale to put M-dim in dim0, passes amax via metadata (scalar, not all-gathered)
  • fsdp_post_all_gather: repads rowwise/columnwise scales, transposes columnwise data/scale back, supports both first-call (new tensor) and subsequent-call (in-place out) paths
  • Removes NVFP4BlockScaling xfails from two FSDP2 integration tests
  • 19 new single-GPU unit tests covering shapes, data integrity, dequantize correctness, in-place update, swizzled-scale rejection, and dispatch handlers

Confidence Score: 4/5

  • PR is on the happy path to merge; all prior review concerns are resolved and the core logic is correct.
  • All previously flagged issues (shard-alignment assertion, rowwise_data null assertion, setup_class classmethod fix, slice handler end/step fix) have been addressed. The new implementation correctly handles NVFP4-specific layout transforms for FSDP2. The only remaining items are a P2 unused import in the test file and a mild defensive-coding gap in the as_strided handler for non-identity calls — neither affects production correctness in the FSDP2 workflow.
  • No files require special attention beyond the minor cleanup in tests/pytorch/test_nvfp4_fsdp2_hooks.py (unused import) and the as_strided handler in nvfp4_tensor.py.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/nvfp4_tensor.py Adds fsdp_pre_all_gather / fsdp_post_all_gather hooks plus aten.as_strided, aten.slice.Tensor, and aten.record_stream dispatch handlers; logic is correct for the FSDP2 all-gather round-trip with minor defensive-coding gaps in the as_strided handler.
tests/pytorch/test_nvfp4_fsdp2_hooks.py New unit-test file with 19 tests covering shape correctness, data integrity, dequantize correctness, in-place update, swizzled-scale rejection, and dispatch handlers; contains one unused import (tex) that should be removed.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Removes the NVFP4BlockScaling xfail from test_fused_adam_fp8_master_weights now that the FSDP2 hooks are in place; straightforward removal with no issues.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Narrows the xfail guard from Float8BlockScaling or NVFP4BlockScaling to Float8BlockScaling only, correctly reflecting that NVFP4 FSDP2 is now supported; no issues.

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant NVFP4Tensor
    participant AllGather

    Note over FSDP2,NVFP4Tensor: Before all-gather (per shard)
    FSDP2->>NVFP4Tensor: fsdp_pre_all_gather()
    NVFP4Tensor->>NVFP4Tensor: assert shard_M % 16 == 0
    NVFP4Tensor->>NVFP4Tensor: rowwise_scale[:shard_M, :]  (unpad dim0)
    NVFP4Tensor->>NVFP4Tensor: columnwise_data.t()  (K,M//2) → (M//2,K)
    NVFP4Tensor->>NVFP4Tensor: columnwise_scale[:, :m_blocks].t()  (m_blocks,K_pad)
    NVFP4Tensor-->>FSDP2: sharded_tensors=(rw_data, rw_scale, cw_data, cw_scale)<br/>metadata=(fp4_dtype, usages, amax_row, amax_col, K)

    Note over FSDP2,AllGather: All-gather across ranks
    FSDP2->>AllGather: concat sharded_tensors along dim0
    AllGather-->>FSDP2: all_gather_outputs (full_M across all ranks)

    Note over FSDP2,NVFP4Tensor: After all-gather (full param)
    FSDP2->>NVFP4Tensor: fsdp_post_all_gather(all_gather_outputs, metadata)
    NVFP4Tensor->>NVFP4Tensor: rowwise_scale repad dim0 → round_up(full_M,128)
    NVFP4Tensor->>NVFP4Tensor: columnwise_data.t()  (full_M//2,K) → (K,full_M//2)
    NVFP4Tensor->>NVFP4Tensor: columnwise_scale.t() + repad dim1 → round_up(ceil(M/16),4)
    NVFP4Tensor->>NVFP4Tensor: amax values injected from metadata (not all-gathered)
    NVFP4Tensor-->>FSDP2: NVFP4Tensor(shape=(full_M, K)) or in-place out update

    Note over FSDP2,NVFP4Tensor: Dispatch handlers (called by FSDP2 internals)
    FSDP2->>NVFP4Tensor: aten.as_strided (identity noop)
    NVFP4Tensor-->>FSDP2: make_like(tensor)
    FSDP2->>NVFP4Tensor: aten.slice.Tensor (full-dim noop)
    NVFP4Tensor-->>FSDP2: make_like(tensor)
    FSDP2->>NVFP4Tensor: aten.record_stream
    NVFP4Tensor->>NVFP4Tensor: record_stream on each sub-tensor
    NVFP4Tensor-->>FSDP2: None

Reviews (3): Last reviewed commit: "addresses greptile comment"

Comment on lines +68 to +71
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)

    @pytest.mark.parametrize("shape", _test_shapes)
Contributor

P1 setup_class will TypeError at runtime

setup_class is decorated with @staticmethod, but its signature accepts a cls parameter. Pytest invokes class-level setup by calling TestNVFP4FSDP2Hooks.setup_class() with no arguments. Because this is a static method, Python does not inject cls, so the call has 0 arguments against a 1-argument signature — raising TypeError: setup_class() missing 1 required positional argument: 'cls' before any test in the class executes.

The seed initialisation would silently never run. Change the decorator to @classmethod:

Suggested change

    @classmethod
    def setup_class(cls) -> None:
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)
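The failure mode is reproducible without pytest: a `@staticmethod` whose signature takes `cls` simply receives no arguments when invoked on the class. A minimal sketch (`Broken`/`Fixed` are illustrative names, not code from the PR):

```python
class Broken:
    @staticmethod
    def setup_class(cls) -> None:  # cls here is an ordinary required argument
        pass

class Fixed:
    @classmethod
    def setup_class(cls) -> None:  # Python injects cls automatically
        pass

try:
    Broken.setup_class()  # pytest-style invocation: no arguments
except TypeError as exc:
    print("Broken:", exc)  # missing 1 required positional argument: 'cls'

Fixed.setup_class()  # works: cls is bound to Fixed
print("Fixed: ok")
```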

Contributor Author

Good catch, fixed. Changed to @classmethod so pytest actually invokes the seed initialization.

Comment on lines +726 to +732
    if func == aten.slice.Tensor:
        tensor = args[0]
        dim = args[1]
        start = args[2]
        length = args[3]
        if start == 0 and length == tensor.size(dim):
            return NVFP4Tensor.make_like(tensor)
Contributor

P2 slice.Tensor handler uses wrong variable name for end and skips step check

The fourth positional argument of aten.slice.Tensor is end, not length. The variable name is misleading, but more importantly the handler unconditionally indexes args[3] without checking whether a fifth step argument is provided and equals 1. A non-unit step with start=0, end=size would still pass the existing check and incorrectly return an unmodified NVFP4Tensor, silently skipping the stride.

Suggested change

    if func == aten.slice.Tensor:
        tensor = args[0]
        dim = args[1] if len(args) > 1 else 0
        start = args[2] if len(args) > 2 else None
        end = args[3] if len(args) > 3 else None
        step = args[4] if len(args) > 4 else 1
        if step == 1 and (start is None or start == 0) and (end is None or end >= tensor.size(dim)):
            return NVFP4Tensor.make_like(tensor)

Contributor Author

Fixed. Renamed length → end, added defensive len(args) checks for the optional positional args, and added a step == 1 guard. Also handles the case where start/end are None (the ATen defaults).
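The corrected guard logic can be exercised in isolation with a hypothetical stand-in that takes the resolved dimension size and the positional args after self (`is_identity_slice` is an illustrative helper, not the PR's actual handler, which operates on an NVFP4Tensor):

```python
def is_identity_slice(dim_size, args):
    """Return True iff aten.slice.Tensor-style args describe a no-op slice.
    args mirrors the positional layout after self: (dim, start, end, step),
    with trailing entries optional, matching the ATen schema defaults."""
    start = args[1] if len(args) > 1 else None
    end = args[2] if len(args) > 2 else None
    step = args[3] if len(args) > 3 else 1
    return (
        step == 1
        and (start is None or start == 0)
        and (end is None or end >= dim_size)
    )

print(is_identity_slice(8, (0, 0, 8)))        # True: full-range slice
print(is_identity_slice(8, (0, 0, 8, 2)))     # False: non-unit step
print(is_identity_slice(8, (0, 0, 4)))        # False: partial slice
print(is_identity_slice(8, (0, None, None)))  # True: ATen defaults
```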

Comment on lines +592 to +597
    # Always send both orientations (GEMM needs both for fwd/bwd)
    rowwise_usage = True
    sharded_tensors = (rowwise_data, rowwise_scale_inv)
    columnwise_usage = self._quantizer.columnwise_usage
    if columnwise_usage:
        sharded_tensors += (columnwise_data, columnwise_scale_inv)
Contributor

P2 Hardcoded rowwise_usage = True can silently pass None to FSDP2

rowwise_usage is hardcoded to True without checking whether _rowwise_data is actually populated. If an NVFP4Tensor is created with only columnwise data (rowwise=False), rowwise_data will be None but it will still be included in sharded_tensors. FSDP2 would then try to all-gather None, which will crash or corrupt silently.

Consider guarding this with an explicit check:

rowwise_usage = self._rowwise_data is not None

Or, if both orientations are always required for FSDP2 training, add an assertion to catch misconfigured tensors early:

rowwise_usage = True
assert self._rowwise_data is not None, (
    "FSDP2 requires rowwise data, but _rowwise_data is None. "
    "Ensure the NVFP4Quantizer was created with rowwise=True."
)

Contributor Author

Added an assertion to fail early if _rowwise_data is None:

    assert self._rowwise_data is not None, (
        "FSDP2 requires rowwise data, but _rowwise_data is None. "
        "Ensure the NVFP4Quantizer was created with rowwise=True."
    )

Comment on lines +591 to +592

# Always send both orientations (GEMM needs both for fwd/bwd)
Collaborator

@vthumbe1503 vthumbe1503 Mar 24, 2026

@jomitchellnv we shouldn't need both orientations. For the forward pass only rowwise is needed, and for the backward pass only columnwise is needed.

Can you refer to this PR for the optimizations: #2789
Right now the columnwise all-gather implementation transposes, all-gathers in pre_allgather, and then transposes again after post_allgather, which is expensive.
You can improve the columnwise all-gather in one of two ways:

  1. Always all-gather rowwise data, similar to the above PR. If columnwise usage is needed, you can just transpose it in post_all_gather.
  2. All-gather only columnwise data if columnwise usage is set, then interleave the stacked columnwise data.

Option 1 is fine to have for now. I would also add assertions for 2D scaling, similar to the FP8 block-scaling implementation.

Contributor Author

Thanks for the pointer to #2789 — that makes sense. I'll adopt option 1: only all-gather rowwise data +
scales, and derive columnwise locally in post_all_gather if columnwise_usage is set. Will update.

On the 2D scaling assertion: NVFP4 doesn't have _is_2D_scaled like Float8Blockwise, but I can add a guard on with_2d_quantization. Could you clarify what layout constraint you're thinking of? The NVFP4 rowwise scale has M in dim0 (round_up(M, 128), ...) regardless of that flag, so it should be compatible with dim0 all-gather either way. Want to make sure I add the right check.

Collaborator

Yes, exactly: with_2d_quantization should be asserted to be True. If with_2d_quantization is False, then columnwise is not derivable from rowwise data/scales (although this is not a common use-case).

@jomitchellnv jomitchellnv force-pushed the jm/nvfp4-block-fused-adam branch from 37a1b01 to abb7b04 on March 24, 2026 21:43
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/nvfp4-block-fused-adam branch from 178c7c3 to 238f2df on March 24, 2026 21:50
Comment on lines +666 to +671
    if current_m_blocks < target_m_blocks:
        columnwise_scale_inv = torch.nn.functional.pad(
            columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
        )
    elif current_m_blocks > target_m_blocks:
        columnwise_scale_inv = columnwise_scale_inv[:, :target_m_blocks]
Contributor

P1 Silent data corruption when shard_M is not a multiple of NVFP4_BLOCK_SCALING_SIZE

The elif current_m_blocks > target_m_blocks trim branch silently discards valid scale data rather than surfacing an alignment error.

Here is why this happens: fsdp_pre_all_gather computes m_blocks = ceil(shard_M / 16) per shard. After all-gather, the concatenated total is world_size * ceil(shard_M / 16). The target is round_up(ceil(full_M / 16), 4). When shard_M % 16 != 0, world_size * ceil(shard_M / 16) > ceil(full_M / 16), and the trim condition fires.

Concrete example: full_M = 136, world_size = 8, shard_M = 17.

  • Each rank: ceil(17/16) = 2 m-blocks
  • All-gathered: 16 m-blocks
  • Target: round_up(ceil(136/16), 4) = round_up(9, 4) = 12
  • elif trims from 16 → 12, discarding 4 real m-block scale columns
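The arithmetic in the example above can be checked directly. A standalone sketch (`round_up` is a local helper mirroring the described padding rule, not a library function):

```python
import math

def round_up(x: int, multiple: int) -> int:
    return math.ceil(x / multiple) * multiple

full_M, world_size = 136, 8
shard_M = full_M // world_size                       # 17 rows per rank
per_rank_blocks = math.ceil(shard_M / 16)            # ceil(17/16) = 2
gathered_blocks = world_size * per_rank_blocks       # 8 * 2 = 16
target_blocks = round_up(math.ceil(full_M / 16), 4)  # round_up(9, 4) = 12

# the shard-alignment assertion from the suggested fix would reject this case:
assert shard_M % 16 != 0  # 17 is not a multiple of NVFP4_BLOCK_SCALING_SIZE
print(gathered_blocks, target_blocks)  # 16 12 (trim would drop 4 real blocks)
```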

Beyond the trim, this represents a deeper problem: each rank's m-block 0 covers its local rows 0–15, not the same global rows. When shard_M % 16 != 0, the per-rank m-blocks do not align with global m-blocks, and the all-gathered scale tensor is fundamentally scrambled before the trim even runs.

The fix is to assert the constraint at the top of fsdp_pre_all_gather rather than silently masking the symptom:

assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, (
    f"FSDP2 requires shard_M ({shard_M}) to be a multiple of "
    f"NVFP4_BLOCK_SCALING_SIZE ({NVFP4_BLOCK_SCALING_SIZE}). "
    "Ensure the weight's row count is divisible by world_size * 16."
)

And the elif trim should then become an unreachable assertion (or be removed), so any unexpected mismatch is caught loudly:

Suggested change

    if current_m_blocks < target_m_blocks:
        columnwise_scale_inv = torch.nn.functional.pad(
            columnwise_scale_inv, (0, target_m_blocks - current_m_blocks)
        )
    else:
        assert current_m_blocks == target_m_blocks, (
            f"Unexpected m_block count after all-gather: "
            f"got {current_m_blocks}, expected {target_m_blocks}"
        )

Contributor Author

Fixed. Added an assertion in fsdp_pre_all_gather that shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, which
guarantees m-blocks align across ranks after all-gather. Converted the elif trim to an unreachable assertion —
if the shard alignment holds, current_m_blocks can never exceed target_m_blocks.

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Comment on lines +727 to +736
    if func == aten.as_strided.default:
        tensor = args[0]
        shape = args[1]
        strides = args[2]
        if (
            len(shape) == len(strides) == 2
            and tuple(strides) == (shape[-1], 1)
            and tuple(shape) == tuple(tensor.size())
        ):
            return NVFP4Tensor.make_like(tensor)
Contributor

P1 as_strided handler silently falls through for non-standard calls

When the shape/stride check does not match (e.g., FSDP2 applies as_strided with a non-unit stride, a shape mismatch, or a non-zero storage offset), the handler does not return anything and falls through to super().__torch_dispatch__(). That super call is unlikely to know how to handle an NVFP4Tensor, so the result will be wrong or raise an obscure error.

The storage_offset argument (args[3] when present) is also never inspected — if FSDP2 ever supplies a non-zero offset, the identity check still fires and silently returns the unmodified tensor, ignoring the offset.

Since the intent is to handle only the no-op identity case and let everything else fall through deliberately, consider making that explicit with a comment or raising NotImplementedError for the non-identity case so failures are surfaced early rather than producing undefined behaviour from the super dispatch.
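One way to make the identity-only intent explicit is a guard that also checks the storage offset and fails loudly otherwise. A hypothetical standalone sketch (`handle_as_strided` is an illustrative helper, not the PR's handler, which operates on an NVFP4Tensor):

```python
def handle_as_strided(tensor_size, shape, strides, storage_offset=0):
    """Accept only the no-op identity as_strided call; raise on anything else
    so non-identity uses surface early instead of falling through to a super
    dispatch that cannot handle the wrapper subclass."""
    is_identity = (
        len(shape) == len(strides) == 2
        and tuple(strides) == (shape[-1], 1)    # contiguous 2-D layout
        and tuple(shape) == tuple(tensor_size)  # same logical shape
        and storage_offset == 0                 # no offset into storage
    )
    if is_identity:
        return "make_like"  # stand-in for NVFP4Tensor.make_like(tensor)
    raise NotImplementedError(
        f"as_strided on NVFP4Tensor supports only the identity case; "
        f"got shape={shape}, strides={strides}, storage_offset={storage_offset}"
    )

print(handle_as_strided((4, 8), (4, 8), (8, 1)))  # make_like
```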
