
[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852

Open
ajrasane wants to merge 7 commits into main from ajrasane/onnx_qdq

Conversation


@ajrasane ajrasane commented Feb 4, 2026

What does this PR do?

Type of change:
New feature

Overview:

  • Updated FP8 quant exporter to replace modelopt custom QDQ nodes with native ONNX QDQ nodes
  • Updated get_onnx_bytes_and_metadata to make convert_float_to_float16() the default instead of autocast
  • Created util functions to fix graph structure after conversion

Testing

python torch_quant_to_onnx.py --quantize_mode=fp8 \
	--onnx_save_path=<model_path> \
	--calibration_data_size 64 \
	--batch_size 128

python evaluate.py --onnx_path=<model_path> \
	--model_name=vit_base_patch16_224 \
	--results_path=./results.txt \
	--batch_size 128

Results:
Before replacement:

The top1 accuracy of the model is 85.06%
The top5 accuracy of the model is 97.558%
Inference latency of the model is 5.27963 ms

After replacement:

The top1 accuracy of the model is 85.054%
The top5 accuracy of the model is 97.542%
Inference latency of the model is 5.74771 ms

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Replaced modelopt QDQ nodes with native ONNX QDQ nodes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Summary by CodeRabbit

  • New Features

    • ONNX utilities added to remove redundant Casts, fold Constant→Cast patterns, and convert targeted Casts to FP16.
  • Improvements

    • FP8 QDQ nodes are converted to native ONNX QuantizeLinear/DequantizeLinear nodes for better compatibility.
    • Export pipeline streamlined: more consistent FP16 handling, unified weight quantization, cast cleanup, and added logging for traceability.
  • Tests

    • Unit tests updated to use the new ONNX utilities.
  • Changelog

    • Entry added noting FP8 QDQ → native ONNX QDQ conversion.

@ajrasane ajrasane requested review from a team as code owners February 4, 2026 01:08
@ajrasane ajrasane requested a review from cjluo-nv February 4, 2026 01:08
coderabbitai bot commented Feb 4, 2026

📝 Walkthrough

Walkthrough

Adds FP8 post-processing to convert TRT-specific FP8 QDQ nodes to native ONNX QDQ nodes, introduces ONNX Cast utilities (redundant-cast removal and targeted FP16 cast conversion), and updates Torch ONNX export flow to use these utilities with a reordered FP16/quantization pipeline.

Changes

Cohort / File(s) Summary
FP8 Export Post-Processing
modelopt/onnx/export/fp8_exporter.py
Replaced no-op post_process with logic converting TRT_FP8QuantizeLinear → QuantizeLinear (adds FP8 zero_point when missing, sets saturate) and TRT_FP8DequantizeLinear → DequantizeLinear; performs graph cleanup/toposort, re-exports ONNX, and adds logging. Updated compress_weights docs/comments.
ONNX Cast Utilities
modelopt/onnx/utils.py
Added cast-focused utilities: read Cast target type, get producer/consumer nodes, stash/replace tensor names, detect/fold redundant casts (same-type, sequential, Constant→Cast), bypass/rewrite connections, fold Constant→Cast, remove_redundant_casts(onnx.ModelProto), and change_casts_to_fp16(model, target_op_types). Minor tweak to randomize_weights_onnx_bytes metadata access.
Precision Converter (autocast)
modelopt/onnx/autocast/precisionconverter.py
Removed several in-class cast-management helpers and delegated those responsibilities to centralized onnx_utils functions (producer/consumer lookups, bypassing, type checks, redundant-cast removal). Updated flows to call onnx_utils.remove_redundant_casts(self.model) and related helpers.
Autocast Utilities / GraphSanitizer
modelopt/onnx/autocast/utils.py, modelopt/onnx/autocast/graphsanitizer.py
Replaced local producer/consumer/cast helpers with calls to onnx_utils (removed get_consumer_nodes, get_producer_nodes, get_cast_to_type from autocast utils). GraphSanitizer updated to use onnx_utils equivalents.
Torch ONNX Integration
modelopt/torch/_deploy/utils/torch_onnx.py
Imported and exposed change_casts_to_fp16 and remove_redundant_casts, patched onnxconverter_common.remove_unnecessary_cast_node with a suppress wrapper, and reordered FP16/quantization pipeline: always quantize_weights, apply FP16 conversion for weights, convert FP32 casts feeding Concat/Add to FP16, then run redundant-cast removal before final IR/external-data handling.
NVFP4 Exporter minor tweak
modelopt/onnx/export/nvfp4_exporter.py
Replaced dict.get(..., None) with dict.get(...) for initializer lookups in three places (stylistic; behavior unchanged).
Tests
tests/unit/onnx/autocast/test_precisionconverter.py
Updated test assertions to use onnx_utils.get_consumer_nodes instead of removed local utils.get_consumer_nodes.
Changelog
CHANGELOG.rst
Added entry noting modelopt FP8 QDQ nodes are replaced with native ONNX QDQ nodes (0.43).
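
The redundant-cast removal summarized above can be illustrated with a toy sketch of the simplest case it covers: a Cast whose target type equals its source's type. This is not the modelopt/onnx/utils.py implementation (which walks an onnx.ModelProto); the dict-based graph and the function name here are invented for illustration:

```python
# Toy sketch only: nodes are plain dicts in topological order, each with a
# single input; Cast nodes carry a "to" element type.

def fold_same_type_casts(nodes, input_types):
    """input_types: (tensor name, element type) pairs for graph inputs.
    Returns (kept nodes, map of removed-cast output -> surviving tensor)."""
    dtype = dict(input_types)  # tensor name -> element type, filled as we walk
    rename = {}                # outputs of removed casts -> their replacement
    kept = []
    for n in nodes:
        src = rename.get(n["input"], n["input"])
        if n["op"] == "Cast" and dtype.get(src) == n["to"]:
            # Cast to the type the tensor already has: bypass it entirely.
            rename[n["output"]] = src
            continue
        kept.append({**n, "input": src})
        dtype[n["output"]] = n["to"] if n["op"] == "Cast" else dtype.get(src)
    return kept, rename
```

A second Cast in a chain that repeats the first one's type is caught the same way, since the walk tracks the type each tensor already has.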

Sequence Diagram(s)

sequenceDiagram
    participant Exporter as FP8 Exporter
    participant Graph as ONNX Graph
    participant TRT as TRT_FP8 Nodes
    participant Native as Native ONNX Ops
    participant Cleaner as Graph Cleaner
    Exporter->>Graph: scan for TRT_FP8QuantizeLinear / TRT_FP8DequantizeLinear
    Graph->>TRT: identify TRT_FP8QuantizeLinear nodes
    Exporter->>Graph: for each TRT_FP8QuantizeLinear -> create `zero_point` const if missing
    Exporter->>Graph: replace TRT_FP8QuantizeLinear with QuantizeLinear (set saturate)
    Graph->>TRT: identify TRT_FP8DequantizeLinear nodes
    Exporter->>Graph: replace TRT_FP8DequantizeLinear with DequantizeLinear
    Exporter->>Cleaner: invoke cleanup & topological sort
    Cleaner->>Graph: remove unused nodes, fix edges, toposort
    Cleaner->>Native: graph now uses native ONNX QDQ nodes
    Exporter->>Exporter: export cleaned ONNX model
    Note over Exporter,Cleaner: logger.info/debug traces conversions

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 73.68% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: replacing Model-Optimizer's custom FP8 QDQ nodes with native ONNX QDQ nodes, which is the core feature addition across the fp8_exporter and supporting utilities.
Security Anti-Patterns ✅ Passed PR introduces no security anti-patterns from SECURITY.md guidelines; all modified code operates on ONNX graph structures without unsafe deserialization or dynamic execution.



@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 104-140: The post_process function's docstring mentions updating
GELU nodes to tanh approximation and inserting Cast nodes after Sqrt, but the
implementation in post_process only converts
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to
QuantizeLinear/DequantizeLinear; either remove or revise those docstring lines
to reflect current behavior, or implement the missing steps: locate GELU nodes
in graph.nodes and replace/modify them to the tanh-approx variant, and insert
Cast nodes immediately after Sqrt nodes' outputs; reference post_process,
TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, GELU, and Sqrt when making the
change.
- Around line 119-126: The FP8 zero-point tensor zp_tensor is missing explicit
shape metadata; update the creation of zp_tensor (used to build zero_point and
appended to node.inputs) to set its dims explicitly (e.g., call
zp_tensor.dims.extend([1]) for a 1-element tensor) so it matches other tensors
created in this module (see the FP8 weights tensor creation) and ensures ONNX
runtimes receive shape info.

In `@modelopt/onnx/utils.py`:
- Around line 1314-1349: In change_casts_to_fp16, only modify Cast nodes that
actually cast from FP32: for each Cast node (node.op_type == "Cast") look up the
source tensor name node.input[0] in graph.initializer, graph.input,
graph.value_info or graph.output to get its element_type and only change the
node.attribute "to" from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 if
the source dtype is FLOAT; also avoid changing Casts that are FP16->FP32 and add
a debug log entry when you modify a Cast (include node.name or node.output[0]
and original->new dtypes) to aid debugging.
🧹 Nitpick comments (1)
modelopt/onnx/utils.py (1)

1218-1261: Consider edge case where first Cast has multiple consumers.

The function checks len(node.outputs[0].outputs) != 1 (line 1231) to ensure the first Cast's output goes to exactly one node. However, this may be overly restrictive. If the first Cast feeds into a duplicate second Cast AND other nodes, you could still remove the duplicate Cast while preserving the connection to other consumers. The current logic skips this optimization opportunity.

This is a minor optimization opportunity and the current implementation is safe.


codecov bot commented Feb 4, 2026

Codecov Report

❌ Patch coverage is 83.33333% with 27 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.10%. Comparing base (58417e5) to head (05c33b2).

Files with missing lines Patch % Lines
modelopt/onnx/utils.py 84.12% 20 Missing ⚠️
modelopt/torch/_deploy/utils/torch_onnx.py 71.42% 4 Missing ⚠️
modelopt/onnx/autocast/precisionconverter.py 78.57% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #852      +/-   ##
==========================================
+ Coverage   70.07%   70.10%   +0.03%     
==========================================
  Files         221      221              
  Lines       25499    25505       +6     
==========================================
+ Hits        17869    17881      +12     
+ Misses       7630     7624       -6     

☔ View full report in Codecov by Sentry.


gcunhase commented Feb 9, 2026

@ajrasane can you please add the before and after accuracy results in the PR description? I.e., with FP8 custom Q/DQ nodes vs. with FP8 native Q/DQ nodes. Thanks!


gcunhase commented Feb 9, 2026

Let's also add this change to the Changelog file.

op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
Contributor
Can you please elaborate the goal/need of this function? Thanks!

Contributor Author

This is because after the convert_float_to_float16() function, one of the inputs for these nodes is FP16, while the other is FP32. Hence we run into a compilation issue with TensorRT. To fix this, I manually update them here for these operators.
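
For illustration only, the retargeting described here can be sketched on a toy graph (dict-based nodes and invented names; the real change_casts_to_fp16 operates on an onnx.ModelProto). A Cast producing FP32 from an FP32 source that feeds one of the listed ops is retargeted to FP16 so both element-wise inputs match, while casts from FP16 sources are left untouched, per the review suggestion above:

```python
FLOAT, FLOAT16 = "FLOAT", "FLOAT16"

def retarget_casts_to_fp16(nodes, tensor_types, target_op_types):
    """Retarget a Cast node's "to" from FLOAT to FLOAT16 when (a) the cast's
    source tensor is FP32 and (b) the cast feeds one of target_op_types
    (e.g. Concat/Add). FP16 -> FP32 casts are left alone."""
    consumers = {}
    for n in nodes:
        for name in n["inputs"]:
            consumers.setdefault(name, []).append(n["op"])
    result = []
    for n in nodes:
        if (n["op"] == "Cast" and n["to"] == FLOAT
                and tensor_types.get(n["inputs"][0]) == FLOAT
                and any(op in target_op_types
                        for op in consumers.get(n["output"], []))):
            n = {**n, "to": FLOAT16}  # downcast so it matches the FP16 sibling
        result.append(n)
    return result
```

The consumer lookup is what restricts the rewrite to casts actually feeding Concat/Add, mirroring the target_op_types parameter mentioned in the walkthrough.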

Contributor
Got it, thanks for the explanation. Can you please update the docstring to give a bit more details? Thanks!

Contributor
@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?

Contributor Author

I run into an error while building the engine with TensorRT:

[03/12/2026-21:06:15] [E] Error[9]: Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [myelin_graph.h:attachExceptionMsgToGraph:1146] MyelinCheckException: operand.h:456: CHECK(is_tensor()) failed.  In compileGraph at optimizer/myelin/codeGenerator.cpp:1421
[03/12/2026-21:06:15] [E] Error[10]: IBuilder::buildSerializedNetworkToStream: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[x_cast_to_fp16...(Unnamed Layer* 1752) [ElementWise]]}. In computeCosts at optimizer/common/tactic/optimizer.cpp:4115)

I also remember that you had previously mentioned that autocast is not supposed to be used after quantization, as it would need a separate design. Hence I removed it from here. Let me know if that is no longer the case.

@ajrasane ajrasane requested a review from a team as a code owner February 13, 2026 14:09
@ajrasane ajrasane requested a review from galagam February 13, 2026 14:09

gcunhase commented Feb 13, 2026

5.74771 ms

Accuracy looks good, any idea why perf is slower after this PR?

Also, can you please specify which model these numbers are for?

Thanks.

@gcunhase gcunhase left a comment
LGTM, @galagam are you okay with making the redundant casts function a utils function? Thanks!

@galagam galagam left a comment

LGTM


logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")

if removed_count > 0:
graph.cleanup().toposort()
Contributor

I recall some issues with toposort.
If you see any failures due to it, we can probably omit it; _bypass_cast maintains node sorting.


galagam commented Feb 15, 2026

LGTM, @galagam are you okay with making the redundant casts function a utils function? Thanks!

AutoCast's unit testing covers this part well, and indeed I see there are quite a few failures with this refactor.
I approve of the general concept, but we need to make sure we don't cause regressions/behavior changes for AutoCast.
Thanks.
@gcunhase @ajrasane

@ajrasane ajrasane force-pushed the ajrasane/onnx_qdq branch from 0186223 to 788313f Compare March 12, 2026 18:37
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

1151-1172: ⚠️ Potential issue | 🟠 Major

This shared cast cleanup can undo output-name preservation.

_cleanup() already fixes network output names before this call, but onnx_utils.remove_redundant_casts() bypasses output casts by replacing graph.outputs with the cast input. For redundant casts on model outputs, that reverts the exported output tensor name to the pre-cast name and can trip _sanity_check() or break the public I/O contract.
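
One way to avoid that, sketched on a toy dict-based graph (the helper name is invented and this is not the PR's code): when the removed cast feeds a graph output, rename the cast's producer to the public output name instead of renaming the graph output, so the exported I/O contract survives.

```python
def bypass_output_cast(nodes, graph_output_names, cast):
    """Remove a redundant Cast feeding a graph output without renaming the
    output: rewrite the cast's producer to emit the public output name."""
    src, dst = cast["inputs"][0], cast["output"]
    assert dst in graph_output_names, "cast must feed a graph output"
    nodes.remove(cast)
    for n in nodes:
        if n["output"] == src:
            n["output"] = dst  # producer now writes the public name directly
        n["inputs"] = [dst if i == src else i for i in n["inputs"]]
    return nodes
```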

🧹 Nitpick comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)

62-71: Scope this onnxconverter_common workaround to the conversion call.

Patching the module at import time changes behavior process-wide, and suppress(AttributeError) hides every upstream AttributeError, not just the known list/attr bug. A temporary patch around convert_float_to_float16() is much safer, and it avoids making this module import brittle if the upstream symbol changes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 62 - 71, The current
import-time monkey-patch of _f16_module.remove_unnecessary_cast_node (using
_original_remove_unnecessary_cast_node and
_patched_remove_unnecessary_cast_node) is global and hides all AttributeError
via suppress(AttributeError); instead, scope the workaround only around the call
to convert_float_to_float16(): before calling convert_float_to_float16() save
the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal
wrapper that only catches the specific list/attribute error, call
convert_float_to_float16(), and finally restore the original function in a
try/finally so the patch is temporary and does not swallow unrelated
AttributeErrors or affect the rest of the process.
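
A minimal shape for such a temporary patch, using only the stdlib (the module and attribute names from the review are the intended call sites, not verified here):

```python
from contextlib import contextmanager

@contextmanager
def temporarily_patched(module, attr, replacement):
    """Swap module.attr for the duration of the with-block only; the original
    is restored in a finally, so unrelated code never sees the patch."""
    original = getattr(module, attr)
    setattr(module, attr, replacement)
    try:
        yield original
    finally:
        setattr(module, attr, original)
```

A hypothetical call site, following the review's suggestion, would wrap only the conversion: `with temporarily_patched(_f16_module, "remove_unnecessary_cast_node", wrapper): convert_float_to_float16(...)`.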
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: In FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).

In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 576-599: The model_metadata is built too early and can become
stale after graph rewrites; move the metadata creation so it runs after all ONNX
mutations: after quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.
- Around line 581-588: The FP16 export path should not use torch.autocast during
tracing because you already perform explicit post-export conversion with
convert_float_to_float16; update the autocast logic so the
torch.autocast("cuda") context is only entered when weights_dtype == "bf16" (and
not when weights_dtype == "fp16"), i.e., change the condition that currently
enables autocast for weights_dtype != "fp32" to specifically check for "bf16"
and leave the FP16 path to rely solely on convert_float_to_float16; keep the
convert_float_to_float16 call for FP16 unchanged.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a86e0e09-2562-47a8-9985-98fa922fd2f0

📥 Commits

Reviewing files that changed from the base of the PR and between 2ebf0a2 and 788313f.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • modelopt/onnx/autocast/precisionconverter.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/export/nvfp4_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/utils.py

Comment on lines +121 to +147
# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
node.op = "QuantizeLinear"
# Add FP8 zero_point if not present
if len(node.inputs) == 2:
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1]) # 1-element tensor
zp_tensor.raw_data = b"\x00" # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
# Add saturate attribute for FP8
node.attrs["saturate"] = 1
logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")

# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
for node in graph.nodes:
if node.op == "TRT_FP8DequantizeLinear":
node.op = "DequantizeLinear"
logger.debug(
f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
)

graph.cleanup().toposort()
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Using the official ONNX operator and type documentation, what is the minimum ai.onnx opset that supports FLOAT8E4M3FN with QuantizeLinear and DequantizeLinear, including the saturate attribute on QuantizeLinear?

💡 Result:

Minimum ai.onnx opset: 19.

  • QuantizeLinear first supports tensor(float8e4m3fn) (via its T2 type constraint) in since_version = 19, and this is also the version that defines the saturate attribute (applying to float8 quantization). [1]
  • DequantizeLinear first supports tensor(float8e4m3fn) (via its T1 type constraint) in since_version = 19. [2]

Sources:
[1] QuantizeLinear (ONNX operator doc, version 19) (onnx.ai)
[2] DequantizeLinear (ONNX operator doc, version 19) (onnx.ai)

Citations:


🏁 Script executed:

# Check the fp8_exporter.py file to understand the full context
# Look for the function containing this code and any opset checks
fd "fp8_exporter.py" -x wc -l {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 153


🏁 Script executed:

# Read the file to understand the function context and any opset validation
cat -n modelopt/onnx/export/fp8_exporter.py | head -200

Repository: NVIDIA/Model-Optimizer

Length of output: 7470


🏁 Script executed:

# Search for opset-related code and checks in the file
rg -n "opset|onnx_opset" modelopt/onnx/export/fp8_exporter.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Find the base exporter to understand the parent class
fd "base_exporter.py" modelopt/onnx/export

Repository: NVIDIA/Model-Optimizer

Length of output: 103


🏁 Script executed:

# Search for calls to post_process and get_onnx_bytes_and_metadata
rg -n "post_process|get_onnx_bytes_and_metadata" modelopt/onnx/export --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 942


🏁 Script executed:

# Check if there's opset validation in the base exporter
cat -n modelopt/onnx/export/base_exporter.py | head -300

Repository: NVIDIA/Model-Optimizer

Length of output: 2365


🏁 Script executed:

# Search for get_onnx_bytes_and_metadata function
rg -n "get_onnx_bytes_and_metadata|onnx_opset" modelopt/onnx --type py -B 2 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 823


🏁 Script executed:

# Look for where FP8QuantExporter is instantiated/used
rg -n "FP8QuantExporter" modelopt/onnx --type py -B 2 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1792


🏁 Script executed:

# Check the quantizer file that likely calls the exporter
fd "quantizer.py" modelopt/onnx -x wc -l {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search more broadly for onnx_opset and opset parameter usage
rg -n "onnx_opset|opset_version" modelopt/onnx --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 3244


🏁 Script executed:

# Look for quantizer or export related files that handle opset
fd "quantizer|export" modelopt/onnx --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 235


🏁 Script executed:

# Check the quantizer base class
rg -n "class.*Quantizer" modelopt/onnx --type py -A 10 | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the ort_patching.py file around the FP8 opset check
cat -n modelopt/onnx/quantization/ort_patching.py | sed -n '220,250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1445


🏁 Script executed:

# Check where _check_opset_version is called and in what context
rg -n "_check_opset_version|check_opset_version" modelopt/onnx/quantization --type py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1253


🏁 Script executed:

# Look at the quantize.py file to understand the full flow
cat -n modelopt/onnx/quantization/quantize.py | sed -n '140,185p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2716


🏁 Script executed:

# Search for get_onnx_bytes_and_metadata function
rg -n "get_onnx_bytes_and_metadata" modelopt/onnx

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the quantize_mode constants to understand FP8 mode handling
rg -n "QDQ_PRECISION_MIN_OPSET|fp8" modelopt/onnx/quantization/quantize.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 3436


🏁 Script executed:

# Check if there's a separate export entry point that may bypass opset checks
rg -n "def.*export|get_onnx_bytes" modelopt/onnx --type py | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 473


🏁 Script executed:

# Search for quantize_fp8 function
rg -n "def quantize_fp8|def quantize" modelopt/onnx/quantization/fp8.py -A 10 | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 604


🏁 Script executed:

# Look for where process_model is called
rg -n "process_model" modelopt/onnx --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 453


🏁 Script executed:

# Check the fp8.py file to see if it calls FP8QuantExporter.post_process
cat -n modelopt/onnx/quantization/fp8.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 7314


🏁 Script executed:

# Continue reading fp8.py to see where FP8QuantExporter is used
cat -n modelopt/onnx/quantization/fp8.py | sed -n '150,250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 5087


🏁 Script executed:

# Search for where FP8QuantExporter.process_model is called
rg -n "\.process_model|FP8QuantExporter" modelopt/onnx --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1639


🏁 Script executed:

# Check for any test files or examples that use FP8 export
fd "test.*fp8|fp8.*test" modelopt --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 235


🏁 Script executed:

# Search for where process_model is called with FP8QuantExporter
rg -n "process_model\|FP8QuantExporter" modelopt/onnx

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look in llm_export_utils since it showed up in opset search
cat -n modelopt/onnx/llm_export_utils/export_utils.py | sed -n '155,170p'

Repository: NVIDIA/Model-Optimizer

Length of output: 379


🏁 Script executed:

# Check the entire quantize function end to see if it calls process_model or post_process
cat -n modelopt/onnx/quantization/fp8.py | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 5585


🏁 Script executed:

# Search for get_onnx_bytes_and_metadata more carefully
rg -n "get_onnx_bytes_and_metadata\|get_onnx_bytes"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for uses of FP8QuantExporter more comprehensively
rg -rn "FP8QuantExporter" modelopt/

Repository: NVIDIA/Model-Optimizer

Length of output: 667


🏁 Script executed:

# Check if there's an export function that uses the exporter classes
rg -n "def export\|def get_onnx" modelopt/onnx -B 2 -A 5 | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the torch_onnx.py to see how exporters are used
cat -n modelopt/torch/_deploy/utils/torch_onnx.py | grep -A 10 -B 10 "FP8QuantExporter\|n\|onnx_exporters"

Repository: NVIDIA/Model-Optimizer

Length of output: 31963


🏁 Script executed:

# Check the actual file content more carefully
grep -n "class.*Exporter" modelopt/onnx/export/*.py

Repository: NVIDIA/Model-Optimizer

Length of output: 564


🏁 Script executed:

# Look at what might be calling process_model on FP8QuantExporter
rg -n "process_model\|post_process" modelopt/torch

Repository: NVIDIA/Model-Optimizer

Length of output: 48


Add opset >= 19 validation before FP8 Q/DQ conversion in FP8QuantExporter.post_process().

The code converts TRT custom ops to native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and the saturate attribute, but does not verify that the model's opset is >= 19 (the minimum required for these operators). When callers invoke get_onnx_bytes_and_metadata() with onnx_opset < 19 on a FP8-quantized model, the post-processor will silently generate an invalid ONNX model instead of upgrading the opset or raising an error.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, In
FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).
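
A minimal sketch of such a guard, operating on (domain, version) pairs rather than a real ModelProto; the constant and function names are invented:

```python
FP8_QDQ_MIN_OPSET = 19  # FLOAT8E4M3FN and `saturate` land in ai.onnx opset 19

def validate_fp8_opset(opset_imports):
    """opset_imports: (domain, version) pairs mirroring model.opset_import;
    the default ai.onnx domain is spelled "" (sometimes "ai.onnx")."""
    for domain, version in opset_imports:
        if domain in ("", "ai.onnx") and version < FP8_QDQ_MIN_OPSET:
            raise ValueError(
                f"Native FP8 QuantizeLinear/DequantizeLinear require opset >= "
                f"{FP8_QDQ_MIN_OPSET}, got {version}; re-export with a higher "
                f"onnx_opset"
            )
```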

Comment on lines +576 to +599
    onnx_opt_graph = quantize_weights(model, onnx_opt_graph)

    if dq_only:
        onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

    try:
        # TODO: Single-precision torch model assumed
        param_dtype = next(model.parameters()).dtype
    except StopIteration:
        param_dtype = torch.float32
    if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
        if is_int4_quantized(model) or is_mxfp8_quantized(model):
            assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
            onnx_opt_graph = convert_float_to_float16(
                onnx_opt_graph,
                keep_io_types=False,
                disable_shape_infer=True,
                check_fp16_ready=False,
            )
        else:
            onnx_opt_graph = convert_to_f16(
                onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
            )
    if weights_dtype == "fp16":
        onnx_opt_graph = convert_float_to_float16(
            onnx_opt_graph,
            keep_io_types=False,
            disable_shape_infer=True,
            check_fp16_ready=False,
            op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
        )
        # Change FP32 cast nodes feeding into Concat/Add to FP16
        onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])

        # TensorRT expects all scales to be postive
        onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
        onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

    # TensorRT expects all scales to be postive
    onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)

    # TODO: Remove manual ir_version change once ORT supports ir_version 11
    # Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version
    onnx_opt_graph.ir_version = 10
⚠️ Potential issue | 🟠 Major

Rebuild model_metadata after the final ONNX rewrites.

Lines 572-574 capture metadata before quantize_weights(), convert_float_to_float16(), change_casts_to_fp16(), and remove_redundant_casts(). Those passes add/remove nodes and can rewrite I/O tensors, so the returned metadata can drift from the serialized model bytes.

🛠️ Proposed fix

Move the metadata creation block below the last graph mutation.

-    model_metadata = create_model_metadata(
-        tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
-    )

Then re-add it after onnx_opt_graph.ir_version = 10:

+    model_metadata = create_model_metadata(
+        tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 576 - 599, The
model_metadata is built too early and can become stale after graph rewrites;
move the metadata creation so it runs after all ONNX mutations: after
quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.

Comment on lines +581 to +588
    if weights_dtype == "fp16":
        onnx_opt_graph = convert_float_to_float16(
            onnx_opt_graph,
            keep_io_types=False,
            disable_shape_infer=True,
            check_fp16_ready=False,
            op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
        )
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd modelopt && wc -l torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 103


🏁 Script executed:

cd modelopt && sed -n '500,515p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 760


🏁 Script executed:

cd modelopt && sed -n '575,595p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 831


🏁 Script executed:

cd modelopt && sed -n '480,530p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2582


🏁 Script executed:

cd modelopt && sed -n '550,600p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1877


🏁 Script executed:

cd modelopt && grep -n "weights_dtype" torch/_deploy/utils/torch_onnx.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 424


🏁 Script executed:

cd modelopt && grep -B5 -A5 "weights_dtype == \"fp16\"" torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 413


🏁 Script executed:

cd modelopt && grep -n "bf16\|fp16" torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 378


🏁 Script executed:

cd modelopt && grep -r "convert_float_to_float16" --include="*.py" | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 242


🏁 Script executed:

cd modelopt && grep -B10 "use_torch_autocast" torch/_deploy/utils/torch_onnx.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 786


🏁 Script executed:

find modelopt -name "*test*torch_onnx*" -o -name "*torch_onnx*test*" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cd modelopt && grep -r "weights_dtype.*fp16\|fp16.*weights_dtype" --include="*.py" tests/ 2>/dev/null | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cd modelopt && git log --oneline -20 -- torch/_deploy/utils/torch_onnx.py 2>/dev/null | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 116


🏁 Script executed:

cd modelopt && sed -n '440,460p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1172


🏁 Script executed:

cd modelopt && sed -n '500,595p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3831


🏁 Script executed:

cd modelopt && git log -p --all -- torch/_deploy/utils/torch_onnx.py 2>/dev/null | grep -A20 -B5 "use_torch_autocast\|weights_dtype.*fp16" | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cd modelopt && find . -path ./modelopt -prune -o -name "*.py" -type f -exec grep -l "weights_dtype.*fp16\|fp16.*weights_dtype" {} \; 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 101


🏁 Script executed:

cd modelopt && find . -name "*.py" -type f -exec grep -l "torch_onnx.*fp16\|export.*fp16" {} \; 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cd modelopt && sed -n '1,50p' torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1648


🏁 Script executed:

cd modelopt && git diff HEAD~5..HEAD -- torch/_deploy/utils/torch_onnx.py 2>/dev/null | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cd modelopt && git show HEAD:torch/_deploy/utils/torch_onnx.py 2>/dev/null | sed -n '500,520p'

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cd modelopt && grep -n "autocast\|bf16" torch/_deploy/utils/torch_onnx.py

Repository: NVIDIA/Model-Optimizer

Length of output: 419


🏁 Script executed:

cd modelopt && python3 << 'EOF'
# Check the logic of use_torch_autocast for different weights_dtype values
def analyze_autocast_logic(is_fp4, is_mxfp8, weights_dtype):
    use_torch_autocast = not (
        is_fp4 or is_mxfp8 or weights_dtype == "fp32"
    )
    return use_torch_autocast

print("Current logic (assuming not quantized):")
for dtype in ["fp32", "fp16", "bf16"]:
    result = analyze_autocast_logic(False, False, dtype)
    print(f"  weights_dtype={dtype}: use_torch_autocast={result}")

print("\nProposed logic (assuming not quantized):")
def proposed_logic(is_fp4, is_mxfp8, weights_dtype):
    use_torch_autocast = weights_dtype == "bf16" and not (is_fp4 or is_mxfp8)
    return use_torch_autocast

for dtype in ["fp32", "fp16", "bf16"]:
    result = proposed_logic(False, False, dtype)
    print(f"  weights_dtype={dtype}: use_torch_autocast={result}")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 426


The FP16 export path mixes autocast with explicit graph conversion, which may cause unintended behavior.

Lines 505-508 enable torch.autocast("cuda") for both FP16 and BF16 (weights_dtype != "fp32" and not quantized). This means the model is traced under autocast, which can promote operations to lower precision. Then at lines 581-588, convert_float_to_float16() applies explicit post-export conversion. This dual approach can conflict—autocast during tracing may insert different operators than expected, undermining the explicit conversion strategy.

BF16 benefits from autocast (mixed precision promotion during tracing), but FP16 should rely solely on post-export graph conversion without autocast interference.

Recommended fix: Disable autocast for FP16 and keep it only for BF16:

-    use_torch_autocast = not (
-        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
-    )
+    use_torch_autocast = weights_dtype == "bf16" and not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model)
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 581 - 588, The FP16
export path should not use torch.autocast during tracing because you already
perform explicit post-export conversion with convert_float_to_float16; update
the autocast logic so the torch.autocast("cuda") context is only entered when
weights_dtype == "bf16" (and not when weights_dtype == "fp16"), i.e., change the
condition that currently enables autocast for weights_dtype != "fp32" to
specifically check for "bf16" and leave the FP16 path to rely solely on
convert_float_to_float16; keep the convert_float_to_float16 call for FP16
unchanged.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/onnx/utils.py (1)

1507-1511: ⚠️ Potential issue | 🟠 Major

Guard FP32→FP16 cast rewrites by input dtype.

At Line 1507, this rewrites any Cast(to=FLOAT) feeding target ops, including casts from non-FP32 sources. That can silently change behavior by removing intentional upcasts.

💡 Proposed fix
     for node in model.graph.node:
         if node.op_type != "Cast":
             continue
@@
         if not feeds_target:
             continue
 
+        cast_input_type = _get_tensor_type_by_name(model, node.input[0])
+        if cast_input_type != onnx.TensorProto.FLOAT:
+            continue
+
         # Check if Cast is to FP32, and change to FP16
         for attr in node.attribute:
             if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
                 attr.i = onnx.TensorProto.FLOAT16
                 break
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

921-923: Prefer public onnx_utils APIs over underscored helpers across modules.

Using _get_tensor_type_by_name, _bypass_cast_node, and _is_same_type_cast from another module couples this class to private implementation details. Exposing public wrappers in modelopt/onnx/utils.py would make this boundary safer.

Also applies to: 942-942, 1097-1102

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/autocast/precisionconverter.py` around lines 921 - 923, The
code currently calls private helpers onnx_utils._get_tensor_type_by_name,
onnx_utils._bypass_cast_node, and onnx_utils._is_same_type_cast from
precisionconverter.py; replace these calls with public wrapper APIs (e.g.,
get_tensor_type_by_name, bypass_cast_node, is_same_type_cast) exported from
modelopt/onnx/utils.py and update precisionconverter.py to call those public
names (also update the other locations that use the underscored helpers around
lines referenced). Add the small public wrapper implementations in
modelopt/onnx/utils.py that delegate to the existing private functions so other
modules use the stable public API and then run tests to ensure behavior is
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/utils.py`:
- Around line 1296-1322: The current _is_sequential_cast only compares the two
Cast target types; modify it to also fetch the data type of the original source
feeding the first Cast (e.g., inspect the producer of node.input[0] or its
ValueInfo/initializer) and verify that this source type equals the second cast's
target type (the value returned by get_cast_to_type(next_node)) before returning
True; this extra check ensures that when _bypass_cast_node rewires the graph the
source type is compatible with the second Cast. Use get_consumer_nodes,
get_cast_to_type and the node.input[0] producer lookup to locate and compare
types.

---

Nitpick comments:
In `@modelopt/onnx/autocast/precisionconverter.py`:
- Around line 921-923: The code currently calls private helpers
onnx_utils._get_tensor_type_by_name, onnx_utils._bypass_cast_node, and
onnx_utils._is_same_type_cast from precisionconverter.py; replace these calls
with public wrapper APIs (e.g., get_tensor_type_by_name, bypass_cast_node,
is_same_type_cast) exported from modelopt/onnx/utils.py and update
precisionconverter.py to call those public names (also update the other
locations that use the underscored helpers around lines referenced). Add the
small public wrapper implementations in modelopt/onnx/utils.py that delegate to
the existing private functions so other modules use the stable public API and
then run tests to ensure behavior is unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 188d9d44-97fd-4011-abb0-37b5f6bdbf27

📥 Commits

Reviewing files that changed from the base of the PR and between 788313f and 40ce80f.

📒 Files selected for processing (5)
  • modelopt/onnx/autocast/graphsanitizer.py
  • modelopt/onnx/autocast/precisionconverter.py
  • modelopt/onnx/autocast/utils.py
  • modelopt/onnx/utils.py
  • tests/unit/onnx/autocast/test_precisionconverter.py

Comment on lines +1296 to +1322
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
    assert node.op_type == "Cast"
    output_type = get_cast_to_type(node)

    # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
    # Cast to low precision -> cast to high precision affects precision and should not be removed
    precision_order = [
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.BFLOAT16,
    ]
    consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"]

    # If the first cast has additional consumers, we should not remove it
    if len(consumers) != 1:
        return False

    next_node = consumers[0]
    first_cast_type = output_type
    second_cast_type = get_cast_to_type(next_node)

    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate and read the function
cat -n modelopt/onnx/utils.py | sed -n '1290,1330p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1919


🏁 Script executed:

# Also check if the helper function _get_tensor_type_by_name exists (used in proposed fix)
rg "_get_tensor_type_by_name" modelopt/onnx/utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 216


🏁 Script executed:

# Search for get_cast_to_type and get_consumer_nodes to understand them
rg "def get_cast_to_type|def get_consumer_nodes" modelopt/onnx/utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 211


🏁 Script executed:

# Check if there are tests for this function
fd "test.*\.py" | xargs rg "_is_sequential_cast" 2>/dev/null || echo "No tests found"

Repository: NVIDIA/Model-Optimizer

Length of output: 80


🏁 Script executed:

# Find where _is_sequential_cast is called
rg "_is_sequential_cast" modelopt/onnx/utils.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 440


🏁 Script executed:

# Look at the context around _bypass_cast_node to understand what happens when removal occurs
cat -n modelopt/onnx/utils.py | sed -n '1325,1360p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2048


🏁 Script executed:

# Search for comments or documentation about cast removal logic
rg "sequential.*cast|cast.*removal|bypass.*cast" modelopt/onnx/utils.py -i -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1472


🏁 Script executed:

# Look for any test files that might test cast optimization
fd "test.*\.py" modelopt/onnx/ | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Get more context about _is_same_type_cast to understand the overall pattern
cat -n modelopt/onnx/utils.py | sed -n '1285,1295p'

Repository: NVIDIA/Model-Optimizer

Length of output: 572


🏁 Script executed:

# Look for the full remove_redundant_casts function to understand the overall flow
rg "def remove_redundant_casts" modelopt/onnx/utils.py -A 30

Repository: NVIDIA/Model-Optimizer

Length of output: 1393


🏁 Script executed:

# Check if there's any documentation or related test about safe cast removal
fd "*.py" modelopt/onnx/ | xargs grep -l "cast\|Cast" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 349


🏁 Script executed:

# Look at how _get_tensor_type_by_name works and what it can return
rg "def _get_tensor_type_by_name" modelopt/onnx/utils.py -A 20

Repository: NVIDIA/Model-Optimizer

Length of output: 991


🏁 Script executed:

# Check for any tests or usage examples that show what types node.input can have
fd ".py" modelopt/onnx/ --type f | xargs grep -l "remove_redundant_casts\|_is_sequential_cast" 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 133


🏁 Script executed:

# Look at the overall structure of the optimization and find if there are any tests
fd ".py" --type f modelopt/onnx/ | xargs grep -l "test.*cast\|cast.*test" 2>/dev/null | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check the precisionconverter file to see if it has any tests or usage
cat -n modelopt/onnx/autocast/precisionconverter.py | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 2246


🏁 Script executed:

# Let me understand the logic better by looking at a concrete example
# First, let me see if there are any examples or documentation about this function
rg "sequential.*cast" modelopt/ -i --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 555


🏁 Script executed:

# Check if the source type is ever checked anywhere in the function
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p' | grep -E "input|src|source"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the broader context to understand how the function comment relates to the logic
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1323


🏁 Script executed:

# Let me trace through a concrete scenario to understand the issue better
# Looking at the comment and logic more carefully
cat -n modelopt/onnx/utils.py | sed -n '1296,1325p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1415


🏁 Script executed:

# Understand what happens when we remove the first cast by looking at bypass logic
cat -n modelopt/onnx/utils.py | sed -n '1325,1355p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1685


Sequential-cast removal must validate the original source type against the final cast target.

The function only compares the two cast target types (lines 1318–1322) but doesn't verify that removing the first cast preserves the input type to the second cast. When _bypass_cast_node rewires the graph, it directly connects the original source to the second cast. If the source type differs from what the second cast expects, this changes behavior.

For example:

  • FLOAT16 → cast(FLOAT) → cast(FLOAT16) passes the current check (FLOAT ≤ FLOAT16)
  • After removal: FLOAT16 → cast(FLOAT16) — but cast2 was designed for FLOAT input, causing incorrect behavior

Add a check to ensure the source type matches the second cast's target type before removal.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1296 - 1322, The current
_is_sequential_cast only compares the two Cast target types; modify it to also
fetch the data type of the original source feeding the first Cast (e.g., inspect
the producer of node.input[0] or its ValueInfo/initializer) and verify that this
source type equals the second cast's target type (the value returned by
get_cast_to_type(next_node)) before returning True; this extra check ensures
that when _bypass_cast_node rewires the graph the source type is compatible with
the second Cast. Use get_consumer_nodes, get_cast_to_type and the node.input[0]
producer lookup to locate and compare types.
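
A dependency-free sketch of the check that prompt describes, with integers standing in for the `onnx.TensorProto` dtype enums; `is_removable_sequential_cast` is a hypothetical name for illustration, not the repo's function:

```python
# Type codes mirror onnx.TensorProto values: FLOAT=1, FLOAT16=10, DOUBLE=11, BFLOAT16=16.
DOUBLE, FLOAT, FLOAT16, BFLOAT16 = 11, 1, 10, 16
PRECISION_ORDER = [DOUBLE, FLOAT, FLOAT16, BFLOAT16]

def is_removable_sequential_cast(source_type, first_cast_to, second_cast_to):
    """The first Cast is removable only if it doesn't reduce precision AND the
    original source dtype matches the second Cast's target (the added guard)."""
    if first_cast_to not in PRECISION_ORDER or second_cast_to not in PRECISION_ORDER:
        return False
    # Existing rule: first cast must not be lower precision than the second.
    if PRECISION_ORDER.index(first_cast_to) > PRECISION_ORDER.index(second_cast_to):
        return False
    # New guard: bypassing the first cast feeds source_type straight into the
    # second Cast, so require it to match the second cast's target.
    return source_type == second_cast_to
```

Without the final comparison, a FLOAT source feeding Cast(FLOAT)→Cast(FLOAT16) and a FLOAT16 source feeding the same pair are treated identically, which is exactly the gap the review flags.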

@ajrasane ajrasane enabled auto-merge (squash) March 13, 2026 17:01
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
@ajrasane ajrasane force-pushed the ajrasane/onnx_qdq branch from 40ce80f to 05c33b2 Compare March 13, 2026 17:07
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (3)
modelopt/torch/_deploy/utils/torch_onnx.py (2)

576-576: ⚠️ Potential issue | 🟠 Major

Rebuild model_metadata after the last graph rewrite.

Starting with quantize_weights(), this function mutates node names, tensor dtypes, and Q/DQ structure after model_metadata has already been captured above. The returned metadata can therefore describe a different graph than the bytes written to disk.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` at line 576, model_metadata is
captured before quantize_weights mutates the ONNX graph (node names, tensor
dtypes, Q/DQ structure), so the saved metadata can be out of sync with
onnx_opt_graph; after the final graph rewrite (the call to quantize_weights that
returns onnx_opt_graph) re-run the metadata extraction routine (the same
function that produced model_metadata earlier) to rebuild model_metadata from
the mutated onnx_opt_graph so the metadata matches the bytes written to disk;
update any downstream uses to reference the new model_metadata variable produced
after quantize_weights.

581-592: ⚠️ Potential issue | 🟠 Major

The FP16 path is still “autocast + convert_float_to_float16()”.

With this new rewrite block, weights_dtype="fp16" still traces under torch.autocast("cuda") earlier in the function, so FP16 export is not actually using convert_float_to_float16() instead of autocast. That makes the exported graph depend on both mechanisms and undermines the stated pipeline change.

🛠️ Suggested earlier-function change
-    use_torch_autocast = not (
-        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
-    )
+    use_torch_autocast = weights_dtype == "bf16" and not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model)
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 581 - 592, The
current FP16 branch still runs under torch.autocast earlier, so the export uses
both autocast and convert_float_to_float16; modify the export flow so when
weights_dtype == "fp16" you do NOT run the model export under
torch.autocast("cuda") (or skip the autocast context) so the graph is produced
in full-precision then transformed only by convert_float_to_float16 and
change_casts_to_fp16; locate the autocast usage earlier in this module
(torch.autocast or the export context manager) and add a conditional to bypass
it when weights_dtype == "fp16", ensuring
convert_float_to_float16/remove_redundant_casts are the sole FP16
transformations.
modelopt/onnx/export/fp8_exporter.py (1)

121-147: ⚠️ Potential issue | 🟠 Major

Reject native FP8 Q/DQ export below opset 19.

This conversion emits native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and saturate, but it still doesn't guard against onnx_opset < 19. Exporting FP8 with a lower opset will silently produce an invalid model instead of failing early or upgrading the opset.

🛠️ Suggested guard
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+        opset = next(
+            (op.version for op in onnx_model.opset_import if op.domain in ("", "ai.onnx")),
+            0,
+        )
+        if opset < 19:
+            raise ValueError("Native FP8 ONNX Q/DQ requires ai.onnx opset >= 19.")
+
         logger.info("Post-processing FP8 quantized model")
         graph = gs.import_onnx(onnx_model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, The conversion
loop for TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx
opset < 19; detect the model/export opset (e.g., an existing variable or by
examining graph.opset or a passed opset parameter) before performing the
conversions in the loops that change node.op to
"QuantizeLinear"/"DequantizeLinear" and set FLOAT8E4M3FN/saturate, and if the
opset is less than 19 either raise a clear exception or upgrade the opset to >=
19 before modifying nodes (apply this check where you manipulate node.op and
node.attrs in the FP8 exporter function that iterates graph.nodes).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 127-134: The injected zero-point Constant currently uses node.name
which may be empty and cause duplicate tensor names; change the naming for the
Constant created from zp_tensor/zp_values/zero_point to a guaranteed-unique
string (e.g., combine node.name when present with a unique suffix such as a
uuid4 or the node's memory id or an incrementing counter, or use an ONNX/graph
helper that returns a unique name) so each FP8 zero-point Constant has a
distinct tensor name even for unnamed TRT FP8 Q nodes.

In `@modelopt/onnx/utils.py`:
- Around line 1262-1276: The helper _get_tensor_type_by_name must also handle
producer-only tensors emitted by nodes (e.g., a Constant node that produces a
tensor but has no value_info entry); modify _get_tensor_type_by_name to iterate
model.graph.node and when a node.output matches tensor_name, if node.op_type ==
"Constant" extract the TensorProto from the node attribute (attribute named
"value") and return its data_type (or elem_type equivalent), otherwise
skip/continue so non-materialized producer-only tensors do not cause an
exception and allow remove_redundant_casts() to fold Constant->Cast patterns.
- Around line 1403-1417: The two cast-removal checks can both match the same
node causing duplicate removal; make them mutually exclusive by ensuring once a
node is handled by _is_sequential_cast(onnx_model, node) (where you call
_bypass_cast_node and append to nodes_to_remove) you skip the subsequent
_is_foldable_constant_cast_pattern check (e.g., use an elif or continue) so you
only call _bypass_cast_node, _convert_constant_values and append to
nodes_to_remove once; update the block containing _is_sequential_cast,
_is_foldable_constant_cast_pattern, _bypass_cast_node, _convert_constant_values,
get_producer_nodes, nodes_to_remove, and logger.debug accordingly.
- Around line 1499-1511: The current logic uses any(...) on tensor_to_consumers
to set Casts to FP16 even if only one consumer is in target_op_types, which
incorrectly changes shared Cast outputs; modify the check so the Cast is
retargeted only when the entire fanout is eligible (replace the any(...) test
with an all(...) test, and treat empty consumer lists as ineligible/skip), then
keep the existing loop over node.attribute that looks for attr.name == "to" and
change attr.i from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 only when
that all-consumers condition holds.
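
The any-vs-all distinction in the last prompt above, as a dependency-free sketch; `should_retarget_cast` is a hypothetical helper over the consumer op types of one Cast output, not modelopt code:

```python
def should_retarget_cast(consumer_op_types, target_op_types=("Concat", "Add")):
    """Retarget Cast(to=FLOAT) to FLOAT16 only when every consumer is a target
    op; a Cast with no consumers, or any non-target consumer, is left alone."""
    return bool(consumer_op_types) and all(
        op in target_op_types for op in consumer_op_types
    )
```

With `any(...)` instead, a Cast output shared between `Concat` and, say, `Mul` would be flipped to FP16 and silently change the `Mul` input dtype.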

---

Duplicate comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: The conversion loop for
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx opset <
19; detect the model/export opset (e.g., an existing variable or by examining
graph.opset or a passed opset parameter) before performing the conversions in
the loops that change node.op to "QuantizeLinear"/"DequantizeLinear" and set
FLOAT8E4M3FN/saturate, and if the opset is less than 19 either raise a clear
exception or upgrade the opset to >= 19 before modifying nodes (apply this check
where you manipulate node.op and node.attrs in the FP8 exporter function that
iterates graph.nodes).

In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Line 576: model_metadata is captured before quantize_weights mutates the ONNX
graph (node names, tensor dtypes, Q/DQ structure), so the saved metadata can be
out of sync with onnx_opt_graph; after the final graph rewrite (the call to
quantize_weights that returns onnx_opt_graph) re-run the metadata extraction
routine (the same function that produced model_metadata earlier) to rebuild
model_metadata from the mutated onnx_opt_graph so the metadata matches the bytes
written to disk; update any downstream uses to reference the new model_metadata
variable produced after quantize_weights.
- Around line 581-592: The current FP16 branch still runs under torch.autocast
earlier, so the export uses both autocast and convert_float_to_float16; modify
the export flow so when weights_dtype == "fp16" you do NOT run the model export
under torch.autocast("cuda") (or skip the autocast context) so the graph is
produced in full-precision then transformed only by convert_float_to_float16 and
change_casts_to_fp16; locate the autocast usage earlier in this module
(torch.autocast or the export context manager) and add a conditional to bypass
it when weights_dtype == "fp16", ensuring
convert_float_to_float16/remove_redundant_casts are the sole FP16
transformations.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7704d15e-7bdf-4b56-b725-8f52b69a07e2

📥 Commits

Reviewing files that changed from the base of the PR and between 40ce80f and 05c33b2.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/onnx/autocast/graphsanitizer.py
  • modelopt/onnx/autocast/precisionconverter.py
  • modelopt/onnx/autocast/utils.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/export/nvfp4_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • tests/unit/onnx/autocast/test_precisionconverter.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/onnx/autocast/test_precisionconverter.py

Comment on lines +127 to +134
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1]) # 1-element tensor
zp_tensor.raw_data = b"\x00" # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
Contributor

⚠️ Potential issue | 🟠 Major

Use a guaranteed-unique tensor name for the injected zero point.

node.name is optional in ONNX, so node.name + "_zero_point" can collapse to the same tensor name for multiple unnamed TRT FP8 Q nodes. That can make the exported graph invalid due to duplicate tensor names.

🛠️ Safer naming
-                    zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+                    zero_point = gs.Constant(f"{node.outputs[0].name}_zero_point", zp_values)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 127-134, the injected zero-point Constant is named from node.name, which may be empty and can therefore produce duplicate tensor names. Give the Constant built from zp_tensor/zp_values a guaranteed-unique name: combine node.name (when present) with a unique suffix such as a uuid4, the node's memory id, or an incrementing counter, or use an ONNX/graph helper that returns a unique name, so each FP8 zero-point Constant is distinct even for unnamed TRT FP8 Q nodes.
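A minimal sketch of one of the suggested approaches (an incrementing counter combined with the possibly-empty node name); `unique_tensor_name` and its fallback base are hypothetical, not part of the exporter:

```python
import itertools

# Module-level counter: each call hands out a fresh suffix.
_zp_counter = itertools.count()


def unique_tensor_name(base: str, suffix: str = "_zero_point") -> str:
    """Build a collision-free tensor name even when the node name is empty."""
    return f"{base or 'fp8_q'}{suffix}_{next(_zp_counter)}"
```

The exporter would then call `gs.Constant(unique_tensor_name(node.name), zp_values)`, so two unnamed Q nodes get distinct zero-point names.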

Comment on lines +1262 to +1276
def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str):
"""Get the tensor element type. Searches value_info, initializers, inputs, and outputs."""
for vi in model.graph.value_info:
if vi.name == tensor_name:
return vi.type.tensor_type.elem_type
for init in model.graph.initializer:
if init.name == tensor_name:
return init.data_type
for inp in model.graph.input:
if inp.name == tensor_name:
return inp.type.tensor_type.elem_type
for out in model.graph.output:
if out.name == tensor_name:
return out.type.tensor_type.elem_type
raise Exception(f"did not find tensor {tensor_name}")

⚠️ Potential issue | 🟠 Major

Teach _get_tensor_type_by_name() about producer-only tensors.

This helper only searches value_info, initializers, inputs, and outputs. A Cast fed by a Constant output that is not materialized in value_info now throws here before remove_redundant_casts() can fold the Constant -> Cast pattern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1262-1276, _get_tensor_type_by_name must also handle producer-only tensors emitted by nodes (e.g., a Constant node that produces a tensor but has no value_info entry). Extend it to iterate model.graph.node: when a node output matches tensor_name and node.op_type == "Constant", extract the TensorProto from the node's "value" attribute and return its data_type (or elem_type equivalent); otherwise continue. Non-materialized producer-only tensors then no longer raise, and remove_redundant_casts() can fold Constant->Cast patterns.
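The Constant case reduces to reading data_type off the node's value attribute. The helper below is a duck-typed sketch of that lookup (it only assumes the standard NodeProto/AttributeProto field names), not the repo's code:

```python
def constant_output_elem_type(node):
    """Return the element type of a Constant node's 'value' tensor, else None.

    ONNX stores the payload of a Constant node in its 'value' attribute as a
    TensorProto (attr.t), whose data_type field is the element type.
    """
    if node.op_type != "Constant":
        return None
    for attr in node.attribute:
        if attr.name == "value":
            return attr.t.data_type
    return None
```

_get_tensor_type_by_name would fall back to this per-node check only after its existing value_info/initializer/input/output searches miss.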

Comment on lines +1403 to +1417
# Find sequential casts that don't change precision
if _is_sequential_cast(onnx_model, node):
nodes_to_remove.append(node)
_bypass_cast_node(onnx_model, node)
logger.debug(f"Found removable double-cast: {node.name}")

# Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
if _is_foldable_constant_cast_pattern(onnx_model, node):
nodes_to_remove.append(node)
cast_producers = get_producer_nodes(onnx_model, node.input[0])
assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant"
constant_producer = cast_producers[0]
_convert_constant_values(constant_producer, node)
_bypass_cast_node(onnx_model, node)
logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}")

⚠️ Potential issue | 🟠 Major

Make the redundant-cast branches mutually exclusive.

A node can satisfy both _is_sequential_cast() and _is_foldable_constant_cast_pattern() (for example Constant -> Cast -> Cast). In that case this block bypasses it twice, appends it twice, and the second graph.node.remove(node) will fail.

🛠️ Minimal fix
             if _is_sequential_cast(onnx_model, node):
                 nodes_to_remove.append(node)
                 _bypass_cast_node(onnx_model, node)
                 logger.debug(f"Found removable double-cast: {node.name}")
+                continue
 
             # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
             if _is_foldable_constant_cast_pattern(onnx_model, node):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1403-1417, the two cast-removal checks can both match the same node, causing a duplicate removal. Make them mutually exclusive: once a node is handled by _is_sequential_cast(onnx_model, node) (where _bypass_cast_node is called and the node is appended to nodes_to_remove), skip the subsequent _is_foldable_constant_cast_pattern check (use an elif or continue) so _bypass_cast_node, _convert_constant_values, and the nodes_to_remove append run at most once per node.

Comment on lines +1499 to +1511
# Check if this Cast outputs to a target op type
cast_output = node.output[0]
consumers = tensor_to_consumers.get(cast_output, [])
feeds_target = any(c.op_type in target_op_types for c in consumers)

if not feeds_target:
continue

# Check if Cast is to FP32, and change to FP16
for attr in node.attribute:
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
attr.i = onnx.TensorProto.FLOAT16
break

⚠️ Potential issue | 🟠 Major

Only retarget Casts whose entire fanout is eligible.

any(...) means one Concat/Add consumer is enough to flip a shared Cast to FP16 for every consumer of that tensor. If the cast output also feeds a non-target branch, that branch is silently changed too.

🛠️ Safe fallback
-        feeds_target = any(c.op_type in target_op_types for c in consumers)
+        feeds_target = bool(consumers) and all(c.op_type in target_op_types for c in consumers)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1499-1511, the current logic uses any(...) over tensor_to_consumers, so a Cast is switched to FP16 even if only one consumer is in target_op_types, which incorrectly changes shared Cast outputs. Retarget the Cast only when its entire fanout is eligible: replace the any(...) test with all(...), treat an empty consumer list as ineligible (skip it), and keep the existing loop over node.attribute that flips attr.i from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 only when that all-consumers condition holds.
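The any-versus-all difference flagged above is easy to demonstrate in isolation. Note that all() over an empty sequence is vacuously True, which is why the suggested fix also guards against an empty fanout:

```python
target_op_types = {"Concat", "Add"}


def fanout_all_eligible(consumer_ops):
    """True only when there is at least one consumer and every one is a target op."""
    return bool(consumer_ops) and all(op in target_op_types for op in consumer_ops)
```

With any(), the mixed fanout `["Add", "Relu"]` would still flip the shared Cast to FP16 and silently change the Relu branch; fanout_all_eligible rejects it.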

@cjluo-nv
Collaborator

Review Comments

Thanks for the PR — the core idea of replacing TRT-specific FP8 QDQ nodes with native ONNX ops is solid, and the refactoring to centralize cast utilities in onnx/utils.py makes sense. However, there are several concerns that I think should be addressed before merging.

1. BF16 regression in _convert_constant_values (High)

The original PrecisionConverter._convert_constant_values had special handling for bfloat16:

  • Reading: used read_f16_tensor_as_fp32() for bf16 input tensors
  • Writing: manually created TensorProto with raw bytes for bf16 output

The new onnx_utils._convert_constant_values (modelopt/onnx/utils.py) uses onnx.numpy_helper.to_array() / from_array() for all types, which doesn't handle bfloat16 natively. This could silently break bf16 constant folding — exactly the kind of AutoCast regression @galagam flagged.
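The reason numpy_helper alone is not enough: NumPy has no native bfloat16 dtype, so bf16 raw bytes must be widened by hand. A self-contained sketch of the read path (the same idea as read_f16_tensor_as_fp32, not the repo's implementation):

```python
import numpy as np


def bf16_raw_to_fp32(raw: bytes) -> np.ndarray:
    """Reinterpret little-endian bfloat16 raw bytes as float32 values.

    bfloat16 is the upper 16 bits of an IEEE-754 float32, so shifting each
    16-bit pattern into the high half of a 32-bit word recovers the value
    exactly (bf16 -> fp32 widening is lossless).
    """
    u16 = np.frombuffer(raw, dtype="<u2")
    return (u16.astype(np.uint32) << 16).view(np.float32)
```

Writing back is the mirror image: reduce each fp32 bit pattern to its top 16 bits and store those as raw_data on a TensorProto with data_type = BFLOAT16, which is what the original PrecisionConverter did manually.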

2. BF16 path dropped in get_onnx_bytes_and_metadata (High)

The old code in torch_onnx.py handled weights_dtype in ["fp16", "bf16"] and used convert_to_f16() for the bf16 path. The new code only handles weights_dtype == "fp16". If weights_dtype == "bf16", no FP16/BF16 conversion happens at all. This looks like a silent behavioral regression.

3. Monkey-patching onnxconverter_common with suppress(AttributeError) (Medium)

The patch at torch_onnx.py:59-65 silently swallows all AttributeError exceptions from remove_unnecessary_cast_node. This could mask real bugs. Could you:

  • Add a comment explaining the specific upstream bug this works around?
  • Add a TODO with a link to the upstream issue (if one exists) so this can be removed when fixed?
  • Consider catching more narrowly if possible?

4. quantize_weights now called unconditionally (Medium)

In torch_onnx.py, the guard if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model) was removed. quantize_weights is now called for all models regardless of quantization mode. Can you confirm this is safe / a no-op for non-quantized models?

5. _get_tensor_type_by_name performance and scope change (Low)

The original PrecisionConverter._get_tensor_type used pre-built O(1) dict lookups (self.value_info_map, self.initializer_map). The new _get_tensor_type_by_name does a linear scan over value_info, initializer, input, and output lists on every call, and also searches graph inputs/outputs which the original didn't. For large models this could be noticeably slower. Worth considering caching or at least noting the tradeoff.
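A straightforward way to recover the O(1) behavior is to build the name-to-type map once per model. This is a duck-typed sketch (it only assumes the standard ModelProto field names), not the PR's code:

```python
def build_tensor_type_map(model):
    """One linear pass over the graph; every later lookup is a dict hit."""
    type_map = {}
    for vi in (*model.graph.value_info, *model.graph.input, *model.graph.output):
        type_map[vi.name] = vi.type.tensor_type.elem_type
    for init in model.graph.initializer:
        type_map[init.name] = init.data_type
    return type_map
```

The map must be rebuilt (or updated) after any pass that adds tensors, which is the usual tradeoff against the per-call linear scan.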

6. Minor: in _is_same_type_cast, the input_types is not None check is always True

input_types is a list comprehension result, so it's always a list (never None). The and input_types is not None clause is a no-op. Copied from the original code but worth cleaning up while you're here.


Items 1, 2, and 4 are the ones I'd like to see addressed or explicitly justified before approving. The rest are suggestions. Thanks!
