Skip to content

Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models#1033

Open
Edwardf0t1 wants to merge 3 commits intomainfrom
zhiyu/handle-moe-w13-scales
Open

Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models#1033
Edwardf0t1 wants to merge 3 commits intomainfrom
zhiyu/handle-moe-w13-scales

Conversation

@Edwardf0t1
Copy link
Contributor

@Edwardf0t1 Edwardf0t1 commented Mar 13, 2026

What does this PR do?

Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models. Serving engines fuse these projections into a single gate_up_proj and require a shared scale; this takes the element-wise max of the two independent scales as a conservative choice that avoids overflow.

Type of change: ? Bug fix

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features
    • Added automatic synchronization of quantization scaling values for Mixture of Experts gate and up projections during model export. This feature applies to non-fused MoE expert configurations and aligns scaling parameters across gate and up projection pairs. Relevant for models such as Qwen MoE and DeepSeek.

…pe in export

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
…pe in export

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@copy-pr-bot
Copy link

copy-pr-bot bot commented Mar 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@Edwardf0t1 Edwardf0t1 marked this pull request as ready for review March 13, 2026 05:50
@Edwardf0t1 Edwardf0t1 requested a review from a team as a code owner March 13, 2026 05:50
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

Walkthrough

This pull request adds MoE gate-up quantization synchronization functionality. A new sync_moe_gate_up_amax function is implemented to synchronize amax values between gate and up weight quantizers for non-fused MoE experts, and integrated into the export checkpoint process via _export_transformers_checkpoint.

Changes

Cohort / File(s) Summary
MoE Synchronization Implementation
modelopt/torch/export/layer_utils.py
Introduces _GATE_UP_PAIRS constant and sync_moe_gate_up_amax(model) function that traverses MOE modules, identifies gate/up expert pairs, and synchronizes their weight quantizer amax values to the element-wise maximum, returning the count of synced pairs.
Export Integration
modelopt/torch/export/unified_export_hf.py
Imports sync_moe_gate_up_amax and calls it within _export_transformers_checkpoint post-quantization, printing a status message when synchronization occurs.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The pull request title accurately describes the main change: unifying weight_scale_2 between gate_proj/up_proj pairs in the HF export path for MOE models, which aligns with the core objective of synchronizing amax values for MoE gate/up projections.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed Code review found no security anti-patterns: no torch.load with weights_only=False, no numpy.load with allow_pickle=True, no eval/exec, no trust_remote_code=True, no nosec comments, no new dependencies. The new sync_moe_gate_up_amax function performs safe tensor operations only.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
  • 📝 Generate docstrings (stacked PR)
  • 📝 Generate docstrings (commit on current branch)
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch zhiyu/handle-moe-w13-scales
📝 Coding Plan
  • Generate coding plan for human review comments

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the Hugging Face export path to ensure weight_scale_2 is shared between fused MoE MLP projection pairs (gate/up), matching serving-engine expectations when they fuse these projections into a single kernel.

Changes:

  • Add max_gate_up_scales() utility to replace gate_proj/up_proj (and w1/w3) weight_scale_2 pairs with their element-wise max in the exported state dict.
  • Invoke the new post-processing step at the end of _export_transformers_checkpoint() after postprocess_state_dict().

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
modelopt/torch/export/unified_export_hf.py Calls max_gate_up_scales() during HF checkpoint export and reports how many pairs were tied.
modelopt/torch/export/quant_utils.py Introduces max_gate_up_scales() to unify weight_scale_2 across gate/up projection pairs.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +795 to +797
tied = max_gate_up_scales(quantized_state_dict)
if tied:
print(f"Tied weight_scale_2 for {tied} gate/up projection pair(s) in MoE experts.")
Comment on lines +1158 to +1168
"""Replace gate_proj and up_proj weight_scale_2 with their element-wise max.

For MOE models where gate_proj and up_proj are quantized independently,
serving engines typically fuse them into a single gate_up_proj and need
a single shared scale. Using max is conservative (avoids overflow at the
cost of slightly reduced dynamic range).
"""
suffix_pairs = {
".gate_proj.weight_scale_2": ".up_proj.weight_scale_2",
".w1.weight_scale_2": ".w3.weight_scale_2",
}
@codecov
Copy link

codecov bot commented Mar 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.06%. Comparing base (bc87981) to head (17b7f46).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1033      +/-   ##
==========================================
- Coverage   70.11%   70.06%   -0.05%     
==========================================
  Files         221      221              
  Lines       25459    25531      +72     
==========================================
+ Hits        17851    17889      +38     
- Misses       7608     7642      +34     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

return post_state_dict


def max_gate_up_scales(tensors: dict[str, torch.Tensor]) -> int:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may not work. We also need to requantize ws1 if the ws2 changes. I think we need to do it here instead: https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L230

Copy link
Contributor Author

@Edwardf0t1 Edwardf0t1 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it's safer to sync the _amax on the weight quantizers of gate/up pairs. In @lukealonso's implementation it's done in both amax level and the post-processing stage, and he verified the accuracy which seems good.

Copy link
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L396-L418 this logic is used to align w1 and w3 scales.

We might want to debug why this fails for qwen3.5

@Edwardf0t1
Copy link
Contributor Author

https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L396-L418 this logic is used to align w1 and w3 scales.

We might want to debug why this fails for qwen3.5

It failed for kimi-k2.5 as well. That's exactly why we needed this PR to fix it.

quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
)

