Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,55 @@ def set_expert_quantizer_amax(
return uncalibrated_modules


# Gate/up naming pairs for standard (unfused) MoE architectures.
# Fused variants (gate_up_proj, linear_fc1) already share a single quantizer and need no sync.
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]


def sync_moe_gate_up_amax(model: nn.Module) -> int:
"""Take element-wise max of gate and up weight quantizer amaxes per expert.

Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
syncing amaxes before quantization ensures the per-block weight_scale values
are computed against a consistent global scale.

Only affects standard MoE models with separate gate/up linear layers
(e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
(e.g. Llama4, GptOss) are unaffected.

Returns:
Number of expert gate/up pairs whose amaxes were synced.
"""
synced = 0
for _, sub_module in model.named_modules():
if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
continue
if not hasattr(sub_module.experts, "__iter__"):
continue
for expert in sub_module.experts:
for gate_name, up_name in _GATE_UP_PAIRS:
gate_linear = getattr(expert, gate_name, None)
up_linear = getattr(expert, up_name, None)
if gate_linear is None or up_linear is None:
continue
gate_wq = getattr(gate_linear, "weight_quantizer", None)
up_wq = getattr(up_linear, "weight_quantizer", None)
if gate_wq is None or up_wq is None:
break
gate_amax = getattr(gate_wq, "amax", None)
up_amax = getattr(up_wq, "amax", None)
if gate_amax is None or up_amax is None:
break
Comment on lines +1206 to +1213
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

❓ Verification inconclusive

Script executed:

#!/bin/bash
set -euo pipefail

echo "== SequentialQuantizer API =="
fd 'tensor_quantizer.py' --exec rg -n -C2 'class SequentialQuantizer|def amax|@property'

echo
echo "== Where SequentialQuantizer is used for weight quantizers =="
rg -n -C3 --type=py 'SequentialQuantizer\(|weight_quantizer' modelopt/torch

echo
echo "== Existing unwrapping pattern in layer_utils.py =="
rg -n -C3 --type=py 'isinstance\(.*SequentialQuantizer\)|weight_quantizer\.amax|sync_moe_gate_up_amax' modelopt/torch/export/layer_utils.py

Repository: NVIDIA/Model-Optimizer


Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

