If model parameters are DTensors, optimizer states should also be DTensors.#2795
cspades wants to merge 7 commits into NVIDIA:main from
Conversation
Greptile Summary

This PR fixes a bug where optimizer states for DTensor model parameters were created as plain (non-distributed) tensors, which Torch DCP then loads as a global / un-sharded state dictionary. Key changes:
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A{param is DTensor?} -->|Yes| B[Extract local_param via _local_tensor]
    A -->|No| C[local_param = param]
    B --> D[Create local data tensor via torch.empty_like]
    C --> D
    D --> E{dtype == uint8?}
    E -->|Yes| F[quantizer.make_empty with data.shape]
    E -->|No| G[state = data]
    F --> H[state.quantize_ with local data]
    H --> I{param is DTensor?}
    G --> I
    I -->|Yes| J[DTensor.from_local with global shape and stride]
    I -->|No| K[state stored as plain tensor]
    J --> L[state stored as DTensor]
    M[step / get_unscaled_state / set_scaled_state] --> N{state is DTensor?}
    N -->|Yes| O[Unwrap via _local_tensor]
    N -->|No| P[Use state directly]
    O --> Q[Pass local tensor to kernel / scale / copy]
    P --> Q
```
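The flowchart's wrap/unwrap pattern can be sketched in pure Python. This is a minimal illustration, not the PR's actual code: `FakeDTensor` is a stand-in for `torch.distributed.tensor.DTensor` (the attribute `_local_tensor` and the classmethod `from_local` mirror the real PyTorch names, but the real `from_local` also takes a device mesh and placements), and plain lists stand in for tensors.

```python
# Stand-in for torch DTensor: a local shard plus global-shape metadata.
class FakeDTensor:
    def __init__(self, local_tensor, global_shape):
        self._local_tensor = local_tensor
        self.global_shape = global_shape

    @classmethod
    def from_local(cls, local_tensor, global_shape):
        # Real API: DTensor.from_local(local, device_mesh, placements, ...)
        return cls(local_tensor, global_shape)


def make_optimizer_state(param):
    """Create an optimizer state buffer matching param's sharding.

    If param is a (Fake)DTensor, the state is allocated from the *local*
    shard and then re-wrapped as a DTensor with the parameter's global
    shape, so a distributed checkpoint sees it as sharded.
    """
    if isinstance(param, FakeDTensor):
        local_param = param._local_tensor      # extract the local shard
    else:
        local_param = param
    state = [0.0] * len(local_param)           # torch.empty_like stand-in
    if isinstance(param, FakeDTensor):
        # Wrap back so the state's distribution matches the parameter's.
        state = FakeDTensor.from_local(state, param.global_shape)
    return state


def unwrap_state(state):
    """In step() / get_unscaled_state(), kernels need the local tensor."""
    return state._local_tensor if isinstance(state, FakeDTensor) else state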
@cspades could you please elaborate on the downstream error/issue caused? As in, what happens if we load the unsharded tensor for optimizer state as a plain tensor instead of a DTensor?
Here is how I understand it (@shjwudp correct me if I am wrong about the Megatron-FSDP details, as I still need to reproduce the bug and confirm this PR fixes it; I believe a customer reported this bug?):

The fix is to keep the optimizer state in `DTensor` form, matching the sharding of the `DTensor` parameter it is associated with.
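A toy illustration of the failure mode in pure Python (this is an analogy, not the real Torch DCP logic): a checkpoint loader that relies on metadata treats a plain tensor as the full global tensor, so a state buffer that actually holds only the local shard gets the wrong shape, while a DTensor-like wrapper carries its own global-shape metadata.

```python
from types import SimpleNamespace


def dcp_view_of_shape(state):
    """Toy model of how a checkpoint loader sizes a state entry.

    A DTensor-like object carries global-shape metadata; a plain
    tensor (here, a list) is assumed to already be the global tensor,
    so a local shard stored as a plain tensor is misread.
    """
    if hasattr(state, "global_shape"):
        return state.global_shape          # sharded: correct global view
    return (len(state),)                   # plain: local size misread as global


local_shard = [0.0, 0.0]                   # local half of a length-4 global tensor
wrapped = SimpleNamespace(_local_tensor=local_shard, global_shape=(4,))
```

With the plain list, the loader reports shape `(2,)` even though the true global shape is `(4,)`; with the wrapper it reports `(4,)`, which is why keeping the state as a DTensor matters.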
Add Greptile bug-fixes.
/te-ci L1 pytorch -> TBA
Description
FusedAdam's optimizer state is converted into a non-distributed Tensor, which Torch DCP then loads as a global / un-sharded state dictionary. We wrap the optimizer state as a DTensor matching the distribution characteristics of the original DTensor parameter that the state is associated with.

Fixes a bug introduced by the new `DTensor(QuantizedTensor)` (FSDP2-only) use case introduced in #2698 (Megatron-FSDP just uses `DTensor(Float32)` for the distributed optimizer state).

Testing
`--use-precision-aware-optimizer` (TBA)

Type of change
Checklist: