🚀[FEA]: Add TensorRT compilation utility and hybrid Warp example#1565

Open
manmeet3591 wants to merge 2 commits into NVIDIA:main from manmeet3591:feature/tensorrt-warp-inference

Conversation


@manmeet3591 manmeet3591 commented Apr 13, 2026

This PR adds a new inference optimization capability to PhysicsNeMo and a corresponding example for hybrid Warp + TensorRT execution.

New Features

  • TensorRT Compilation Utility: Added physicsnemo/utils/inference.py containing a compile_to_trt function. This utility wraps torch_tensorrt to simplify the process of optimizing PhysicsNeMo models for high-performance inference.
  • Hybrid Inference Example: Added examples/minimal/inference/torch_trt_warp_inference.py. This example demonstrates how to integrate NVIDIA Warp (for geometric tasks like neighbor search) with a TensorRT-optimized neural network in a single, zero-copy pipeline.

Why this is useful

Many Physics-AI models require complex geometric preprocessing (best handled by Warp) and high-speed neural network inference (best handled by TensorRT). This PR provides the necessary utilities and patterns to build such hybrid pipelines efficiently.
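Since the PR body does not include the utility itself, here is a minimal sketch of what a wrapper along these lines could look like. The names `compile_to_trt` and `is_trt_available` come from the PR; the availability guard, keyword handling, and body are assumptions, not the actual implementation:

```python
# Hypothetical sketch of a thin torch_tensorrt wrapper, as described in
# this PR. The guard lets the module import cleanly even when the
# optional torch_tensorrt dependency is not installed.
import logging

logger = logging.getLogger(__name__)

try:
    import torch_tensorrt

    _TRT_AVAILABLE = True
except ImportError:
    _TRT_AVAILABLE = False


def is_trt_available() -> bool:
    """Return True if torch_tensorrt can be imported."""
    return _TRT_AVAILABLE


def compile_to_trt(model, input_signature, enabled_precisions=None, **kwargs):
    """Compile a PyTorch module to TensorRT via torch_tensorrt.compile."""
    if not _TRT_AVAILABLE:
        raise ImportError("torch_tensorrt is required for compile_to_trt")
    compile_spec = {"inputs": input_signature, **kwargs}
    if enabled_precisions is not None:
        compile_spec["enabled_precisions"] = enabled_precisions
    return torch_tensorrt.compile(model, **compile_spec)
```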

Signed-off-by: Manmeet Singh <manmeet20singh11@gmail.com>
This commit introduces a new inference utility module that provides
support for compiling PhysicsNeMo models to TensorRT using Torch-TensorRT.
It also adds a minimal example demonstrating a hybrid inference pipeline
that combines NVIDIA Warp for geometric processing (neighbor search)
with TensorRT for accelerated neural network execution.

Signed-off-by: Manmeet Singh <manmeet20singh11@gmail.com>
Contributor

greptile-apps bot commented Apr 13, 2026

Greptile Summary

This PR adds a compile_to_trt utility wrapping torch_tensorrt and a hybrid Warp+TensorRT inference example. Two blocking bugs need to be fixed before the new code is usable:

  • physicsnemo/utils/inference.py is missing Set in its typing import, causing a NameError at import time (breaking the entire physicsnemo.utils package on import).
  • The example calls radius_search_warp(..., device=device.type), but that function does not accept a device parameter — it derives the device from the input tensors — causing a TypeError at runtime.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `physicsnemo/utils/inference.py` | New TensorRT compilation utility — has a critical missing `Set` import that causes `NameError` at module import time; also has an unused `Union` import. |
| `examples/minimal/inference/torch_trt_warp_inference.py` | New hybrid Warp+TensorRT example — calls `radius_search_warp` with an unsupported `device` keyword argument, which will raise `TypeError` at runtime. |
| `physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py` | Type annotation updated from `wp.context.Device` to `wp.Device` for both `count_neighbors` and `gather_neighbors` — uses the correct public Warp API. |
| `physicsnemo/utils/__init__.py` | Exports `compile_to_trt` and `is_trt_available` from the new inference module; straightforward addition. |
| `physicsnemo/mesh/README.md` | Documentation link fixes — paths updated to be relative to the file's location rather than the repo root, improving portability when viewed from GitHub or rendered locally. |

Reviews (1): Last reviewed commit: "🚀[FEA]: Add TensorRT compilation utilit..."

```python
# limitations under the License.

import logging
from typing import Any, List, Optional, Union
```
Contributor

P0 Missing Set import causes NameError at import time

Set is used in the type annotation on line 34 (Optional[Set[torch.dtype]]) but is not imported from typing. Because Python evaluates function annotations eagerly at definition time (without from __future__ import annotations), importing this module will immediately raise NameError: name 'Set' is not defined. Union is also imported but never used.

Suggested change

```diff
- from typing import Any, List, Optional, Union
+ from typing import Any, List, Optional, Set
```
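To illustrate the failure mode Greptile describes: without `from __future__ import annotations`, Python evaluates annotations when the `def` statement executes, so an unimported name fails immediately. A minimal reproduction (using `exec` so the broken definition can be caught):

```python
# Without the future import, annotations are evaluated at definition
# time, so an unimported name like `Set` raises NameError at once.
broken = "def f(x: Set = None): pass"  # `Set` was never imported

try:
    exec(broken, {})  # empty globals: `Set` is undefined here
    raised = False
except NameError:
    raised = True

# raised is True: a module containing this definition fails on import
```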

Comment on lines +63 to +66

```python
queries = torch.randn(1000, 3, device=device)
radius = 0.1

# 3. Geometric Processing with Warp (Neighbor Search)
```
Contributor

P0 radius_search_warp does not accept a device keyword argument

The function signature in physicsnemo/models/figconvnet/warp_neighbor_search.py is radius_search_warp(points, queries, radius, grid_dim=...) — it derives the device directly from the input tensors and has no device parameter. Passing device=device.type will raise TypeError: radius_search_warp() got an unexpected keyword argument 'device' at runtime.

Suggested change

```python
queries = torch.randn(1000, 3, device=device)
radius = 0.1

# 3. Geometric Processing with Warp (Neighbor Search)
neighbor_index, neighbor_dist, neighbor_offset = radius_search_warp(
    points, queries, radius
)
```

```python
# limitations under the License.

import logging
from typing import Any, List, Optional, Union
```
Contributor

P2 Unused Union import

Union is imported but never referenced in this file. It can be removed to keep the imports clean.

Collaborator

@peterdsharpe peterdsharpe left a comment

Hi @manmeet3591,

Thank you for opening this PR!

While this contribution is interesting, I'm not sure that it belongs in PhysicsNeMo-Core:
a) it adds a new optional dependency on TensorRT that is used in this example, and we'd like to keep the dependency list for PhysicsNeMo minimal;
b) the main compilation logic in inference.py is basically a thin wrapper around torch_tensorrt.compile(), which end users could instead call directly in their downstream training utilities.

For example, by analogy: we don't include any PhysicsNeMo-specific wrappers for torch.compile support, yet many downstream users torch.compile() their PhysicsNeMo models in their own scripts.

Is there an angle of the value proposition here that I'm not fully understanding?

```diff
 adjacency relationships, which is used internally for all computations. (See the
 dedicated
-[`physicsnemo.mesh.neighbors._adjacency.py`](physicsnemo/mesh/neighbors/_adjacency.py)
+[`physicsnemo.mesh.neighbors._adjacency.py`](./neighbors/_adjacency.py)
```
Collaborator

Please remove these unrelated changes, or bring them in with a separate PR.

```diff
 wp_points: wp.array(dtype=wp.vec3),
 wp_queries: wp.array(dtype=wp.vec3),
-wp_launch_device: wp.context.Device | None,
+wp_launch_device: wp.Device | None,
```
Collaborator

This is a good idea, and fixes a deprecation warning, but is unrelated to the main PR here - please bring these changes in a separate PR.

```python
def compile_to_trt(
    model: torch.nn.Module,
    input_signature: List[torch.Tensor],
    enabled_precisions: Optional[Set[torch.dtype]] = None,
```
Collaborator

Suggested change

```diff
-    enabled_precisions: Optional[Set[torch.dtype]] = None,
+    enabled_precisions: Set[torch.dtype] | None = None,
```

Modernizes type-hint syntax
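One caveat worth noting alongside this suggestion (my addition, not part of the review): the `X | None` union syntax is only valid at runtime on Python 3.10+; on 3.8/3.9 it works in annotations only with the future import. A quick sketch:

```python
# `X | None` in annotations works on pre-3.10 Pythons only with this
# import, which defers annotation evaluation (PEP 563). On 3.10+ the
# syntax is valid everywhere, including at runtime.
from __future__ import annotations


def first_or_none(items: list[int] | None = None) -> int | None:
    """Return the first element, or None for an empty/missing list."""
    if not items:
        return None
    return items[0]
```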


```python
def compile_to_trt(
    model: torch.nn.Module,
    input_signature: List[torch.Tensor],
```
Collaborator

Suggested change

```diff
-    input_signature: List[torch.Tensor],
+    input_signature: list[torch.Tensor],
```

```python
        trt_model = torch_tensorrt.compile(model, **compile_spec)
        logger.info("TensorRT compilation successful.")
        return trt_model
    except Exception as e:
```
Collaborator

I'd recommend allowing this to fail, rather than re-raising here. If we do choose to re-raise, perhaps we can narrow scope to tighter than a bare Exception?
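A sketch of the narrower-scope option the reviewer raises. The specific exception types here are an assumption for illustration, not the documented set raised by `torch_tensorrt.compile`:

```python
import logging

logger = logging.getLogger(__name__)


def compile_with_context(compile_fn, model, **compile_spec):
    """Run a compile callable, logging context on known failure modes.

    Catches only (RuntimeError, TypeError) -- an assumed set, not
    torch_tensorrt's documented one -- and re-raises unchanged so the
    caller still sees the original error. Anything else propagates
    untouched, per the "allow this to fail" suggestion.
    """
    try:
        return compile_fn(model, **compile_spec)
    except (RuntimeError, TypeError):
        logger.error("TensorRT compilation failed for %s", type(model).__name__)
        raise
```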

Author

Yes, agree.

```python
import torch

try:
    import torch_tensorrt
```
Collaborator

This violates the repo-wide optional-import conventions and will trip up importlinter.

@coreyjadams
Collaborator

Hi @manmeet3591, I want to echo @peterdsharpe:

  1. Thanks for opening this, I actually do think there is significant value here. We've put a lot of work in physicsnemo to support warp kernels without graph breaks, for example. TensorRT inference is a similar story and important for our models.
  2. That said, I also can't quite see where the "secret sauce" is here: I don't see any special treatment of Warp in the TensorRT steps, so I'm confused about which steps are necessary for enablement. Is there something that, if you didn't have it here, would prevent this from working?

It's an interesting direction. I think it's worth exploring, if we can see the benefit? I am certainly grateful you've started this conversation.