tied = max_gate_up_scales(quantized_state_dict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do this before compressing the weights, i.e., syncing the amax of gate and up? Is my understanding correct that the block-wise scales are still based on the separate global scales?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right - updated to sync at the amax level.

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/layer_utils.py`:
- Around line 1206-1213: The loop that iterates over _GATE_UP_PAIRS incorrectly
uses break when a pair is missing a quantizer or amax, which stops checking
remaining pairs; update the logic in the loop that accesses
gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to use continue
instead of break so only the current pair is skipped and the next pair in
_GATE_UP_PAIRS is checked; keep the final break that exits after successfully
syncing unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c2d4c1ed-e8c9-4c53-b650-0b34276086cf

📥 Commits

Reviewing files that changed from the base of the PR and between 7398fea and 17b7f46.

📒 Files selected for processing (2)
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/unified_export_hf.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/unified_export_hf.py

Comment on lines +1206 to +1213
gate_wq = getattr(gate_linear, "weight_quantizer", None)
up_wq = getattr(up_linear, "weight_quantizer", None)
if gate_wq is None or up_wq is None:
break
gate_amax = getattr(gate_wq, "amax", None)
up_amax = getattr(up_wq, "amax", None)
if gate_amax is None or up_amax is None:
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

❓ Verification inconclusive

Script executed:

#!/bin/bash
set -euo pipefail

echo "== SequentialQuantizer API =="
fd 'tensor_quantizer.py' --exec rg -n -C2 'class SequentialQuantizer|def amax|@property'

echo
echo "== Where SequentialQuantizer is used for weight quantizers =="
rg -n -C3 --type=py 'SequentialQuantizer\(|weight_quantizer' modelopt/torch

echo
echo "== Existing unwrapping pattern in layer_utils.py =="
rg -n -C3 --type=py 'isinstance\(.*SequentialQuantizer\)|weight_quantizer\.amax|sync_moe_gate_up_amax' modelopt/torch/export/layer_utils.py

Repository: NVIDIA/Model-Optimizer


Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

== SequentialQuantizer API ==
252-        return qtensor.dequantize(**kwarg)
253-
254:    `@property`
255-    def num_bits(self):
256-        """Return num_bits for quantization."""
--
262-        self._calibrator._num_bits = value
263-
264:    `@property`
265-    def maxbound(self):
266-        """Return maxbound for quantization."""
--
271-        return (1 << (self._num_bits - 1 + int(self._unsigned))) - 1
272-
273:    `@property`
274-    def unsigned(self):
275-        """Return True if unsigned quantization is used."""
--
281-        self._calibrator._unsigned = value
282-
283:    `@property`
284-    def pre_quant_scale(self):
285-        """Return pre_quant_scale used for smoothquant."""
--
305-            )
306-
307:    `@property`
308:    def amax(self):
309-        """Return amax for quantization."""
310-        if not hasattr(self, "_amax") or self.is_mx_format:
--
314-
315-    `@amax.setter`
316:    def amax(self, value):
317-        assert value is not None, "amax cannot be set to None."
318-
--
341-            self._bias_calibrator.reset()
342-
343:    `@property`
344-    def step_size(self):
345-        """Return step size for integer quantization."""
--
352-        return self._amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)
353-
354:    `@property`
355-    def axis(self):
356-        """Return axis for quantization."""
--
362-        self._calibrator._axis = value
363-
364:    `@property`
365-    def block_sizes(self):
366-        """Return block_sizes for quantization."""
--
372-        self._block_sizes = value
373-
374:    `@property`
375-    def bias(self):
376-        """Return bias for quantization."""
--
379-        return self._bias
380-
381:    `@property`
382-    def bias_axis(self):
383-        """Return bias_axis for quantization."""
--
392-        self._bias_axis = value
393-
394:    `@property`
395-    def bias_method(self):
396-        """Return bias_method for quantization."""
--
399-        return self._bias.get("method", "mean")
400-
401:    `@property`
402-    def bias_type(self):
403-        """Return bias_type for quantization."""
--
414-        self._bias["type"] = value
415-
416:    `@property`
417-    def bias_value(self):
418-        """Return bias for quantization."""
--
435-            self._bias_value.data.copy_(value.clone().detach().to(self._bias_value.device))
436-
437:    `@property`
438-    def bias_calibrator(self):
439-        """Return bias_calibrator for quantization."""
--
450-        return self._bias_calibrator
451-
452:    `@property`
453-    def fake_quant(self):
454-        """Return True if fake quantization is used."""
455-        return self._fake_quant
456-
457:    `@property`
458-    def narrow_range(self):
459-        """Return True if symmetric integer range for signed quantization is used."""
--
464-        self._narrow_range = value
465-
466:    `@property`
467-    def is_enabled(self):
468-        """Return true if the modules is not disabled."""
--
480-        self._disabled = False
481-
482:    `@property`
483-    def trt_high_precision_dtype(self):
484-        """Return True if FP16 AMAX is used when exporting the model."""
--
489-        self._trt_high_precision_dtype = value
490-
491:    `@property`
492-    def is_mx_format(self):
493-        """Check if is MX formats."""
--
521-            raise NotImplementedError()
522-
523:    `@property`
524-    def is_static_block_quant(self):
525-        """Check if is static block quantization."""
--
530-        )
531-
532:    `@property`
533-    def rotate_is_enabled(self):
534-        """Check if rotate is enabled in quant config."""
535-        return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
536-
537:    `@property`
538-    def rotate_is_fp32(self):
539-        """Check if rotation needs to be computed in float32."""
--
1286-        return tq
1287-
1288:    `@property`
1289-    def global_amax(self):
1290-        """Return global_amax for quantization."""
--
1320-
1321-
1322:class SequentialQuantizer(nn.Sequential):
1323-    """A sequential container for  :class:`TensorQuantizer` modules.
1324-

== Where SequentialQuantizer is used for weight quantizers ==
modelopt/torch/quantization/utils.py-228-
modelopt/torch/quantization/utils.py-229-    # the standard weight and quantizer case
modelopt/torch/quantization/utils.py-230-    weight = getattr(module, "weight", None)
modelopt/torch/quantization/utils.py:231:    weight_quantizer = getattr(module, "weight_quantizer", None)
modelopt/torch/quantization/utils.py:232:    if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
modelopt/torch/quantization/utils.py-233-        yield "weight"
modelopt/torch/quantization/utils.py-234-
modelopt/torch/quantization/utils.py-235-    # other weight and quantizer case
modelopt/torch/quantization/utils.py-236-    for name, _ in module.named_parameters(recurse=False):
modelopt/torch/quantization/utils.py-237-        weight = getattr(module, name, None)
modelopt/torch/quantization/utils.py:238:        weight_quantizer = getattr(module, f"{name}_weight_quantizer", None)
modelopt/torch/quantization/utils.py-239-        if isinstance(weight, nn.Parameter) and isinstance(
modelopt/torch/quantization/utils.py:240:            weight_quantizer, (TensorQuantizer, SequentialQuantizer)
modelopt/torch/quantization/utils.py-241-        ):
modelopt/torch/quantization/utils.py-242-            yield name
modelopt/torch/quantization/utils.py-243-
--
modelopt/torch/quantization/utils.py-246-QuantizerAttrNames = namedtuple(
modelopt/torch/quantization/utils.py-247-    "QuantizerAttrNames",
modelopt/torch/quantization/utils.py-248-    (
modelopt/torch/quantization/utils.py:249:        "weight_quantizer",
modelopt/torch/quantization/utils.py-250-        "input_quantizer",
modelopt/torch/quantization/utils.py-251-        "output_quantizer",
modelopt/torch/quantization/utils.py-252-        "weight_scale",
--
modelopt/torch/quantization/utils.py-261-    """Get all the quantizer related attribute names for a given weight name."""
modelopt/torch/quantization/utils.py-262-    prefix = f"{weight_name}_" if weight_name != "weight" else ""
modelopt/torch/quantization/utils.py-263-    return QuantizerAttrNames(
modelopt/torch/quantization/utils.py:264:        weight_quantizer=f"{prefix}weight_quantizer",
modelopt/torch/quantization/utils.py-265-        input_quantizer=f"{prefix}input_quantizer",
modelopt/torch/quantization/utils.py-266-        output_quantizer=f"{prefix}output_quantizer",
modelopt/torch/quantization/utils.py-267-        weight_scale=f"{prefix}weight_scale",
--
modelopt/torch/quantization/utils.py-285-    return (
modelopt/torch/quantization/utils.py-286-        isinstance(module, QuantModule)
modelopt/torch/quantization/utils.py-287-        and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
modelopt/torch/quantization/utils.py:288:        and hasattr(module, "weight_quantizer")
modelopt/torch/quantization/utils.py-289-        and (
modelopt/torch/quantization/utils.py-290-            (getattr(module, "weight", None) is not None and module.weight.dim() == 2)
modelopt/torch/quantization/utils.py-291-            # module.weight0 check is required to support TEGroupedLinear
--
modelopt/torch/quantization/utils.py-329-    config["quant_cfg"]["*lora*"] = {"enable": False}
modelopt/torch/quantization/utils.py-330-    for layer in layers:
modelopt/torch/quantization/utils.py-331-        config["quant_cfg"][f"*{layer}.input_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py:332:        config["quant_cfg"][f"*{layer}.weight_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py-333-        config["quant_cfg"][f"*{layer}.output_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py-334-    return config
modelopt/torch/quantization/utils.py-335-
--
modelopt/torch/quantization/utils.py-537-
modelopt/torch/quantization/utils.py-538-    1. Takes the element-wise max of each ``input_quantizer`` amax across all experts
modelopt/torch/quantization/utils.py-539-       and writes it back, so every expert shares the same input amax.
modelopt/torch/quantization/utils.py:540:    2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert
modelopt/torch/quantization/utils.py-541-       received no tokens during calibration), runs a weight-only ``max_calibrate``
modelopt/torch/quantization/utils.py-542-       to populate the missing amax.
modelopt/torch/quantization/utils.py-543-    """
--
modelopt/torch/quantization/utils.py-566-
modelopt/torch/quantization/utils.py-567-    for expert in experts:
modelopt/torch/quantization/utils.py-568-        for name, module in expert.named_modules():
modelopt/torch/quantization/utils.py:569:            if name.endswith("weight_quantizer") and module.is_enabled and module.amax is None:
modelopt/torch/quantization/utils.py:570:                weight = expert.state_dict().get(name.replace("weight_quantizer", "weight"))
modelopt/torch/quantization/utils.py-571-                if weight is not None:
modelopt/torch/quantization/utils.py-572-                    max_calibrate(module, lambda m, w=weight: m(w), distributed_sync=False)
modelopt/torch/quantization/utils.py-573-
--
modelopt/torch/quantization/utils.py-687-    """
modelopt/torch/quantization/utils.py-688-    original_fake_quant = []
modelopt/torch/quantization/utils.py-689-    for m in module.modules():
modelopt/torch/quantization/utils.py:690:        if hasattr(m, "weight_quantizer"):
modelopt/torch/quantization/utils.py:691:            original_fake_quant.append(m.weight_quantizer._fake_quant)
modelopt/torch/quantization/utils.py:692:            m.weight_quantizer._fake_quant = True
modelopt/torch/quantization/utils.py-693-    yield
modelopt/torch/quantization/utils.py-694-    for m in module.modules():
modelopt/torch/quantization/utils.py:695:        if hasattr(m, "weight_quantizer"):
modelopt/torch/quantization/utils.py:696:            m.weight_quantizer._fake_quant = original_fake_quant.pop(0)
modelopt/torch/quantization/utils.py-697-
modelopt/torch/quantization/utils.py-698-
modelopt/torch/quantization/utils.py-699-@contextmanager
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-97-    def get_weights_scaling_factor_from_quantizer(
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-98-        cls,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-99-        weight: torch.Tensor,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:100:        weight_quantizer,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-101-    ) -> torch.Tensor:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-102-        """Returns E8M0 scale from quantizer or computes from weight.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-103-
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-107-        Args:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-108-            weight: The weight tensor. Can be 2D (out_dim, in_dim) or
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-109-                3D for MoE (num_experts, out_dim, in_dim).
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:110:            weight_quantizer: The weight quantizer with block_sizes and optional _scale.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-111-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-112-        Returns:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-113-            torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-114-        """
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:115:        assert hasattr(weight_quantizer, "block_sizes"), (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:116:            "weight_quantizer must have 'block_sizes' attribute"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-117-        )
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:118:        assert weight_quantizer.block_sizes[-1] == cls.BLOCK_SIZE, (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:119:            f"MXFP8 requires block size {cls.BLOCK_SIZE}, got {weight_quantizer.block_sizes[-1]}"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-120-        )
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-121-        assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-122-
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-126-        # For 3D MoE: (num_experts, out_dim, in_dim // 32)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-127-        expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-128-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:129:        if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:130:            scale = weight_quantizer._scale
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-131-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-132-            assert scale.dtype == cls.SCALE_DTYPE, (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-133-                f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}"
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-53-        return cls.e2m1_bounds_on_device[device]
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-54-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-55-    `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:56:    def _is_static_quantizer(cls, weight_quantizer) -> bool:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-57-        """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:58:        return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-59-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-60-    `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:61:    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:62:        """Returns per tensor weight scaling factor from the weight_quantizer.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-63-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-64-        Handles both static NVFP4 quantizers (using global_amax) and
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-65-        dynamic quantizers (using _amax).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-66-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-67-        Args:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:68:            weight_quantizer: The weight quantizer (static or dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-69-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-70-        Returns:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-71-            The global scaling factor as a float tensor.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-72-        """
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:73:        if cls._is_static_quantizer(weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:74:            return weight_quantizer.global_amax.float() / (6.0 * 448.0)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-75-        else:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:76:            assert hasattr(weight_quantizer, "_amax"), (
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-77-                "Weight quantizer does not have attribute amax"
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-78-            )
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:79:            return weight_quantizer._amax.float() / (6.0 * 448.0)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-80-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-81-    `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-82-    def get_weights_scaling_factor_from_quantizer(
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-83-        cls,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:84:        weight_quantizer,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-85-        weight: torch.Tensor,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-86-        weights_scaling_factor_2: torch.Tensor | None = None,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-87-        keep_high_precision: bool = False,
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-92-        and dynamic quantizers (computing from weight tensor).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-93-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-94-        Args:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:95:            weight_quantizer: The weight quantizer (static or dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-96-            weight: The weight tensor (used for shape in static, values in dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-97-            weights_scaling_factor_2: Optional pre-computed global scale.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-98-            keep_high_precision: Whether to keep scales in high precision.
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-100-        Returns:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-101-            Tuple of (per_block_scale, weights_scaling_factor_2).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-102-        """
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:103:        block_size = weight_quantizer.block_sizes[-1]
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-104-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-105-        if weights_scaling_factor_2 is None:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-106-            weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer(
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:107:                weight_quantizer
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-108-            )
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-109-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:110:        if cls._is_static_quantizer(weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-111-            # Static path: use pre-computed per-block amax values from quantizer
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:112:            global_amax = weight_quantizer.global_amax.float()
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:113:            per_block_amax = weight_quantizer._amax.float()
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-114-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-115-            # Compute scales in float
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-116-            per_block_scale_max = global_amax / 6.0
--
modelopt/torch/quantization/qtensor/base_qtensor.py-204-            # We dont compress meta tensors or None
modelopt/torch/quantization/qtensor/base_qtensor.py-205-            return False
modelopt/torch/quantization/qtensor/base_qtensor.py-206-        if (
modelopt/torch/quantization/qtensor/base_qtensor.py:207:            hasattr(module, "weight_quantizer")
modelopt/torch/quantization/qtensor/base_qtensor.py:208:            and module.weight_quantizer.is_enabled
modelopt/torch/quantization/qtensor/base_qtensor.py:209:            and not module.weight_quantizer._fake_quant
modelopt/torch/quantization/qtensor/base_qtensor.py-210-            and module.weight.element_size() > 1
modelopt/torch/quantization/qtensor/base_qtensor.py-211-        ):
modelopt/torch/quantization/qtensor/base_qtensor.py-212-            if force_quantize:
modelopt/torch/quantization/qtensor/base_qtensor.py:213:                module.weight_quantizer._dequantize = False
modelopt/torch/quantization/qtensor/base_qtensor.py-214-
modelopt/torch/quantization/qtensor/base_qtensor.py:215:            real_quant_tensor = module.weight_quantizer(module.weight)
modelopt/torch/quantization/qtensor/base_qtensor.py-216-            module.weight = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/qtensor/base_qtensor.py-217-            return True
modelopt/torch/quantization/qtensor/base_qtensor.py-218-
--
modelopt/torch/quantization/plugins/vllm.py-80-            torch.Tensor: The quantized output tensor.
modelopt/torch/quantization/plugins/vllm.py-81-        """
modelopt/torch/quantization/plugins/vllm.py-82-        x = layer.input_quantizer(x)
modelopt/torch/quantization/plugins/vllm.py:83:        if layer.weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-84-            original_weight = layer.weight
modelopt/torch/quantization/plugins/vllm.py:85:            quantized_tensor = layer.weight_quantizer(layer.weight)
modelopt/torch/quantization/plugins/vllm.py-86-            # parameterize the quantized weight
modelopt/torch/quantization/plugins/vllm.py-87-            if isinstance(original_weight, torch.nn.Parameter) and not isinstance(
modelopt/torch/quantization/plugins/vllm.py-88-                quantized_tensor, torch.nn.Parameter
--
modelopt/torch/quantization/plugins/vllm.py-110-class _VLLMParallelLinear(QuantModule):
modelopt/torch/quantization/plugins/vllm.py-111-    def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-112-        self.input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:113:        self.weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-114-        self.output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-115-        self.output_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-116-        assert type(self.quant_method) is vllm_linear.UnquantizedLinearMethod, (
--
modelopt/torch/quantization/plugins/vllm.py-159-    def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-160-        self.w13_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py-161-        self.w2_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:162:        self.w13_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py:163:        self.w2_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-164-        self.w13_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-165-        self.w2_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-166-        self.w13_output_quantizer.disable()
--
modelopt/torch/quantization/plugins/vllm.py-181-        if B is self.w13_weight:
modelopt/torch/quantization/plugins/vllm.py-182-            # First layer of expert
modelopt/torch/quantization/plugins/vllm.py-183-            A = self.w13_input_quantizer(A)  # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:184:            if self.w13_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-185-                original_weight = self.w13_weight
modelopt/torch/quantization/plugins/vllm.py:186:                self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
modelopt/torch/quantization/plugins/vllm.py-187-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-188-                self.w13_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-189-            else:
--
modelopt/torch/quantization/plugins/vllm.py-192-                C[:] = self.w13_output_quantizer(C)
modelopt/torch/quantization/plugins/vllm.py-193-        elif B is self.w2_weight:
modelopt/torch/quantization/plugins/vllm.py-194-            A = self.w2_input_quantizer(A)  # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:195:            if self.w2_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-196-                original_weight = self.w2_weight
modelopt/torch/quantization/plugins/vllm.py:197:                self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
modelopt/torch/quantization/plugins/vllm.py-198-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-199-                self.w2_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-200-            else:
--
modelopt/torch/quantization/plugins/vllm.py-232-        # the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
modelopt/torch/quantization/plugins/vllm.py-233-        for i in range(self.w13_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-234-            self.w13_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:235:                self.w13_weight_quantizer(self.w13_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-236-                    self.w13_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-237-                )
modelopt/torch/quantization/plugins/vllm.py-238-            )
modelopt/torch/quantization/plugins/vllm.py:239:        self.w13_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-240-        for i in range(self.w2_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-241-            self.w2_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:242:                self.w2_weight_quantizer(self.w2_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-243-                    self.w2_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-244-                )
modelopt/torch/quantization/plugins/vllm.py-245-            )
modelopt/torch/quantization/plugins/vllm.py:246:        self.w2_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-247-
modelopt/torch/quantization/plugins/vllm.py-248-        torch.cuda.empty_cache()
modelopt/torch/quantization/plugins/vllm.py-249-
--
modelopt/torch/quantization/nn/modules/quant_module.py-121-
modelopt/torch/quantization/nn/modules/quant_module.py-122-    def fold_weight(self, keep_attrs: bool = False):
modelopt/torch/quantization/nn/modules/quant_module.py-123-        """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_module.py:124:        # Handle all attributes that end with _weight_quantizer
modelopt/torch/quantization/nn/modules/quant_module.py-125-        for name in dir(self):
modelopt/torch/quantization/nn/modules/quant_module.py-126-            attr = getattr(self, name)
modelopt/torch/quantization/nn/modules/quant_module.py-127-            if (
modelopt/torch/quantization/nn/modules/quant_module.py:128:                name.endswith("weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_module.py-129-                and isinstance(attr, TensorQuantizer)
modelopt/torch/quantization/nn/modules/quant_module.py-130-                and attr.fake_quant
modelopt/torch/quantization/nn/modules/quant_module.py-131-            ):
modelopt/torch/quantization/nn/modules/quant_module.py:132:                # Get the corresponding weight name by removing _weight_quantizer suffix
modelopt/torch/quantization/nn/modules/quant_module.py-133-                weight_name = name[:-10]
modelopt/torch/quantization/nn/modules/quant_module.py-134-
modelopt/torch/quantization/nn/modules/quant_module.py-135-                assert hasattr(self, weight_name), (
--
modelopt/torch/quantization/nn/modules/quant_module.py-203-    Quantized linear modules are modules where both the input and the weight are quantized.
modelopt/torch/quantization/nn/modules/quant_module.py-204-    """
modelopt/torch/quantization/nn/modules/quant_module.py-205-
modelopt/torch/quantization/nn/modules/quant_module.py:206:    weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_module.py-207-    _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_module.py-208-    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_module.py-209-
--
modelopt/torch/quantization/nn/modules/quant_module.py-219-    `@staticmethod`
modelopt/torch/quantization/nn/modules/quant_module.py-220-    def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
modelopt/torch/quantization/nn/modules/quant_module.py-221-        if module._enable_weight_quantization or is_torch_export_mode():
modelopt/torch/quantization/nn/modules/quant_module.py:222:            return module.weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_module.py-223-        return weight
modelopt/torch/quantization/nn/modules/quant_module.py-224-
modelopt/torch/quantization/nn/modules/quant_module.py-225-    def forward(self, input, *args, **kwargs):
--
modelopt/torch/quantization/nn/modules/quant_module.py-234-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_module.py-235-        super()._setup()
modelopt/torch/quantization/nn/modules/quant_module.py-236-        self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_module.py:237:            "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_module.py-238-        )
modelopt/torch/quantization/nn/modules/quant_module.py-239-        self._register_temp_attribute("_enable_weight_quantization", False)
modelopt/torch/quantization/nn/modules/quant_module.py-240-        self._register_dynamic_attribute("weight", self._get_quantized_weight)
--
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1319-        return super()._fake_quantize(inputs)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1320-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1321-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py:1322:class SequentialQuantizer(nn.Sequential):
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1323-    """A sequential container for  :class:`TensorQuantizer` modules.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1324-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1325-    This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-47-class QuantRNNBase(QuantModule):
modelopt/torch/quantization/nn/modules/quant_rnn.py-48-    """Base class for quantized RNN modules."""
modelopt/torch/quantization/nn/modules/quant_rnn.py-49-
modelopt/torch/quantization/nn/modules/quant_rnn.py:50:    weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py-51-    _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_rnn.py-52-    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_rnn.py-53-    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-75-        self._enable_weight_quantization = False
modelopt/torch/quantization/nn/modules/quant_rnn.py-76-
modelopt/torch/quantization/nn/modules/quant_rnn.py-77-    `@staticmethod`
modelopt/torch/quantization/nn/modules/quant_rnn.py:78:    def _get_quantized_weight_handler(weight_quantizer_name: str):
modelopt/torch/quantization/nn/modules/quant_rnn.py-79-        def _get_quantized_weight(module: "QuantRNNBase", weight: torch.Tensor):
modelopt/torch/quantization/nn/modules/quant_rnn.py-80-            if module._enable_weight_quantization:
modelopt/torch/quantization/nn/modules/quant_rnn.py:81:                weight_quantizer = getattr(module, weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py:82:                return weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-83-            return weight
modelopt/torch/quantization/nn/modules/quant_rnn.py-84-
modelopt/torch/quantization/nn/modules/quant_rnn.py-85-        return _get_quantized_weight
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-102-        for name, _ in self.named_parameters():
modelopt/torch/quantization/nn/modules/quant_rnn.py-103-            if name.startswith("weight"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-104-                # to be compatible with our current config, the name is some what weird
modelopt/torch/quantization/nn/modules/quant_rnn.py:105:                # it would be weight_xxx_weight_quantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py:106:                weight_quantizer_name = name + "_weight_quantizer"
modelopt/torch/quantization/nn/modules/quant_rnn.py-107-                self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:108:                    weight_quantizer_name, TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-109-                )
modelopt/torch/quantization/nn/modules/quant_rnn.py-110-                self._register_dynamic_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:111:                    name, self._get_quantized_weight_handler(weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py-112-                )
modelopt/torch/quantization/nn/modules/quant_rnn.py-113-        # for cells
modelopt/torch/quantization/nn/modules/quant_rnn.py-114-        self._register_temp_attribute("_input_quantizers", [])
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-143-        for iq in self._input_quantizers + self._proj_input_quantizers:
modelopt/torch/quantization/nn/modules/quant_rnn.py-144-            iq.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-145-
modelopt/torch/quantization/nn/modules/quant_rnn.py:146:    def _disable_weight_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-147-        for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:148:            if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-149-                module.disable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-150-
modelopt/torch/quantization/nn/modules/quant_rnn.py:151:    def _enable_weight_quantizer(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-152-        for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:153:            if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-154-                module.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-155-
modelopt/torch/quantization/nn/modules/quant_rnn.py-156-    def _setup(self):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-46-        """Quantized version of a generic linear functional."""
modelopt/torch/quantization/nn/modules/quant_linear.py-47-        output = getattr(package, func_name)(
modelopt/torch/quantization/nn/modules/quant_linear.py-48-            self.input_quantizer(input),
modelopt/torch/quantization/nn/modules/quant_linear.py:49:            self.weight_quantizer(weight),
modelopt/torch/quantization/nn/modules/quant_linear.py-50-            *args,
modelopt/torch/quantization/nn/modules/quant_linear.py-51-            **kwargs,
modelopt/torch/quantization/nn/modules/quant_linear.py-52-        )
--
modelopt/torch/quantization/nn/modules/quant_linear.py-119-
modelopt/torch/quantization/nn/modules/quant_linear.py-120-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-121-        """Overrides and bypass the _setup function."""
modelopt/torch/quantization/nn/modules/quant_linear.py:122:        if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
modelopt/torch/quantization/nn/modules/quant_linear.py-123-            return
modelopt/torch/quantization/nn/modules/quant_linear.py:124:        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-125-
modelopt/torch/quantization/nn/modules/quant_linear.py-126-    def _not_sequential_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_linear.py:127:        return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
modelopt/torch/quantization/nn/modules/quant_linear.py-128-            self.input_quantizer, TensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-129-        )
modelopt/torch/quantization/nn/modules/quant_linear.py-130-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-138-        """Compute the LoRA residual if present, otherwise return None."""
modelopt/torch/quantization/nn/modules/quant_linear.py-139-        if (
modelopt/torch/quantization/nn/modules/quant_linear.py-140-            self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:141:            and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:142:            and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-143-        ):
modelopt/torch/quantization/nn/modules/quant_linear.py:144:            lora_a = F.linear(input, weight=self.weight_quantizer.svdquant_lora_a)
modelopt/torch/quantization/nn/modules/quant_linear.py:145:            lora_b = F.linear(lora_a, weight=self.weight_quantizer.svdquant_lora_b)
modelopt/torch/quantization/nn/modules/quant_linear.py-146-            return lora_b
modelopt/torch/quantization/nn/modules/quant_linear.py-147-        return None
modelopt/torch/quantization/nn/modules/quant_linear.py-148-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-150-        """SVDQuant layer forward function."""
modelopt/torch/quantization/nn/modules/quant_linear.py-151-        has_svdquant_lora = (
modelopt/torch/quantization/nn/modules/quant_linear.py-152-            self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:153:            and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:154:            and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-155-        )
modelopt/torch/quantization/nn/modules/quant_linear.py-156-        if has_svdquant_lora:
modelopt/torch/quantization/nn/modules/quant_linear.py-157-            input = self._apply_pre_quant_scale(input)
--
modelopt/torch/quantization/nn/modules/quant_linear.py-166-        """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_linear.py-167-        super().fold_weight(keep_attrs)
modelopt/torch/quantization/nn/modules/quant_linear.py-168-        if (
modelopt/torch/quantization/nn/modules/quant_linear.py:169:            hasattr(self, "weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_linear.py-170-            and hasattr(self, "weight")
modelopt/torch/quantization/nn/modules/quant_linear.py:171:            and self.weight_quantizer.fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-172-        ):
modelopt/torch/quantization/nn/modules/quant_linear.py-173-            if (
modelopt/torch/quantization/nn/modules/quant_linear.py-174-                self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:175:                and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:176:                and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-177-            ):
modelopt/torch/quantization/nn/modules/quant_linear.py-178-                self.weight.data.copy_(
modelopt/torch/quantization/nn/modules/quant_linear.py-179-                    self.weight
modelopt/torch/quantization/nn/modules/quant_linear.py:180:                    + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
modelopt/torch/quantization/nn/modules/quant_linear.py-181-                )
modelopt/torch/quantization/nn/modules/quant_linear.py-182-            if not keep_attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py-183-                _attrs = [
--
modelopt/torch/quantization/nn/modules/quant_linear.py-185-                    "_svdquant_lora_b",
modelopt/torch/quantization/nn/modules/quant_linear.py-186-                ]
modelopt/torch/quantization/nn/modules/quant_linear.py-187-                for attr in _attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py:188:                    if hasattr(self.weight_quantizer, attr):
modelopt/torch/quantization/nn/modules/quant_linear.py:189:                        delattr(self.weight_quantizer, attr)
modelopt/torch/quantization/nn/modules/quant_linear.py-190-
modelopt/torch/quantization/nn/modules/quant_linear.py-191-
modelopt/torch/quantization/nn/modules/quant_linear.py-192-class RealQuantLinear(QuantModule):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-241-
modelopt/torch/quantization/nn/modules/quant_linear.py-242-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-243-        class RealQuantParameterDict(dict):
modelopt/torch/quantization/nn/modules/quant_linear.py:244:            def __init__(self, weight_quantizer: TensorQuantizer, *args, **kwargs):
modelopt/torch/quantization/nn/modules/quant_linear.py-245-                super().__init__(*args, **kwargs)
modelopt/torch/quantization/nn/modules/quant_linear.py:246:                self.weight_quantizer = weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-247-
modelopt/torch/quantization/nn/modules/quant_linear.py-248-            def __setitem__(self, key, value):
modelopt/torch/quantization/nn/modules/quant_linear.py-249-                if (
modelopt/torch/quantization/nn/modules/quant_linear.py-250-                    key == "weight"
modelopt/torch/quantization/nn/modules/quant_linear.py:251:                    and self.weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py:252:                    and self.weight_quantizer.is_enabled
modelopt/torch/quantization/nn/modules/quant_linear.py:253:                    and not self.weight_quantizer._fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-254-                    and value.element_size() > 1
modelopt/torch/quantization/nn/modules/quant_linear.py-255-                ):
modelopt/torch/quantization/nn/modules/quant_linear.py-256-                    # reset the amax for later calibration
modelopt/torch/quantization/nn/modules/quant_linear.py-257-                    if (
modelopt/torch/quantization/nn/modules/quant_linear.py:258:                        self.weight_quantizer.amax is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:259:                        and self.weight_quantizer.amax.is_meta
modelopt/torch/quantization/nn/modules/quant_linear.py-260-                    ):
modelopt/torch/quantization/nn/modules/quant_linear.py:261:                        delattr(self.weight_quantizer, "_amax")
modelopt/torch/quantization/nn/modules/quant_linear.py:262:                        self.weight_quantizer.amax = self.weight_quantizer._get_amax(value)
modelopt/torch/quantization/nn/modules/quant_linear.py:263:                        self.weight_quantizer._calibrator.reset()
modelopt/torch/quantization/nn/modules/quant_linear.py-264-                    # compress the weight
modelopt/torch/quantization/nn/modules/quant_linear.py:265:                    real_quant_tensor = self.weight_quantizer(value)
modelopt/torch/quantization/nn/modules/quant_linear.py-266-                    real_quant_value = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/nn/modules/quant_linear.py-267-                    del value  # delete the original weight to save memory
modelopt/torch/quantization/nn/modules/quant_linear.py-268-                    value = real_quant_value
--
modelopt/torch/quantization/nn/modules/quant_linear.py-270-
modelopt/torch/quantization/nn/modules/quant_linear.py-271-        # Monkey patch the _parameters.__setitem__ to real quant the weight when loading
modelopt/torch/quantization/nn/modules/quant_linear.py-272-        # HF accelerate loads the weight by directly assigning the weight through the _parameters dict.
modelopt/torch/quantization/nn/modules/quant_linear.py:273:        self._parameters = RealQuantParameterDict(self.weight_quantizer, self._parameters)
modelopt/torch/quantization/nn/modules/quant_linear.py-274-
modelopt/torch/quantization/nn/modules/quant_linear.py-275-        # Function to dynamically override load_state_dict
modelopt/torch/quantization/nn/modules/quant_linear.py-276-        dynamically_update_state_methods(self)
--
modelopt/torch/quantization/plugins/transformer_engine.py-78-            idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-79-            weight, inputs = args[idx], args[idx + 1]
modelopt/torch/quantization/plugins/transformer_engine.py-80-            remaining_args = args[idx + 2 :]
modelopt/torch/quantization/plugins/transformer_engine.py:81:            weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-82-            inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-83-            new_args = (weight, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-84-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-90-            idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-91-            weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
modelopt/torch/quantization/plugins/transformer_engine.py-92-            remaining_args = args[idx + 3 :]
modelopt/torch/quantization/plugins/transformer_engine.py:93:            weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-94-            inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-95-            new_args = (weight, weight_fp8, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-96-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-170-        weights_and_biases = args[-2 * num_gemms :]
modelopt/torch/quantization/plugins/transformer_engine.py-171-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
modelopt/torch/quantization/plugins/transformer_engine.py-172-        quantized_inputs = self.input_quantizer(inp)
modelopt/torch/quantization/plugins/transformer_engine.py:173:        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
modelopt/torch/quantization/plugins/transformer_engine.py-174-
modelopt/torch/quantization/plugins/transformer_engine.py-175-        output = getattr(package, func_name)(
modelopt/torch/quantization/plugins/transformer_engine.py-176-            *(
--
modelopt/torch/quantization/plugins/transformer_engine.py-208-
modelopt/torch/quantization/plugins/transformer_engine.py-209-    `@staticmethod`
modelopt/torch/quantization/plugins/transformer_engine.py-210-    def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
modelopt/torch/quantization/plugins/transformer_engine.py:211:        input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers
modelopt/torch/quantization/plugins/transformer_engine.py-212-
modelopt/torch/quantization/plugins/transformer_engine.py:213:        qweight = weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-214-        qweight.requires_grad = weight.requires_grad
modelopt/torch/quantization/plugins/transformer_engine.py-215-        if ctx is not None:
modelopt/torch/quantization/plugins/transformer_engine.py-216-            # We need to recompute the quantized input for the backward pass, so we save the input_quantizer
--
modelopt/torch/quantization/plugins/transformer_engine.py-291-        # This is multi-process safe (such as in torch distributed jobs), not multi-thread safe
modelopt/torch/quantization/plugins/transformer_engine.py-292-        _QuantLayerNormLinearFunc.modelopt_quantizers = (
modelopt/torch/quantization/plugins/transformer_engine.py-293-            self.input_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py:294:            self.weight_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py-295-        )
modelopt/torch/quantization/plugins/transformer_engine.py-296-        with replace_function(
modelopt/torch/quantization/plugins/transformer_engine.py-297-            te_layernorm_linear,
--
modelopt/torch/quantization/plugins/peft.py-34-class _QuantLoraLinear(QuantModule):
modelopt/torch/quantization/plugins/peft.py-35-    def _setup(self):
modelopt/torch/quantization/plugins/peft.py-36-        self.input_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py:37:        self.weight_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py-38-        self.output_quantizer = Tenso

... [stdout truncated: output exceeded 50000 characters (209210 total). Please modify your command to reduce output size, for example by:
- Using grep/rg to filter results
- Using head/tail to limit lines
- Using --max-count flag with grep/rg
- Processing output with awk/sed to extract only needed information]

Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Full context of sync_moe_gate_up_amax function ==="
sed -n '1174,1221p' modelopt/torch/export/layer_utils.py

echo
echo "=== Check if SequentialQuantizer has amax property ==="
sed -n '1322,1400p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py | head -80

echo
echo "=== Check imports and _GATE_UP_PAIRS ==="
sed -n '1,50p' modelopt/torch/export/layer_utils.py
rg -n '_GATE_UP_PAIRS|from.*SequentialQuantizer|import.*SequentialQuantizer' modelopt/torch/export/layer_utils.py

Repository: NVIDIA/Model-Optimizer


Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

=== Full context of sync_moe_gate_up_amax function ===
# Gate/up naming pairs for standard (unfused) MoE architectures.
# Fused variants (gate_up_proj, linear_fc1) already share a single quantizer and need no sync.
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]


def sync_moe_gate_up_amax(model: nn.Module) -> int:
    """Take element-wise max of gate and up weight quantizer amaxes per expert.

    Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
    require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
    syncing amaxes before quantization ensures the per-block weight_scale values
    are computed against a consistent global scale.

    Only affects standard MoE models with separate gate/up linear layers
    (e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
    (e.g. Llama4, GptOss) are unaffected.

    Returns:
        Number of expert gate/up pairs whose amaxes were synced.
    """
    synced = 0
    for _, sub_module in model.named_modules():
        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
            continue
        if not hasattr(sub_module.experts, "__iter__"):
            continue
        for expert in sub_module.experts:
            for gate_name, up_name in _GATE_UP_PAIRS:
                gate_linear = getattr(expert, gate_name, None)
                up_linear = getattr(expert, up_name, None)
                if gate_linear is None or up_linear is None:
                    continue
                gate_wq = getattr(gate_linear, "weight_quantizer", None)
                up_wq = getattr(up_linear, "weight_quantizer", None)
                if gate_wq is None or up_wq is None:
                    break
                gate_amax = getattr(gate_wq, "amax", None)
                up_amax = getattr(up_wq, "amax", None)
                if gate_amax is None or up_amax is None:
                    break
                if not torch.equal(gate_amax, up_amax):
                    shared_amax = torch.max(gate_amax, up_amax)
                    gate_wq.amax = shared_amax
                    up_wq.amax = shared_amax.clone()
                    synced += 1
                break
    return synced


=== Check if SequentialQuantizer has amax property ===
class SequentialQuantizer(nn.Sequential):
    """A sequential container for  :class:`TensorQuantizer` modules.

    This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
    :class:`TensorQuantizer` modules and containerize them similar to :class:`torch.nn.Sequential`.

    We delegate certain properties and methods to all contained quantizers.
    In the case of conflicts, the first quantizer's property or method takes priority.

    `SequentialQuantizer` is useful in cases like INT4 weights, FP8 activations where weight quantization is not the
    same as the gemm quantization. It allows for applying multiple quantization formats to the same tensor in sequence.

    Use `SequentialQuantizer` methods in lower level implementations for better code organization and readability.

    Args:
        quantizers (TensorQuantizer): :class:`TensorQuantizer` modules to be added to the container.

    """

    _delegated_properties = ["fake_quant", "is_enabled", "amax"]
    _delegated_methods = [
        "reset_amax",
        "disable",
        "enable",
        "load_calib_amax",
        "load_calib_bias",
    ]

    def __init__(self, *quantizers: TensorQuantizer):
        """Initialize SequentialQuantizer module."""
        super().__init__(*quantizers)
        assert all(isinstance(q, TensorQuantizer) for q in self), (
            "All quantizers must be a TensorQuantizer."
        )

    def __getattr__(self, name):
        """Delegate properties and methods to all contained quantizers."""
        if name in self._delegated_properties:
            # Return the property of the first quantizer
            return getattr(self[0], name)

        if name in self._delegated_methods:

            def method_wrapper(*args, **kwargs):
                outputs = getattr(self[0], name)(*args, **kwargs)
                for quantizer in self[1:]:
                    outputs = getattr(quantizer, name)(*args, **kwargs)
                return outputs

            return method_wrapper

        # Defer to super class for attributes not handled here
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in self._delegated_properties:
            for quantizer in self:
                setattr(quantizer, name, value)
        else:
            super().__setattr__(name, value)

    def get_modelopt_state(self) -> dict[str, Any]:
        """Get meta state to be saved in checkpoint."""
        return {"num_quantizers": len(self), "is_sequential_quantizer": True}

    def set_from_attribute_config(
        self,
        attributes: list[dict[str, Any] | QuantizerAttributeConfig]
        | dict[str, Any]
        | QuantizerAttributeConfig,
    ):
        """Set the attributes of contained quantizers from a list of attribute_dicts."""
        if not isinstance(attributes, (list, tuple)):
            assert isinstance(attributes, (dict, QuantizerAttributeConfig)), (
                "attributes must be a list or a dict."
            )
            attributes = [attributes] * len(self)

        for attribute, quantizer in zip(attributes, self):

=== Check imports and _GATE_UP_PAIRS ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for model_config export.

Some of the logics in this file are empirical and needs constant update if exceptions occur.
"""

from warnings import warn

import torch
import torch.nn as nn

try:
    from transformers.activations import ACT2FN
except Exception:
    warn("Cannot find transformers package. Hugginface modules cannot be exported.")

from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import import_plugin

from ..quantization.nn import SequentialQuantizer, TensorQuantizer
from .hf_config_map import HF_CONFIG_MAP
from .mcore_config_map import MCORE_CONFIG_MAP
from .model_config import (
    LAYERNORM_DEFAULT,
    LAYERNORM_RMS,
    LINEAR_COLUMN,
    LINEAR_GROUP,
    LINEAR_ROW,
    QUANTIZATION_FP8,
    QUANTIZATION_NONE,
    QUANTIZATION_NVFP4,
    AttentionConfig,
    ConvConfig,
    DecoderLayerConfig,
    EmbeddingConfig,
    ExpertConfig,
34:from ..quantization.nn import SequentialQuantizer, TensorQuantizer
1176:_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
1201:            for gate_name, up_name in _GATE_UP_PAIRS:

Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check build_mlp_config handling at lines 827-831 ==="
sed -n '820,840p' modelopt/torch/export/layer_utils.py

echo
echo "=== Check if SequentialQuantizer.amax delegation actually works ==="
sed -n '1342,1380p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py

echo
echo "=== Verify the loop structure - what does break actually break? ==="
sed -n '1195,1225p' modelopt/torch/export/layer_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer


Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

=== Check build_mlp_config handling at lines 827-831 ===
            or get_quantization_format(module) in [QUANTIZATION_FP8, QUANTIZATION_NVFP4]
        )
    ):
        preprocess_linear_fusion([fc_linear, gate_linear])

    if fc_linear is not None:
        weight_quantizer = None
        if hasattr(fc_linear, "weight_quantizer"):
            weight_quantizer = fc_linear.weight_quantizer
            if isinstance(weight_quantizer, SequentialQuantizer):
                weight_quantizer = weight_quantizer[0]

        # swap fused fc and gate
        if decoder_type in ["chatglm", "phi3"]:
            weights = torch.chunk(fc_linear.weight, 2, dim=0)
            weights = (weights[1], weights[0])
            fc_linear.weight.data = torch.cat(weights, dim=0)

            if (
                weight_quantizer is not None
                and weight_quantizer.is_enabled

=== Check if SequentialQuantizer.amax delegation actually works ===
    _delegated_methods = [
        "reset_amax",
        "disable",
        "enable",
        "load_calib_amax",
        "load_calib_bias",
    ]

    def __init__(self, *quantizers: TensorQuantizer):
        """Initialize SequentialQuantizer module."""
        super().__init__(*quantizers)
        assert all(isinstance(q, TensorQuantizer) for q in self), (
            "All quantizers must be a TensorQuantizer."
        )

    def __getattr__(self, name):
        """Delegate properties and methods to all contained quantizers."""
        if name in self._delegated_properties:
            # Return the property of the first quantizer
            return getattr(self[0], name)

        if name in self._delegated_methods:

            def method_wrapper(*args, **kwargs):
                outputs = getattr(self[0], name)(*args, **kwargs)
                for quantizer in self[1:]:
                    outputs = getattr(quantizer, name)(*args, **kwargs)
                return outputs

            return method_wrapper

        # Defer to super class for attributes not handled here
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in self._delegated_properties:
            for quantizer in self:
                setattr(quantizer, name, value)
        else:

=== Verify the loop structure - what does break actually break? ===
     1	    for _, sub_module in model.named_modules():
     2	        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
     3	            continue
     4	        if not hasattr(sub_module.experts, "__iter__"):
     5	            continue
     6	        for expert in sub_module.experts:
     7	            for gate_name, up_name in _GATE_UP_PAIRS:
     8	                gate_linear = getattr(expert, gate_name, None)
     9	                up_linear = getattr(expert, up_name, None)
    10	                if gate_linear is None or up_linear is None:
    11	                    continue
    12	                gate_wq = getattr(gate_linear, "weight_quantizer", None)
    13	                up_wq = getattr(up_linear, "weight_quantizer", None)
    14	                if gate_wq is None or up_wq is None:
    15	                    break
    16	                gate_amax = getattr(gate_wq, "amax", None)
    17	                up_amax = getattr(up_wq, "amax", None)
    18	                if gate_amax is None or up_amax is None:
    19	                    break
    20	                if not torch.equal(gate_amax, up_amax):
    21	                    shared_amax = torch.max(gate_amax, up_amax)
    22	                    gate_wq.amax = shared_amax
    23	                    up_wq.amax = shared_amax.clone()
    24	                    synced += 1
    25	                break
    26	    return synced
    27	
    28	
    29	def build_stacked_experts(
    30	    experts: nn.Module,
    31	    linear_names: list[str],

Use continue instead of break to check all gate/up pairs (not just the first) per expert.

On lines 1209 and 1213, break exits the pair-checking loop entirely if any pair lacks a quantizer or amax. Since _GATE_UP_PAIRS has two pairs (gate_proj/up_proj and w1/w3), a missing quantizer in the first pair prevents checking the second. This should be continue to skip only the current pair and attempt the next one. (The final break on line 1214 correctly exits after syncing.)

Proposed fix
                 gate_wq = getattr(gate_linear, "weight_quantizer", None)
                 up_wq = getattr(up_linear, "weight_quantizer", None)
                 if gate_wq is None or up_wq is None:
-                    break
+                    continue
                 gate_amax = getattr(gate_wq, "amax", None)
                 up_amax = getattr(up_wq, "amax", None)
                 if gate_amax is None or up_amax is None:
-                    break
+                    continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/layer_utils.py` around lines 1206 - 1213, The loop that
iterates over _GATE_UP_PAIRS incorrectly uses break when a pair is missing a
quantizer or amax, which stops checking remaining pairs; update the logic in the
loop that accesses gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to
use continue instead of break so only the current pair is skipped and the next
pair in _GATE_UP_PAIRS is checked; keep the final break that exits after
successfully syncing unchanged.

Copy link
Contributor

@meenchen meenchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Edwardf0t1 Edwardf0t1 requested a review from cjluo-nv March 14, 2026 05:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants