
Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes#1024

Open
kevalmorabia97 wants to merge 6 commits into main from kmorabia/minitron-full-te-spec

Conversation


kevalmorabia97 (Collaborator) commented Mar 11, 2026

What does this PR do?

Type of change: Improvement

Quantization recently added support for the full TE spec. This PR adds the same for pruning, so we can retire the ModelOpt spec and just use the standard TE spec.
NOTE: We still don't support TEGroupedGemm and instead use TE SequentialMLP for now (but this can be configured in the standard TE spec, so we don't need the ModelOpt spec).
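The configurability claim above can be illustrated with a toy sketch (plain Python; `TESpecConfig` and `mlp_module_name` are hypothetical stand-ins, not Megatron's actual API): switching between grouped and sequential MoE MLPs is just a flag on the spec builder, so no custom ModelOpt spec is needed:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TESpecConfig:
    """Toy stand-in for the arguments of a TE layer-spec builder."""

    num_experts: Optional[int] = None
    moe_grouped_gemm: bool = False  # False -> sequential per-expert MLPs


def mlp_module_name(cfg: TESpecConfig) -> str:
    """Pick which MoE MLP implementation the spec would reference."""
    if cfg.num_experts is None:
        return "MLP"  # dense model, no MoE
    if cfg.moe_grouped_gemm:
        return "TEGroupedMLP"  # fused grouped GEMM (not yet supported for pruning)
    return "SequentialMLP"  # one TE linear pair per expert


print(mlp_module_name(TESpecConfig(num_experts=128)))  # -> SequentialMLP
```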

Note that this does not affect usage of the pruning workflow, but it makes pruning slightly faster and may produce a slightly different pruned model because of different kernels and numerics.

[Bug fix]: Previously, NAS-based pruning for MoE models would hang when evaluating MMLU for pruned candidate models. Fixed in this PR as well.

[Bug fix]: Previously, hidden-size importance hooks were not applied to pre_mlp_layernorm for MoE layers. Fixed in this PR as well, resulting in a significant MMLU improvement for Qwen3-30B-A3B.
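A minimal plain-Python sketch (hypothetical names, not ModelOpt's actual API) of why the missing hook matters: hidden-size importance is aggregated across all hooked layernorms, so skipping pre_mlp_layernorm silently drops its activation statistics from the ranking:

```python
def aggregate_hidden_importance(activations_by_module, hooked_modules):
    """Sum per-channel mean-absolute activations over the hooked modules."""
    importance = None
    for name in hooked_modules:
        rows = activations_by_module[name]  # list of per-channel |activation| rows
        per_channel = [sum(col) / len(rows) for col in zip(*rows)]
        importance = (
            per_channel if importance is None
            else [a + b for a, b in zip(importance, per_channel)]
        )
    return importance


# toy activations for one decoder layer (values are made up)
acts = {
    "input_layernorm": [[0.1, 0.9, 0.2]],
    "pre_mlp_layernorm": [[0.8, 0.1, 0.1]],  # MoE layers: previously not hooked
}
buggy = aggregate_hidden_importance(acts, ["input_layernorm"])
fixed = aggregate_hidden_importance(acts, ["input_layernorm", "pre_mlp_layernorm"])
# with pre_mlp_layernorm included, channel 0 is no longer ranked least important
```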

Testing

  • Unit tests updated and passing
  • Compare pruning results for Qwen3-8B -> 6B. ⚠️ MMLU scores differ, resulting in a different best-picked model, but the scores are in a similar range; the difference may be due to different kernels for TE layers.
Least important 6 layers:
    ModelOpt Spec: 27, 28, 29, 31, 32, 33
    TE Spec: 27, 28, 30, 31, 32, 33
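Depth pruning as above boils down to ranking layers by an importance score and dropping the k lowest; a self-contained sketch with made-up scores (real scores come from activation-based estimation, not this toy):

```python
def least_important_layers(scores, k):
    """Return the indices of the k lowest-scoring layers, sorted ascending."""
    return sorted(sorted(range(len(scores)), key=scores.__getitem__)[:k])


# hypothetical per-layer importance for a 36-layer model
scores = [1.0] * 36
for idx in (27, 28, 29, 31, 32, 33):
    scores[idx] = 0.1  # pretend these scored lowest
print(least_important_layers(scores, 6))  # -> [27, 28, 29, 31, 32, 33]
```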

Top 10 pruned candidates:
| num_layers | hidden_size | ffn_hidden_size | Params (B) | MMLU (ModelOpt Spec) | MMLU (TE Spec) |
|------------|-------------|-----------------|------------|----------------------|----------------|
| 34         | 3328        | 11264           | 5.99       | 0.390                | 0.397          |
| 30         | 3584        | 11776           | 5.99       | 0.572 [BEST]         | 0.575          |
| 36         | 3840        | 8192            | 5.98       | 0.511                | 0.511          |
| 36         | 3584        | 9216            | 5.98       | 0.477                | 0.497          |
| 36         | 3072        | 11776           | 5.97       | 0.278                | 0.252          |
| 32         | 3584        | 10752           | 5.96       | 0.542                | 0.541          |
| 36         | 3328        | 10240           | 5.92       | 0.365                | 0.412          |
| 34         | 3840        | 8704            | 5.91       | 0.537                | 0.539          |
| 30         | 4096        | 9216            | 5.90       | 0.566                | 0.591 [BEST]   |
| 34         | 3584        | 9728            | 5.89       | 0.499                | 0.510          |
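Picking the [BEST] row in sweeps like the one above is an argmax over MMLU among candidates within the parameter budget; a minimal sketch using three rows from the table:

```python
def best_candidate(candidates, budget_b):
    """Pick the highest-MMLU candidate whose size fits the budget (in billions)."""
    feasible = [c for c in candidates if c["params_b"] <= budget_b]
    return max(feasible, key=lambda c: c["mmlu"])


# a few (num_layers, params, MMLU-with-TE-spec) rows from the Qwen3-8B sweep
candidates = [
    {"num_layers": 34, "params_b": 5.99, "mmlu": 0.397},
    {"num_layers": 30, "params_b": 5.99, "mmlu": 0.575},
    {"num_layers": 30, "params_b": 5.90, "mmlu": 0.591},
]
print(best_candidate(candidates, budget_b=6.0)["mmlu"])  # -> 0.591
```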
  • Compare pruning results for Nemotron-Nano-9B-v2 -> 7B. MMLU scores differ slightly, but the best pruned model selection is the same.
Least important 8 layers (Before and After): [43, 44, 45, 46, 47, 48, 50, 52]

Top 10 pruned candidates:
| num_layers | hidden_size | mamba_num_heads | mamba_head_dim | ffn_hidden_size | Params (B) | MMLU (ModelOpt Spec) | MMLU (TE Spec) |
|------------|-------------|------------------|---------------|-----------------|------------|----------------------|----------------|
| 50         | 4480        | 128              | 56            | 15680           | 7.00       | 0.211                | 0.202          |
| 56         | 4096        | 96               | 80            | 14336           | 7.00       | 0.438                | 0.436          |
| 48         | 4352        | 120              | 80            | 13824           | 7.00       | 0.679 [BEST]         | 0.679 [BEST]   |
| 56         | 4352        | 112              | 80            | 10240           | 7.00       | 0.516                | 0.520          |
| 54         | 4480        | 104              | 80            | 11264           | 7.00       | 0.263                | 0.262          |
| 46         | 4480        | 128              | 72            | 14848           | 7.00       | 0.610                | 0.617          |
| 50         | 4480        | 112              | 64            | 15680           | 7.00       | 0.426                | 0.421          |
| 54         | 4096        | 112              | 80            | 13312           | 7.00       | 0.579                | 0.589          |
| 56         | 4352        | 120              | 72            | 10752           | 7.00       | 0.466                | 0.469          |
| 52         | 4352        | 120              | 72            | 12800           | 7.00       | 0.561                | 0.560          |
  • Compare pruning results for Qwen3-30B-A3B -> 24B. Previously there was a bug in the importance hooks, so we now see a big improvement.
Top 10 pruned candidates (MMLU computation takes ~1 hour per candidate, so it was skipped after the first 3):
| num_layers | hidden_size | num_attention_heads | num_moe_experts | Params (B) | MMLU (ModelOpt Spec) | MMLU (TE Spec) |
|------------|-------------|---------------------|-----------------|------------|----------------------|----------------|
| 46         | 2048        | 28                  | 104             | 23.98      | 0.663                | 0.698          |
| 40         | 2048        | 28                  | 120             | 23.95      | 0.577                | 0.668          |
| 46         | 1792        | 24                  | 120             | 23.94      | 0.435                | 0.500          |
| 46         | 2048        | 24                  | 104             | 23.88      |                      |                |
| 40         | 2048        | 24                  | 120             | 23.87      |                      |                |
| 46         | 1792        | 20                  | 120             | 23.85      |                      |                |
| 40         | 2048        | 20                  | 120             | 23.78      |                      |                |
| 46         | 2048        | 20                  | 104             | 23.78      |                      |                |
| 42         | 2048        | 32                  | 112             | 23.62      |                      |                |
| 48         | 1792        | 32                  | 112             | 23.54      |                      |                |

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ⚠️ TE has different kernels, so the pruned model may be slightly different because of different numerics
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

OMNIML-3504

Summary by CodeRabbit

Release Notes

  • New Features

    • Added full Transformer Engine specification support for Minitron pruning, enabling pruning without custom ModelOpt specifications.
  • Documentation

    • Updated container image tags and Docker configuration in bridge examples.
    • Enhanced pruning workflow documentation.
  • Improvements

    • Refined activation collection strategy for pruning processes.
    • Improved tokenizer configuration handling for better compatibility.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

copy-pr-bot bot commented Mar 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.



coderabbitai bot commented Mar 11, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: de38bd26-c432-4768-bcc4-f85704da30d9

📥 Commits

Reviewing files that changed from the base of the PR and between 72a5b3d and 7b198d6.

📒 Files selected for processing (12)
  • CHANGELOG.rst
  • examples/megatron_bridge/README.md
  • examples/megatron_bridge/prune_minitron.py
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/nas/plugins/transformer_engine.py
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • modelopt/torch/utils/plugins/mbridge.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
  • tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/nas/plugins/transformer_engine.py

📝 Walkthrough

Walkthrough

This PR migrates the NVIDIA Model Optimizer codebase to use full Transformer Engine specifications for Minitron pruning. The changes consolidate TE module integrations from a dedicated plugin into the Megatron plugin, introduce a new factory function for TE-compatible Mamba stacks, update activation collection logic to handle TE's fused layer normalization outputs, and align test utilities and test suites to exercise the TE implementation paths.

Changes

Cohort / File(s) Summary
Documentation Updates
CHANGELOG.rst, examples/megatron_bridge/README.md, examples/megatron_bridge/prune_minitron.py
Changelog version bump and feature documentation for TE Minitron support; updated NeMo container version, Docker run configuration, and log message formatting in pruning examples.
Transformer Engine NAS Plugin Migration
modelopt/torch/nas/plugins/megatron.py, modelopt/torch/nas/plugins/transformer_engine.py
Consolidated TE dynamic module definitions (TENorm, TELinear variants) and MoE handling into megatron.py; deleted dedicated transformer_engine.py plugin file; added get_te_mamba_stack_spec() factory function.
Minitron Pruning Updates
modelopt/torch/prune/plugins/mcore_minitron.py
Reworked activation importance estimation to handle TE's fused layer normalization outputs via new _collect_activations() helper; updated hook registration for TELayerNormColumnParallelLinear modules with return_layernorm_output patching.
Bridge/Utilities Updates
modelopt/torch/utils/plugins/mbridge.py, tests/_test_utils/torch/megatron/models.py
Replaced static spec assignments with factory function calls (get_te_mamba_stack_spec, get_gpt_layer_with_transformer_engine_spec); added conditional logic for selecting modelopt vs. transformer_engine implementations in model construction.
Test Suite Updates
tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py, tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py, tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py, tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
Updated imports and type assertions to reference TE-based dynamic modules; added transformer_impl="transformer_engine" parameters to model construction calls; adjusted pruning assertions to reflect new MambaLayer API shapes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • danielkorzekwa
  • ChenhanYu
  • AAnoosheh
  • LianaMikael
  • realAsma
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)
  • Description Check: Skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: The title clearly and specifically describes the main changes: adding Full TE Spec support for Megatron Pruning DynamicModules and fixing MoE-related bugs, which aligns with the PR's primary objectives.
  • Security Anti-Patterns: The pull request does not introduce critical security anti-patterns; trust_remote_code is properly exposed as configurable and validated through is_safe_repo().




kevalmorabia97 changed the title from "Add Full TE Spec support for Megatron Pruning DynamicModules" to "Add Full TE Spec and GroupedMLP support for Megatron Pruning DynamicModules" on Mar 11, 2026

codecov bot commented Mar 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.12%. Comparing base (26cad67) to head (7b198d6).
⚠️ Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1024      +/-   ##
==========================================
+ Coverage   70.09%   70.12%   +0.03%     
==========================================
  Files         221      221              
  Lines       25459    25459              
==========================================
+ Hits        17845    17854       +9     
+ Misses       7614     7605       -9     

☔ View full report in Codecov by Sentry.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
kevalmorabia97 changed the title from "Add Full TE Spec and GroupedMLP support for Megatron Pruning DynamicModules" back to "Add Full TE Spec support for Megatron Pruning DynamicModules" on Mar 11, 2026
kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch from ade6edf to d4820c8 on March 11, 2026 21:09
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch from d4820c8 to 98d5291 on March 11, 2026 21:23
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch 3 times, most recently from f3071a3 to 8f42e0f on March 12, 2026 11:02
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
kevalmorabia97 force-pushed the kmorabia/minitron-full-te-spec branch from 8f42e0f to cff7137 on March 12, 2026 12:03
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
kevalmorabia97 changed the title from "Add Full TE Spec support for Megatron Pruning DynamicModules" to "Add Full TE Spec support for Megatron Pruning DynamicModules + MoE bug fixes" on Mar 12, 2026
kevalmorabia97 marked this pull request as ready for review on March 12, 2026 20:26
kevalmorabia97 requested review from a team as code owners on March 12, 2026 20:26