## Problem
When using `torch.use_deterministic_algorithms(True)`, the sparsemax function fails because the CUDA `cumsum` kernel has no deterministic implementation. The failure occurs in the `_sparsemax_threshold_and_support` function, where the operation

```python
topk_cumsum = topk.cumsum(dim) - 1
```

raises the following error:

```
RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.
```
Without deterministic support, models trained on GPU with sparsemax behave unexpectedly when run on CPU for inference (with a radical drop in prediction accuracy).
## Steps to reproduce
```python
import torch
from entmax import sparsemax

torch.use_deterministic_algorithms(True)

x = torch.tensor([-2, 0, 0.5]).to("cuda")
sparsemax(x, dim=0)
```
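As the error message itself suggests, a temporary user-side mitigation is to demote the determinism error to a warning; note that this gives up the determinism guarantee for the `cumsum` call:

```python
import torch
from entmax import sparsemax

# Only warn (instead of raising) for ops without a deterministic
# implementation; the CUDA cumsum may still be non-deterministic.
torch.use_deterministic_algorithms(True, warn_only=True)

x = torch.tensor([-2, 0, 0.5]).to("cuda")
print(sparsemax(x, dim=0))
```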
## Environment

- Entmax version: 1.3
- PyTorch version: 2.4.0
- CUDA version: 12.1
- Python version: 3.12.4
- OS: Ubuntu 20.04 LTS
## Dependencies

```
entmax==1.3
└── torch [required: >=1.3, installed: 2.4.0]
    ├── filelock [required: Any, installed: 3.15.4]
    ├── fsspec [required: Any, installed: 2024.6.1]
    ├── Jinja2 [required: Any, installed: 3.1.4]
    │   └── MarkupSafe [required: >=2.0, installed: 2.1.5]
    ├── networkx [required: Any, installed: 3.2.1]
    ├── nvidia-cublas-cu12 [required: ==12.1.3.1, installed: 12.1.3.1]
    ├── nvidia-cuda-cupti-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── nvidia-cuda-nvrtc-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── nvidia-cuda-runtime-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── nvidia-cudnn-cu12 [required: ==9.1.0.70, installed: 9.1.0.70]
    │   └── nvidia-cublas-cu12 [required: Any, installed: 12.1.3.1]
    ├── nvidia-cufft-cu12 [required: ==11.0.2.54, installed: 11.0.2.54]
    ├── nvidia-curand-cu12 [required: ==10.3.2.106, installed: 10.3.2.106]
    ├── nvidia-cusolver-cu12 [required: ==11.4.5.107, installed: 11.4.5.107]
    │   ├── nvidia-cublas-cu12 [required: Any, installed: 12.1.3.1]
    │   ├── nvidia-cusparse-cu12 [required: Any, installed: 12.1.0.106]
    │   │   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.5.82]
    │   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.5.82]
    ├── nvidia-cusparse-cu12 [required: ==12.1.0.106, installed: 12.1.0.106]
    │   └── nvidia-nvjitlink-cu12 [required: Any, installed: 12.5.82]
    ├── nvidia-nccl-cu12 [required: ==2.20.5, installed: 2.20.5]
    ├── nvidia-nvtx-cu12 [required: ==12.1.105, installed: 12.1.105]
    ├── setuptools [required: Any, installed: 69.5.1]
    ├── sympy [required: Any, installed: 1.13.2]
    │   └── mpmath [required: >=1.1.0,<1.4, installed: 1.3.0]
    ├── triton [required: ==3.0.0, installed: 3.0.0]
    │   └── filelock [required: Any, installed: 3.15.4]
    └── typing_extensions [required: >=4.8.0, installed: 4.12.2]
```
## Solution
According to this issue, deterministic support for `cumsum` is added in PyTorch 2.6.0, but that version has not been released yet.

For older versions, the following workarounds could be considered:

- Add a CPU fallback or an alternative deterministic algorithm for `cumsum` (a minimal sketch follows this list).
- Explicitly document this limitation in the Entmax README, or emit an explicit warning when sparsemax runs on CUDA with deterministic algorithms enabled.
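As a rough illustration of the first workaround, here is a minimal sketch; the helper name `deterministic_cumsum` and its placement are my assumptions, not existing Entmax API. It falls back to the CPU, where `cumsum` has a deterministic implementation, only when deterministic mode is active on a CUDA tensor:

```python
import torch

def deterministic_cumsum(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: round-trip through the CPU, where cumsum is
    # deterministic, only when deterministic mode is requested on CUDA.
    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
        return x.cpu().cumsum(dim).to(x.device)
    return x.cumsum(dim)

# Inside _sparsemax_threshold_and_support, the failing line could then read:
# topk_cumsum = deterministic_cumsum(topk, dim) - 1
```

The CPU round-trip forces a device synchronization, so this is a correctness fallback rather than a fast path.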