Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
03ad6a9
ModelOpt Framework, Recipe Lib, converting existing recipes 1/N
shengliangxu Dec 5, 2025
8b48590
Merge remote-tracking branch 'origin/main' into shengliangx/modeopt-r…
shengliangxu Mar 13, 2026
901b948
StrEnum is not available at 3.10
shengliangxu Mar 13, 2026
a6529be
fix for python 3.10
shengliangxu Mar 13, 2026
5e82641
Flatten all the configurations
shengliangxu Mar 16, 2026
5c4d6ce
remove __base__ inheritance logic
shengliangxu Mar 16, 2026
113390c
Merge remote-tracking branch 'origin/main' into shengliangx/modeopt-r…
shengliangxu Mar 16, 2026
4ab4861
using eXmY instead of a list in YAML
shengliangxu Mar 16, 2026
8d8fe0f
remove list parsing for the bits
shengliangxu Mar 16, 2026
518592a
print the recipe path that we actually use
shengliangxu Mar 16, 2026
972ae8b
move tests
shengliangxu Mar 16, 2026
87e1f70
keep minimum set
shengliangxu Mar 16, 2026
a51aed3
Merge remote-tracking branch 'origin/main' into HEAD
shengliangxu Mar 16, 2026
e8e39d3
flatten quant_cfg
shengliangxu Mar 17, 2026
395d274
remove the config building blocks, will add back when inheritance is …
shengliangxu Mar 17, 2026
2259b7c
move assertions to tests
shengliangxu Mar 17, 2026
0e8af88
Merge remote-tracking branch 'origin/main' into shengliangx/modeopt-r…
shengliangxu Mar 17, 2026
b61f288
wrap the quant_cfg and algorithm into ptq_cfg
shengliangxu Mar 17, 2026
9e00bf1
final bring back directory recipes
shengliangxu Mar 17, 2026
e41157a
final cleanup
shengliangxu Mar 17, 2026
e70e84a
typo
shengliangxu Mar 17, 2026
c146218
mypy
shengliangxu Mar 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 3 additions & 23 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
except ImportError:
snapshot_download = None

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -199,22 +198,13 @@ def calibrate_loop(_model):

def build_quant_cfg(
qformat,
kv_cache_qformat,
quant_cfg,
awq_block_size,
model_type,
quant_cfg_choices,
kv_quant_cfg_choices,
moe_calib_experts_ratio: float | None = None,
) -> dict[str, Any]:
quant_cfg = {}
assert qformat in quant_cfg_choices, (
f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
)

quant_cfg = quant_cfg_choices[qformat]

if "awq" in qformat:
quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
quant_cfg = copy.deepcopy(quant_cfg)
if "awq" in str(quant_cfg.get("algorithm")):
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
Expand All @@ -226,16 +216,6 @@ def build_quant_cfg(
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

enable_quant_kv_cache = kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!!

quant_cfg,
getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
)

