
[DRAFT] feat(zero2): add CPU offload support for Muon optimizer#7939

Draft
delock wants to merge 2 commits into deepspeedai:master from delock:gma/muon_cpuoffload

Conversation

@delock
Collaborator

@delock delock commented Mar 31, 2026

Add Muon optimizer support to the ZeRO Stage 1 & 2 CPU offload path by:

  1. Copying the full parameter gradient (instead of the partial gradient) to CPU when a parameter crosses a partition boundary.
  2. Storing the Muon momentum buffer in CPU memory to save GPU memory.
  3. Running the Muon update on CPU, utilizing CPU matmul.
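Point 1 hinges on detecting whether a parameter straddles an equal-split partition boundary. A minimal illustration of that check with hypothetical names (ZeRO's real bookkeeping lives in structures like `grad_position`, not shown in this PR excerpt):

```python
def crosses_partition_boundary(param_offset, param_numel, partition_size):
    """True if the flat range [offset, offset + numel) spans more than one
    equally sized partition, i.e. the gradient is split across ranks."""
    first = param_offset // partition_size
    last = (param_offset + param_numel - 1) // partition_size
    return first != last
```

For example, with `partition_size=4`, a parameter at offset 3 with 3 elements spans ranks 0 and 1, so its full gradient would be copied to CPU; a parameter at offset 4 with 4 elements lies entirely within rank 1's partition.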

@delock delock marked this pull request as draft March 31, 2026 07:02

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 54364fbe9a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

pad_tensor = torch.zeros(padded_size - self.bit16_groups_flat[i].numel(),
                         dtype=self.bit16_groups_flat[i].dtype,
                         device=self.bit16_groups_flat[i].device)
self.bit16_groups_flat[i] = torch.cat([self.bit16_groups_flat[i], pad_tensor])


P1: Insert per-partition padding before the Muon equal split

Appending a single padding block at the tail does not guarantee parameter-boundary partitioning: when an earlier partition is smaller than max_partition_size (e.g., sizes [4,5,1] for dp=3), get_data_parallel_partitions() still cuts at fixed max_partition_size offsets and splits a parameter across ranks. That breaks the new CPU-offload Muon path, which assumes unsplit parameters and writes a full update.view(-1) into a partition slice computed from grad_position, leading to shape mismatch or incorrect updates when source_offset != 0.
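The fix this review suggests amounts to padding each rank's partition individually so every cut lands on a parameter boundary. A toy sketch of that idea using the reviewer's `[4, 5, 1]` example with `dp=3` (greedy packing; names are illustrative, not DeepSpeed's actual helpers):

```python
def partition_by_param_boundary(param_numels, dp_size):
    """Greedily pack whole params into dp_size partitions, left to right,
    never splitting a param; return partitions, the uniform (max) partition
    size, and the per-partition padding needed to reach it."""
    total = sum(param_numels)
    target = (total + dp_size - 1) // dp_size  # ideal per-rank size
    partitions, current = [[]], 0
    for numel in param_numels:
        if current + numel > target and len(partitions) < dp_size and current > 0:
            partitions.append([])  # start a new partition at a param boundary
            current = 0
        partitions[-1].append(numel)
        current += numel
    while len(partitions) < dp_size:
        partitions.append([])  # ranks that hold no params, padding only
    max_size = max(sum(p) for p in partitions)
    padding = [max_size - sum(p) for p in partitions]
    return partitions, max_size, padding
```

With sizes `[4, 5, 1]` and `dp=3` this yields partitions `[[4], [5], [1]]` padded per-rank to a uniform size of 5, whereas a single tail pad leaves the fixed `max_partition_size` cuts splitting the 5-element parameter across ranks.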


if self._is_muon_param_group(i):
    dp_size = dist.get_world_size(group=self.real_dp_process_group[i])
    max_ps = self._get_muon_max_partition_size(self.round_robin_bit16_groups[i], dp_size, orig_group_numel)
    padded_size = max_ps * dp_size


P1: Keep the Muon partition size aligned for NCCL boundaries

max_partition_size is used directly to set padded_size, but it is not rounded to the existing NCCL start-alignment factor. If max_partition_size is odd with fp16/bf16 tensors, partition starts after rank 0 become 2-byte shifted and fail the existing 4-byte alignment assertion in the same initialization flow. This makes valid Muon configurations crash depending on parameter shapes.
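A sketch of the rounding this review asks for, assuming the 4-byte NCCL start alignment it mentions (the constant name here is illustrative):

```python
NCCL_START_ALIGNMENT_BYTES = 4  # assumption: matches ZeRO's start-alignment check

def aligned_partition_size(max_partition_size, element_size_bytes):
    """Round the per-rank partition size (in elements) up so every rank's
    start offset lands on a NCCL_START_ALIGNMENT_BYTES byte boundary."""
    align_elems = max(1, NCCL_START_ALIGNMENT_BYTES // element_size_bytes)
    return ((max_partition_size + align_elems - 1) // align_elems) * align_elems
```

For 2-byte fp16/bf16 elements an odd `max_partition_size` of 5 rounds up to 6, so rank 1's partition starts at a 12-byte offset instead of a misaligned 10-byte one; 4-byte fp32 partitions need no rounding.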


@delock delock force-pushed the gma/muon_cpuoffload branch 2 times, most recently from d802f0e to c058864 on March 31, 2026 10:07
delock added 2 commits March 31, 2026 03:07
Add Muon optimizer support in ZeRO Stage 1&2 CPU offload path by:

1. Partition strategy: Muon param groups now partition by parameter
   boundaries (never split a param across ranks), padding to uniform
   max size for all-gather compatibility. Logs padding overhead ratio.

2. CPU Newton-Schulz: Add muon_update_cpu() and
   zeropower_via_newtonschulz5_cpu() using PyTorch CPU bf16 matmul
   as baseline. Architecture allows future replacement with AMX C++ kernel.

3. CPU offload integration: _apply_muon_update_for_cpu_offload() copies
   complete gradients to CPU, runs muon_update on CPU (momentum buffer
   stays on CPU), writes result to FP32 grad buffer. No extra PCIe transfers.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Some CPUs lack hardware bf16 matmul support (AMX/AVX-512-BF16), causing
NS iterations to be ~800x slower than fp32 via MKL. This change uses
fp32 if CPU does not support bf16, reducing CPU offload NS time from
~18s to ~24ms for 512x2048 matrices.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
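The CPU Newton-Schulz routine added in the first commit (and switched to fp32 on non-bf16 CPUs in the second) can be sketched as the quintic Newton-Schulz-5 iteration used by the Muon optimizer, with coefficients (3.4445, -4.7750, 2.0315). This pure-Python version uses list-based matrices in place of torch tensors and assumes a square or wide input (Muon transposes tall matrices first); the actual `zeropower_via_newtonschulz5_cpu` signature is not shown in this PR, so names are illustrative.

```python
import math

def matmul(A, B):
    """Naive matrix product; stands in for torch matmul on CPU."""
    return [[sum(A[i][t] * B[t][j] for t in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def add_scaled(A, B, alpha):
    """Elementwise A + alpha * B."""
    return [[a + alpha * b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def scale(A, s):
    return [[s * a for a in row] for row in A]

def zeropower_via_newtonschulz5(G, steps=5):
    """Approximately orthogonalize G: drive its singular values toward 1."""
    a, b, c = 3.4445, -4.7750, 2.0315  # Muon's quintic coefficients
    # Normalize so all singular values are <= 1 before iterating.
    fro = math.sqrt(sum(x * x for row in G for x in row)) + 1e-7
    X = scale(G, 1.0 / fro)
    for _ in range(steps):
        A = matmul(X, transpose(X))                    # X X^T
        B = add_scaled(scale(A, b), matmul(A, A), c)   # b*A + c*A^2
        X = add_scaled(scale(X, a), matmul(B, X), 1.0)  # a*X + B X
    return X
```

The iteration converges quickly but not exactly: after five steps the singular values oscillate in a band around 1, which is sufficient for Muon's update. Running it in bf16 relies on hardware matmul support (AMX/AVX-512-BF16), which is why the second commit falls back to fp32 where that is absent.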
