
[train] Enable expandable_segments to reduce GPU memory fragmentation#1470

Draft
CharlieFRuan wants to merge 3 commits into main from feat/expandable-segments

Conversation


@CharlieFRuan CharlieFRuan commented Apr 7, 2026

Summary

Problem

PyTorch's default CUDA allocator uses fixed-size memory segments via cudaMalloc. Over multiple training steps, freed blocks leave fragmented gaps — memory that is reserved but unusable for large contiguous allocations. This is especially problematic for long-context training with MoE models where allocation patterns vary between steps due to expert routing.

The symptom: training passes step 1, but OOMs on step 2+ with an error like:

CUDA out of memory. Tried to allocate 18.91 GiB.
Of the allocated memory 124.24 GiB is allocated by PyTorch,
and 15.50 GiB is reserved by PyTorch but unallocated.

That 15.50 GiB of reserved-but-unallocated memory is fragmentation — wasted GPU memory.
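The arithmetic behind that figure can be checked directly from the error message (a torch-free sketch; at runtime the same two quantities are exposed by `torch.cuda.memory_reserved()` and `torch.cuda.memory_allocated()`):

```python
def fragmented_gib(reserved_gib: float, allocated_gib: float) -> float:
    """Reserved-but-unallocated memory: held by the caching allocator
    but not usable for a large contiguous allocation."""
    return reserved_gib - allocated_gib

# From the OOM message: 124.24 GiB allocated, 15.50 GiB reserved but
# unallocated, so the allocator holds 139.74 GiB in total.
total_reserved = 124.24 + 15.50
print(f"{fragmented_gib(total_reserved, 124.24):.2f} GiB fragmented")  # 15.50
```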

Solution

Enable expandable_segments programmatically after model initialization on all training worker processes. PyTorch's expandable_segments uses CUDA Virtual Memory Management (VMM) to create segments that can grow — non-contiguous physical blocks appear contiguous to PyTorch, eliminating fragmentation.
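For comparison, PyTorch also supports enabling the same mode process-wide through the allocator config environment variable; this PR deliberately avoids that, since it would put model weights into expandable segments too (`train.py` below is a placeholder entry point):

```
# Process-wide alternative (not what this PR does): every allocation,
# including model weights, would land in expandable segments.
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python train.py
```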

Key design decisions

  1. Enable after model init, not at process start: Model weights are allocated with the standard allocator so they reside in standard CUDA memory, which is compatible with CUDA IPC (used for weight sync in colocated mode). Only subsequent allocations (activations during forward/backward) use expandable segments.

  2. Toggle around CUDA IPC weight sync for colocate_all=True: expandable_segments is incompatible with cudaIpcGetMemHandle (VMM-based virtual address mappings are process-local). Before weight sync, we toggle it OFF so that any new buffer allocations (packed tensors for IPC) use standard CUDA memory. After sync completes, we toggle it back ON. Regular NCCL communication (allreduce, allgather) is unaffected — NCCL uses its own internal buffers.

  3. No toggling needed for colocate_all=False: Weight sync uses NCCL broadcast (not CUDA IPC), which works fine with expandable segments.

  4. Clone in vLLM native IPC path: The new inference IPC path (_send_chunks_vllm_native) uses .clone() to ensure weight tensors are copied into the current allocator's memory space when expandable_segments is toggled OFF.

Benchmark Results

GLM-4.7-Flash, 64K context, TP=2 EP=4 CP=1, colocate_all=true, 8×B200 (183 GiB each)

| Step | `use_expandable_segments=false` | `use_expandable_segments=true` |
|------|------|------|
| 1 | PASS (84s) | PASS (109s) |
| 2 | OOM (15.50 GiB fragmented) | PASS (75s) |
| 3 | — | PASS (77s) |
| 4 | — | PASS (79s) |
| 5 | — | PASS (79s) |

E2E correctness (colocate_all=True with CUDA IPC weight sync)

test_policy_local_engines_e2e: 4/4 PASSED

The full weight sync + inference pipeline works: model init → enable expandable_segments → train → toggle OFF → CUDA IPC weight sync → toggle ON → vLLM inference with updated weights.
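That ordering can be sketched as a torch-free trace (all function names below are illustrative stand-ins, not the PR's actual APIs):

```python
calls = []

def init_model():            calls.append("init_model")
def set_expandable(on):      calls.append("expandable=" + ("on" if on else "off"))
def train_step():            calls.append("train_step")
def cuda_ipc_weight_sync():  calls.append("cuda_ipc_weight_sync")
def vllm_inference():        calls.append("vllm_inference")

# model init -> enable expandable_segments -> train -> toggle OFF
# -> CUDA IPC weight sync -> toggle ON -> inference with updated weights
init_model()
set_expandable(True)
train_step()
set_expandable(False)
cuda_ipc_weight_sync()
set_expandable(True)
vllm_inference()

assert calls == [
    "init_model", "expandable=on", "train_step", "expandable=off",
    "cuda_ipc_weight_sync", "expandable=on", "vllm_inference",
]
```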

How to reproduce

# OOM without expandable_segments (step 2 OOMs):
MEGATRON_TP=2 MEGATRON_EP=4 MEGATRON_CP=1 \
MAX_PROMPT_LENGTH=8192 MAX_GENERATE_LENGTH=57344 NUM_DUMMY_STEPS=5 \
  bash run_full_ctx_b200_glm.sh \
  trainer.placement.use_expandable_segments=false

# Stable with expandable_segments (all 5 steps pass):
MEGATRON_TP=2 MEGATRON_EP=4 MEGATRON_CP=1 \
MAX_PROMPT_LENGTH=8192 MAX_GENERATE_LENGTH=57344 NUM_DUMMY_STEPS=5 \
  bash run_full_ctx_b200_glm.sh \
  trainer.placement.use_expandable_segments=true

The run_full_ctx_b200_glm.sh script is available at: https://gist.github.com/CharlieFRuan/e925cdc9d036fdc128f9d88208608cae

Test plan

  • Unit tests: 33/33 weight sync strategy tests pass
  • GPU E2E: test_worker_dispatch_offload — 4/4 colocate_all=True FSDP tests pass
  • GPU E2E: test_policy_local_engines_e2e — 4/4 colocate + weight sync + inference tests pass
  • Full context: GLM-4.7-Flash 64K, TP=2/EP=4, colocate_all=True — 5/5 steps pass with expandable_segments, OOMs on step 2 without
  • Verified that the expandable_segments toggle logging is visible in worker logs

🤖 Generated with Claude Code



Add `use_expandable_segments` flag (default: True) to PlacementConfig that
enables PyTorch's CUDA expandable_segments allocator on training workers.
This dramatically reduces GPU memory fragmentation during training,
unlocking longer context lengths and preventing step-2+ OOMs.

## Problem

PyTorch's default CUDA allocator uses fixed-size memory segments. Over
multiple training steps, freed blocks leave fragmented gaps — memory is
reserved but unusable for large contiguous allocations. This manifests as
OOMs even when total free memory should be sufficient, with the error
message suggesting `PYTORCH_ALLOC_CONF=expandable_segments:True`.

## Solution

Enable `expandable_segments` programmatically after model initialization
on all training worker processes (policy, critic, ref). The key design:

1. **Model weights are allocated BEFORE enabling expandable_segments**, so
   they reside in standard CUDA memory compatible with CUDA IPC
2. **Activations during training** (forward/backward) use expandable
   segments, eliminating fragmentation
3. **For colocate_all=True**: expandable_segments is toggled OFF before
   CUDA IPC weight sync and ON after, since CUDA IPC handles are
   incompatible with VMM-based allocations
4. **For colocate_all=False**: no toggling needed (weight sync uses NCCL
   broadcast, which is unaffected)

## Results (GLM-4.7-Flash, 64K context, TP=2 EP=4 CP=1, 8×B200)

| Step | `use_expandable_segments=false` | `use_expandable_segments=true` |
|------|---|----|
| 1 | PASS | PASS |
| 2 | **OOM** (15.50 GiB fragmented, only 15.82 GiB free for 18.91 GiB alloc) | PASS |
| 3 | — | PASS |
| 4 | — | PASS |
| 5 | — | PASS |

Without expandable_segments, 15.50 GiB of GPU memory is reserved but
unusable due to fragmentation. With expandable_segments, fragmentation
drops to ~266 MiB, freeing that memory for actual allocations.

## Files changed

- `skyrl/train/config/config.py` — add `use_expandable_segments: bool = True`
- `skyrl/backends/skyrl_train/workers/worker.py` — toggle helpers on Worker base class
- `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` — enable after init, toggle around weight sync
- `skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py` — same
- `skyrl/backends/skyrl_train/weight_sync/cuda_ipc_strategy.py` — clone tensors in vLLM native IPC path
- `ai_docs/run_full_ctx_b200_glm.sh` — B200 full context test script for GLM-4.7-Flash

Closes #1405

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the use_expandable_segments configuration to optimize GPU memory usage by reducing fragmentation during training. It implements logic to toggle the PyTorch CUDA allocator setting, enabling it after model initialization and disabling it during weight synchronization to ensure CUDA IPC compatibility. Additionally, a new script for testing GLM-4.7-Flash context boundaries is provided. Feedback suggests removing a redundant .clone() operation in the weight synchronization strategy to prevent unnecessary memory overhead and wrapping the synchronization logic in try...finally blocks to guarantee the allocator settings are correctly restored in case of failures.

Comment on lines +201 to +205
# clone() ensures the tensor is allocated with the *current*
# allocator settings. When expandable_segments has been toggled
# OFF before weight sync, the clone lives in standard CUDA
# memory which is compatible with cudaIpcGetMemHandle.
weight = tensor.detach().contiguous().clone()


Severity: medium

The .clone() call here appears to be redundant and could lead to significant memory overhead. Since expandable_segments is toggled OFF in the worker before extract_weights is called, the tensors yielded by the iterator (which are typically fresh allocations from full_tensor() or .to(dtype)) are already allocated using the standard CUDA allocator and are thus IPC-compatible. Materializing a full copy of the model weights via .clone() effectively doubles the memory required for weights during synchronization, which may cause OOM on memory-constrained GPUs for large models.

Suggested change
# clone() ensures the tensor is allocated with the *current*
# allocator settings. When expandable_segments has been toggled
# OFF before weight sync, the clone lives in standard CUDA
# memory which is compatible with cudaIpcGetMemHandle.
weight = tensor.detach().contiguous().clone()
# Ensure the tensor is detached and contiguous for IPC.
# Since expandable_segments was toggled OFF before extraction,
# the tensor is already in standard CUDA memory.
weight = tensor.detach().contiguous()

Comment on lines +308 to +309
if use_expandable:
self._set_expandable_segments(True)


Severity: medium

It is safer to wrap the weight synchronization logic in a try...finally block to ensure that expandable_segments is re-enabled even if an exception occurs during the broadcast process. If synchronization fails (e.g., due to a network timeout or NCCL error), leaving the allocator in the standard mode could lead to memory fragmentation and OOMs in subsequent training steps.

Comment on lines +811 to +812
if use_expandable:
self._set_expandable_segments(True)


Severity: medium

Similar to the FSDP worker, consider using a try...finally block around the weight synchronization logic. This ensures that expandable_segments is toggled back ON even if send_chunks or other intermediate operations fail, preventing unexpected memory fragmentation in future steps.


@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 5 additional findings.


The weight extractor creates fresh tensor allocations (via all_gather or
full_tensor) while expandable_segments is toggled OFF, so the tensors
are already in standard CUDA memory. The .clone() doubled memory usage
during weight sync unnecessarily.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Member Author


TODO: update .rst doc

# Enable expandable_segments after model init so weights are in standard
# CUDA memory (IPC-compatible). Only new allocations (activations during
# training) will use expandable segments.
self._maybe_enable_expandable_segments()
Member Author


TODO: I don't like how there are both _maybe_enable_expandable_segments() and _set_expandable_segments(). Should fix

- Merge `_maybe_enable_expandable_segments` into `_set_expandable_segments`
  (single method that checks config internally, no-op when disabled)
- Remove `ai_docs/run_full_ctx_b200_glm.sh` from PR
- Add Memory Optimization section to placement.mdx documenting
  `use_expandable_segments` and how it interacts with colocation
- Simplify toggle logic in broadcast_to_inference_engines (check
  colocate_all directly since _set_expandable_segments handles config check)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment


Development

Successfully merging this pull request may close these issues.

[train] Check when PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True can be set