if moe_calib_experts_ratio:
assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
if isinstance(quant_cfg["algorithm"], str):
Expand Down
90 changes: 51 additions & 39 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
from modelopt.torch.export import (
export_hf_checkpoint,
export_speculative_decoding,
Expand Down Expand Up @@ -262,7 +263,7 @@ def auto_quantize(
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
assert all(
args.qformat
qformat
in [
"fp8",
"int8_sq",
Expand All @@ -277,7 +278,7 @@ def auto_quantize(
"nvfp4_omlp_only",
"mxfp8",
]
for args.qformat in qformat_list
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

def loss_func(output, data):
Expand Down Expand Up @@ -548,9 +549,6 @@ def mono_quantize(
print("Quantization will only be applied to the decoder (text generation) component")

if not model_is_already_quantized or calibration_only:
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")

# quantize the model

use_calibration = need_calibration(quant_cfg)
Expand Down Expand Up @@ -746,8 +744,6 @@ def pre_quantize(
)
else:
generated_ids_before_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")

return preview_input_ids, generated_ids_before_ptq

Expand Down Expand Up @@ -923,38 +919,42 @@ def quantize_main(

else:
# mono quantization
assert len(args.qformat.split(",")) == 1, (
"Plain quantization supports only one quantization format."
)

assert (
args.qformat
in [
"int8_wo",
"int4_awq",
"fp8",
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_omlp_only",
"mxfp8",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Plain quantization format {args.qformat} not supported for HF export path"

quant_cfg = build_quant_cfg(
args.qformat,
args.kv_cache_qformat,
args.awq_block_size,
model_type,
QUANT_CFG_CHOICES,
KV_QUANT_CFG_CHOICES,
args.moe_calib_experts_ratio,
)
if args.recipe is not None:
print(f"Use recipe {args.recipe} for quantization")
recipe = load_recipe(args.recipe)
assert isinstance(recipe, ModelOptPTQRecipe), (
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
)
Comment thread
shengliangxu marked this conversation as resolved.
quant_cfg = recipe.ptq_cfg

else:
assert len(args.qformat.split(",")) == 1, (
Comment thread
kevalmorabia97 marked this conversation as resolved.
"Plain quantization supports only one quantization format."
)

assert args.qformat in QUANT_CFG_CHOICES, (
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]

quant_cfg = build_quant_cfg(
args.qformat,
quant_cfg,
args.awq_block_size,
model_type,
args.moe_calib_experts_ratio,
)

enable_quant_kv_cache = args.kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
)

# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
# These layers are typically speculative decoding layers that should be exported as-is
Expand Down Expand Up @@ -1013,9 +1013,21 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--pyt_ckpt_path",
help="Specify where the PyTorch checkpoint path is",
"--model",
help=(
"Model name or path to the PyTorch checkpoint to be quantized. "
"Can be a local path or a Huggingface model name."
),
required=True,
)
parser.add_argument(
"--recipe",
help=(
"PTQ recipe YAML file or name without suffix (e.g. general/ptq/nvfp4_default-fp8_kv)."
),
default=None,
)

parser.add_argument("--device", default="cuda")
parser.add_argument(
"--qformat",
Expand Down
17 changes: 13 additions & 4 deletions examples/llm_ptq/multinode_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,25 @@ def main(args):
trust_remote_code=args.trust_remote_code,
)

# Build quantization config
quant_cfg = QUANT_CFG_CHOICES[args.qformat]

quant_cfg = build_quant_cfg(
args.qformat,
args.kv_cache_qformat,
quant_cfg,
args.awq_block_size,
model_type,
QUANT_CFG_CHOICES,
KV_QUANT_CFG_CHOICES,
)

enable_quant_kv_cache = args.kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
)
Comment thread
shengliangxu marked this conversation as resolved.

# Quantize the model
if accelerator.is_main_process:
print("Starting quantization...")
Expand Down
27 changes: 27 additions & 0 deletions modelopt/recipe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for the ModelOpt recipe lib.

``modelopt.recipe`` contains tooling to:

* load and store model optimization recipes
* (TODO) utilities to manipulate the recipes, such as merging multiple recipes together, or
overriding some fields in a recipe with user-provided values.

"""

from .config import *
from .loader import *
113 changes: 113 additions & 0 deletions modelopt/recipe/_config_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""YAML config loading utilities.

This module is intentionally free of ``modelopt.torch`` imports so that
``modelopt.torch.quantization.config`` can import :func:`load_config` without
triggering a circular import through ``modelopt.recipe.loader``.
"""

from importlib.resources import files

try:
from importlib.resources.abc import Traversable
except ImportError: # Python < 3.11
from importlib.abc import Traversable
import re
from pathlib import Path
from typing import Any

import yaml

# Root to all built-in recipes. Users can create own recipes.
BUILTIN_RECIPES_LIB = files("modelopt_recipes")

_EXMY_RE = re.compile(r"^[Ee](\d+)[Mm](\d+)$")
_EXMY_KEYS = frozenset({"num_bits", "scale_bits"})


def _parse_exmy_num_bits(obj: Any) -> Any:
"""Recursively convert ``ExMy`` strings in ``num_bits`` / ``scale_bits`` to ``(x, y)`` tuples."""
if isinstance(obj, dict):
return {
k: (
_parse_exmy(v)
if k in _EXMY_KEYS and isinstance(v, str)
else _parse_exmy_num_bits(v)
)
for k, v in obj.items()
}
if isinstance(obj, list):
return [_parse_exmy_num_bits(item) for item in obj]
return obj


def _parse_exmy(s: str) -> tuple[int, int] | str:
m = _EXMY_RE.match(s)
if m:
return (int(m.group(1)), int(m.group(2)))
return s


def load_config(config_file: str | Path | Traversable) -> dict[str, Any]:
    """Load a config YAML file and return its contents as a dict.

    Args:
        config_file: Path to a config YAML file. The ``.yml`` / ``.yaml``
            suffix may be omitted, and relative paths are also resolved
            against the built-in recipe library (``modelopt_recipes``).

    Returns:
        The parsed YAML mapping, with ``ExMy`` strings under ``num_bits`` /
        ``scale_bits`` keys converted to ``(x, y)`` tuples. An empty file
        yields an empty dict.

    Raises:
        ValueError: If ``config_file`` has an unsupported type, no candidate
            path exists, or the YAML top level is not a mapping.
    """
    paths_to_check: list[Path | Traversable] = []
    if isinstance(config_file, str):
        if not config_file.endswith((".yml", ".yaml")):
            # Suffix omitted: try both suffixes, locally first, then built-ins.
            paths_to_check.append(Path(f"{config_file}.yml"))
            paths_to_check.append(Path(f"{config_file}.yaml"))
            paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml"))
            paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml"))
        else:
            paths_to_check.append(Path(config_file))
            paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(config_file))
    elif isinstance(config_file, Path):
        if config_file.suffix in (".yml", ".yaml"):
            paths_to_check.append(config_file)
            if not config_file.is_absolute():
                # A relative Path may name a built-in recipe.
                paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(str(config_file)))
        else:
            paths_to_check.append(Path(f"{config_file}.yml"))
            paths_to_check.append(Path(f"{config_file}.yaml"))
            if not config_file.is_absolute():
                paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml"))
                paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml"))
    elif isinstance(config_file, Traversable):
        paths_to_check.append(config_file)
    else:
        raise ValueError(f"Invalid config file of {config_file}")

    # First existing candidate wins; local paths take precedence over built-ins.
    config_path = None
    for path in paths_to_check:
        if path.is_file():
            config_path = path
            break
    if not config_path:
        raise ValueError(
            f"Cannot find config file of {config_file}, paths checked: {paths_to_check}"
        )

    _raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    if _raw is None:
        # An empty YAML file parses to None; treat it as an empty config.
        return {}
    if not isinstance(_raw, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping, got {type(_raw).__name__}"
        )
    return _parse_exmy_num_bits(_raw)
Loading
Loading