Fix: resolve TorchScript JIT stability and PyTorch 2.2 compatibility bugs #277

Merged: phiandark merged 4 commits into NVIDIA:main from changzhiai:fix/torchscript-jit-compat on May 12, 2026

@changzhiai
Contributor

Summary

This Pull Request addresses two critical issues in the cuequivariance-torch library that were preventing successful training on certain PyTorch versions and blocking TorchScript export for deployment (e.g., for use in LAMMPS).


1. TorchScript JIT Stability & Type Inference

  • Problem: In tp_channel_wise.py, the JIT compiler could not infer types for empty dictionaries. In addition, the name indices_out was used both for an input argument (a Tensor) and for an internal dictionary. TorchScript requires each variable to have a single static type, so this rebinding is rejected.
  • Fix:
    • Renamed the input argument to indices_0 to ensure stable typing.
    • Added explicit type hints (e.g., Dict[int, torch.Tensor]) to internal dictionaries to assist the JIT compiler.
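The two points above can be illustrated with a minimal sketch. The function and variable names below (select_rows, index_by_operand) are hypothetical and are not taken from tp_channel_wise.py; the sketch only assumes a PyTorch installation with torch.jit.script available.

```python
from typing import Dict

import torch


@torch.jit.script
def select_rows(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Annotate the empty dict explicitly: without a hint TorchScript assumes
    # Dict[str, Tensor] for `{}` and would reject the integer key used below.
    index_by_operand: Dict[int, torch.Tensor] = {}
    # The dict has its own name, so the Tensor argument `indices` is never
    # rebound to a value of a different type.
    index_by_operand[0] = indices
    return torch.index_select(x, 0, index_by_operand[0])


x = torch.arange(6).reshape(3, 2)
rows = select_rows(x, torch.tensor([2, 0]))
```

Here the argument keeps the type Tensor for the whole function body, and the dictionary is typed at the point where it is created, which is what the JIT compiler needs to compile the function.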

2. PyTorch 2.2+ Compatibility & Scripting Support

  • Problem: In segmented_polynomial.py, a call to the private function torch.fx._symbolic_trace.is_fx_symbolic_tracing() caused an AttributeError in PyTorch 2.2.x.
  • Fix:
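    • Used getattr with the built-in bool function as the default fallback. If the function is missing, getattr returns bool, and calling bool() with no arguments returns False, so the check safely reports that no tracing is in progress. The approach is also compatible with torch.jit.script.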
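The fallback idea can be sketched in plain Python. The two namespace objects below are hypothetical stand-ins for a module that does or does not expose is_fx_symbolic_tracing; in the library the lookup would target torch.fx._symbolic_trace instead, and the exact code in segmented_polynomial.py may differ.

```python
import types

# Hypothetical stand-ins: one "module" exposes the helper, the other does not.
module_with_helper = types.SimpleNamespace(is_fx_symbolic_tracing=lambda: True)
module_without_helper = types.SimpleNamespace()


def is_tracing(module) -> bool:
    # If the attribute is missing, getattr returns the built-in `bool`.
    # Calling bool() with no arguments returns False, so the fallback
    # behaves like a helper that always answers "not tracing".
    check = getattr(module, "is_fx_symbolic_tracing", bool)
    return check()


tracing_new = is_tracing(module_with_helper)
tracing_old = is_tracing(module_without_helper)
```

Using bool as the default avoids defining a separate helper just to return False, and the call site stays the same whether or not the attribute exists.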

Files Modified:

  • cuequivariance_torch/operations/tp_channel_wise.py
  • cuequivariance_torch/primitives/segmented_polynomial.py

Verification:

  • Verified that models can be successfully exported to TorchScript using torch.jit.script for LAMMPS deployment.

@copy-pr-bot

copy-pr-bot Bot commented May 2, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@phiandark
Collaborator

Hello @changzhiai , thank you very much for the PR, all changes look good to me.
Could you please "signoff" your commits as mentioned in the contributing guide?
After that I should be able to approve this.
Thanks!

@changzhiai force-pushed the fix/torchscript-jit-compat branch from 1f05449 to 5d49e47 on May 6, 2026 19:06
@changzhiai
Contributor Author

Hi @phiandark, thanks a lot for the guidance. I have now signed off the commits.

changzhiai added 3 commits May 6, 2026 12:38
…roduct

Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
@changzhiai force-pushed the fix/torchscript-jit-compat branch from 5d49e47 to acd70b4 on May 6, 2026 19:38
@phiandark
Collaborator

Hi @changzhiai , thank you for signing off.

Our CI caught another issue: this changes the named argument of ChannelWise, which can break downstream code that passes it by keyword. If it's just a matter of duplicate names, would you mind leaving the named argument as indices_out and changing the internal name to something else instead?

Also, the pre-commit style check fails: could you install and run the pre-commit hooks before your next push? Then I can re-approve and merge.
Thank you!

…and fx-tracing fixes for Pytorch 2.2

Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
@changzhiai force-pushed the fix/torchscript-jit-compat branch from 004c9eb to 0f59ba5 on May 12, 2026 17:27
@changzhiai
Contributor Author

Hi @phiandark, thank you so much for your suggestion. I have updated the PR with the following fixes:

  1. Backward Compatibility: Reverted the indices_out argument rename in ChannelWiseTensorProduct and handled the naming conflict internally to prevent breaking downstream usage.
  2. Style Check: Ran the pre-commit hooks (ruff and ruff format). All style checks now pass locally.
  3. PyTorch Compatibility: Implemented a more robust getattr check for is_fx_symbolic_tracing in SegmentedPolynomial. This should resolve the documentation build failure by ensuring compatibility with older PyTorch versions (like 2.2).
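The backward-compatibility concern in item 1 can be sketched in plain Python. The functions below are hypothetical and far simpler than ChannelWiseTensorProduct; they only show why a renamed keyword breaks callers, and how giving the internal mapping its own name avoids the rename.

```python
from typing import Dict, List, Optional


def scatter_rows(
    values: List[float], indices_out: Optional[List[int]] = None
) -> Dict[int, float]:
    # The public keyword `indices_out` is unchanged, so existing callers
    # that pass it by name keep working.
    if indices_out is None:
        indices_out = list(range(len(values)))
    # The internal mapping gets its own name instead of reusing
    # `indices_out` for a value of a different type.
    out_by_index: Dict[int, float] = {}
    for i, v in zip(indices_out, values):
        out_by_index[i] = out_by_index.get(i, 0.0) + v
    return out_by_index


def scatter_rows_renamed(
    values: List[float], indices_0: Optional[List[int]] = None
) -> Dict[int, float]:
    # Same behaviour, but the keyword was renamed: a caller that still
    # passes `indices_out=` now gets a TypeError.
    return scatter_rows(values, indices_0)


summed = scatter_rows([1.0, 2.0, 3.0], indices_out=[0, 1, 0])
```

Calling scatter_rows_renamed([1.0], indices_out=[0]) raises TypeError, which is the kind of downstream breakage the CI flagged.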

Thank you!

@phiandark merged commit 74c226e into NVIDIA:main on May 12, 2026
9 checks passed

2 participants