[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852
Conversation
📝 Walkthrough
Adds FP8 post-processing to convert TRT-specific FP8 QDQ nodes to native ONNX QDQ nodes, introduces ONNX Cast utilities (redundant-cast removal and targeted FP16 cast conversion), and updates the Torch ONNX export flow to use these utilities with a reordered FP16/quantization pipeline.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Exporter as FP8 Exporter
    participant Graph as ONNX Graph
    participant TRT as TRT_FP8 Nodes
    participant Native as Native ONNX Ops
    participant Cleaner as Graph Cleaner
    Exporter->>Graph: scan for TRT_FP8QuantizeLinear / TRT_FP8DequantizeLinear
    Graph->>TRT: identify TRT_FP8QuantizeLinear nodes
    Exporter->>Graph: for each TRT_FP8QuantizeLinear -> create zero_point const if missing
    Exporter->>Graph: replace TRT_FP8QuantizeLinear with QuantizeLinear (set saturate)
    Graph->>TRT: identify TRT_FP8DequantizeLinear nodes
    Exporter->>Graph: replace TRT_FP8DequantizeLinear with DequantizeLinear
    Exporter->>Cleaner: invoke cleanup & topological sort
    Cleaner->>Graph: remove unused nodes, fix edges, toposort
    Cleaner->>Native: graph now uses native ONNX QDQ nodes
    Exporter->>Exporter: export cleaned ONNX model
    Note over Exporter,Cleaner: logger.info/debug traces conversions
```
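The conversion loop in the walkthrough above can be sketched in plain Python. This is a schematic only: the `Node` dataclass stands in for graphsurgeon's node objects, and `convert_trt_fp8_qdq` is an illustrative name, not the PR's actual function.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Minimal stand-in for a graphsurgeon node (illustrative only)."""

    op: str
    name: str
    inputs: list = field(default_factory=list)
    attrs: dict = field(default_factory=dict)


def convert_trt_fp8_qdq(nodes):
    """Rewrite TRT-specific FP8 QDQ ops to native ONNX QDQ ops in place."""
    for node in nodes:
        if node.op == "TRT_FP8QuantizeLinear":
            node.op = "QuantizeLinear"
            if len(node.inputs) == 2:  # only (x, scale): add an FP8 zero_point
                node.inputs.append(node.name + "_zero_point")
            node.attrs["saturate"] = 1  # opset-19 FP8 saturation behavior
        elif node.op == "TRT_FP8DequantizeLinear":
            node.op = "DequantizeLinear"
    return nodes
```

In the real exporter the zero point is a `gs.Constant` holding an FP8 initializer rather than a bare tensor name; the sketch only mirrors the control flow.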
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (3 passed)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 104-140: The post_process function's docstring mentions updating
GELU nodes to tanh approximation and inserting Cast nodes after Sqrt, but the
implementation in post_process only converts
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to
QuantizeLinear/DequantizeLinear; either remove or revise those docstring lines
to reflect current behavior, or implement the missing steps: locate GELU nodes
in graph.nodes and replace/modify them to the tanh-approx variant, and insert
Cast nodes immediately after Sqrt nodes' outputs; reference post_process,
TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, GELU, and Sqrt when making the
change.
- Around line 119-126: The FP8 zero-point tensor zp_tensor is missing explicit
shape metadata; update the creation of zp_tensor (used to build zero_point and
appended to node.inputs) to set its dims explicitly (e.g., call
zp_tensor.dims.extend([1]) for a 1-element tensor) so it matches other tensors
created in this module (see the FP8 weights tensor creation) and ensures ONNX
runtimes receive shape info.
In `@modelopt/onnx/utils.py`:
- Around line 1314-1349: In change_casts_to_fp16, only modify Cast nodes that
actually cast from FP32: for each Cast node (node.op_type == "Cast") look up the
source tensor name node.input[0] in graph.initializer, graph.input,
graph.value_info or graph.output to get its element_type and only change the
node.attribute "to" from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 if
the source dtype is FLOAT; also avoid changing Casts that are FP16->FP32 and add
a debug log entry when you modify a Cast (include node.name or node.output[0]
and original->new dtypes) to aid debugging.
🧹 Nitpick comments (1)
modelopt/onnx/utils.py (1)
1218-1261: Consider edge case where the first Cast has multiple consumers.
The function checks `len(node.outputs[0].outputs) != 1` (line 1231) to ensure the first Cast's output goes to exactly one node. However, this may be overly restrictive: if the first Cast feeds into a duplicate second Cast AND other nodes, you could still remove the duplicate Cast while preserving the connection to the other consumers. The current logic skips this optimization opportunity, but the implementation is safe.
Codecov Report
❌ Patch coverage is … Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #852      +/-   ##
==========================================
+ Coverage   70.07%   70.10%   +0.03%
==========================================
  Files         221      221
  Lines       25499    25505       +6
==========================================
+ Hits        17869    17881      +12
+ Misses       7630     7624       -6
```

☔ View full report in Codecov by Sentry.
@ajrasane can you please add the before and after accuracy results in the PR description? i.e., with FP8 custom Q/DQ nodes vs. with FP8 native Q/DQ nodes. Thanks!
Let's also add this change to the Changelog file.
```python
    op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
```
Can you please elaborate the goal/need of this function? Thanks!
This is because after the convert_float_to_float16() function, one of the inputs for these nodes is FP16, while the other is FP32. Hence we run into a compilation issue with TensorRT. To fix this, I manually update them here for these operators.
Got it, thanks for the explanation. Can you please update the docstring to give a bit more details? Thanks!
@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?
I run into an error while building the engine with TensorRT:

```
[03/12/2026-21:06:15] [E] Error[9]: Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [myelin_graph.h:attachExceptionMsgToGraph:1146] MyelinCheckException: operand.h:456: CHECK(is_tensor()) failed. In compileGraph at optimizer/myelin/codeGenerator.cpp:1421
[03/12/2026-21:06:15] [E] Error[10]: IBuilder::buildSerializedNetworkToStream: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[x_cast_to_fp16...(Unnamed Layer* 1752) [ElementWise]]}. In computeCosts at optimizer/common/tactic/optimizer.cpp:4115)
```
I also remember that you had previously mentioned that autocast is not supposed to be used after quantization, as it would need a separate design. Hence I removed it from here. Let me know if that is no longer the case.
Force-pushed from 9b30f17 to e2abd9d.
Accuracy looks good, any idea why perf is slower after this PR? Also, can you please specify which model these numbers are for? Thanks.
modelopt/onnx/utils.py (Outdated)

```python
        logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")

if removed_count > 0:
    graph.cleanup().toposort()
```
I recall some issues with toposort. If you see any failures due to it, we can probably omit it; _bypass_cast maintains node sorting.
AutoCast's unit testing covers this part well, and indeed, I see there are quite a few failures with this refactor.
Force-pushed from 0186223 to 788313f.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
1151-1172: ⚠️ Potential issue | 🟠 Major: This shared cast cleanup can undo output-name preservation.
`_cleanup()` already fixes network output names before this call, but `onnx_utils.remove_redundant_casts()` bypasses output casts by replacing `graph.outputs` with the cast input. For redundant casts on model outputs, that reverts the exported output tensor name to the pre-cast name and can trip `_sanity_check()` or break the public I/O contract.
🧹 Nitpick comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)
62-71: Scope this `onnxconverter_common` workaround to the conversion call.
Patching the module at import time changes behavior process-wide, and `suppress(AttributeError)` hides every upstream `AttributeError`, not just the known list/attr bug. A temporary patch around `convert_float_to_float16()` is much safer, and it avoids making this module import brittle if the upstream symbol changes.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 62 - 71, The current import-time monkey-patch of _f16_module.remove_unnecessary_cast_node (using _original_remove_unnecessary_cast_node and _patched_remove_unnecessary_cast_node) is global and hides all AttributeError via suppress(AttributeError); instead, scope the workaround only around the call to convert_float_to_float16(): before calling convert_float_to_float16() save the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal wrapper that only catches the specific list/attribute error, call convert_float_to_float16(), and finally restore the original function in a try/finally so the patch is temporary and does not swallow unrelated AttributeErrors or affect the rest of the process.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: In FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 576-599: The model_metadata is built too early and can become
stale after graph rewrites; move the metadata creation so it runs after all ONNX
mutations: after quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.
- Around line 581-588: The FP16 export path should not use torch.autocast during
tracing because you already perform explicit post-export conversion with
convert_float_to_float16; update the autocast logic so the
torch.autocast("cuda") context is only entered when weights_dtype == "bf16" (and
not when weights_dtype == "fp16"), i.e., change the condition that currently
enables autocast for weights_dtype != "fp32" to specifically check for "bf16"
and leave the FP16 path to rely solely on convert_float_to_float16; keep the
convert_float_to_float16 call for FP16 unchanged.
---
Nitpick comments:
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 62-71: The current import-time monkey-patch of
_f16_module.remove_unnecessary_cast_node (using
_original_remove_unnecessary_cast_node and
_patched_remove_unnecessary_cast_node) is global and hides all AttributeError
via suppress(AttributeError); instead, scope the workaround only around the call
to convert_float_to_float16(): before calling convert_float_to_float16() save
the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal
wrapper that only catches the specific list/attribute error, call
convert_float_to_float16(), and finally restore the original function in a
try/finally so the patch is temporary and does not swallow unrelated
AttributeErrors or affect the rest of the process.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a86e0e09-2562-47a8-9985-98fa922fd2f0
📒 Files selected for processing (6)
- CHANGELOG.rst
- modelopt/onnx/autocast/precisionconverter.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/export/nvfp4_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/utils.py
```python
# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
for node in graph.nodes:
    if node.op == "TRT_FP8QuantizeLinear":
        node.op = "QuantizeLinear"
        # Add FP8 zero_point if not present
        if len(node.inputs) == 2:
            # Create FP8 zero point constant
            zp_tensor = onnx.TensorProto()
            zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
            zp_tensor.dims.extend([1])  # 1-element tensor
            zp_tensor.raw_data = b"\x00"  # Zero in FP8
            zp_values = LazyValues(zp_tensor)
            zero_point = gs.Constant(node.name + "_zero_point", zp_values)
            node.inputs.append(zero_point)
        # Add saturate attribute for FP8
        node.attrs["saturate"] = 1
        logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")

# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
for node in graph.nodes:
    if node.op == "TRT_FP8DequantizeLinear":
        node.op = "DequantizeLinear"
        logger.debug(
            f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
        )

graph.cleanup().toposort()
```
🧩 Analysis chain
🌐 Web query: Using the official ONNX operator and type documentation, what is the minimum ai.onnx opset that supports FLOAT8E4M3FN with QuantizeLinear and DequantizeLinear, including the saturate attribute on QuantizeLinear?
💡 Result: Minimum ai.onnx opset: 19.
- QuantizeLinear first supports `tensor(float8e4m3fn)` (via its `T2` type constraint) in since_version = 19, and this is also the version that defines the `saturate` attribute (applying to float8 quantization). [1]
- DequantizeLinear first supports `tensor(float8e4m3fn)` (via its `T1` type constraint) in since_version = 19. [2]

Citations:
- [1]: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html?utm_source=openai
- [2]: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html
🏁 Scripts executed (repository: NVIDIA/Model-Optimizer; command outputs collapsed):

```shell
# Check the fp8_exporter.py file to understand the full context
fd "fp8_exporter.py" -x wc -l {} \;
# Read the file to understand the function context and any opset validation
cat -n modelopt/onnx/export/fp8_exporter.py | head -200
# Search for opset-related code and checks in the file
rg -n "opset|onnx_opset" modelopt/onnx/export/fp8_exporter.py -B 2 -A 2
# Find the base exporter to understand the parent class
fd "base_exporter.py" modelopt/onnx/export
# Search for calls to post_process and get_onnx_bytes_and_metadata
rg -n "post_process|get_onnx_bytes_and_metadata" modelopt/onnx/export --type py
# Check if there's opset validation in the base exporter
cat -n modelopt/onnx/export/base_exporter.py | head -300
# Search for get_onnx_bytes_and_metadata and opset usage
rg -n "get_onnx_bytes_and_metadata|onnx_opset" modelopt/onnx --type py -B 2 -A 5
# Look for where FP8QuantExporter is instantiated/used
rg -n "FP8QuantExporter" modelopt/onnx --type py -B 2 -A 3
# Check the quantizer file that likely calls the exporter
fd "quantizer.py" modelopt/onnx -x wc -l {} \;
# Search more broadly for onnx_opset and opset parameter usage
rg -n "onnx_opset|opset_version" modelopt/onnx --type py
# Look for quantizer or export related files that handle opset
fd "quantizer|export" modelopt/onnx --type py | head -20
# Check the quantizer base class
rg -n "class.*Quantizer" modelopt/onnx --type py -A 10 | head -50
# Look at the ort_patching.py file around the FP8 opset check
cat -n modelopt/onnx/quantization/ort_patching.py | sed -n '220,250p'
# Check where _check_opset_version is called and in what context
rg -n "_check_opset_version|check_opset_version" modelopt/onnx/quantization --type py -B 3 -A 3
# Look at the quantize.py file to understand the full flow
cat -n modelopt/onnx/quantization/quantize.py | sed -n '140,185p'
# Search for get_onnx_bytes_and_metadata
rg -n "get_onnx_bytes_and_metadata" modelopt/onnx
# Look at the quantize_mode constants to understand FP8 mode handling
rg -n "QDQ_PRECISION_MIN_OPSET|fp8" modelopt/onnx/quantization/quantize.py -B 2 -A 2
# Check if there's a separate export entry point that may bypass opset checks
rg -n "def.*export|get_onnx_bytes" modelopt/onnx --type py | head -30
# Search for quantize_fp8 function
rg -n "def quantize_fp8|def quantize" modelopt/onnx/quantization/fp8.py -A 10 | head -50
# Look for where process_model is called
rg -n "process_model" modelopt/onnx --type py -B 2 -A 2
# Check whether fp8.py calls FP8QuantExporter.post_process
cat -n modelopt/onnx/quantization/fp8.py | head -150
cat -n modelopt/onnx/quantization/fp8.py | sed -n '150,250p'
# Search for where FP8QuantExporter.process_model is called
rg -n "\.process_model|FP8QuantExporter" modelopt/onnx --type py -B 2 -A 2
# Check for any test files or examples that use FP8 export
fd "test.*fp8|fp8.*test" modelopt --type py
rg -n "process_model\|FP8QuantExporter" modelopt/onnx
# Look in llm_export_utils since it showed up in opset search
cat -n modelopt/onnx/llm_export_utils/export_utils.py | sed -n '155,170p'
# Check the end of the quantize function for process_model / post_process calls
cat -n modelopt/onnx/quantization/fp8.py | tail -100
rg -n "get_onnx_bytes_and_metadata\|get_onnx_bytes"
rg -rn "FP8QuantExporter" modelopt/
# Check if there's an export function that uses the exporter classes
rg -n "def export\|def get_onnx" modelopt/onnx -B 2 -A 5 | head -80
# Look at torch_onnx.py to see how exporters are used
cat -n modelopt/torch/_deploy/utils/torch_onnx.py | grep -A 10 -B 10 "FP8QuantExporter\|onnx_exporters"
grep -n "class.*Exporter" modelopt/onnx/export/*.py
# Look for callers of process_model / post_process in modelopt/torch
rg -n "process_model\|post_process" modelopt/torch
```
Add opset >= 19 validation before FP8 Q/DQ conversion in FP8QuantExporter.post_process().
The code converts TRT custom ops to native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and the saturate attribute, but does not verify that the model's opset is >= 19 (the minimum required for these operators). When callers invoke get_onnx_bytes_and_metadata() with onnx_opset < 19 on a FP8-quantized model, the post-processor will silently generate an invalid ONNX model instead of upgrading the opset or raising an error.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, In
FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).
```python
onnx_opt_graph = quantize_weights(model, onnx_opt_graph)

if dq_only:
    onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

try:
    # TODO: Single-precision torch model assumed
    param_dtype = next(model.parameters()).dtype
except StopIteration:
    param_dtype = torch.float32
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
    if is_int4_quantized(model) or is_mxfp8_quantized(model):
        assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
        onnx_opt_graph = convert_float_to_float16(
            onnx_opt_graph,
            keep_io_types=False,
            disable_shape_infer=True,
            check_fp16_ready=False,
        )
    else:
        onnx_opt_graph = convert_to_f16(
            onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
        )
if weights_dtype == "fp16":
    onnx_opt_graph = convert_float_to_float16(
        onnx_opt_graph,
        keep_io_types=False,
        disable_shape_infer=True,
        check_fp16_ready=False,
        op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
    )
    # Change FP32 cast nodes feeding into Concat/Add to FP16
    onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])

# TensorRT expects all scales to be positive
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

# TODO: Remove manual ir_version change once ORT supports ir_version 11
# Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version
onnx_opt_graph.ir_version = 10
```
Rebuild model_metadata after the final ONNX rewrites.
Lines 572-574 capture metadata before quantize_weights(), convert_float_to_float16(), change_casts_to_fp16(), and remove_redundant_casts(). Those passes add/remove nodes and can rewrite I/O tensors, so the returned metadata can drift from the serialized model bytes.
🛠️ Proposed fix
Move the metadata creation block below the last graph mutation.
```diff
- model_metadata = create_model_metadata(
-     tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
- )
```

Then re-add it after `onnx_opt_graph.ir_version = 10`:

```diff
+ model_metadata = create_model_metadata(
+     tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
+ )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 576 - 599, The
model_metadata is built too early and can become stale after graph rewrites;
move the metadata creation so it runs after all ONNX mutations: after
quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.
```python
if weights_dtype == "fp16":
    onnx_opt_graph = convert_float_to_float16(
        onnx_opt_graph,
        keep_io_types=False,
        disable_shape_infer=True,
        check_fp16_ready=False,
        op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
    )
```
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/Model-Optimizer; command outputs collapsed):

```shell
cd modelopt && wc -l torch/_deploy/utils/torch_onnx.py
cd modelopt && sed -n '500,515p' torch/_deploy/utils/torch_onnx.py
cd modelopt && sed -n '575,595p' torch/_deploy/utils/torch_onnx.py
cd modelopt && sed -n '480,530p' torch/_deploy/utils/torch_onnx.py
cd modelopt && sed -n '550,600p' torch/_deploy/utils/torch_onnx.py
cd modelopt && grep -n "weights_dtype" torch/_deploy/utils/torch_onnx.py | head -20
cd modelopt && grep -B5 -A5 "weights_dtype == \"fp16\"" torch/_deploy/utils/torch_onnx.py
cd modelopt && grep -n "bf16\|fp16" torch/_deploy/utils/torch_onnx.py
cd modelopt && grep -r "convert_float_to_float16" --include="*.py" | head -10
cd modelopt && grep -B10 "use_torch_autocast" torch/_deploy/utils/torch_onnx.py | head -20
find modelopt -name "*test*torch_onnx*" -o -name "*torch_onnx*test*" | head -5
cd modelopt && grep -r "weights_dtype.*fp16\|fp16.*weights_dtype" --include="*.py" tests/ 2>/dev/null | head -10
cd modelopt && git log --oneline -20 -- torch/_deploy/utils/torch_onnx.py 2>/dev/null | head -10
cd modelopt && sed -n '440,460p' torch/_deploy/utils/torch_onnx.py
cd modelopt && sed -n '500,595p' torch/_deploy/utils/torch_onnx.py
cd modelopt && git log -p --all -- torch/_deploy/utils/torch_onnx.py 2>/dev/null | grep -A20 -B5 "use_torch_autocast\|weights_dtype.*fp16" | head -60
cd modelopt && find . -path ./modelopt -prune -o -name "*.py" -type f -exec grep -l "weights_dtype.*fp16\|fp16.*weights_dtype" {} \; 2>/dev/null
cd modelopt && find . -name "*.py" -type f -exec grep -l "torch_onnx.*fp16\|export.*fp16" {} \; 2>/dev/null
cd modelopt && sed -n '1,50p' torch/_deploy/utils/torch_onnx.py
cd modelopt && git diff HEAD~5..HEAD -- torch/_deploy/utils/torch_onnx.py 2>/dev/null | head -100
cd modelopt && git show HEAD:torch/_deploy/utils/torch_onnx.py 2>/dev/null | sed -n '500,520p'
cd modelopt && grep -n "autocast\|bf16" torch/_deploy/utils/torch_onnx.py
```

Finally, a quick check of the autocast predicate for each weights_dtype:

```python
# Check the logic of use_torch_autocast for different weights_dtype values
def analyze_autocast_logic(is_fp4, is_mxfp8, weights_dtype):
    return not (is_fp4 or is_mxfp8 or weights_dtype == "fp32")

def proposed_logic(is_fp4, is_mxfp8, weights_dtype):
    return weights_dtype == "bf16" and not (is_fp4 or is_mxfp8)

print("Current logic (assuming not quantized):")
for dtype in ["fp32", "fp16", "bf16"]:
    print(f"  weights_dtype={dtype}: use_torch_autocast={analyze_autocast_logic(False, False, dtype)}")

print("\nProposed logic (assuming not quantized):")
for dtype in ["fp32", "fp16", "bf16"]:
    print(f"  weights_dtype={dtype}: use_torch_autocast={proposed_logic(False, False, dtype)}")
```
The FP16 export path mixes autocast with explicit graph conversion, which may cause unintended behavior.
Lines 505-508 enable torch.autocast("cuda") for both FP16 and BF16 (weights_dtype != "fp32" and not quantized). This means the model is traced under autocast, which can promote operations to lower precision. Then at lines 581-588, convert_float_to_float16() applies explicit post-export conversion. This dual approach can conflict—autocast during tracing may insert different operators than expected, undermining the explicit conversion strategy.
BF16 benefits from autocast (mixed precision promotion during tracing), but FP16 should rely solely on post-export graph conversion without autocast interference.
Recommended fix: disable autocast for FP16 and keep it only for BF16:

```diff
- use_torch_autocast = not (
-     is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
- )
+ use_torch_autocast = weights_dtype == "bf16" and not (
+     is_fp4_quantized(model) or is_mxfp8_quantized(model)
+ )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 581 - 588, The FP16
export path should not use torch.autocast during tracing because you already
perform explicit post-export conversion with convert_float_to_float16; update
the autocast logic so the torch.autocast("cuda") context is only entered when
weights_dtype == "bf16" (and not when weights_dtype == "fp16"), i.e., change the
condition that currently enables autocast for weights_dtype != "fp32" to
specifically check for "bf16" and leave the FP16 path to rely solely on
convert_float_to_float16; keep the convert_float_to_float16 call for FP16
unchanged.
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/onnx/utils.py (1)
1507-1511: ⚠️ Potential issue | 🟠 Major

Guard FP32→FP16 cast rewrites by input dtype.

At Line 1507, this rewrites any `Cast(to=FLOAT)` feeding target ops, including casts from non-FP32 sources. That can silently change behavior by removing intentional upcasts.

💡 Proposed fix

```diff
 for node in model.graph.node:
     if node.op_type != "Cast":
         continue
 @@
     if not feeds_target:
         continue

+    cast_input_type = _get_tensor_type_by_name(model, node.input[0])
+    if cast_input_type != onnx.TensorProto.FLOAT:
+        continue
+
     # Check if Cast is to FP32, and change to FP16
     for attr in node.attribute:
         if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
             attr.i = onnx.TensorProto.FLOAT16
             break
```
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
921-923: Prefer public `onnx_utils` APIs over underscored helpers across modules.

Using `_get_tensor_type_by_name`, `_bypass_cast_node`, and `_is_same_type_cast` from another module couples this class to private implementation details. Exposing public wrappers in `modelopt/onnx/utils.py` would make this boundary safer.

Also applies to: 942-942, 1097-1102
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/autocast/precisionconverter.py` around lines 921 - 923, The code currently calls private helpers onnx_utils._get_tensor_type_by_name, onnx_utils._bypass_cast_node, and onnx_utils._is_same_type_cast from precisionconverter.py; replace these calls with public wrapper APIs (e.g., get_tensor_type_by_name, bypass_cast_node, is_same_type_cast) exported from modelopt/onnx/utils.py and update precisionconverter.py to call those public names (also update the other locations that use the underscored helpers around lines referenced). Add the small public wrapper implementations in modelopt/onnx/utils.py that delegate to the existing private functions so other modules use the stable public API and then run tests to ensure behavior is unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/utils.py`:
- Around line 1296-1322: The current _is_sequential_cast only compares the two
Cast target types; modify it to also fetch the data type of the original source
feeding the first Cast (e.g., inspect the producer of node.input[0] or its
ValueInfo/initializer) and verify that this source type equals the second cast's
target type (the value returned by get_cast_to_type(next_node)) before returning
True; this extra check ensures that when _bypass_cast_node rewires the graph the
source type is compatible with the second Cast. Use get_consumer_nodes,
get_cast_to_type and the node.input[0] producer lookup to locate and compare
types.
---
Nitpick comments:
In `@modelopt/onnx/autocast/precisionconverter.py`:
- Around line 921-923: The code currently calls private helpers
onnx_utils._get_tensor_type_by_name, onnx_utils._bypass_cast_node, and
onnx_utils._is_same_type_cast from precisionconverter.py; replace these calls
with public wrapper APIs (e.g., get_tensor_type_by_name, bypass_cast_node,
is_same_type_cast) exported from modelopt/onnx/utils.py and update
precisionconverter.py to call those public names (also update the other
locations that use the underscored helpers around lines referenced). Add the
small public wrapper implementations in modelopt/onnx/utils.py that delegate to
the existing private functions so other modules use the stable public API and
then run tests to ensure behavior is unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 188d9d44-97fd-4011-abb0-37b5f6bdbf27
📒 Files selected for processing (5)
- modelopt/onnx/autocast/graphsanitizer.py
- modelopt/onnx/autocast/precisionconverter.py
- modelopt/onnx/autocast/utils.py
- modelopt/onnx/utils.py
- tests/unit/onnx/autocast/test_precisionconverter.py
```python
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
    assert node.op_type == "Cast"
    output_type = get_cast_to_type(node)

    # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
    # Cast to low precision -> cast to high precision affects precision and should not be removed
    precision_order = [
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.BFLOAT16,
    ]
    consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"]

    # If the first cast has additional consumers, we should not remove it
    if len(consumers) != 1:
        return False

    next_node = consumers[0]
    first_cast_type = output_type
    second_cast_type = get_cast_to_type(next_node)

    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )
```
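As a standalone illustration of the precision-order rule in the snippet above, here is a hedged sketch with the TensorProto enum values inlined (DOUBLE=11, FLOAT=1, FLOAT16=10, BFLOAT16=16) as an assumption, so it runs without the onnx package:

```python
# Inlined onnx.TensorProto enum values (assumption, not imported from onnx)
DOUBLE, FLOAT, FLOAT16, BFLOAT16 = 11, 1, 10, 16
precision_order = [DOUBLE, FLOAT, FLOAT16, BFLOAT16]

def first_cast_is_redundant(first_cast_type: int, second_cast_type: int) -> bool:
    # A cast to equal-or-higher precision followed by a cast to lower (or equal)
    # precision leaves the final value unchanged, so the first cast can go.
    # The list index grows toward lower precision.
    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )

print(first_cast_is_redundant(FLOAT, FLOAT16))   # True: downcast pair, first cast removable
print(first_cast_is_redundant(FLOAT16, FLOAT))   # False: upcast affects precision
```

Types outside the ordering (integers, FP8) fall through both membership checks and are left alone.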
🧩 Analysis chain
🏁 Script executed:

```shell
# First, locate and read the function
cat -n modelopt/onnx/utils.py | sed -n '1290,1330p'
```

Length of output: 1919

🏁 Script executed:

```shell
# Also check if the helper function _get_tensor_type_by_name exists (used in proposed fix)
rg "_get_tensor_type_by_name" modelopt/onnx/utils.py
```

Length of output: 216

🏁 Script executed:

```shell
# Search for get_cast_to_type and get_consumer_nodes to understand them
rg "def get_cast_to_type|def get_consumer_nodes" modelopt/onnx/utils.py
```

Length of output: 211

🏁 Script executed:

```shell
# Check if there are tests for this function
fd "test.*\.py" | xargs rg "_is_sequential_cast" 2>/dev/null || echo "No tests found"
```

Length of output: 80

🏁 Script executed:

```shell
# Find where _is_sequential_cast is called
rg "_is_sequential_cast" modelopt/onnx/utils.py -B 2 -A 2
```

Length of output: 440

🏁 Script executed:

```shell
# Look at the context around _bypass_cast_node to understand what happens when removal occurs
cat -n modelopt/onnx/utils.py | sed -n '1325,1360p'
```

Length of output: 2048

🏁 Script executed:

```shell
# Search for comments or documentation about cast removal logic
rg "sequential.*cast|cast.*removal|bypass.*cast" modelopt/onnx/utils.py -i -B 2 -A 2
```

Length of output: 1472

🏁 Script executed:

```shell
# Look for any test files that might test cast optimization
fd "test.*\.py" modelopt/onnx/ | head -10
```

Length of output: 48

🏁 Script executed:

```shell
# Get more context about _is_same_type_cast to understand the overall pattern
cat -n modelopt/onnx/utils.py | sed -n '1285,1295p'
```

Length of output: 572

🏁 Script executed:

```shell
# Look for the full remove_redundant_casts function to understand the overall flow
rg "def remove_redundant_casts" modelopt/onnx/utils.py -A 30
```

Length of output: 1393

🏁 Script executed:

```shell
# Check if there's any documentation or related test about safe cast removal
fd "*.py" modelopt/onnx/ | xargs grep -l "cast\|Cast" | head -5
```

Length of output: 349

🏁 Script executed:

```shell
# Look at how _get_tensor_type_by_name works and what it can return
rg "def _get_tensor_type_by_name" modelopt/onnx/utils.py -A 20
```

Length of output: 991

🏁 Script executed:

```shell
# Check for any tests or usage examples that show what types node.input can have
fd ".py" modelopt/onnx/ --type f | xargs grep -l "remove_redundant_casts\|_is_sequential_cast" 2>/dev/null
```

Length of output: 133

🏁 Script executed:

```shell
# Look at the overall structure of the optimization and find if there are any tests
fd ".py" --type f modelopt/onnx/ | xargs grep -l "test.*cast\|cast.*test" 2>/dev/null | head -5
```

Length of output: 48

🏁 Script executed:

```shell
# Check the precisionconverter file to see if it has any tests or usage
cat -n modelopt/onnx/autocast/precisionconverter.py | head -50
```

Length of output: 2246

🏁 Script executed:

```shell
# Let me understand the logic better by looking at a concrete example
# First, let me see if there are any examples or documentation about this function
rg "sequential.*cast" modelopt/ -i --type py | head -20
```

Length of output: 555

🏁 Script executed:

```shell
# Check if the source type is ever checked anywhere in the function
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p' | grep -E "input|src|source"
```

Length of output: 48

🏁 Script executed:

```shell
# Look at the broader context to understand how the function comment relates to the logic
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p'
```

Length of output: 1323

🏁 Script executed:

```shell
# Let me trace through a concrete scenario to understand the issue better
# Looking at the comment and logic more carefully
cat -n modelopt/onnx/utils.py | sed -n '1296,1325p'
```

Length of output: 1415

🏁 Script executed:

```shell
# Understand what happens when we remove the first cast by looking at bypass logic
cat -n modelopt/onnx/utils.py | sed -n '1325,1355p'
```

Length of output: 1685
Sequential-cast removal must validate the original source type against the final cast target.
The function only compares the two cast target types (lines 1318–1322) but doesn't verify that removing the first cast preserves the input type to the second cast. When _bypass_cast_node rewires the graph, it directly connects the original source to the second cast. If the source type differs from what the second cast expects, this changes behavior.
For example:
- `FLOAT16 → cast(FLOAT) → cast(FLOAT16)` passes the current check (FLOAT ≤ FLOAT16)
- After removal: `FLOAT16 → cast(FLOAT16)`, but cast2 was designed for FLOAT input, causing incorrect behavior
Add a check to ensure the source type matches the second cast's target type before removal.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/utils.py` around lines 1296 - 1322, The current
_is_sequential_cast only compares the two Cast target types; modify it to also
fetch the data type of the original source feeding the first Cast (e.g., inspect
the producer of node.input[0] or its ValueInfo/initializer) and verify that this
source type equals the second cast's target type (the value returned by
get_cast_to_type(next_node)) before returning True; this extra check ensures
that when _bypass_cast_node rewires the graph the source type is compatible with
the second Cast. Use get_consumer_nodes, get_cast_to_type and the node.input[0]
producer lookup to locate and compare types.
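As a hedged sketch of that guard, with the TensorProto enum values inlined as an assumption (FLOAT=1, FLOAT16=10):

```python
FLOAT, FLOAT16 = 1, 10  # onnx.TensorProto values, inlined as an assumption

def bypass_is_type_safe(source_type: int, second_cast_to: int) -> bool:
    # After bypassing the first Cast, the source feeds the second Cast directly;
    # the proposed guard only allows this when the source dtype already equals
    # the second cast's target dtype, so the rewiring cannot change the output.
    return source_type == second_cast_to

print(bypass_is_type_safe(FLOAT16, FLOAT16))  # True: FLOAT16 -> Cast(FLOAT) -> Cast(FLOAT16)
print(bypass_is_type_safe(FLOAT, FLOAT16))    # False: source/target mismatch
```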
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
40ce80f to 05c33b2 (Compare)
Actionable comments posted: 4
♻️ Duplicate comments (3)
modelopt/torch/_deploy/utils/torch_onnx.py (2)
576-576: ⚠️ Potential issue | 🟠 Major

Rebuild `model_metadata` after the last graph rewrite.

Starting with `quantize_weights()`, this function mutates node names, tensor dtypes, and Q/DQ structure after `model_metadata` has already been captured above. The returned metadata can therefore describe a different graph than the bytes written to disk.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` at line 576, model_metadata is captured before quantize_weights mutates the ONNX graph (node names, tensor dtypes, Q/DQ structure), so the saved metadata can be out of sync with onnx_opt_graph; after the final graph rewrite (the call to quantize_weights that returns onnx_opt_graph) re-run the metadata extraction routine (the same function that produced model_metadata earlier) to rebuild model_metadata from the mutated onnx_opt_graph so the metadata matches the bytes written to disk; update any downstream uses to reference the new model_metadata variable produced after quantize_weights.
581-592: ⚠️ Potential issue | 🟠 Major

The FP16 path is still "autocast + `convert_float_to_float16()`".

With this new rewrite block, `weights_dtype="fp16"` still traces under `torch.autocast("cuda")` earlier in the function, so FP16 export is not actually using `convert_float_to_float16()` instead of autocast. That makes the exported graph depend on both mechanisms and undermines the stated pipeline change.

🛠️ Suggested earlier-function change

```diff
-    use_torch_autocast = not (
-        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
-    )
+    use_torch_autocast = weights_dtype == "bf16" and not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model)
+    )
```
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 581 - 592, the current FP16 branch still runs under torch.autocast earlier, so the export uses both autocast and convert_float_to_float16; modify the export flow so when weights_dtype == "fp16" you do NOT run the model export under torch.autocast("cuda") (or skip the autocast context) so the graph is produced in full precision and then transformed only by convert_float_to_float16 and change_casts_to_fp16; locate the autocast usage earlier in this module (torch.autocast or the export context manager) and add a conditional to bypass it when weights_dtype == "fp16", ensuring convert_float_to_float16/remove_redundant_casts are the sole FP16 transformations.

modelopt/onnx/export/fp8_exporter.py (1)
121-147: ⚠️ Potential issue | 🟠 Major

Reject native FP8 Q/DQ export below opset 19.

This conversion emits native `QuantizeLinear`/`DequantizeLinear` with `FLOAT8E4M3FN` and `saturate`, but it still doesn't guard against `onnx_opset < 19`. Exporting FP8 with a lower opset will silently produce an invalid model instead of failing early or upgrading the opset.

🛠️ Suggested guard

```diff
 def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    opset = next(
+        (op.version for op in onnx_model.opset_import if op.domain in ("", "ai.onnx")),
+        0,
+    )
+    if opset < 19:
+        raise ValueError("Native FP8 ONNX Q/DQ requires ai.onnx opset >= 19.")
+
     logger.info("Post-processing FP8 quantized model")
     graph = gs.import_onnx(onnx_model)
```
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, The conversion loop for TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx opset < 19; detect the model/export opset (e.g., an existing variable or by examining graph.opset or a passed opset parameter) before performing the conversions in the loops that change node.op to "QuantizeLinear"/"DequantizeLinear" and set FLOAT8E4M3FN/saturate, and if the opset is less than 19 either raise a clear exception or upgrade the opset to >= 19 before modifying nodes (apply this check where you manipulate node.op and node.attrs in the FP8 exporter function that iterates graph.nodes).
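The opset lookup in the suggested guard can be sketched standalone; `OpsetEntry` below is a hand-rolled stand-in for the real onnx `OperatorSetIdProto` (which exposes `.domain` and `.version` the same way), so nothing here is the PR's actual code:

```python
class OpsetEntry:
    """Stand-in for onnx OperatorSetIdProto (.domain, .version)."""
    def __init__(self, domain: str, version: int):
        self.domain = domain
        self.version = version

def default_opset(opset_import) -> int:
    # The default ONNX domain is spelled "" or "ai.onnx"; fall back to 0 if absent.
    return next((op.version for op in opset_import if op.domain in ("", "ai.onnx")), 0)

def check_fp8_opset(opset_import) -> None:
    if default_opset(opset_import) < 19:
        raise ValueError("Native FP8 ONNX Q/DQ requires ai.onnx opset >= 19.")

check_fp8_opset([OpsetEntry("", 19)])  # opset 19 introduced FP8 Q/DQ, so this passes
print(default_opset([OpsetEntry("", 17), OpsetEntry("com.microsoft", 1)]))  # 17
```

Custom domains such as `com.microsoft` are ignored on purpose; only the default domain's version decides FP8 eligibility.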
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 127-134: The injected zero-point Constant currently uses node.name
which may be empty and cause duplicate tensor names; change the naming for the
Constant created from zp_tensor/zp_values/zero_point to a guaranteed-unique
string (e.g., combine node.name when present with a unique suffix such as a
uuid4 or the node's memory id or an incrementing counter, or use an ONNX/graph
helper that returns a unique name) so each FP8 zero-point Constant has a
distinct tensor name even for unnamed TRT FP8 Q nodes.
In `@modelopt/onnx/utils.py`:
- Around line 1262-1276: The helper _get_tensor_type_by_name must also handle
producer-only tensors emitted by nodes (e.g., a Constant node that produces a
tensor but has no value_info entry); modify _get_tensor_type_by_name to iterate
model.graph.node and when a node.output matches tensor_name, if node.op_type ==
"Constant" extract the TensorProto from the node attribute (attribute named
"value") and return its data_type (or elem_type equivalent), otherwise
skip/continue so non-materialized producer-only tensors do not cause an
exception and allow remove_redundant_casts() to fold Constant->Cast patterns.
- Around line 1403-1417: The two cast-removal checks can both match the same
node causing duplicate removal; make them mutually exclusive by ensuring once a
node is handled by _is_sequential_cast(onnx_model, node) (where you call
_bypass_cast_node and append to nodes_to_remove) you skip the subsequent
_is_foldable_constant_cast_pattern check (e.g., use an elif or continue) so you
only call _bypass_cast_node, _convert_constant_values and append to
nodes_to_remove once; update the block containing _is_sequential_cast,
_is_foldable_constant_cast_pattern, _bypass_cast_node, _convert_constant_values,
get_producer_nodes, nodes_to_remove, and logger.debug accordingly.
- Around line 1499-1511: The current logic uses any(...) on tensor_to_consumers
to set Casts to FP16 even if only one consumer is in target_op_types, which
incorrectly changes shared Cast outputs; modify the check so the Cast is
retargeted only when the entire fanout is eligible (replace the any(...) test
with an all(...) test, and treat empty consumer lists as ineligible/skip), then
keep the existing loop over node.attribute that looks for attr.name == "to" and
change attr.i from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 only when
that all-consumers condition holds.
---
Duplicate comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: The conversion loop for
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx opset <
19; detect the model/export opset (e.g., an existing variable or by examining
graph.opset or a passed opset parameter) before performing the conversions in
the loops that change node.op to "QuantizeLinear"/"DequantizeLinear" and set
FLOAT8E4M3FN/saturate, and if the opset is less than 19 either raise a clear
exception or upgrade the opset to >= 19 before modifying nodes (apply this check
where you manipulate node.op and node.attrs in the FP8 exporter function that
iterates graph.nodes).
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Line 576: model_metadata is captured before quantize_weights mutates the ONNX
graph (node names, tensor dtypes, Q/DQ structure), so the saved metadata can be
out of sync with onnx_opt_graph; after the final graph rewrite (the call to
quantize_weights that returns onnx_opt_graph) re-run the metadata extraction
routine (the same function that produced model_metadata earlier) to rebuild
model_metadata from the mutated onnx_opt_graph so the metadata matches the bytes
written to disk; update any downstream uses to reference the new model_metadata
variable produced after quantize_weights.
- Around line 581-592: The current FP16 branch still runs under torch.autocast
earlier, so the export uses both autocast and convert_float_to_float16; modify
the export flow so when weights_dtype == "fp16" you do NOT run the model export
under torch.autocast("cuda") (or skip the autocast context) so the graph is
produced in full-precision then transformed only by convert_float_to_float16 and
change_casts_to_fp16; locate the autocast usage earlier in this module
(torch.autocast or the export context manager) and add a conditional to bypass
it when weights_dtype == "fp16", ensuring
convert_float_to_float16/remove_redundant_casts are the sole FP16
transformations.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7704d15e-7bdf-4b56-b725-8f52b69a07e2
📒 Files selected for processing (9)
- CHANGELOG.rst
- modelopt/onnx/autocast/graphsanitizer.py
- modelopt/onnx/autocast/precisionconverter.py
- modelopt/onnx/autocast/utils.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/export/nvfp4_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- tests/unit/onnx/autocast/test_precisionconverter.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/onnx/autocast/test_precisionconverter.py
```python
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1])  # 1-element tensor
zp_tensor.raw_data = b"\x00"  # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
```
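Why `raw_data = b"\x00"` encodes FP8 zero can be sanity-checked with a small decoder. This is an illustrative sketch of the E4M3FN layout (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7), not code from the PR, and it ignores the format's NaN encoding for brevity:

```python
def decode_fp8_e4m3fn(byte: int) -> float:
    """Decode one FP8 E4M3FN byte (NaN encodings ignored for brevity)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F   # 4 exponent bits
    mant = byte & 0x07         # 3 mantissa bits
    if exp == 0:               # subnormal range
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

print(decode_fp8_e4m3fn(0x00))  # 0.0, matching the injected zero point
print(decode_fp8_e4m3fn(0x38))  # 1.0
```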
Use a guaranteed-unique tensor name for the injected zero point.
node.name is optional in ONNX, so node.name + "_zero_point" can collapse to the same tensor name for multiple unnamed TRT FP8 Q nodes. That can make the exported graph invalid due to duplicate tensor names.
🛠️ Safer naming
```diff
- zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+ zero_point = gs.Constant(f"{node.outputs[0].name}_zero_point", zp_values)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/export/fp8_exporter.py` around lines 127 - 134, The injected
zero-point Constant currently uses node.name which may be empty and cause
duplicate tensor names; change the naming for the Constant created from
zp_tensor/zp_values/zero_point to a guaranteed-unique string (e.g., combine
node.name when present with a unique suffix such as a uuid4 or the node's memory
id or an incrementing counter, or use an ONNX/graph helper that returns a unique
name) so each FP8 zero-point Constant has a distinct tensor name even for
unnamed TRT FP8 Q nodes.
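One shape such a helper could take is a counter-based name generator; everything here (the `_zp_counter` module global, the `trt_fp8_q` fallback base) is a hypothetical illustration, not the PR's code:

```python
import itertools

# Module-level counter guarantees distinct names even for unnamed nodes.
_zp_counter = itertools.count()

def unique_zero_point_name(node_name: str) -> str:
    base = node_name if node_name else "trt_fp8_q"
    return f"{base}_zero_point_{next(_zp_counter)}"

print(unique_zero_point_name("quant_0"))  # quant_0_zero_point_0
print(unique_zero_point_name(""))         # trt_fp8_q_zero_point_1
```

A uuid4 suffix works equally well; the counter variant just keeps names short and deterministic within one export run.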
```python
def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str):
    """Get the tensor element type. Searches value_info, initializers, inputs, and outputs."""
    for vi in model.graph.value_info:
        if vi.name == tensor_name:
            return vi.type.tensor_type.elem_type
    for init in model.graph.initializer:
        if init.name == tensor_name:
            return init.data_type
    for inp in model.graph.input:
        if inp.name == tensor_name:
            return inp.type.tensor_type.elem_type
    for out in model.graph.output:
        if out.name == tensor_name:
            return out.type.tensor_type.elem_type
    raise Exception(f"did not find tensor {tensor_name}")
```
Teach _get_tensor_type_by_name() about producer-only tensors.
This helper only searches value_info, initializers, inputs, and outputs. A Cast fed by a Constant output that is not materialized in value_info now throws here before remove_redundant_casts() can fold the Constant -> Cast pattern.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/utils.py` around lines 1262 - 1276, The helper
_get_tensor_type_by_name must also handle producer-only tensors emitted by nodes
(e.g., a Constant node that produces a tensor but has no value_info entry);
modify _get_tensor_type_by_name to iterate model.graph.node and when a
node.output matches tensor_name, if node.op_type == "Constant" extract the
TensorProto from the node attribute (attribute named "value") and return its
data_type (or elem_type equivalent), otherwise skip/continue so non-materialized
producer-only tensors do not cause an exception and allow
remove_redundant_casts() to fold Constant->Cast patterns.
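A minimal stand-in sketch of that producer fallback, with hand-rolled classes replacing the onnx protobufs (so nothing below is the real modelopt code; the attribute names mirror `TensorProto.data_type`, `AttributeProto.name`/`.t`, and `NodeProto.op_type`/`.output`/`.attribute`):

```python
class FakeTensor:
    def __init__(self, data_type): self.data_type = data_type

class FakeAttr:
    def __init__(self, name, t): self.name, self.t = name, t

class FakeNode:
    def __init__(self, op_type, output, attribute):
        self.op_type, self.output, self.attribute = op_type, output, attribute

def type_from_constant_producer(nodes, tensor_name):
    # Fallback for producer-only tensors: read the dtype off a Constant's
    # "value" attribute when the tensor has no value_info entry.
    for node in nodes:
        if tensor_name in node.output and node.op_type == "Constant":
            for attr in node.attribute:
                if attr.name == "value":
                    return attr.t.data_type
    raise KeyError(f"did not find tensor {tensor_name}")

FLOAT = 1  # onnx.TensorProto.FLOAT
nodes = [FakeNode("Constant", ["const_out"], [FakeAttr("value", FakeTensor(FLOAT))])]
print(type_from_constant_producer(nodes, "const_out"))  # 1
```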
```python
# Find sequential casts that don't change precision
if _is_sequential_cast(onnx_model, node):
    nodes_to_remove.append(node)
    _bypass_cast_node(onnx_model, node)
    logger.debug(f"Found removable double-cast: {node.name}")

# Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
if _is_foldable_constant_cast_pattern(onnx_model, node):
    nodes_to_remove.append(node)
    cast_producers = get_producer_nodes(onnx_model, node.input[0])
    assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant"
    constant_producer = cast_producers[0]
    _convert_constant_values(constant_producer, node)
    _bypass_cast_node(onnx_model, node)
    logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}")
```
Make the redundant-cast branches mutually exclusive.
A node can satisfy both _is_sequential_cast() and _is_foldable_constant_cast_pattern() (for example Constant -> Cast -> Cast). In that case this block bypasses it twice, appends it twice, and the second graph.node.remove(node) will fail.
🛠️ Minimal fix
```diff
 if _is_sequential_cast(onnx_model, node):
     nodes_to_remove.append(node)
     _bypass_cast_node(onnx_model, node)
     logger.debug(f"Found removable double-cast: {node.name}")
+    continue

 # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
 if _is_foldable_constant_cast_pattern(onnx_model, node):
```
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/utils.py` around lines 1403 - 1417, The two cast-removal checks
can both match the same node causing duplicate removal; make them mutually
exclusive by ensuring once a node is handled by _is_sequential_cast(onnx_model,
node) (where you call _bypass_cast_node and append to nodes_to_remove) you skip
the subsequent _is_foldable_constant_cast_pattern check (e.g., use an elif or
continue) so you only call _bypass_cast_node, _convert_constant_values and
append to nodes_to_remove once; update the block containing _is_sequential_cast,
_is_foldable_constant_cast_pattern, _bypass_cast_node, _convert_constant_values,
get_producer_nodes, nodes_to_remove, and logger.debug accordingly.
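A toy reproduction of the failure mode, using a plain list as a stand-in for `graph.node` (the names are illustrative, not the real modelopt code):

```python
# A node matched by both branches gets queued for removal twice; the second
# list.remove() then raises, just as graph.node.remove(node) would.
nodes = ["constant_0", "cast_0", "cast_1"]
nodes_to_remove = ["cast_0", "cast_0"]  # same Cast matched by both checks

removed, errors = 0, 0
for n in nodes_to_remove:
    try:
        nodes.remove(n)
        removed += 1
    except ValueError:
        errors += 1

print(removed, errors)  # 1 1
```

A `continue` after the first branch (or an `elif`) keeps each node in the removal list at most once.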
```python
# Check if this Cast outputs to a target op type
cast_output = node.output[0]
consumers = tensor_to_consumers.get(cast_output, [])
feeds_target = any(c.op_type in target_op_types for c in consumers)

if not feeds_target:
    continue

# Check if Cast is to FP32, and change to FP16
for attr in node.attribute:
    if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
        attr.i = onnx.TensorProto.FLOAT16
        break
```
Only retarget Casts whose entire fanout is eligible.
any(...) means one Concat/Add consumer is enough to flip a shared Cast to FP16 for every consumer of that tensor. If the cast output also feeds a non-target branch, that branch is silently changed too.
🛠️ Safe fallback
```diff
- feeds_target = any(c.op_type in target_op_types for c in consumers)
+ feeds_target = bool(consumers) and all(c.op_type in target_op_types for c in consumers)
```
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/utils.py` around lines 1499 - 1511, The current logic uses
any(...) on tensor_to_consumers to set Casts to FP16 even if only one consumer
is in target_op_types, which incorrectly changes shared Cast outputs; modify the
check so the Cast is retargeted only when the entire fanout is eligible (replace
the any(...) test with an all(...) test, and treat empty consumer lists as
ineligible/skip), then keep the existing loop over node.attribute that looks for
attr.name == "to" and change attr.i from onnx.TensorProto.FLOAT to
onnx.TensorProto.FLOAT16 only when that all-consumers condition holds.
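A minimal sketch of the recommended fanout check, reduced to op-type lists (the op names are illustrative stand-ins for real consumers):

```python
# Retarget the Cast only when every consumer is a target op, and skip
# casts with no consumers at all.
target_op_types = {"Concat", "Add"}

def feeds_target(consumer_op_types):
    return bool(consumer_op_types) and all(op in target_op_types for op in consumer_op_types)

print(feeds_target(["Concat", "Add"]))   # True: whole fanout is eligible
print(feeds_target(["Concat", "Relu"]))  # False: shared with a non-target branch
print(feeds_target([]))                  # False: dangling cast output
```

With `any(...)`, the middle case would flip the shared Cast to FP16 and silently change the `Relu` branch too.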
Review Comments

Thanks for the PR: the core idea of replacing TRT-specific FP8 QDQ nodes with native ONNX ops is solid, and the refactoring to centralize cast utilities in

1. BF16 regression in
What does this PR do?
Type of change:
New feature
Overview:
Testing
Results:
Before replacement:
After replacement:
Before your PR is "Ready for review"
Summary by CodeRabbit
New Features
Improvements
Tests
Changelog