Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models#1033
Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models#1033Edwardf0t1 wants to merge 3 commits intomainfrom
Conversation
…pe in export Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
…pe in export Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
📝 WalkthroughWalkthroughThis pull request adds MoE gate-up quantization synchronization functionality. A new Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 4✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches
🧪 Generate unit tests (beta)
📝 Coding Plan
Comment |
There was a problem hiding this comment.
Pull request overview
This PR updates the Hugging Face export path to ensure weight_scale_2 is shared between fused MoE MLP projection pairs (gate/up), matching serving-engine expectations when they fuse these projections into a single kernel.
Changes:
- Add
max_gate_up_scales()utility to replacegate_proj/up_proj(andw1/w3)weight_scale_2pairs with their element-wise max in the exported state dict. - Invoke the new post-processing step at the end of
_export_transformers_checkpoint()afterpostprocess_state_dict().
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
modelopt/torch/export/unified_export_hf.py |
Calls max_gate_up_scales() during HF checkpoint export and reports how many pairs were tied. |
modelopt/torch/export/quant_utils.py |
Introduces max_gate_up_scales() to unify weight_scale_2 across gate/up projection pairs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
| tied = max_gate_up_scales(quantized_state_dict) | ||
| if tied: | ||
| print(f"Tied weight_scale_2 for {tied} gate/up projection pair(s) in MoE experts.") |
modelopt/torch/export/quant_utils.py
Outdated
| """Replace gate_proj and up_proj weight_scale_2 with their element-wise max. | ||
|
|
||
| For MOE models where gate_proj and up_proj are quantized independently, | ||
| serving engines typically fuse them into a single gate_up_proj and need | ||
| a single shared scale. Using max is conservative (avoids overflow at the | ||
| cost of slightly reduced dynamic range). | ||
| """ | ||
| suffix_pairs = { | ||
| ".gate_proj.weight_scale_2": ".up_proj.weight_scale_2", | ||
| ".w1.weight_scale_2": ".w3.weight_scale_2", | ||
| } |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #1033 +/- ##
==========================================
- Coverage 70.11% 70.06% -0.05%
==========================================
Files 221 221
Lines 25459 25531 +72
==========================================
+ Hits 17851 17889 +38
- Misses 7608 7642 +34 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
modelopt/torch/export/quant_utils.py
Outdated
| return post_state_dict | ||
|
|
||
|
|
||
| def max_gate_up_scales(tensors: dict[str, torch.Tensor]) -> int: |
There was a problem hiding this comment.
I think this may not work. We also need to requantize ws1 if the ws2 changes. I think we need to do it here instead: https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L230
There was a problem hiding this comment.
I agree it's safer to sync the _amax on the weight quantizers of gate/up pairs. In @lukealonso's implementation it's done in both amax level and the post-processing stage, and he verified the accuracy which seems good.
cjluo-nv
left a comment
There was a problem hiding this comment.
https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L396-L418 this logic is used to align w1 and w3 scales.
We might want to debug why this fails for qwen3.5
It failed for kimi-k2.5 as well. That's exactly why we needed this PR to fix it. |
| quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora | ||
| ) | ||
|
|
||
| tied = max_gate_up_scales(quantized_state_dict) |
There was a problem hiding this comment.
Should we do this before compressing the weights, i.e., syncing the amax of gate and up? Is my understanding correct that the block-wise scales are still based on the separate global scales?
There was a problem hiding this comment.
You are right - updated to sync at the amax level.
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/layer_utils.py`:
- Around line 1206-1213: The loop that iterates over _GATE_UP_PAIRS incorrectly
uses break when a pair is missing a quantizer or amax, which stops checking
remaining pairs; update the logic in the loop that accesses
gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to use continue
instead of break so only the current pair is skipped and the next pair in
_GATE_UP_PAIRS is checked; keep the final break that exits after successfully
syncing unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c2d4c1ed-e8c9-4c53-b650-0b34276086cf
📒 Files selected for processing (2)
modelopt/torch/export/layer_utils.pymodelopt/torch/export/unified_export_hf.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/export/unified_export_hf.py
| gate_wq = getattr(gate_linear, "weight_quantizer", None) | ||
| up_wq = getattr(up_linear, "weight_quantizer", None) | ||
| if gate_wq is None or up_wq is None: | ||
| break | ||
| gate_amax = getattr(gate_wq, "amax", None) | ||
| up_amax = getattr(up_wq, "amax", None) | ||
| if gate_amax is None or up_amax is None: | ||
| break |
There was a problem hiding this comment.
❓ Verification inconclusive
Script executed:
#!/bin/bash
set -euo pipefail
echo "== SequentialQuantizer API =="
fd 'tensor_quantizer.py' --exec rg -n -C2 'class SequentialQuantizer|def amax|@property'
echo
echo "== Where SequentialQuantizer is used for weight quantizers =="
rg -n -C3 --type=py 'SequentialQuantizer\(|weight_quantizer' modelopt/torch
echo
echo "== Existing unwrapping pattern in layer_utils.py =="
rg -n -C3 --type=py 'isinstance\(.*SequentialQuantizer\)|weight_quantizer\.amax|sync_moe_gate_up_amax' modelopt/torch/export/layer_utils.pyRepository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
== SequentialQuantizer API ==
252- return qtensor.dequantize(**kwarg)
253-
254: `@property`
255- def num_bits(self):
256- """Return num_bits for quantization."""
--
262- self._calibrator._num_bits = value
263-
264: `@property`
265- def maxbound(self):
266- """Return maxbound for quantization."""
--
271- return (1 << (self._num_bits - 1 + int(self._unsigned))) - 1
272-
273: `@property`
274- def unsigned(self):
275- """Return True if unsigned quantization is used."""
--
281- self._calibrator._unsigned = value
282-
283: `@property`
284- def pre_quant_scale(self):
285- """Return pre_quant_scale used for smoothquant."""
--
305- )
306-
307: `@property`
308: def amax(self):
309- """Return amax for quantization."""
310- if not hasattr(self, "_amax") or self.is_mx_format:
--
314-
315- `@amax.setter`
316: def amax(self, value):
317- assert value is not None, "amax cannot be set to None."
318-
--
341- self._bias_calibrator.reset()
342-
343: `@property`
344- def step_size(self):
345- """Return step size for integer quantization."""
--
352- return self._amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)
353-
354: `@property`
355- def axis(self):
356- """Return axis for quantization."""
--
362- self._calibrator._axis = value
363-
364: `@property`
365- def block_sizes(self):
366- """Return block_sizes for quantization."""
--
372- self._block_sizes = value
373-
374: `@property`
375- def bias(self):
376- """Return bias for quantization."""
--
379- return self._bias
380-
381: `@property`
382- def bias_axis(self):
383- """Return bias_axis for quantization."""
--
392- self._bias_axis = value
393-
394: `@property`
395- def bias_method(self):
396- """Return bias_method for quantization."""
--
399- return self._bias.get("method", "mean")
400-
401: `@property`
402- def bias_type(self):
403- """Return bias_type for quantization."""
--
414- self._bias["type"] = value
415-
416: `@property`
417- def bias_value(self):
418- """Return bias for quantization."""
--
435- self._bias_value.data.copy_(value.clone().detach().to(self._bias_value.device))
436-
437: `@property`
438- def bias_calibrator(self):
439- """Return bias_calibrator for quantization."""
--
450- return self._bias_calibrator
451-
452: `@property`
453- def fake_quant(self):
454- """Return True if fake quantization is used."""
455- return self._fake_quant
456-
457: `@property`
458- def narrow_range(self):
459- """Return True if symmetric integer range for signed quantization is used."""
--
464- self._narrow_range = value
465-
466: `@property`
467- def is_enabled(self):
468- """Return true if the modules is not disabled."""
--
480- self._disabled = False
481-
482: `@property`
483- def trt_high_precision_dtype(self):
484- """Return True if FP16 AMAX is used when exporting the model."""
--
489- self._trt_high_precision_dtype = value
490-
491: `@property`
492- def is_mx_format(self):
493- """Check if is MX formats."""
--
521- raise NotImplementedError()
522-
523: `@property`
524- def is_static_block_quant(self):
525- """Check if is static block quantization."""
--
530- )
531-
532: `@property`
533- def rotate_is_enabled(self):
534- """Check if rotate is enabled in quant config."""
535- return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate
536-
537: `@property`
538- def rotate_is_fp32(self):
539- """Check if rotation needs to be computed in float32."""
--
1286- return tq
1287-
1288: `@property`
1289- def global_amax(self):
1290- """Return global_amax for quantization."""
--
1320-
1321-
1322:class SequentialQuantizer(nn.Sequential):
1323- """A sequential container for :class:`TensorQuantizer` modules.
1324-
== Where SequentialQuantizer is used for weight quantizers ==
modelopt/torch/quantization/utils.py-228-
modelopt/torch/quantization/utils.py-229- # the standard weight and quantizer case
modelopt/torch/quantization/utils.py-230- weight = getattr(module, "weight", None)
modelopt/torch/quantization/utils.py:231: weight_quantizer = getattr(module, "weight_quantizer", None)
modelopt/torch/quantization/utils.py:232: if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
modelopt/torch/quantization/utils.py-233- yield "weight"
modelopt/torch/quantization/utils.py-234-
modelopt/torch/quantization/utils.py-235- # other weight and quantizer case
modelopt/torch/quantization/utils.py-236- for name, _ in module.named_parameters(recurse=False):
modelopt/torch/quantization/utils.py-237- weight = getattr(module, name, None)
modelopt/torch/quantization/utils.py:238: weight_quantizer = getattr(module, f"{name}_weight_quantizer", None)
modelopt/torch/quantization/utils.py-239- if isinstance(weight, nn.Parameter) and isinstance(
modelopt/torch/quantization/utils.py:240: weight_quantizer, (TensorQuantizer, SequentialQuantizer)
modelopt/torch/quantization/utils.py-241- ):
modelopt/torch/quantization/utils.py-242- yield name
modelopt/torch/quantization/utils.py-243-
--
modelopt/torch/quantization/utils.py-246-QuantizerAttrNames = namedtuple(
modelopt/torch/quantization/utils.py-247- "QuantizerAttrNames",
modelopt/torch/quantization/utils.py-248- (
modelopt/torch/quantization/utils.py:249: "weight_quantizer",
modelopt/torch/quantization/utils.py-250- "input_quantizer",
modelopt/torch/quantization/utils.py-251- "output_quantizer",
modelopt/torch/quantization/utils.py-252- "weight_scale",
--
modelopt/torch/quantization/utils.py-261- """Get all the quantizer related attribute names for a given weight name."""
modelopt/torch/quantization/utils.py-262- prefix = f"{weight_name}_" if weight_name != "weight" else ""
modelopt/torch/quantization/utils.py-263- return QuantizerAttrNames(
modelopt/torch/quantization/utils.py:264: weight_quantizer=f"{prefix}weight_quantizer",
modelopt/torch/quantization/utils.py-265- input_quantizer=f"{prefix}input_quantizer",
modelopt/torch/quantization/utils.py-266- output_quantizer=f"{prefix}output_quantizer",
modelopt/torch/quantization/utils.py-267- weight_scale=f"{prefix}weight_scale",
--
modelopt/torch/quantization/utils.py-285- return (
modelopt/torch/quantization/utils.py-286- isinstance(module, QuantModule)
modelopt/torch/quantization/utils.py-287- and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
modelopt/torch/quantization/utils.py:288: and hasattr(module, "weight_quantizer")
modelopt/torch/quantization/utils.py-289- and (
modelopt/torch/quantization/utils.py-290- (getattr(module, "weight", None) is not None and module.weight.dim() == 2)
modelopt/torch/quantization/utils.py-291- # module.weight0 check is required to support TEGroupedLinear
--
modelopt/torch/quantization/utils.py-329- config["quant_cfg"]["*lora*"] = {"enable": False}
modelopt/torch/quantization/utils.py-330- for layer in layers:
modelopt/torch/quantization/utils.py-331- config["quant_cfg"][f"*{layer}.input_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py:332: config["quant_cfg"][f"*{layer}.weight_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py-333- config["quant_cfg"][f"*{layer}.output_quantizer"] = {"enable": False}
modelopt/torch/quantization/utils.py-334- return config
modelopt/torch/quantization/utils.py-335-
--
modelopt/torch/quantization/utils.py-537-
modelopt/torch/quantization/utils.py-538- 1. Takes the element-wise max of each ``input_quantizer`` amax across all experts
modelopt/torch/quantization/utils.py-539- and writes it back, so every expert shares the same input amax.
modelopt/torch/quantization/utils.py:540: 2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert
modelopt/torch/quantization/utils.py-541- received no tokens during calibration), runs a weight-only ``max_calibrate``
modelopt/torch/quantization/utils.py-542- to populate the missing amax.
modelopt/torch/quantization/utils.py-543- """
--
modelopt/torch/quantization/utils.py-566-
modelopt/torch/quantization/utils.py-567- for expert in experts:
modelopt/torch/quantization/utils.py-568- for name, module in expert.named_modules():
modelopt/torch/quantization/utils.py:569: if name.endswith("weight_quantizer") and module.is_enabled and module.amax is None:
modelopt/torch/quantization/utils.py:570: weight = expert.state_dict().get(name.replace("weight_quantizer", "weight"))
modelopt/torch/quantization/utils.py-571- if weight is not None:
modelopt/torch/quantization/utils.py-572- max_calibrate(module, lambda m, w=weight: m(w), distributed_sync=False)
modelopt/torch/quantization/utils.py-573-
--
modelopt/torch/quantization/utils.py-687- """
modelopt/torch/quantization/utils.py-688- original_fake_quant = []
modelopt/torch/quantization/utils.py-689- for m in module.modules():
modelopt/torch/quantization/utils.py:690: if hasattr(m, "weight_quantizer"):
modelopt/torch/quantization/utils.py:691: original_fake_quant.append(m.weight_quantizer._fake_quant)
modelopt/torch/quantization/utils.py:692: m.weight_quantizer._fake_quant = True
modelopt/torch/quantization/utils.py-693- yield
modelopt/torch/quantization/utils.py-694- for m in module.modules():
modelopt/torch/quantization/utils.py:695: if hasattr(m, "weight_quantizer"):
modelopt/torch/quantization/utils.py:696: m.weight_quantizer._fake_quant = original_fake_quant.pop(0)
modelopt/torch/quantization/utils.py-697-
modelopt/torch/quantization/utils.py-698-
modelopt/torch/quantization/utils.py-699-@contextmanager
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-97- def get_weights_scaling_factor_from_quantizer(
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-98- cls,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-99- weight: torch.Tensor,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:100: weight_quantizer,
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-101- ) -> torch.Tensor:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-102- """Returns E8M0 scale from quantizer or computes from weight.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-103-
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-107- Args:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-108- weight: The weight tensor. Can be 2D (out_dim, in_dim) or
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-109- 3D for MoE (num_experts, out_dim, in_dim).
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:110: weight_quantizer: The weight quantizer with block_sizes and optional _scale.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-111-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-112- Returns:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-113- torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32].
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-114- """
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:115: assert hasattr(weight_quantizer, "block_sizes"), (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:116: "weight_quantizer must have 'block_sizes' attribute"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-117- )
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:118: assert weight_quantizer.block_sizes[-1] == cls.BLOCK_SIZE, (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:119: f"MXFP8 requires block size {cls.BLOCK_SIZE}, got {weight_quantizer.block_sizes[-1]}"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-120- )
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-121- assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D"
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-122-
--
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-126- # For 3D MoE: (num_experts, out_dim, in_dim // 32)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-127- expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE)
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-128-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:129: if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
modelopt/torch/quantization/qtensor/mxfp8_tensor.py:130: scale = weight_quantizer._scale
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-131-
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-132- assert scale.dtype == cls.SCALE_DTYPE, (
modelopt/torch/quantization/qtensor/mxfp8_tensor.py-133- f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}"
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-53- return cls.e2m1_bounds_on_device[device]
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-54-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-55- `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:56: def _is_static_quantizer(cls, weight_quantizer) -> bool:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-57- """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax."""
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:58: return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-59-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-60- `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:61: def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:62: """Returns per tensor weight scaling factor from the weight_quantizer.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-63-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-64- Handles both static NVFP4 quantizers (using global_amax) and
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-65- dynamic quantizers (using _amax).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-66-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-67- Args:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:68: weight_quantizer: The weight quantizer (static or dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-69-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-70- Returns:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-71- The global scaling factor as a float tensor.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-72- """
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:73: if cls._is_static_quantizer(weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:74: return weight_quantizer.global_amax.float() / (6.0 * 448.0)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-75- else:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:76: assert hasattr(weight_quantizer, "_amax"), (
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-77- "Weight quantizer does not have attribute amax"
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-78- )
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:79: return weight_quantizer._amax.float() / (6.0 * 448.0)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-80-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-81- `@classmethod`
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-82- def get_weights_scaling_factor_from_quantizer(
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-83- cls,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:84: weight_quantizer,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-85- weight: torch.Tensor,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-86- weights_scaling_factor_2: torch.Tensor | None = None,
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-87- keep_high_precision: bool = False,
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-92- and dynamic quantizers (computing from weight tensor).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-93-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-94- Args:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:95: weight_quantizer: The weight quantizer (static or dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-96- weight: The weight tensor (used for shape in static, values in dynamic).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-97- weights_scaling_factor_2: Optional pre-computed global scale.
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-98- keep_high_precision: Whether to keep scales in high precision.
--
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-100- Returns:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-101- Tuple of (per_block_scale, weights_scaling_factor_2).
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-102- """
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:103: block_size = weight_quantizer.block_sizes[-1]
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-104-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-105- if weights_scaling_factor_2 is None:
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-106- weights_scaling_factor_2 = cls.get_weights_scaling_factor_2_from_quantizer(
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:107: weight_quantizer
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-108- )
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-109-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:110: if cls._is_static_quantizer(weight_quantizer):
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-111- # Static path: use pre-computed per-block amax values from quantizer
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:112: global_amax = weight_quantizer.global_amax.float()
modelopt/torch/quantization/qtensor/nvfp4_tensor.py:113: per_block_amax = weight_quantizer._amax.float()
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-114-
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-115- # Compute scales in float
modelopt/torch/quantization/qtensor/nvfp4_tensor.py-116- per_block_scale_max = global_amax / 6.0
--
modelopt/torch/quantization/qtensor/base_qtensor.py-204- # We dont compress meta tensors or None
modelopt/torch/quantization/qtensor/base_qtensor.py-205- return False
modelopt/torch/quantization/qtensor/base_qtensor.py-206- if (
modelopt/torch/quantization/qtensor/base_qtensor.py:207: hasattr(module, "weight_quantizer")
modelopt/torch/quantization/qtensor/base_qtensor.py:208: and module.weight_quantizer.is_enabled
modelopt/torch/quantization/qtensor/base_qtensor.py:209: and not module.weight_quantizer._fake_quant
modelopt/torch/quantization/qtensor/base_qtensor.py-210- and module.weight.element_size() > 1
modelopt/torch/quantization/qtensor/base_qtensor.py-211- ):
modelopt/torch/quantization/qtensor/base_qtensor.py-212- if force_quantize:
modelopt/torch/quantization/qtensor/base_qtensor.py:213: module.weight_quantizer._dequantize = False
modelopt/torch/quantization/qtensor/base_qtensor.py-214-
modelopt/torch/quantization/qtensor/base_qtensor.py:215: real_quant_tensor = module.weight_quantizer(module.weight)
modelopt/torch/quantization/qtensor/base_qtensor.py-216- module.weight = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/qtensor/base_qtensor.py-217- return True
modelopt/torch/quantization/qtensor/base_qtensor.py-218-
--
modelopt/torch/quantization/plugins/vllm.py-80- torch.Tensor: The quantized output tensor.
modelopt/torch/quantization/plugins/vllm.py-81- """
modelopt/torch/quantization/plugins/vllm.py-82- x = layer.input_quantizer(x)
modelopt/torch/quantization/plugins/vllm.py:83: if layer.weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-84- original_weight = layer.weight
modelopt/torch/quantization/plugins/vllm.py:85: quantized_tensor = layer.weight_quantizer(layer.weight)
modelopt/torch/quantization/plugins/vllm.py-86- # parameterize the quantized weight
modelopt/torch/quantization/plugins/vllm.py-87- if isinstance(original_weight, torch.nn.Parameter) and not isinstance(
modelopt/torch/quantization/plugins/vllm.py-88- quantized_tensor, torch.nn.Parameter
--
modelopt/torch/quantization/plugins/vllm.py-110-class _VLLMParallelLinear(QuantModule):
modelopt/torch/quantization/plugins/vllm.py-111- def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-112- self.input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:113: self.weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-114- self.output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-115- self.output_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-116- assert type(self.quant_method) is vllm_linear.UnquantizedLinearMethod, (
--
modelopt/torch/quantization/plugins/vllm.py-159- def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-160- self.w13_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py-161- self.w2_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:162: self.w13_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py:163: self.w2_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-164- self.w13_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-165- self.w2_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-166- self.w13_output_quantizer.disable()
--
modelopt/torch/quantization/plugins/vllm.py-181- if B is self.w13_weight:
modelopt/torch/quantization/plugins/vllm.py-182- # First layer of expert
modelopt/torch/quantization/plugins/vllm.py-183- A = self.w13_input_quantizer(A) # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:184: if self.w13_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-185- original_weight = self.w13_weight
modelopt/torch/quantization/plugins/vllm.py:186: self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
modelopt/torch/quantization/plugins/vllm.py-187- vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-188- self.w13_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-189- else:
--
modelopt/torch/quantization/plugins/vllm.py-192- C[:] = self.w13_output_quantizer(C)
modelopt/torch/quantization/plugins/vllm.py-193- elif B is self.w2_weight:
modelopt/torch/quantization/plugins/vllm.py-194- A = self.w2_input_quantizer(A) # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:195: if self.w2_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-196- original_weight = self.w2_weight
modelopt/torch/quantization/plugins/vllm.py:197: self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
modelopt/torch/quantization/plugins/vllm.py-198- vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-199- self.w2_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-200- else:
--
modelopt/torch/quantization/plugins/vllm.py-232- # the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
modelopt/torch/quantization/plugins/vllm.py-233- for i in range(self.w13_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-234- self.w13_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:235: self.w13_weight_quantizer(self.w13_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-236- self.w13_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-237- )
modelopt/torch/quantization/plugins/vllm.py-238- )
modelopt/torch/quantization/plugins/vllm.py:239: self.w13_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-240- for i in range(self.w2_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-241- self.w2_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:242: self.w2_weight_quantizer(self.w2_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-243- self.w2_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-244- )
modelopt/torch/quantization/plugins/vllm.py-245- )
modelopt/torch/quantization/plugins/vllm.py:246: self.w2_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-247-
modelopt/torch/quantization/plugins/vllm.py-248- torch.cuda.empty_cache()
modelopt/torch/quantization/plugins/vllm.py-249-
--
modelopt/torch/quantization/nn/modules/quant_module.py-121-
modelopt/torch/quantization/nn/modules/quant_module.py-122- def fold_weight(self, keep_attrs: bool = False):
modelopt/torch/quantization/nn/modules/quant_module.py-123- """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_module.py:124: # Handle all attributes that end with _weight_quantizer
modelopt/torch/quantization/nn/modules/quant_module.py-125- for name in dir(self):
modelopt/torch/quantization/nn/modules/quant_module.py-126- attr = getattr(self, name)
modelopt/torch/quantization/nn/modules/quant_module.py-127- if (
modelopt/torch/quantization/nn/modules/quant_module.py:128: name.endswith("weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_module.py-129- and isinstance(attr, TensorQuantizer)
modelopt/torch/quantization/nn/modules/quant_module.py-130- and attr.fake_quant
modelopt/torch/quantization/nn/modules/quant_module.py-131- ):
modelopt/torch/quantization/nn/modules/quant_module.py:132: # Get the corresponding weight name by removing _weight_quantizer suffix
modelopt/torch/quantization/nn/modules/quant_module.py-133- weight_name = name[:-10]
modelopt/torch/quantization/nn/modules/quant_module.py-134-
modelopt/torch/quantization/nn/modules/quant_module.py-135- assert hasattr(self, weight_name), (
--
modelopt/torch/quantization/nn/modules/quant_module.py-203- Quantized linear modules are modules where both the input and the weight are quantized.
modelopt/torch/quantization/nn/modules/quant_module.py-204- """
modelopt/torch/quantization/nn/modules/quant_module.py-205-
modelopt/torch/quantization/nn/modules/quant_module.py:206: weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_module.py-207- _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_module.py-208- default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_module.py-209-
--
modelopt/torch/quantization/nn/modules/quant_module.py-219- `@staticmethod`
modelopt/torch/quantization/nn/modules/quant_module.py-220- def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
modelopt/torch/quantization/nn/modules/quant_module.py-221- if module._enable_weight_quantization or is_torch_export_mode():
modelopt/torch/quantization/nn/modules/quant_module.py:222: return module.weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_module.py-223- return weight
modelopt/torch/quantization/nn/modules/quant_module.py-224-
modelopt/torch/quantization/nn/modules/quant_module.py-225- def forward(self, input, *args, **kwargs):
--
modelopt/torch/quantization/nn/modules/quant_module.py-234- def _setup(self):
modelopt/torch/quantization/nn/modules/quant_module.py-235- super()._setup()
modelopt/torch/quantization/nn/modules/quant_module.py-236- self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_module.py:237: "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_module.py-238- )
modelopt/torch/quantization/nn/modules/quant_module.py-239- self._register_temp_attribute("_enable_weight_quantization", False)
modelopt/torch/quantization/nn/modules/quant_module.py-240- self._register_dynamic_attribute("weight", self._get_quantized_weight)
--
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1319- return super()._fake_quantize(inputs)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1320-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1321-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py:1322:class SequentialQuantizer(nn.Sequential):
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1323- """A sequential container for :class:`TensorQuantizer` modules.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1324-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1325- This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-47-class QuantRNNBase(QuantModule):
modelopt/torch/quantization/nn/modules/quant_rnn.py-48- """Base class for quantized RNN modules."""
modelopt/torch/quantization/nn/modules/quant_rnn.py-49-
modelopt/torch/quantization/nn/modules/quant_rnn.py:50: weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py-51- _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_rnn.py-52- default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_rnn.py-53- default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-75- self._enable_weight_quantization = False
modelopt/torch/quantization/nn/modules/quant_rnn.py-76-
modelopt/torch/quantization/nn/modules/quant_rnn.py-77- `@staticmethod`
modelopt/torch/quantization/nn/modules/quant_rnn.py:78: def _get_quantized_weight_handler(weight_quantizer_name: str):
modelopt/torch/quantization/nn/modules/quant_rnn.py-79- def _get_quantized_weight(module: "QuantRNNBase", weight: torch.Tensor):
modelopt/torch/quantization/nn/modules/quant_rnn.py-80- if module._enable_weight_quantization:
modelopt/torch/quantization/nn/modules/quant_rnn.py:81: weight_quantizer = getattr(module, weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py:82: return weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-83- return weight
modelopt/torch/quantization/nn/modules/quant_rnn.py-84-
modelopt/torch/quantization/nn/modules/quant_rnn.py-85- return _get_quantized_weight
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-102- for name, _ in self.named_parameters():
modelopt/torch/quantization/nn/modules/quant_rnn.py-103- if name.startswith("weight"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-104- # to be compatible with our current config, the name is some what weird
modelopt/torch/quantization/nn/modules/quant_rnn.py:105: # it would be weight_xxx_weight_quantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py:106: weight_quantizer_name = name + "_weight_quantizer"
modelopt/torch/quantization/nn/modules/quant_rnn.py-107- self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:108: weight_quantizer_name, TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-109- )
modelopt/torch/quantization/nn/modules/quant_rnn.py-110- self._register_dynamic_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:111: name, self._get_quantized_weight_handler(weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py-112- )
modelopt/torch/quantization/nn/modules/quant_rnn.py-113- # for cells
modelopt/torch/quantization/nn/modules/quant_rnn.py-114- self._register_temp_attribute("_input_quantizers", [])
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-143- for iq in self._input_quantizers + self._proj_input_quantizers:
modelopt/torch/quantization/nn/modules/quant_rnn.py-144- iq.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-145-
modelopt/torch/quantization/nn/modules/quant_rnn.py:146: def _disable_weight_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-147- for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:148: if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-149- module.disable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-150-
modelopt/torch/quantization/nn/modules/quant_rnn.py:151: def _enable_weight_quantizer(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-152- for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:153: if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-154- module.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-155-
modelopt/torch/quantization/nn/modules/quant_rnn.py-156- def _setup(self):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-46- """Quantized version of a generic linear functional."""
modelopt/torch/quantization/nn/modules/quant_linear.py-47- output = getattr(package, func_name)(
modelopt/torch/quantization/nn/modules/quant_linear.py-48- self.input_quantizer(input),
modelopt/torch/quantization/nn/modules/quant_linear.py:49: self.weight_quantizer(weight),
modelopt/torch/quantization/nn/modules/quant_linear.py-50- *args,
modelopt/torch/quantization/nn/modules/quant_linear.py-51- **kwargs,
modelopt/torch/quantization/nn/modules/quant_linear.py-52- )
--
modelopt/torch/quantization/nn/modules/quant_linear.py-119-
modelopt/torch/quantization/nn/modules/quant_linear.py-120- def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-121- """Overrides and bypass the _setup function."""
modelopt/torch/quantization/nn/modules/quant_linear.py:122: if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
modelopt/torch/quantization/nn/modules/quant_linear.py-123- return
modelopt/torch/quantization/nn/modules/quant_linear.py:124: self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-125-
modelopt/torch/quantization/nn/modules/quant_linear.py-126- def _not_sequential_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_linear.py:127: return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
modelopt/torch/quantization/nn/modules/quant_linear.py-128- self.input_quantizer, TensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-129- )
modelopt/torch/quantization/nn/modules/quant_linear.py-130-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-138- """Compute the LoRA residual if present, otherwise return None."""
modelopt/torch/quantization/nn/modules/quant_linear.py-139- if (
modelopt/torch/quantization/nn/modules/quant_linear.py-140- self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:141: and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:142: and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-143- ):
modelopt/torch/quantization/nn/modules/quant_linear.py:144: lora_a = F.linear(input, weight=self.weight_quantizer.svdquant_lora_a)
modelopt/torch/quantization/nn/modules/quant_linear.py:145: lora_b = F.linear(lora_a, weight=self.weight_quantizer.svdquant_lora_b)
modelopt/torch/quantization/nn/modules/quant_linear.py-146- return lora_b
modelopt/torch/quantization/nn/modules/quant_linear.py-147- return None
modelopt/torch/quantization/nn/modules/quant_linear.py-148-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-150- """SVDQuant layer forward function."""
modelopt/torch/quantization/nn/modules/quant_linear.py-151- has_svdquant_lora = (
modelopt/torch/quantization/nn/modules/quant_linear.py-152- self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:153: and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:154: and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-155- )
modelopt/torch/quantization/nn/modules/quant_linear.py-156- if has_svdquant_lora:
modelopt/torch/quantization/nn/modules/quant_linear.py-157- input = self._apply_pre_quant_scale(input)
--
modelopt/torch/quantization/nn/modules/quant_linear.py-166- """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_linear.py-167- super().fold_weight(keep_attrs)
modelopt/torch/quantization/nn/modules/quant_linear.py-168- if (
modelopt/torch/quantization/nn/modules/quant_linear.py:169: hasattr(self, "weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_linear.py-170- and hasattr(self, "weight")
modelopt/torch/quantization/nn/modules/quant_linear.py:171: and self.weight_quantizer.fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-172- ):
modelopt/torch/quantization/nn/modules/quant_linear.py-173- if (
modelopt/torch/quantization/nn/modules/quant_linear.py-174- self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:175: and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:176: and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-177- ):
modelopt/torch/quantization/nn/modules/quant_linear.py-178- self.weight.data.copy_(
modelopt/torch/quantization/nn/modules/quant_linear.py-179- self.weight
modelopt/torch/quantization/nn/modules/quant_linear.py:180: + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
modelopt/torch/quantization/nn/modules/quant_linear.py-181- )
modelopt/torch/quantization/nn/modules/quant_linear.py-182- if not keep_attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py-183- _attrs = [
--
modelopt/torch/quantization/nn/modules/quant_linear.py-185- "_svdquant_lora_b",
modelopt/torch/quantization/nn/modules/quant_linear.py-186- ]
modelopt/torch/quantization/nn/modules/quant_linear.py-187- for attr in _attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py:188: if hasattr(self.weight_quantizer, attr):
modelopt/torch/quantization/nn/modules/quant_linear.py:189: delattr(self.weight_quantizer, attr)
modelopt/torch/quantization/nn/modules/quant_linear.py-190-
modelopt/torch/quantization/nn/modules/quant_linear.py-191-
modelopt/torch/quantization/nn/modules/quant_linear.py-192-class RealQuantLinear(QuantModule):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-241-
modelopt/torch/quantization/nn/modules/quant_linear.py-242- def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-243- class RealQuantParameterDict(dict):
modelopt/torch/quantization/nn/modules/quant_linear.py:244: def __init__(self, weight_quantizer: TensorQuantizer, *args, **kwargs):
modelopt/torch/quantization/nn/modules/quant_linear.py-245- super().__init__(*args, **kwargs)
modelopt/torch/quantization/nn/modules/quant_linear.py:246: self.weight_quantizer = weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-247-
modelopt/torch/quantization/nn/modules/quant_linear.py-248- def __setitem__(self, key, value):
modelopt/torch/quantization/nn/modules/quant_linear.py-249- if (
modelopt/torch/quantization/nn/modules/quant_linear.py-250- key == "weight"
modelopt/torch/quantization/nn/modules/quant_linear.py:251: and self.weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py:252: and self.weight_quantizer.is_enabled
modelopt/torch/quantization/nn/modules/quant_linear.py:253: and not self.weight_quantizer._fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-254- and value.element_size() > 1
modelopt/torch/quantization/nn/modules/quant_linear.py-255- ):
modelopt/torch/quantization/nn/modules/quant_linear.py-256- # reset the amax for later calibration
modelopt/torch/quantization/nn/modules/quant_linear.py-257- if (
modelopt/torch/quantization/nn/modules/quant_linear.py:258: self.weight_quantizer.amax is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:259: and self.weight_quantizer.amax.is_meta
modelopt/torch/quantization/nn/modules/quant_linear.py-260- ):
modelopt/torch/quantization/nn/modules/quant_linear.py:261: delattr(self.weight_quantizer, "_amax")
modelopt/torch/quantization/nn/modules/quant_linear.py:262: self.weight_quantizer.amax = self.weight_quantizer._get_amax(value)
modelopt/torch/quantization/nn/modules/quant_linear.py:263: self.weight_quantizer._calibrator.reset()
modelopt/torch/quantization/nn/modules/quant_linear.py-264- # compress the weight
modelopt/torch/quantization/nn/modules/quant_linear.py:265: real_quant_tensor = self.weight_quantizer(value)
modelopt/torch/quantization/nn/modules/quant_linear.py-266- real_quant_value = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/nn/modules/quant_linear.py-267- del value # delete the original weight to save memory
modelopt/torch/quantization/nn/modules/quant_linear.py-268- value = real_quant_value
--
modelopt/torch/quantization/nn/modules/quant_linear.py-270-
modelopt/torch/quantization/nn/modules/quant_linear.py-271- # Monkey patch the _parameters.__setitem__ to real quant the weight when loading
modelopt/torch/quantization/nn/modules/quant_linear.py-272- # HF accelerate loads the weight by directly assigning the weight through the _parameters dict.
modelopt/torch/quantization/nn/modules/quant_linear.py:273: self._parameters = RealQuantParameterDict(self.weight_quantizer, self._parameters)
modelopt/torch/quantization/nn/modules/quant_linear.py-274-
modelopt/torch/quantization/nn/modules/quant_linear.py-275- # Function to dynamically override load_state_dict
modelopt/torch/quantization/nn/modules/quant_linear.py-276- dynamically_update_state_methods(self)
--
modelopt/torch/quantization/plugins/transformer_engine.py-78- idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-79- weight, inputs = args[idx], args[idx + 1]
modelopt/torch/quantization/plugins/transformer_engine.py-80- remaining_args = args[idx + 2 :]
modelopt/torch/quantization/plugins/transformer_engine.py:81: weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-82- inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-83- new_args = (weight, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-84- new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-90- idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-91- weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
modelopt/torch/quantization/plugins/transformer_engine.py-92- remaining_args = args[idx + 3 :]
modelopt/torch/quantization/plugins/transformer_engine.py:93: weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-94- inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-95- new_args = (weight, weight_fp8, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-96- new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-170- weights_and_biases = args[-2 * num_gemms :]
modelopt/torch/quantization/plugins/transformer_engine.py-171- weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
modelopt/torch/quantization/plugins/transformer_engine.py-172- quantized_inputs = self.input_quantizer(inp)
modelopt/torch/quantization/plugins/transformer_engine.py:173: quantized_weights = [self.weight_quantizer(weight) for weight in weights]
modelopt/torch/quantization/plugins/transformer_engine.py-174-
modelopt/torch/quantization/plugins/transformer_engine.py-175- output = getattr(package, func_name)(
modelopt/torch/quantization/plugins/transformer_engine.py-176- *(
--
modelopt/torch/quantization/plugins/transformer_engine.py-208-
modelopt/torch/quantization/plugins/transformer_engine.py-209- `@staticmethod`
modelopt/torch/quantization/plugins/transformer_engine.py-210- def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
modelopt/torch/quantization/plugins/transformer_engine.py:211: input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers
modelopt/torch/quantization/plugins/transformer_engine.py-212-
modelopt/torch/quantization/plugins/transformer_engine.py:213: qweight = weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-214- qweight.requires_grad = weight.requires_grad
modelopt/torch/quantization/plugins/transformer_engine.py-215- if ctx is not None:
modelopt/torch/quantization/plugins/transformer_engine.py-216- # We need to recompute the quantized input for the backward pass, so we save the input_quantizer
--
modelopt/torch/quantization/plugins/transformer_engine.py-291- # This is multi-process safe (such as in torch distributed jobs), not multi-thread safe
modelopt/torch/quantization/plugins/transformer_engine.py-292- _QuantLayerNormLinearFunc.modelopt_quantizers = (
modelopt/torch/quantization/plugins/transformer_engine.py-293- self.input_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py:294: self.weight_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py-295- )
modelopt/torch/quantization/plugins/transformer_engine.py-296- with replace_function(
modelopt/torch/quantization/plugins/transformer_engine.py-297- te_layernorm_linear,
--
modelopt/torch/quantization/plugins/peft.py-34-class _QuantLoraLinear(QuantModule):
modelopt/torch/quantization/plugins/peft.py-35- def _setup(self):
modelopt/torch/quantization/plugins/peft.py-36- self.input_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py:37: self.weight_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py-38- self.output_quantizer = Tenso
... [stdout truncated: output exceeded 50000 characters (209210 total). Please modify your command to reduce output size, for example by:
- Using grep/rg to filter results
- Using head/tail to limit lines
- Using --max-count flag with grep/rg
- Processing output with awk/sed to extract only needed information]
Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Full context of sync_moe_gate_up_amax function ==="
sed -n '1174,1221p' modelopt/torch/export/layer_utils.py
echo
echo "=== Check if SequentialQuantizer has amax property ==="
sed -n '1322,1400p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py | head -80
echo
echo "=== Check imports and _GATE_UP_PAIRS ==="
sed -n '1,50p' modelopt/torch/export/layer_utils.py
rg -n '_GATE_UP_PAIRS|from.*SequentialQuantizer|import.*SequentialQuantizer' modelopt/torch/export/layer_utils.pyRepository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
=== Full context of sync_moe_gate_up_amax function ===
# Gate/up naming pairs for standard (unfused) MoE architectures.
# Fused variants (gate_up_proj, linear_fc1) already share a single quantizer and need no sync.
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
def sync_moe_gate_up_amax(model: nn.Module) -> int:
"""Take element-wise max of gate and up weight quantizer amaxes per expert.
Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
syncing amaxes before quantization ensures the per-block weight_scale values
are computed against a consistent global scale.
Only affects standard MoE models with separate gate/up linear layers
(e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
(e.g. Llama4, GptOss) are unaffected.
Returns:
Number of expert gate/up pairs whose amaxes were synced.
"""
synced = 0
for _, sub_module in model.named_modules():
if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
continue
if not hasattr(sub_module.experts, "__iter__"):
continue
for expert in sub_module.experts:
for gate_name, up_name in _GATE_UP_PAIRS:
gate_linear = getattr(expert, gate_name, None)
up_linear = getattr(expert, up_name, None)
if gate_linear is None or up_linear is None:
continue
gate_wq = getattr(gate_linear, "weight_quantizer", None)
up_wq = getattr(up_linear, "weight_quantizer", None)
if gate_wq is None or up_wq is None:
break
gate_amax = getattr(gate_wq, "amax", None)
up_amax = getattr(up_wq, "amax", None)
if gate_amax is None or up_amax is None:
break
if not torch.equal(gate_amax, up_amax):
shared_amax = torch.max(gate_amax, up_amax)
gate_wq.amax = shared_amax
up_wq.amax = shared_amax.clone()
synced += 1
break
return synced
=== Check if SequentialQuantizer has amax property ===
class SequentialQuantizer(nn.Sequential):
"""A sequential container for :class:`TensorQuantizer` modules.
This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
:class:`TensorQuantizer` modules and containerize them similar to :class:`torch.nn.Sequential`.
We delegate certain properties and methods to all contained quantizers.
In the case of conflicts, the first quantizer's property or method takes priority.
`SequentialQuantizer` is useful in cases like INT4 weights, FP8 activations where weight quantization is not the
same as the gemm quantization. It allows for applying multiple quantization formats to the same tensor in sequence.
Use `SequentialQuantizer` methods in lower level implementations for better code organization and readability.
Args:
quantizers (TensorQuantizer): :class:`TensorQuantizer` modules to be added to the container.
"""
_delegated_properties = ["fake_quant", "is_enabled", "amax"]
_delegated_methods = [
"reset_amax",
"disable",
"enable",
"load_calib_amax",
"load_calib_bias",
]
def __init__(self, *quantizers: TensorQuantizer):
"""Initialize SequentialQuantizer module."""
super().__init__(*quantizers)
assert all(isinstance(q, TensorQuantizer) for q in self), (
"All quantizers must be a TensorQuantizer."
)
def __getattr__(self, name):
"""Delegate properties and methods to all contained quantizers."""
if name in self._delegated_properties:
# Return the property of the first quantizer
return getattr(self[0], name)
if name in self._delegated_methods:
def method_wrapper(*args, **kwargs):
outputs = getattr(self[0], name)(*args, **kwargs)
for quantizer in self[1:]:
outputs = getattr(quantizer, name)(*args, **kwargs)
return outputs
return method_wrapper
# Defer to super class for attributes not handled here
return super().__getattr__(name)
def __setattr__(self, name, value):
if name in self._delegated_properties:
for quantizer in self:
setattr(quantizer, name, value)
else:
super().__setattr__(name, value)
def get_modelopt_state(self) -> dict[str, Any]:
"""Get meta state to be saved in checkpoint."""
return {"num_quantizers": len(self), "is_sequential_quantizer": True}
def set_from_attribute_config(
self,
attributes: list[dict[str, Any] | QuantizerAttributeConfig]
| dict[str, Any]
| QuantizerAttributeConfig,
):
"""Set the attributes of contained quantizers from a list of attribute_dicts."""
if not isinstance(attributes, (list, tuple)):
assert isinstance(attributes, (dict, QuantizerAttributeConfig)), (
"attributes must be a list or a dict."
)
attributes = [attributes] * len(self)
for attribute, quantizer in zip(attributes, self):
=== Check imports and _GATE_UP_PAIRS ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for model_config export.
Some of the logics in this file are empirical and needs constant update if exceptions occur.
"""
from warnings import warn
import torch
import torch.nn as nn
try:
from transformers.activations import ACT2FN
except Exception:
warn("Cannot find transformers package. Hugginface modules cannot be exported.")
from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import import_plugin
from ..quantization.nn import SequentialQuantizer, TensorQuantizer
from .hf_config_map import HF_CONFIG_MAP
from .mcore_config_map import MCORE_CONFIG_MAP
from .model_config import (
LAYERNORM_DEFAULT,
LAYERNORM_RMS,
LINEAR_COLUMN,
LINEAR_GROUP,
LINEAR_ROW,
QUANTIZATION_FP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
AttentionConfig,
ConvConfig,
DecoderLayerConfig,
EmbeddingConfig,
ExpertConfig,
34:from ..quantization.nn import SequentialQuantizer, TensorQuantizer
1176:_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
1201: for gate_name, up_name in _GATE_UP_PAIRS:
Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check build_mlp_config handling at lines 827-831 ==="
sed -n '820,840p' modelopt/torch/export/layer_utils.py
echo
echo "=== Check if SequentialQuantizer.amax delegation actually works ==="
sed -n '1342,1380p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py
echo
echo "=== Verify the loop structure - what does break actually break? ==="
sed -n '1195,1225p' modelopt/torch/export/layer_utils.py | cat -nRepository: NVIDIA/Model-Optimizer
Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
=== Check build_mlp_config handling at lines 827-831 ===
or get_quantization_format(module) in [QUANTIZATION_FP8, QUANTIZATION_NVFP4]
)
):
preprocess_linear_fusion([fc_linear, gate_linear])
if fc_linear is not None:
weight_quantizer = None
if hasattr(fc_linear, "weight_quantizer"):
weight_quantizer = fc_linear.weight_quantizer
if isinstance(weight_quantizer, SequentialQuantizer):
weight_quantizer = weight_quantizer[0]
# swap fused fc and gate
if decoder_type in ["chatglm", "phi3"]:
weights = torch.chunk(fc_linear.weight, 2, dim=0)
weights = (weights[1], weights[0])
fc_linear.weight.data = torch.cat(weights, dim=0)
if (
weight_quantizer is not None
and weight_quantizer.is_enabled
=== Check if SequentialQuantizer.amax delegation actually works ===
_delegated_methods = [
"reset_amax",
"disable",
"enable",
"load_calib_amax",
"load_calib_bias",
]
def __init__(self, *quantizers: TensorQuantizer):
"""Initialize SequentialQuantizer module."""
super().__init__(*quantizers)
assert all(isinstance(q, TensorQuantizer) for q in self), (
"All quantizers must be a TensorQuantizer."
)
def __getattr__(self, name):
"""Delegate properties and methods to all contained quantizers."""
if name in self._delegated_properties:
# Return the property of the first quantizer
return getattr(self[0], name)
if name in self._delegated_methods:
def method_wrapper(*args, **kwargs):
outputs = getattr(self[0], name)(*args, **kwargs)
for quantizer in self[1:]:
outputs = getattr(quantizer, name)(*args, **kwargs)
return outputs
return method_wrapper
# Defer to super class for attributes not handled here
return super().__getattr__(name)
def __setattr__(self, name, value):
if name in self._delegated_properties:
for quantizer in self:
setattr(quantizer, name, value)
else:
=== Verify the loop structure - what does break actually break? ===
1 for _, sub_module in model.named_modules():
2 if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
3 continue
4 if not hasattr(sub_module.experts, "__iter__"):
5 continue
6 for expert in sub_module.experts:
7 for gate_name, up_name in _GATE_UP_PAIRS:
8 gate_linear = getattr(expert, gate_name, None)
9 up_linear = getattr(expert, up_name, None)
10 if gate_linear is None or up_linear is None:
11 continue
12 gate_wq = getattr(gate_linear, "weight_quantizer", None)
13 up_wq = getattr(up_linear, "weight_quantizer", None)
14 if gate_wq is None or up_wq is None:
15 break
16 gate_amax = getattr(gate_wq, "amax", None)
17 up_amax = getattr(up_wq, "amax", None)
18 if gate_amax is None or up_amax is None:
19 break
20 if not torch.equal(gate_amax, up_amax):
21 shared_amax = torch.max(gate_amax, up_amax)
22 gate_wq.amax = shared_amax
23 up_wq.amax = shared_amax.clone()
24 synced += 1
25 break
26 return synced
27
28
29 def build_stacked_experts(
30 experts: nn.Module,
31 linear_names: list[str],
Use continue instead of break to check all gate/up pairs (not just the first) per expert.
On lines 1209 and 1213, break exits the pair-checking loop entirely if any pair lacks a quantizer or amax. Since _GATE_UP_PAIRS has two pairs (gate_proj/up_proj and w1/w3), a missing quantizer in the first pair prevents checking the second. This should be continue to skip only the current pair and attempt the next one. (The final break on line 1214 correctly exits after syncing.)
Proposed fix
gate_wq = getattr(gate_linear, "weight_quantizer", None)
up_wq = getattr(up_linear, "weight_quantizer", None)
if gate_wq is None or up_wq is None:
- break
+ continue
gate_amax = getattr(gate_wq, "amax", None)
up_amax = getattr(up_wq, "amax", None)
if gate_amax is None or up_amax is None:
- break
+ continue🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/layer_utils.py` around lines 1206 - 1213, The loop that
iterates over _GATE_UP_PAIRS incorrectly uses break when a pair is missing a
quantizer or amax, which stops checking remaining pairs; update the logic in the
loop that accesses gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to
use continue instead of break so only the current pair is skipped and the next
pair in _GATE_UP_PAIRS is checked; keep the final break that exits after
successfully syncing unchanged.
What does this PR do?
Unify
weight_scale_2betweengate_proj/up_proj(andw1/w3) in the HF export path for MOE models. Serving engines fuse these projections into a singlegate_up_projand require a shared scale; this takes the element-wise max of the two independent scales as a conservative choice that avoids overflow.Type of change: ? Bug fix
Usage
# Add a code snippet demonstrating how to use thisTesting
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: ✅ / ❌ / N/AAdditional Information
Summary by CodeRabbit
Release Notes