
[PyT] Fix FSDP2 memory leaks for FP8 weight workspaces and transpose caches#2805

Open
pstjohn wants to merge 3 commits into NVIDIA:main from pstjohn:pstjohn/fix-fsdp2-mem-leak

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Mar 26, 2026

Summary

Fixes memory leaks where FP8 quantized weight workspaces and transpose caches accumulate during FSDP2 training, defeating FSDP2's per-layer memory savings.

Approach

When FSDP2 is detected (via _get_module_fsdp_state), the fix applies to layernorm_mlp.py, layernorm_linear.py, and linear.py:

  1. Skip columnwise/transpose creation during forward — backward's FSDP2 all-gather recreates them
  2. Disable workspace caching (cache_name=None) — prevents _fp8_workspaces from retaining per-layer copies
  3. Don't save workspace copies for backward — re-quantize from FSDP2 all-gathered weight instead
  4. Clear _transpose after backward dgrad GEMMs — prevents transpose data persisting on reusable buffers

Guarded to Float8Quantizer, Float8CurrentScalingQuantizer, and MXFP8Quantizer. Blockwise quantizers (Float8BlockScaling, NVFP4BlockScaling) have separate internal caching not yet addressed.

Test changes

Updates the memory leak detection tests from #2803:

  • Removes xfail markers for fixed recipes
  • Adds targeted xfail for blockwise recipes
  • Increases backward test tolerance to 1 MiB (from 256 KiB) to account for temporary backward workspace re-creation

Test plan

  • All FSDP2 memory leak tests pass (14 passed, 6 xfailed for blockwise, 2 skipped)
  • All existing FSDP2 functional tests pass (model tests, fused_adam tests)
  • No regressions in test_torch_fsdp2.py (4 passed)

🤖 Generated with Claude Code

Closes #2681
Closes #2717

pstjohn and others added 3 commits March 25, 2026 16:47
Add tests that demonstrate two known memory issues with FSDP2 + FP8:

- Issue NVIDIA#2681: FP8 weight copies created during te.autocast() forward pass
  accumulate across layers instead of being freed between layers, defeating
  FSDP2's memory efficiency. Detected by comparing per-layer forward memory
  increments against a bf16 baseline using layer hooks.

- Issue NVIDIA#2717: Transpose cache tensors (_create_transpose) allocated during
  backward persist until the next forward pass instead of being freed after
  backward completes. Detected by comparing the backward memory delta
  (post_bwd - post_fwd) against a bf16 baseline.

New tests:
- test_bf16_no_excess_forward_memory: control, validates per-layer measurement
- test_bf16_no_excess_backward_memory: control, validates backward delta comparison
- test_fp8_temp_accumulation_across_layers: xfail, detects NVIDIA#2681
- test_transpose_cache_retained_after_backward: xfail, detects NVIDIA#2717

All parametrized over 5 FP8 recipes x {no_quant_init, quant_init}.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
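The detection strategy described in the commit message above (per-layer forward memory increments compared against a bf16 baseline) can be sketched without CUDA; excess_layers and the byte counts below are hypothetical stand-ins for readings taken with torch.cuda.memory_allocated() inside layer hooks:

```python
def excess_layers(fp8_increments, bf16_increments, tol_bytes):
    """Return indices of layers whose forward memory increment exceeds
    the bf16 baseline by more than tol_bytes.

    Each list holds one per-layer delta (bytes allocated across that
    layer's forward), e.g. collected via forward pre/post hooks.
    """
    return [
        i
        for i, (fp8, bf16) in enumerate(zip(fp8_increments, bf16_increments))
        if fp8 - bf16 > tol_bytes
    ]

# Healthy run: FP8 increments track the bf16 baseline within tolerance.
assert excess_layers([100, 110, 105], [100, 100, 100], tol_bytes=16) == []

# Leaky run (issue #2681 pattern): each layer retains an extra FP8 copy,
# so increments grow relative to the baseline instead of staying flat.
assert excess_layers([100, 260, 420], [100, 100, 100], tol_bytes=16) == [1, 2]
```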
… constant

- Fix standalone runner to not pass recipe/quantized_model_init args to
  bf16 control tests (which take no arguments)
- Fix stale comment referencing 4-layer model (now 8 layers)
- Remove unused MEASURED_STEPS constant

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…caches

Fix memory leaks where FP8 quantized weight copies and transpose caches
accumulate during FSDP2 training, defeating FSDP2's per-layer memory
savings (Issues NVIDIA#2681, NVIDIA#2717).

Changes to layernorm_mlp.py, layernorm_linear.py, linear.py:
- Detect FSDP2 via _get_module_fsdp_state; guard to tensor-scaling and
  MXFP8 quantizers whose backward re-creation is validated.
- Skip columnwise/transpose creation on weight quantizers during forward
  so FP8 caches don't accumulate across layers.
- Disable workspace caching (cache_name=None) under FSDP2 to prevent
  _fp8_workspaces from retaining per-layer copies.
- Don't save separate FP8 workspace copies for backward; re-create from
  the FSDP2 all-gathered weight in backward instead.
- Clear Float8TensorStorage._transpose after backward dgrad GEMMs to
  prevent transpose data persisting on FSDP2's reusable buffers.

Test changes (run_fsdp2_mem_leak.py):
- Remove xfail markers for fixed recipes (DelayedScaling,
  Float8CurrentScaling, MXFP8BlockScaling).
- Add targeted xfail for Float8BlockScaling/NVFP4BlockScaling whose
  blockwise storage classes have separate internal caching.
- Increase backward test tolerance to 1 MiB to account for temporary
  workspace re-creation during backward.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@greptile-apps

greptile-apps bot commented Mar 26, 2026

Greptile Summary

This PR fixes two FSDP2 memory leaks in TransformerEngine's FP8 training path — quantized weight workspaces accumulating across layers during forward (Issue #2681) and _create_transpose tensors persisting on FSDP2's reusable all-gather buffers after backward (Issue #2717). The fix is applied consistently across linear.py, layernorm_linear.py, and layernorm_mlp.py, guarded to Float8/MXFP8 quantizer types whose backward reconstruction is validated.

Key changes:

- FSDP2 is detected at runtime via _get_module_fsdp_state; the result gates four optimizations per module
- columnwise=False is set during forward when FSDP2 is detected, preventing transpose creation
- cache_name=None disables _fp8_workspaces persistence across layers
- wt_save=None prevents the quantized workspace from being saved in the autograd context; backward re-quantizes from the FSDP2 all-gathered weight
- weight._transpose is set to None after dgrad GEMMs to release memory on the shared all-gather buffer
- A new run_fsdp2_mem_leak.py test file provides direct forward-accumulation and backward-retention regression tests with appropriate xfail markers for blockwise recipes

Issues found:

- Redundant MXFP8Quantizer import inside linear.py's forward body — P2 style
- ctx.is_fsdp2 in layernorm_mlp.py stores fsdp2_skip_columnwise (a stricter compound condition), which could confuse readers — P2 style
- Incidental pre-existing bug fix where isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) (always False) was corrected to isinstance(ctx.fc1_weight, QuantizedTensorStorage) — this changes non-FSDP2 behavior and deserves a separate callout — P1
- _is_safe_for_fsdp2 in layernorm_mlp.py is derived from only fc1_weight_quantizer but guards FC2 as well — P2 style
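The wt_save=None / backward re-quantization step listed above can be sketched as a toy (quantize and backward_weight are hypothetical placeholders, not TransformerEngine APIs):

```python
def quantize(weight, columnwise):
    """Placeholder for real FP8 quantization; returns a tagged tuple."""
    return ("fp8", weight, "columnwise" if columnwise else "rowwise")

def backward_weight(saved_workspace, origin_weight):
    """Recover the FP8 weight needed for the dgrad GEMM.

    Under FSDP2 this PR saves None instead of the workspace, so backward
    re-quantizes from the freshly all-gathered origin_weight (with
    columnwise usage for the transposed dgrad GEMM) rather than keeping
    a per-layer FP8 copy alive between forward and backward.
    """
    if saved_workspace is None:  # FSDP2 path: nothing was retained
        return quantize(origin_weight, columnwise=True)
    return saved_workspace  # non-FSDP2 path: reuse the saved copy

# FSDP2 path: workspace rebuilt from the all-gathered weight.
assert backward_weight(None, "w0") == ("fp8", "w0", "columnwise")

# Non-FSDP2 path: the saved workspace is used unchanged.
assert backward_weight(("fp8", "w0", "rowwise"), "w0") == ("fp8", "w0", "rowwise")
```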

Confidence Score: 4/5

Safe to merge with one minor follow-up: the bundled pre-existing bug fix in layernorm_mlp.py should be acknowledged; all other issues are style/naming.

The core FSDP2 memory-leak fixes are logically sound and consistently applied across all three module files. The new test suite directly validates the fixed scenarios with appropriate tolerances and xfail markers for unresolved blockwise cases. Backward reconstruction of the FP8 workspace is correct: origin_weight is always saved (never None) and the quantizer is always non-None when wt_save=None can be set. The P1 comment covers an incidental fix to a pre-existing bug (wrong isinstance target) that changes non-FSDP2 behavior but is covered by passing functional tests. Remaining comments are P2 style/naming issues.

transformer_engine/pytorch/module/layernorm_mlp.py — contains the bundled isinstance bug fix and the ctx.is_fsdp2 naming confusion; warrants a close read on the non-FSDP2 FC1 dgrad path.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/layernorm_mlp.py Most complex changes: FSDP2 detection for two weights (FC1/FC2), workspace skip, None save, two backward reconstructions, two transpose clears. Contains a bundled pre-existing bug fix (isinstance target) and ctx.is_fsdp2 naming confusion.
transformer_engine/pytorch/module/linear.py Adds FSDP2 detection and workspace/cache fixes consistently; minor redundant MXFP8Quantizer import inside the method body.
transformer_engine/pytorch/module/layernorm_linear.py FSDP2 detection, workspace caching disabled, wt_save=None, and transpose clearing — logic is consistent with linear.py.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_mem_leak.py New test file with well-structured memory leak detection for forward accumulation and backward transpose retention; appropriate xfail markers for unresolved blockwise recipes.
tests/pytorch/distributed/test_torch_fsdp2.py Adds a new test entry point that invokes the memory leak test file via torchrun; straightforward and consistent with existing test runners.

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2 Runtime
    participant FWD as Forward Pass (_Linear / _LayerNormLinear / _LayerNormMLP)
    participant CTX as Autograd Context
    participant BWD as Backward Pass

    FSDP2->>FWD: All-gather weight (BF16/FP8 param)
    FWD->>FWD: detect FSDP2 via _get_module_fsdp_state()
    Note over FWD: is_fsdp2=True → columnwise=False,<br/>cache_name=None, wt_save=None
    FWD->>FWD: get_weight_workspace() → weightmat (FP8, rowwise only)
    FWD->>CTX: save(inputmat, wt_save=None, origin_weight, ...)
    FWD-->>FSDP2: Reshard weight (buffer freed)

    FSDP2->>BWD: All-gather weight (BF16/FP8 param)
    BWD->>CTX: restore_from_func_ctx() → weight_fp8=None
    BWD->>BWD: weight_fp8 is None → re-quantize(origin_weight, columnwise=True)
    BWD->>BWD: dgrad GEMM (weight_fp8 transposed)
    BWD->>BWD: weight._transpose = None (clear cache, Issue #2717)
    BWD-->>FSDP2: Reshard weight (buffer freed, no dangling _transpose)

Comments Outside Diff (2)

  1. transformer_engine/pytorch/module/linear.py, line 876-880 (link)

    P2 Redundant MXFP8Quantizer import inside forward method

    MXFP8Quantizer is already imported at the module level (line 67), so re-importing it here is unnecessary and adds a small overhead on every forward call.

  2. transformer_engine/pytorch/module/layernorm_mlp.py, line 680-695 (link)

    P2 FSDP2 safety check uses only fc1_weight_quantizer but guards both FC1 and FC2

    _is_safe_for_fsdp2 checks isinstance(fc1_weight_quantizer, _fsdp2_safe_quantizers) or isinstance(fc1_weight, Float8Tensor), but the same flag controls workspace skipping for fc2_weight as well. In practice both weights share the same quantizer type within a layer, but checking both explicitly would make the assumption visible and guard against future mixed-precision configurations.

    _is_safe_for_fsdp2 = (
        isinstance(fc1_weight_quantizer, _fsdp2_safe_quantizers) or isinstance(fc1_weight, Float8Tensor)
    ) and (
        isinstance(fc2_weight_quantizer, _fsdp2_safe_quantizers) or isinstance(fc2_weight, Float8Tensor)
    )


ctx.fc2_weight_requires_grad = fc2_weight.requires_grad
ctx.fc1_weight = fc1_weight
ctx.fc2_weight = fc2_weight
ctx.is_fsdp2 = fsdp2_skip_columnwise

P2 ctx.is_fsdp2 stores fsdp2_skip_columnwise, not the raw is_fsdp2 flag

fsdp2_skip_columnwise is is_fsdp2 and not is_recomputation and _is_safe_for_fsdp2, which is more restrictive than is_fsdp2. Storing it as ctx.is_fsdp2 could mislead a reader into thinking it simply encodes "this is an FSDP2 run." The backward's transpose-clearing guard (getattr(ctx, "is_fsdp2", False)) happens to be correct because you only want to clear during a real (non-recompute) backward, but the naming makes the invariant implicit.

Consider renaming to ctx.fsdp2_skip_columnwise (matching the local variable name) to make the semantics self-documenting.

Suggested change
ctx.is_fsdp2 = fsdp2_skip_columnwise
ctx.fsdp2_skip_columnwise = fsdp2_skip_columnwise


Comment on lines 1436 to 1438
  if ctx.fc1_weight_quantizer is not None and isinstance(
-     ctx.fc1_weight_quantizer, QuantizedTensorStorage
+     ctx.fc1_weight, QuantizedTensorStorage
  ):

P1 Incidental bug fix: isinstance target changed from quantizer to weight

The pre-existing code read:

if ctx.fc1_weight_quantizer is not None and isinstance(
    ctx.fc1_weight_quantizer, QuantizedTensorStorage   # always False
):
    ctx.fc1_weight.update_usage(columnwise_usage=True)

Because a Quantizer is never a QuantizedTensorStorage, this condition was always False and update_usage was never called on the FC1 weight in this branch. This PR silently fixes that by checking ctx.fc1_weight instead — which is correct — but it also changes backward behavior for non-FSDP2 paths (e.g., fp8_init with primary-FP8 weights). The passing functional tests provide coverage, but a comment or separate commit noting this as a pre-existing bug fix would keep the diff more reviewable.
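The always-False check is easy to demonstrate with stand-in classes (toy definitions, not TransformerEngine's real Quantizer/storage hierarchy):

```python
class QuantizedTensorStorage:  # stand-in for TE's storage base class
    pass

class Float8Tensor(QuantizedTensorStorage):  # stand-in FP8 weight tensor
    pass

class Float8Quantizer:  # stand-in quantizer; NOT a storage subclass
    pass

fc1_weight_quantizer = Float8Quantizer()
fc1_weight = Float8Tensor()

# Pre-existing check: a quantizer is never a QuantizedTensorStorage,
# so the update_usage branch could never fire.
assert not isinstance(fc1_weight_quantizer, QuantizedTensorStorage)

# Corrected check from this PR: the weight itself is storage, so
# update_usage(columnwise_usage=True) can now actually run.
assert isinstance(fc1_weight, QuantizedTensorStorage)
```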
