Fix: resolve TorchScript JIT stability and PyTorch 2.2 compatibility bugs#277
Conversation
|
Hello @changzhiai , thank you very much for the PR, all changes look good to me. |
1f05449 to
5d49e47
Compare
|
Hi @phiandark, thanks a lot for the guide. Now I already "signoff" the commits. |
…roduct Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
5d49e47 to
acd70b4
Compare
|
Hi @changzhiai , thank you for signing off. Our CI caught another issue: this changes the named argument of ChannelWise and that can break downstream use, which might be bad for existing implementations. If it's just a matter of duplicate names, would you mind leaving the named argument as Also, the pre-commit style check fails: can you make sure to install and run the pre-commit hooks before your next push. Then I can re-approve and merge. |
…and fx-tracing fixes for Pytorch 2.2 Signed-off-by: Changzhi Ai <changzhiai@outlook.com>
004c9eb to
0f59ba5
Compare
|
Hi @phiandark, thank you so much for your suggestion. I have updated the PR with the following fixes:
Thank you! |
Summary
This Pull Request addresses two critical issues in the
cuequivariance-torchlibrary that were preventing successful training on certain PyTorch versions and blocking TorchScript export for deployment (e.g., for use in LAMMPS).1. TorchScript JIT Stability & Type Inference
tp_channel_wise.py, the JIT compiler failed to infer types for empty dictionaries. Additionally, a variable shadowing issue existed whereindices_outwas being used as both an input argument (Tensor) and an internal dictionary, which is not permitted in TorchScript.indices_0to ensure stable typing.Dict[int, torch.Tensor]) to internal dictionaries to assist the JIT compiler.2. PyTorch 2.2+ Compatibility & Scripting Support
segmented_polynomial.py, a call to the private functiontorch.fx._symbolic_trace.is_fx_symbolic_tracing()caused anAttributeErrorin PyTorch 2.2.x.getattrwith the built-inboolfunction as a default fallback. This safely returnsFalseif the function is missing and is fully compatible withtorch.jit.script.Files Modified:
cuequivariance_torch/operations/tp_channel_wise.pycuequivariance_torch/primitives/segmented_polynomial.pyVerification:
torch.jit.scriptfor LAMMPS deployment.