
If model parameters are DTensors, optimizer states should also be DTensors. #2795

Open
cspades wants to merge 7 commits into NVIDIA:main from cspades:cye/fused-adam-dcp

Conversation

@cspades
Member

@cspades cspades commented Mar 24, 2026

Description

  • There is a bug where, if the model parameters are DTensors (either FSDP2 distributed parameters or Megatron-FSDP main-weight DTensors), FusedAdam's optimizer state is converted into a non-distributed Tensor, which Torch DCP then loads as a global / un-sharded state dictionary. This PR wraps each optimizer state as a DTensor matching the distribution characteristics of the DTensor parameter it is associated with.

Fixes a bug introduced by the new DTensor(QuantizedTensor) (FSDP2-only) use case from #2698 (Megatron-FSDP itself only uses DTensor(Float32) for the distributed optimizer state).

Testing

  • TE CI/CD
TE_PATH=/workspace/TransformerEngine ./qa/L1_pytorch_distributed_unittest/test.sh
  • Megatron-LM + --use-precision-aware-optimizer (TBA)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cspades and others added 3 commits March 24, 2026 09:48
@cspades cspades marked this pull request as ready for review March 24, 2026 17:42
@cspades
Member Author

cspades commented Mar 24, 2026

@greptile-apps
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR fixes a bug where FusedAdam optimizer states were initialized as plain (non-distributed) tensors even when the corresponding model parameters were DTensors (arising from FSDP2 distributed parameters or Megatron-FSDP main-weight DTensors). Plain optimizer states cause Torch DCP to treat them as global/un-sharded tensors during checkpointing, leading to incorrect checkpoint behavior.

Key changes:

  • _initialize_state: Optimizer states are now re-wrapped as DTensor via DTensor.from_local after initialization when the originating parameter is a DTensor. Also fixes a pre-existing shape bug where param.shape (global) was used instead of data.shape (local shard) when allocating the FP8 (uint8) optimizer state — previously this would raise a shape-mismatch error on sharded FP8 parameters.
  • get_unscaled_state / set_scaled_state: Both now unwrap DTensor states to their local shard before performing scale/copy operations, ensuring the underlying CUDA kernels receive plain tensors.
  • step(): DTensor states are unwrapped before being appended to the kernel tensor lists. New assert guards validate that the parameter, its gradient, and all optimizer states either all are or all are not DTensors, catching mismatches early.
  • Minor docstring and comment cleanups throughout.

Confidence Score: 4/5

  • The PR is on a clear path to merge; the core bug fix is sound and the prior FP8 shape-mismatch concern is now correctly addressed with data.shape.
  • The DTensor wrapping/unwrapping logic is internally consistent across _initialize_state, get_unscaled_state, set_scaled_state, and step(). The previously flagged param.shape → data.shape regression is fixed. One minor style issue remains (verbose tensor printing in assertion messages) but it does not affect correctness. The complex DTensor(Float8Tensor) interaction warrants confidence from the CI/CD test pass mentioned in the PR.
  • The _initialize_state DTensor-wrapping block (lines 432–441) is the most novel and complex addition — worth a focused review pass, particularly the stride=param.stride() argument when optimizer state dtype differs from param dtype.

Important Files Changed

Filename Overview
transformer_engine/pytorch/optimizers/fused_adam.py Bug fix: optimizer states are now wrapped as DTensors when the corresponding parameter is a DTensor (FSDP2 / Megatron-FSDP). Includes the param.shape → data.shape correction for FP8 state allocation, consistent DTensor unwrapping in get_unscaled_state / set_scaled_state / step(), and new parity assertions in step().

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A{param is DTensor?} -->|Yes| B[Extract local_param via _local_tensor]
    A -->|No| C[local_param = param]
    B --> D[Create local data tensor via torch.empty_like]
    C --> D
    D --> E{dtype == uint8?}
    E -->|Yes| F[quantizer.make_empty with data.shape]
    E -->|No| G[state = data]
    F --> H[state.quantize_ with local data]
    H --> I{param is DTensor?}
    G --> I
    I -->|Yes| J[DTensor.from_local with global shape and stride]
    I -->|No| K[state stored as plain tensor]
    J --> L[state stored as DTensor]

    M[step / get_unscaled_state / set_scaled_state] --> N{state is DTensor?}
    N -->|Yes| O[Unwrap via _local_tensor]
    N -->|No| P[Use state directly]
    O --> Q[Pass local tensor to kernel / scale / copy]
    P --> Q

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

@vthumbe1503
Collaborator

@cspades could you please elaborate on the downstream error/issue this causes? That is, what happens if we load the unsharded tensor for the optimizer state as a plain Tensor instead of a DTensor?

@cspades
Member Author

cspades commented Mar 24, 2026

@cspades could you please elaborate on the downstream error/issue this causes? That is, what happens if we load the unsharded tensor for the optimizer state as a plain Tensor instead of a DTensor?

Here is how I understand it (@shjwudp, correct me if I am wrong about the Megatron-FSDP details, as I still need to reproduce the bug and confirm this PR fixes it — I believe a customer reported this bug?):

  • Add fused_adam, quantized_model_init, and fsdp2 example #2698 introduced logic in FusedAdam.__init__ such that if the TE model parameters are DTensors, the optimizer state is changed to a normal Tensor.
    • The reason is that empty_like does not pick up the correct dtype from a DTensor (per the in-line commentary) when the local data is a QuantizedTensor. Note that Megatron-FSDP's main weights are FP32, not QuantizedTensor, so our code worked with the original FusedAdam.
  • When Megatron-FSDP (or Megatron-LM's distributed optimizer) performs its first optimizer.step(), Megatron-FSDP exposes FP32 DTensor main weights to the FusedAdam optimizer, and because of the above logic, normal Tensor optimizer states are constructed from the DTensor main weights.
  • Megatron-FSDP depends on DTensor optimizer states for DCP checkpointing of FusedAdam's state, because we employ uneven sharding. Instead, it now sees normal Tensors, which may break our DCP integration and/or the uneven DTensor metadata.

The fix is to keep the optimizer state in DTensor form whenever the model parameters are DTensors, and to localize or perform in-place operations on the local Tensor for all FusedAdam operations.

cspades and others added 3 commits March 24, 2026 12:01
Add Greptile bug-fixes.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades
Member Author

cspades commented Mar 24, 2026

/te-ci L1 pytorch -> TBA