== SequentialQuantizer API ==
252-        return qtensor.dequantize(**kwarg)
253-
254:    `@property`
255-    def num_bits(self):
256-        """Return num_bits for quantization."""
--
262-        self._calibrator._num_bits = value
263-
264:    `@property`
265-    def maxbound(self):
266-        """Return maxbound for quantization."""
--
271-        return (1 << (self._num_bits - 1 + int(self._unsigned))) - 1
272-
273:    `@property`
274-    def unsigned(self):
275-        """Return True if unsigned quantization is used."""
--
281-        self._calibrator._unsigned = value
282-
283:    `@property`
284-    def pre_quant_scale(self):
285-        """Return pre_quant_scale used for smoothquant."""
--
305-            )
306-
307:    `@property`
308:    def amax(self):
309-        """Return amax for quantization."""
310-        if not hasattr(self, "_amax") or self.is_mx_format:
--
314-
315-    `@amax.setter`
316:    def amax(self, value):
317-        assert value is not None, "amax cannot be set to None."
318-
--
341-            self._bias_calibrator.reset()
342-
343:    `@property`
344-    def step_size(self):
345-        """Return step size for integer quantization."""
--
352-        return self._amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)
353-
354:    `@property`
355-    def axis(self):
356-        """Return axis for quantization."""
--
362-        self._calibrator._axis = value
363-
364:    `@property`
365-    def block_sizes(self):
366-        """Return block_sizes for quantization."""
--
372-        self._block_sizes = value
373-
374:    `@property`
375-    def bias(self):
376-        """Return bias for quantization."""
--
379-        return self._bias
380-
381:    `@property`
382-    def bias_axis(self):
383-        """Return bias_axis for quantization."""
--
392-        self._bias_axis = value
393-
394:    `@property`
395-    def bias_method(self):
396-        """Return bias_method for quantization."""
--
399-        return self._bias.get("method", "mean")
400-
401:    `@property`
402-    def bias_type(self):
403-        """Return bias_type for quantization."""
--
414-        self._bias["type"] = value
415-
416:    `@property`
417-    def bias_value(self):
418-        """Return bias for quantization."""
--
435-            self._bias_value.data.copy_(value.clone().detach().to(self._bias_value.device))
436-
437:    `@property`
438-    def bias_calibrator(self):
439-        """Return bias_calibrator for quantization."""
--
450-        return self._bias_calibrator
451-
452:    `@property`
453-    def fake_quant(self):
454-        """Return True if fake quantization is used."""
455-        return self._fake_quant
456-
457:    `@property`
458-    def narrow_range(self):
459-        """Return True if symmetric integer range for signed quantization is used."""
--
464-        self._narrow_range = value
465-
466:    `@property`
467-    def is_enabled(self):
468-        """Return true if the modules is not disabled."""
--
480-        self._disabled = False
481-
482:    `@property`
483-    def trt_high_precision_dtype(self):
484-        """Return True if FP16 AMAX is used when exporting the model."""
--
489-        self._trt_high_precision_dtype = value
490-
491:    `@property`
492-    def is_mx_format(self):
493-        """Check if is MX formats."""
--
521-            raise NotImplementedError()
522-
523:    `@property`
524-    def is_static_block_quant(self):
525-        """Check if is static block quantization."""
--
530-        )
531-
532:    `@property`
533-    def rotate_is_enabled(self):
534-        """Check if rotate is enabled in quant config."""
535-        return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
536-
537:    `@property`
538-    def rotate_is_fp32(self):
539-        """Check if rotation needs to be computed in float32."""
--
1286-        return tq
1287-
1288:    `@property`
1289-    def global_amax(self):
1290-        """Return global_amax for quantization."""
--
1320-
1321-
1322:class SequentialQuantizer(nn.Sequential):
1323-    """A sequential container for  :class:`TensorQuantizer` modules.
1324-

== Where SequentialQuantizer is used for weight quantizers ==
modelopt/torch/quantization/utils.py-228-
modelopt/torch/quantization/utils.py-229-    # the standard weight and quantizer case
modelopt/torch/quantization/utils.py-230-    weight = getattr(module, "weight", None)
modelopt/torch/quantization/utils.py:231:    weight_quantizer = getattr(module, "weight_quantizer", None)
modelopt/torch/quantization/utils.py:232:    if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
modelopt/torch/quantization/utils.py-233-        yield "weight"
modelopt/torch/quantization/utils.py-234-
modelopt/torch/quantization/utils.py-235-    # other weight and quantizer case
modelopt/torch/quantization/utils.py-236-    for name, _ in module.named_parameters(recurse=False):
modelopt/torch/quantization/utils.py-237-        weight = getattr(module, name, None)
modelopt/torch/quantization/utils.py:238:        weight_quantizer = getattr(module, f"{name}_weight_quantizer", None)
modelopt/torch/quantization/utils.py-239-        if isinstance(weight, nn.Parameter) and isinstance(
modelopt/torch/quantization/utils.py:240:            weight_quantizer, (TensorQuantizer, SequentialQuantizer)
modelopt/torch/quantization/utils.py-241-        ):
modelopt/torch/quantization/utils.py-242-            yield name
modelopt/torch/quantization/utils.py-243-
--
modelopt/torch/quantization/utils.py-246-QuantizerAttrNames = namedtuple(
modelopt/torch/quantization/utils.py-247-    "QuantizerAttrNames",
modelopt/torch/quantization/utils.py-248-    (
modelopt/torch/quantization/utils.py:249:        "weight_quantizer",
modelopt/torch/quantization/utils.py-250-        "input_quantizer",
modelopt/torch/quantization/utils.py-251-        "output_quantizer",
modelopt/torch/quantization/utils.py-252-        "weight_scale",
--
modelopt/torch/quantization/utils.py-261-    """Get all the quantizer related attribute names for a given weight name."""
modelopt/torch/quantization/utils.py-262-    prefix = f"{weight_name}_" if weight_name != "weight" else ""
modelopt/torch/quantization/utils.py-263-    return QuantizerAttrNames(
modelopt/torch/quantization/utils.py:264:        weight_quantizer=f"{prefix}weight_quantizer",
modelopt/torch/quantization/utils.py-265-        input_quantizer=f"{prefix}input_quantizer",
modelopt/torch/quantization/utils.py-266-        output_quantizer=f"{prefix}output_quantizer",
modelopt/torch/quantization/utils.py-267-        weight_scale=f"{prefix}weight_scale",
--
modelopt/torch/quantization/utils.py-285-    return (
modelopt/torch/quantization/utils.py-286-        isinstance(module, QuantModule)
modelopt/torch/quantization/utils.py-287-        and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
modelopt/torch/quantization/utils.py:288:        and hasattr(module, "weight_quantizer")
modelopt/torch/quantization/utils.py-289-        and (
modelopt/torch/quantization/utils.py-290-            (getattr(module, "weight", None) is not None and module.weight.dim() == 2)
modelopt/torch/quantization/utils.py-291-            # module.weight0 check is required to support TEGroupedLinear
--
modelopt/torch/quantization/utils.py-329-    config["quant_cfg"]["*lora*"] = {"enable": False}
modelopt/torch/quantization/utils.py-330-    for layer in layers:
modelopt/torch/quantization/utils.py-331-        config["quant_cfg"][f"*{layer}.input_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py:332:        config["quant_cfg"][f"*{layer}.weight_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py-333-        config["quant_cfg"][f"*{layer}.output_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py-334-    return config
modelopt/torch/quantization/utils.py-335-
--
modelopt/torch/quantization/utils.py-537-
modelopt/torch/quantization/utils.py-538-    1. Takes the element-wise max of each ``input_quantizer`` amax across all experts
modelopt/torch/quantization/utils.py-539-       and writes it back, so every expert shares the same input amax.
modelopt/torch/quantization/utils.py:540:    2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert
modelopt/torch/quantization/utils.py-541-       received no tokens during calibration), runs a weight-only ``max_calibrate``
modelopt/torch/quantization/utils.py-542-       to populate the missing amax.
modelopt/torch/quantization/utils.py-543-    """
--
modelopt/torch/quantization/utils.py-566-
modelopt/torch/quantization/utils.py-567-    for expert in experts:
modelopt/torch/quantization/utils.py-568-        for name, module in expert.named_modules():
modelopt/torch/quantization/utils.py:569:            if name.endswith("weight_quantizer") and module.is_enabled and module.amax is None:
modelopt/torch/quantization/utils.py:570:                weight = expert.state_dict().get(name.replace("weight_quantizer", "weight"))
modelopt/torch/quantization/utils.py-571-                if weight is not None:
modelopt/torch/quantization/utils.py-572-                    max_calibrate(module, lambda m, w=weight: m(w), distributed_sync=False)
modelopt/torch/quantization/utils.py-573-
--
modelopt/torch/quantization/utils.py-687-    """
modelopt/torch/quantization/utils.py-688-    original_fake_quant = []
modelopt/torch/quantization/utils.py-689-    for m in module.modules():
modelopt/torch/quantization/utils.py:690:        if hasattr(m, "weight_quantizer"):
modelopt/torch/quantization/utils.py:691:            original_fake_quant.append(m.weight_quantizer._fake_quant)
modelopt/torch/quantization/utils.py:692:            m.weight_quantizer._fake_quant = True
modelopt/torch/quantization/utils.py-693-    yield
modelopt/torch/quantization/utils.py-694-    for m in module.modules():
modelopt/torch/quantization/utils.py:695:        if hasattr(m, "weight_quantizer"):
modelopt/torch/quantization/utils.py:696:            m.weight_quantizer._fake_quant = original_fake_quant.pop(0)
modelopt/torch/quantization/utils.py-697-
modelopt/torch/quantization/utils.py-698-
modelopt/torch/quantization/utils.py-699-@contextmanager
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-97-    def get_weights_scaling_factor_from_quantizer(
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-98-        cls,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-99-        weight: torch.Tensor,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:100:        weight_quantizer,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-101-    ) -> torch.Tensor:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-102-        """Returns E8M0 scale from quantizer or computes from weight.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-103-
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-107-        Args:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-108-            weight: The weight tensor. Can be 2D (out_dim, in_dim) or
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-109-                3D for MoE (num_experts, out_dim, in_dim).
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:110:            weight_quantizer: The weight quantizer with block_sizes and optional _scale.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-111-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-112-        Returns:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-113-            torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-114-        """
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:115:        assert hasattr(weight_quantizer, "block_sizes"), (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:116:            "weight_quantizer must have 'block_sizes' attribute"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-117-        )
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:118:        assert weight_quantizer.block_sizes[-1] == cls.BLOCK_SIZE, (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:119:            f"MXFP8 requires block size {cls.BLOCK_SIZE}, got {weight_quantizer.block_sizes[-1]}"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-120-        )
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-121-        assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-122-
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-126-        # For 3D MoE: (num_experts, out_dim, in_dim // 32)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-127-        expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-128-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:129:        if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:130:            scale = weight_quantizer._scale
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-131-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-132-            assert scale.dtype == cls.SCALE_DTYPE, (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-133-                f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}"
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-53-        return cls.e2m1_bounds_on_device[device]
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-54-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-55-    `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:56:    def _is_static_quantizer(cls, weight_quantizer) -> bool:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-57-        """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:58:        return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-59-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-60-    `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:61:    def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:62:        """Returns per tensor weight scaling factor from the weight_quantizer.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-63-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-64-        Handles both static NVFP4 quantizers (using global_amax) and
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-65-        dynamic quantizers (using _amax).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-66-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-67-        Args:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:68:            weight_quantizer: The weight quantizer (static or dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-69-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-70-        Returns:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-71-            The global scaling factor as a float tensor.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-72-        """
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:73:        if cls._is_static_quantizer(weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:74:            return weight_quantizer.global_amax.float() / (6.0 * 448.0)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-75-        else:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:76:            assert hasattr(weight_quantizer, "_amax"), (
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-77-                "Weight quantizer does not have attribute amax"
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-78-            )
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:79:            return weight_quantizer._amax.float() / (6.0 * 448.0)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-80-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-81-    `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-82-    def get_weights_scaling_factor_from_quantizer(
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-83-        cls,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:84:        weight_quantizer,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-85-        weight: torch.Tensor,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-86-        weights_scaling_factor_2: torch.Tensor | None = None,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-87-        keep_high_precision: bool = False,
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-92-        and dynamic quantizers (computing from weight tensor).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-93-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-94-        Args:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:95:            weight_quantizer: The weight quantizer (static or dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-96-            weight: The weight tensor (used for shape in static, values in dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-97-            weights_scaling_factor_2: Optional pre-computed global scale.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-98-            keep_high_precision: Whether to keep scales in high precision.
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-100-        Returns:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-101-            Tuple of (per_block_scale, weights_scaling_factor_2).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-102-        """
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:103:        block_size = weight_quantizer.block_sizes[-1]
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-104-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-105-        if weights_scaling_factor_2 is None:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-106-            weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer(
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:107:                weight_quantizer
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-108-            )
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-109-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:110:        if cls._is_static_quantizer(weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-111-            # Static path: use pre-computed per-block amax values from quantizer
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:112:            global_amax = weight_quantizer.global_amax.float()
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:113:            per_block_amax = weight_quantizer._amax.float()
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-114-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-115-            # Compute scales in float
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-116-            per_block_scale_max = global_amax / 6.0
--
modelopt/torch/quantization/qtensor/base_qtensor.py-204-            # We dont compress meta tensors or None
modelopt/torch/quantization/qtensor/base_qtensor.py-205-            return False
modelopt/torch/quantization/qtensor/base_qtensor.py-206-        if (
modelopt/torch/quantization/qtensor/base_qtensor.py:207:            hasattr(module, "weight_quantizer")
modelopt/torch/quantization/qtensor/base_qtensor.py:208:            and module.weight_quantizer.is_enabled
modelopt/torch/quantization/qtensor/base_qtensor.py:209:            and not module.weight_quantizer._fake_quant
modelopt/torch/quantization/qtensor/base_qtensor.py-210-            and module.weight.element_size() > 1
modelopt/torch/quantization/qtensor/base_qtensor.py-211-        ):
modelopt/torch/quantization/qtensor/base_qtensor.py-212-            if force_quantize:
modelopt/torch/quantization/qtensor/base_qtensor.py:213:                module.weight_quantizer._dequantize = False
modelopt/torch/quantization/qtensor/base_qtensor.py-214-
modelopt/torch/quantization/qtensor/base_qtensor.py:215:            real_quant_tensor = module.weight_quantizer(module.weight)
modelopt/torch/quantization/qtensor/base_qtensor.py-216-            module.weight = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/qtensor/base_qtensor.py-217-            return True
modelopt/torch/quantization/qtensor/base_qtensor.py-218-
--
modelopt/torch/quantization/plugins/vllm.py-80-            torch.Tensor: The quantized output tensor.
modelopt/torch/quantization/plugins/vllm.py-81-        """
modelopt/torch/quantization/plugins/vllm.py-82-        x = layer.input_quantizer(x)
modelopt/torch/quantization/plugins/vllm.py:83:        if layer.weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-84-            original_weight = layer.weight
modelopt/torch/quantization/plugins/vllm.py:85:            quantized_tensor = layer.weight_quantizer(layer.weight)
modelopt/torch/quantization/plugins/vllm.py-86-            # parameterize the quantized weight
modelopt/torch/quantization/plugins/vllm.py-87-            if isinstance(original_weight, torch.nn.Parameter) and not isinstance(
modelopt/torch/quantization/plugins/vllm.py-88-                quantized_tensor, torch.nn.Parameter
--
modelopt/torch/quantization/plugins/vllm.py-110-class _VLLMParallelLinear(QuantModule):
modelopt/torch/quantization/plugins/vllm.py-111-    def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-112-        self.input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:113:        self.weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-114-        self.output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-115-        self.output_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-116-        assert type(self.quant_method) is vllm_linear.UnquantizedLinearMethod, (
--
modelopt/torch/quantization/plugins/vllm.py-159-    def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-160-        self.w13_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py-161-        self.w2_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:162:        self.w13_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py:163:        self.w2_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-164-        self.w13_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-165-        self.w2_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-166-        self.w13_output_quantizer.disable()
--
modelopt/torch/quantization/plugins/vllm.py-181-        if B is self.w13_weight:
modelopt/torch/quantization/plugins/vllm.py-182-            # First layer of expert
modelopt/torch/quantization/plugins/vllm.py-183-            A = self.w13_input_quantizer(A)  # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:184:            if self.w13_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-185-                original_weight = self.w13_weight
modelopt/torch/quantization/plugins/vllm.py:186:                self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
modelopt/torch/quantization/plugins/vllm.py-187-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-188-                self.w13_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-189-            else:
--
modelopt/torch/quantization/plugins/vllm.py-192-                C[:] = self.w13_output_quantizer(C)
modelopt/torch/quantization/plugins/vllm.py-193-        elif B is self.w2_weight:
modelopt/torch/quantization/plugins/vllm.py-194-            A = self.w2_input_quantizer(A)  # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:195:            if self.w2_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-196-                original_weight = self.w2_weight
modelopt/torch/quantization/plugins/vllm.py:197:                self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
modelopt/torch/quantization/plugins/vllm.py-198-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-199-                self.w2_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-200-            else:
--
modelopt/torch/quantization/plugins/vllm.py-232-        # the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
modelopt/torch/quantization/plugins/vllm.py-233-        for i in range(self.w13_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-234-            self.w13_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:235:                self.w13_weight_quantizer(self.w13_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-236-                    self.w13_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-237-                )
modelopt/torch/quantization/plugins/vllm.py-238-            )
modelopt/torch/quantization/plugins/vllm.py:239:        self.w13_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-240-        for i in range(self.w2_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-241-            self.w2_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:242:                self.w2_weight_quantizer(self.w2_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-243-                    self.w2_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-244-                )
modelopt/torch/quantization/plugins/vllm.py-245-            )
modelopt/torch/quantization/plugins/vllm.py:246:        self.w2_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-247-
modelopt/torch/quantization/plugins/vllm.py-248-        torch.cuda.empty_cache()
modelopt/torch/quantization/plugins/vllm.py-249-
--
modelopt/torch/quantization/nn/modules/quant_module.py-121-
modelopt/torch/quantization/nn/modules/quant_module.py-122-    def fold_weight(self, keep_attrs: bool = False):
modelopt/torch/quantization/nn/modules/quant_module.py-123-        """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_module.py:124:        # Handle all attributes that end with _weight_quantizer
modelopt/torch/quantization/nn/modules/quant_module.py-125-        for name in dir(self):
modelopt/torch/quantization/nn/modules/quant_module.py-126-            attr = getattr(self, name)
modelopt/torch/quantization/nn/modules/quant_module.py-127-            if (
modelopt/torch/quantization/nn/modules/quant_module.py:128:                name.endswith("weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_module.py-129-                and isinstance(attr, TensorQuantizer)
modelopt/torch/quantization/nn/modules/quant_module.py-130-                and attr.fake_quant
modelopt/torch/quantization/nn/modules/quant_module.py-131-            ):
modelopt/torch/quantization/nn/modules/quant_module.py:132:                # Get the corresponding weight name by removing _weight_quantizer suffix
modelopt/torch/quantization/nn/modules/quant_module.py-133-                weight_name = name[:-10]
modelopt/torch/quantization/nn/modules/quant_module.py-134-
modelopt/torch/quantization/nn/modules/quant_module.py-135-                assert hasattr(self, weight_name), (
--
modelopt/torch/quantization/nn/modules/quant_module.py-203-    Quantized linear modules are modules where both the input and the weight are quantized.
modelopt/torch/quantization/nn/modules/quant_module.py-204-    """
modelopt/torch/quantization/nn/modules/quant_module.py-205-
modelopt/torch/quantization/nn/modules/quant_module.py:206:    weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_module.py-207-    _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_module.py-208-    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_module.py-209-
--
modelopt/torch/quantization/nn/modules/quant_module.py-219-    `@staticmethod`
modelopt/torch/quantization/nn/modules/quant_module.py-220-    def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
modelopt/torch/quantization/nn/modules/quant_module.py-221-        if module._enable_weight_quantization or is_torch_export_mode():
modelopt/torch/quantization/nn/modules/quant_module.py:222:            return module.weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_module.py-223-        return weight
modelopt/torch/quantization/nn/modules/quant_module.py-224-
modelopt/torch/quantization/nn/modules/quant_module.py-225-    def forward(self, input, *args, **kwargs):
--
modelopt/torch/quantization/nn/modules/quant_module.py-234-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_module.py-235-        super()._setup()
modelopt/torch/quantization/nn/modules/quant_module.py-236-        self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_module.py:237:            "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_module.py-238-        )
modelopt/torch/quantization/nn/modules/quant_module.py-239-        self._register_temp_attribute("_enable_weight_quantization", False)
modelopt/torch/quantization/nn/modules/quant_module.py-240-        self._register_dynamic_attribute("weight", self._get_quantized_weight)
--
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1319-        return super()._fake_quantize(inputs)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1320-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1321-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py:1322:class SequentialQuantizer(nn.Sequential):
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1323-    """A sequential container for  :class:`TensorQuantizer` modules.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1324-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1325-    This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-47-class QuantRNNBase(QuantModule):
modelopt/torch/quantization/nn/modules/quant_rnn.py-48-    """Base class for quantized RNN modules."""
modelopt/torch/quantization/nn/modules/quant_rnn.py-49-
modelopt/torch/quantization/nn/modules/quant_rnn.py:50:    weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py-51-    _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_rnn.py-52-    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_rnn.py-53-    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-75-        self._enable_weight_quantization = False
modelopt/torch/quantization/nn/modules/quant_rnn.py-76-
modelopt/torch/quantization/nn/modules/quant_rnn.py-77-    `@staticmethod`
modelopt/torch/quantization/nn/modules/quant_rnn.py:78:    def _get_quantized_weight_handler(weight_quantizer_name: str):
modelopt/torch/quantization/nn/modules/quant_rnn.py-79-        def _get_quantized_weight(module: "QuantRNNBase", weight: torch.Tensor):
modelopt/torch/quantization/nn/modules/quant_rnn.py-80-            if module._enable_weight_quantization:
modelopt/torch/quantization/nn/modules/quant_rnn.py:81:                weight_quantizer = getattr(module, weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py:82:                return weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-83-            return weight
modelopt/torch/quantization/nn/modules/quant_rnn.py-84-
modelopt/torch/quantization/nn/modules/quant_rnn.py-85-        return _get_quantized_weight
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-102-        for name, _ in self.named_parameters():
modelopt/torch/quantization/nn/modules/quant_rnn.py-103-            if name.startswith("weight"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-104-                # to be compatible with our current config, the name is some what weird
modelopt/torch/quantization/nn/modules/quant_rnn.py:105:                # it would be weight_xxx_weight_quantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py:106:                weight_quantizer_name = name + "_weight_quantizer"
modelopt/torch/quantization/nn/modules/quant_rnn.py-107-                self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:108:                    weight_quantizer_name, TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-109-                )
modelopt/torch/quantization/nn/modules/quant_rnn.py-110-                self._register_dynamic_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:111:                    name, self._get_quantized_weight_handler(weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py-112-                )
modelopt/torch/quantization/nn/modules/quant_rnn.py-113-        # for cells
modelopt/torch/quantization/nn/modules/quant_rnn.py-114-        self._register_temp_attribute("_input_quantizers", [])
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-143-        for iq in self._input_quantizers + self._proj_input_quantizers:
modelopt/torch/quantization/nn/modules/quant_rnn.py-144-            iq.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-145-
modelopt/torch/quantization/nn/modules/quant_rnn.py:146:    def _disable_weight_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-147-        for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:148:            if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-149-                module.disable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-150-
modelopt/torch/quantization/nn/modules/quant_rnn.py:151:    def _enable_weight_quantizer(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-152-        for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:153:            if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-154-                module.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-155-
modelopt/torch/quantization/nn/modules/quant_rnn.py-156-    def _setup(self):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-46-        """Quantized version of a generic linear functional."""
modelopt/torch/quantization/nn/modules/quant_linear.py-47-        output = getattr(package, func_name)(
modelopt/torch/quantization/nn/modules/quant_linear.py-48-            self.input_quantizer(input),
modelopt/torch/quantization/nn/modules/quant_linear.py:49:            self.weight_quantizer(weight),
modelopt/torch/quantization/nn/modules/quant_linear.py-50-            *args,
modelopt/torch/quantization/nn/modules/quant_linear.py-51-            **kwargs,
modelopt/torch/quantization/nn/modules/quant_linear.py-52-        )
--
modelopt/torch/quantization/nn/modules/quant_linear.py-119-
modelopt/torch/quantization/nn/modules/quant_linear.py-120-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-121-        """Overrides and bypass the _setup function."""
modelopt/torch/quantization/nn/modules/quant_linear.py:122:        if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
modelopt/torch/quantization/nn/modules/quant_linear.py-123-            return
modelopt/torch/quantization/nn/modules/quant_linear.py:124:        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-125-
modelopt/torch/quantization/nn/modules/quant_linear.py-126-    def _not_sequential_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_linear.py:127:        return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
modelopt/torch/quantization/nn/modules/quant_linear.py-128-            self.input_quantizer, TensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-129-        )
modelopt/torch/quantization/nn/modules/quant_linear.py-130-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-138-        """Compute the LoRA residual if present, otherwise return None."""
modelopt/torch/quantization/nn/modules/quant_linear.py-139-        if (
modelopt/torch/quantization/nn/modules/quant_linear.py-140-            self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:141:            and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:142:            and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-143-        ):
modelopt/torch/quantization/nn/modules/quant_linear.py:144:            lora_a = F.linear(input, weight=self.weight_quantizer.svdquant_lora_a)
modelopt/torch/quantization/nn/modules/quant_linear.py:145:            lora_b = F.linear(lora_a, weight=self.weight_quantizer.svdquant_lora_b)
modelopt/torch/quantization/nn/modules/quant_linear.py-146-            return lora_b
modelopt/torch/quantization/nn/modules/quant_linear.py-147-        return None
modelopt/torch/quantization/nn/modules/quant_linear.py-148-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-150-        """SVDQuant layer forward function."""
modelopt/torch/quantization/nn/modules/quant_linear.py-151-        has_svdquant_lora = (
modelopt/torch/quantization/nn/modules/quant_linear.py-152-            self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:153:            and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:154:            and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-155-        )
modelopt/torch/quantization/nn/modules/quant_linear.py-156-        if has_svdquant_lora:
modelopt/torch/quantization/nn/modules/quant_linear.py-157-            input = self._apply_pre_quant_scale(input)
--
modelopt/torch/quantization/nn/modules/quant_linear.py-166-        """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_linear.py-167-        super().fold_weight(keep_attrs)
modelopt/torch/quantization/nn/modules/quant_linear.py-168-        if (
modelopt/torch/quantization/nn/modules/quant_linear.py:169:            hasattr(self, "weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_linear.py-170-            and hasattr(self, "weight")
modelopt/torch/quantization/nn/modules/quant_linear.py:171:            and self.weight_quantizer.fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-172-        ):
modelopt/torch/quantization/nn/modules/quant_linear.py-173-            if (
modelopt/torch/quantization/nn/modules/quant_linear.py-174-                self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:175:                and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:176:                and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-177-            ):
modelopt/torch/quantization/nn/modules/quant_linear.py-178-                self.weight.data.copy_(
modelopt/torch/quantization/nn/modules/quant_linear.py-179-                    self.weight
modelopt/torch/quantization/nn/modules/quant_linear.py:180:                    + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
modelopt/torch/quantization/nn/modules/quant_linear.py-181-                )
modelopt/torch/quantization/nn/modules/quant_linear.py-182-            if not keep_attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py-183-                _attrs = [
--
modelopt/torch/quantization/nn/modules/quant_linear.py-185-                    "_svdquant_lora_b",
modelopt/torch/quantization/nn/modules/quant_linear.py-186-                ]
modelopt/torch/quantization/nn/modules/quant_linear.py-187-                for attr in _attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py:188:                    if hasattr(self.weight_quantizer, attr):
modelopt/torch/quantization/nn/modules/quant_linear.py:189:                        delattr(self.weight_quantizer, attr)
modelopt/torch/quantization/nn/modules/quant_linear.py-190-
modelopt/torch/quantization/nn/modules/quant_linear.py-191-
modelopt/torch/quantization/nn/modules/quant_linear.py-192-class RealQuantLinear(QuantModule):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-241-
modelopt/torch/quantization/nn/modules/quant_linear.py-242-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-243-        class RealQuantParameterDict(dict):
modelopt/torch/quantization/nn/modules/quant_linear.py:244:            def __init__(self, weight_quantizer: TensorQuantizer, *args, **kwargs):
modelopt/torch/quantization/nn/modules/quant_linear.py-245-                super().__init__(*args, **kwargs)
modelopt/torch/quantization/nn/modules/quant_linear.py:246:                self.weight_quantizer = weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-247-
modelopt/torch/quantization/nn/modules/quant_linear.py-248-            def __setitem__(self, key, value):
modelopt/torch/quantization/nn/modules/quant_linear.py-249-                if (
modelopt/torch/quantization/nn/modules/quant_linear.py-250-                    key == "weight"
modelopt/torch/quantization/nn/modules/quant_linear.py:251:                    and self.weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py:252:                    and self.weight_quantizer.is_enabled
modelopt/torch/quantization/nn/modules/quant_linear.py:253:                    and not self.weight_quantizer._fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-254-                    and value.element_size() > 1
modelopt/torch/quantization/nn/modules/quant_linear.py-255-                ):
modelopt/torch/quantization/nn/modules/quant_linear.py-256-                    # reset the amax for later calibration
modelopt/torch/quantization/nn/modules/quant_linear.py-257-                    if (
modelopt/torch/quantization/nn/modules/quant_linear.py:258:                        self.weight_quantizer.amax is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:259:                        and self.weight_quantizer.amax.is_meta
modelopt/torch/quantization/nn/modules/quant_linear.py-260-                    ):
modelopt/torch/quantization/nn/modules/quant_linear.py:261:                        delattr(self.weight_quantizer, "_amax")
modelopt/torch/quantization/nn/modules/quant_linear.py:262:                        self.weight_quantizer.amax = self.weight_quantizer._get_amax(value)
modelopt/torch/quantization/nn/modules/quant_linear.py:263:                        self.weight_quantizer._calibrator.reset()
modelopt/torch/quantization/nn/modules/quant_linear.py-264-                    # compress the weight
modelopt/torch/quantization/nn/modules/quant_linear.py:265:                    real_quant_tensor = self.weight_quantizer(value)
modelopt/torch/quantization/nn/modules/quant_linear.py-266-                    real_quant_value = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/nn/modules/quant_linear.py-267-                    del value  # delete the original weight to save memory
modelopt/torch/quantization/nn/modules/quant_linear.py-268-                    value = real_quant_value
--
modelopt/torch/quantization/nn/modules/quant_linear.py-270-
modelopt/torch/quantization/nn/modules/quant_linear.py-271-        # Monkey patch the _parameters.__setitem__ to real quant the weight when loading
modelopt/torch/quantization/nn/modules/quant_linear.py-272-        # HF accelerate loads the weight by directly assigning the weight through the _parameters dict.
modelopt/torch/quantization/nn/modules/quant_linear.py:273:        self._parameters = RealQuantParameterDict(self.weight_quantizer, self._parameters)
modelopt/torch/quantization/nn/modules/quant_linear.py-274-
modelopt/torch/quantization/nn/modules/quant_linear.py-275-        # Function to dynamically override load_state_dict
modelopt/torch/quantization/nn/modules/quant_linear.py-276-        dynamically_update_state_methods(self)
--
modelopt/torch/quantization/plugins/transformer_engine.py-78-            idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-79-            weight, inputs = args[idx], args[idx + 1]
modelopt/torch/quantization/plugins/transformer_engine.py-80-            remaining_args = args[idx + 2 :]
modelopt/torch/quantization/plugins/transformer_engine.py:81:            weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-82-            inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-83-            new_args = (weight, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-84-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-90-            idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-91-            weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
modelopt/torch/quantization/plugins/transformer_engine.py-92-            remaining_args = args[idx + 3 :]
modelopt/torch/quantization/plugins/transformer_engine.py:93:            weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-94-            inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-95-            new_args = (weight, weight_fp8, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-96-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-170-        weights_and_biases = args[-2 * num_gemms :]
modelopt/torch/quantization/plugins/transformer_engine.py-171-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
modelopt/torch/quantization/plugins/transformer_engine.py-172-        quantized_inputs = self.input_quantizer(inp)
modelopt/torch/quantization/plugins/transformer_engine.py:173:        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
modelopt/torch/quantization/plugins/transformer_engine.py-174-
modelopt/torch/quantization/plugins/transformer_engine.py-175-        output = getattr(package, func_name)(
modelopt/torch/quantization/plugins/transformer_engine.py-176-            *(
--
modelopt/torch/quantization/plugins/transformer_engine.py-208-
modelopt/torch/quantization/plugins/transformer_engine.py-209-    `@staticmethod`
modelopt/torch/quantization/plugins/transformer_engine.py-210-    def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
modelopt/torch/quantization/plugins/transformer_engine.py:211:        input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers
modelopt/torch/quantization/plugins/transformer_engine.py-212-
modelopt/torch/quantization/plugins/transformer_engine.py:213:        qweight = weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-214-        qweight.requires_grad = weight.requires_grad
modelopt/torch/quantization/plugins/transformer_engine.py-215-        if ctx is not None:
modelopt/torch/quantization/plugins/transformer_engine.py-216-            # We need to recompute the quantized input for the backward pass, so we save the input_quantizer
--
modelopt/torch/quantization/plugins/transformer_engine.py-291-        # This is multi-process safe (such as in torch distributed jobs), not multi-thread safe
modelopt/torch/quantization/plugins/transformer_engine.py-292-        _QuantLayerNormLinearFunc.modelopt_quantizers = (
modelopt/torch/quantization/plugins/transformer_engine.py-293-            self.input_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py:294:            self.weight_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py-295-        )
modelopt/torch/quantization/plugins/transformer_engine.py-296-        with replace_function(
modelopt/torch/quantization/plugins/transformer_engine.py-297-            te_layernorm_linear,
--
modelopt/torch/quantization/plugins/peft.py-34-class _QuantLoraLinear(QuantModule):
modelopt/torch/quantization/plugins/peft.py-35-    def _setup(self):
modelopt/torch/quantization/plugins/peft.py-36-        self.input_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py:37:        self.weight_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py-38-        self.output_quantizer = Tenso

... [stdout truncated: output exceeded 50000 characters (209210 total). Please modify your command to reduce output size, for example by:
- Using grep/rg to filter results
- Using head/tail to limit lines
- Using --max-count flag with grep/rg
- Processing output with awk/sed to extract only needed information]

Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Full context of sync_moe_gate_up_amax function ==="
sed -n '1174,1221p' modelopt/torch/export/layer_utils.py

echo
echo "=== Check if SequentialQuantizer has amax property ==="
sed -n '1322,1400p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py | head -80

echo
echo "=== Check imports and _GATE_UP_PAIRS ==="
sed -n '1,50p' modelopt/torch/export/layer_utils.py
rg -n '_GATE_UP_PAIRS|from.*SequentialQuantizer|import.*SequentialQuantizer' modelopt/torch/export/layer_utils.py

Repository: NVIDIA/Model-Optimizer


Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

=== Full context of sync_moe_gate_up_amax function ===
# Gate/up naming pairs for standard (unfused) MoE architectures.
# Fused variants (gate_up_proj, linear_fc1) already share a single quantizer and need no sync.
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]


def sync_moe_gate_up_amax(model: nn.Module) -> int:
    """Take element-wise max of gate and up weight quantizer amaxes per expert.

    Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
    require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
    syncing amaxes before quantization ensures the per-block weight_scale values
    are computed against a consistent global scale.

    Only affects standard MoE models with separate gate/up linear layers
    (e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
    (e.g. Llama4, GptOss) are unaffected.

    Returns:
        Number of expert gate/up pairs whose amaxes were synced.
    """
    synced = 0
    for _, sub_module in model.named_modules():
        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
            continue
        if not hasattr(sub_module.experts, "__iter__"):
            continue
        for expert in sub_module.experts:
            for gate_name, up_name in _GATE_UP_PAIRS:
                gate_linear = getattr(expert, gate_name, None)
                up_linear = getattr(expert, up_name, None)
                if gate_linear is None or up_linear is None:
                    continue
                gate_wq = getattr(gate_linear, "weight_quantizer", None)
                up_wq = getattr(up_linear, "weight_quantizer", None)
                if gate_wq is None or up_wq is None:
                    break
                gate_amax = getattr(gate_wq, "amax", None)
                up_amax = getattr(up_wq, "amax", None)
                if gate_amax is None or up_amax is None:
                    break
                if not torch.equal(gate_amax, up_amax):
                    shared_amax = torch.max(gate_amax, up_amax)
                    gate_wq.amax = shared_amax
                    up_wq.amax = shared_amax.clone()
                    synced += 1
                break
    return synced


=== Check if SequentialQuantizer has amax property ===
class SequentialQuantizer(nn.Sequential):
    """A sequential container for  :class:`TensorQuantizer` modules.

    This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
    :class:`TensorQuantizer` modules and containerize them similar to :class:`torch.nn.Sequential`.

    We delegate certain properties and methods to all contained quantizers.
    In the case of conflicts, the first quantizer's property or method takes priority.

    `SequentialQuantizer` is useful in cases like INT4 weights, FP8 activations where weight quantization is not the
    same as the gemm quantization. It allows for applying multiple quantization formats to the same tensor in sequence.

    Use `SequentialQuantizer` methods in lower level implementations for better code organization and readability.

    Args:
        quantizers (TensorQuantizer): :class:`TensorQuantizer` modules to be added to the container.

    """

    _delegated_properties = ["fake_quant", "is_enabled", "amax"]
    _delegated_methods = [
        "reset_amax",
        "disable",
        "enable",
        "load_calib_amax",
        "load_calib_bias",
    ]

    def __init__(self, *quantizers: TensorQuantizer):
        """Initialize SequentialQuantizer module."""
        super().__init__(*quantizers)
        assert all(isinstance(q, TensorQuantizer) for q in self), (
            "All quantizers must be a TensorQuantizer."
        )

    def __getattr__(self, name):
        """Delegate properties and methods to all contained quantizers."""
        if name in self._delegated_properties:
            # Return the property of the first quantizer
            return getattr(self[0], name)

        if name in self._delegated_methods:

            def method_wrapper(*args, **kwargs):
                outputs = getattr(self[0], name)(*args, **kwargs)
                for quantizer in self[1:]:
                    outputs = getattr(quantizer, name)(*args, **kwargs)
                return outputs

            return method_wrapper

        # Defer to super class for attributes not handled here
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in self._delegated_properties:
            for quantizer in self:
                setattr(quantizer, name, value)
        else:
            super().__setattr__(name, value)

    def get_modelopt_state(self) -> dict[str, Any]:
        """Get meta state to be saved in checkpoint."""
        return {"num_quantizers": len(self), "is_sequential_quantizer": True}

    def set_from_attribute_config(
        self,
        attributes: list[dict[str, Any] | QuantizerAttributeConfig]
        | dict[str, Any]
        | QuantizerAttributeConfig,
    ):
        """Set the attributes of contained quantizers from a list of attribute_dicts."""
        if not isinstance(attributes, (list, tuple)):
            assert isinstance(attributes, (dict, QuantizerAttributeConfig)), (
                "attributes must be a list or a dict."
            )
            attributes = [attributes] * len(self)

        for attribute, quantizer in zip(attributes, self):

=== Check imports and _GATE_UP_PAIRS ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for model_config export.

Some of the logics in this file are empirical and needs constant update if exceptions occur.
"""

from warnings import warn

import torch
import torch.nn as nn

try:
    from transformers.activations import ACT2FN
except Exception:
    warn("Cannot find transformers package. Hugginface modules cannot be exported.")

from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import import_plugin

from ..quantization.nn import SequentialQuantizer, TensorQuantizer
from .hf_config_map import HF_CONFIG_MAP
from .mcore_config_map import MCORE_CONFIG_MAP
from .model_config import (
    LAYERNORM_DEFAULT,
    LAYERNORM_RMS,
    LINEAR_COLUMN,
    LINEAR_GROUP,
    LINEAR_ROW,
    QUANTIZATION_FP8,
    QUANTIZATION_NONE,
    QUANTIZATION_NVFP4,
    AttentionConfig,
    ConvConfig,
    DecoderLayerConfig,
    EmbeddingConfig,
    ExpertConfig,
34:from ..quantization.nn import SequentialQuantizer, TensorQuantizer
1176:_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
1201:            for gate_name, up_name in _GATE_UP_PAIRS:

Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check build_mlp_config handling at lines 827-831 ==="
sed -n '820,840p' modelopt/torch/export/layer_utils.py

echo
echo "=== Check if SequentialQuantizer.amax delegation actually works ==="
sed -n '1342,1380p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py

echo
echo "=== Verify the loop structure - what does break actually break? ==="
sed -n '1195,1225p' modelopt/torch/export/layer_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer


Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

=== Check build_mlp_config handling at lines 827-831 ===
            or get_quantization_format(module) in [QUANTIZATION_FP8, QUANTIZATION_NVFP4]
        )
    ):
        preprocess_linear_fusion([fc_linear, gate_linear])

    if fc_linear is not None:
        weight_quantizer = None
        if hasattr(fc_linear, "weight_quantizer"):
            weight_quantizer = fc_linear.weight_quantizer
            if isinstance(weight_quantizer, SequentialQuantizer):
                weight_quantizer = weight_quantizer[0]

        # swap fused fc and gate
        if decoder_type in ["chatglm", "phi3"]:
            weights = torch.chunk(fc_linear.weight, 2, dim=0)
            weights = (weights[1], weights[0])
            fc_linear.weight.data = torch.cat(weights, dim=0)

            if (
                weight_quantizer is not None
                and weight_quantizer.is_enabled

=== Check if SequentialQuantizer.amax delegation actually works ===
    _delegated_methods = [
        "reset_amax",
        "disable",
        "enable",
        "load_calib_amax",
        "load_calib_bias",
    ]

    def __init__(self, *quantizers: TensorQuantizer):
        """Initialize SequentialQuantizer module."""
        super().__init__(*quantizers)
        assert all(isinstance(q, TensorQuantizer) for q in self), (
            "All quantizers must be a TensorQuantizer."
        )

    def __getattr__(self, name):
        """Delegate properties and methods to all contained quantizers."""
        if name in self._delegated_properties:
            # Return the property of the first quantizer
            return getattr(self[0], name)

        if name in self._delegated_methods:

            def method_wrapper(*args, **kwargs):
                outputs = getattr(self[0], name)(*args, **kwargs)
                for quantizer in self[1:]:
                    outputs = getattr(quantizer, name)(*args, **kwargs)
                return outputs

            return method_wrapper

        # Defer to super class for attributes not handled here
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in self._delegated_properties:
            for quantizer in self:
                setattr(quantizer, name, value)
        else:

=== Verify the loop structure - what does break actually break? ===
     1	    for _, sub_module in model.named_modules():
     2	        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
     3	            continue
     4	        if not hasattr(sub_module.experts, "__iter__"):
     5	            continue
     6	        for expert in sub_module.experts:
     7	            for gate_name, up_name in _GATE_UP_PAIRS:
     8	                gate_linear = getattr(expert, gate_name, None)
     9	                up_linear = getattr(expert, up_name, None)
    10	                if gate_linear is None or up_linear is None:
    11	                    continue
    12	                gate_wq = getattr(gate_linear, "weight_quantizer", None)
    13	                up_wq = getattr(up_linear, "weight_quantizer", None)
    14	                if gate_wq is None or up_wq is None:
    15	                    break
    16	                gate_amax = getattr(gate_wq, "amax", None)
    17	                up_amax = getattr(up_wq, "amax", None)
    18	                if gate_amax is None or up_amax is None:
    19	                    break
    20	                if not torch.equal(gate_amax, up_amax):
    21	                    shared_amax = torch.max(gate_amax, up_amax)
    22	                    gate_wq.amax = shared_amax
    23	                    up_wq.amax = shared_amax.clone()
    24	                    synced += 1
    25	                break
    26	    return synced
    27	
    28	
    29	def build_stacked_experts(
    30	    experts: nn.Module,
    31	    linear_names: list[str],

Use continue instead of break to check all gate/up pairs (not just the first) per expert.

On lines 1209 and 1213, break exits the pair-checking loop entirely if any pair lacks a quantizer or amax. Since _GATE_UP_PAIRS has two pairs (gate_proj/up_proj and w1/w3), a missing quantizer in the first pair prevents checking the second. This should be continue to skip only the current pair and attempt the next one. (The final break on line 1214 correctly exits after syncing.)

Proposed fix
                 gate_wq = getattr(gate_linear, "weight_quantizer", None)
                 up_wq = getattr(up_linear, "weight_quantizer", None)
                 if gate_wq is None or up_wq is None:
-                    break
+                    continue
                 gate_amax = getattr(gate_wq, "amax", None)
                 up_amax = getattr(up_wq, "amax", None)
                 if gate_amax is None or up_amax is None:
-                    break
+                    continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/layer_utils.py` around lines 1206 - 1213, The loop that
iterates over _GATE_UP_PAIRS incorrectly uses break when a pair is missing a
quantizer or amax, which stops checking remaining pairs; update the logic in the
loop that accesses gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to
use continue instead of break so only the current pair is skipped and the next
pair in _GATE_UP_PAIRS is checked; keep the final break that exits after
successfully syncing unchanged.

if not torch.equal(gate_amax, up_amax):
shared_amax = torch.max(gate_amax, up_amax)
gate_wq.amax = shared_amax
up_wq.amax = shared_amax.clone()
synced += 1
break
return synced


def build_stacked_experts(
experts: nn.Module,
linear_names: list[str],
Expand Down
13 changes: 13 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
is_moe,
is_quantlinear,
set_expert_quantizer_amax,
sync_moe_gate_up_amax,
)
from .model_config import (
QUANTIZATION_FP8,
Expand Down Expand Up @@ -775,6 +776,18 @@ def _export_transformers_checkpoint(
exclude_modules.append(pattern)
print(f"Adding MTP layer to quantization_config ignore: {pattern}")

# Safety net: sync any gate/up weight quantizer amaxes that
# requantize_resmooth_fused_llm_layers did not reach (e.g. experts not
# activated during the dummy forward, or non-standard expert naming).
synced = sync_moe_gate_up_amax(model)
if synced:
warnings.warn(
f"Found {synced} MoE expert gate/up projection pair(s) with mismatched "
f"weight_scale_2 after requantize_resmooth_fused_llm_layers. "
f"This typically means the dummy forward did not activate these experts. "
f"Taking element-wise max of amaxes for serving-engine fusion."
)

# Process all quantized modules and export weights
_process_quantized_modules(model, dtype, is_modelopt_qlora)

Expand Down
Loading