[ROCm] Pass bool attention masks directly to hipDNN #3198
Open
rkayaith wants to merge 1 commit into users/rkayaith/sdpa-hipdnn-backend from
Conversation
hipDNN now exposes a BOOLEAN tensor data type (rocm-libraries 3bc293796d), so we can route bool attn_mask tensors through the PyTorch dispatch -> hipDNN -> fusilli pipeline unchanged, instead of materializing a float additive mask via convert_boolean_attn_mask.

- attention.cpp: skip the upstream bool->float conversion on USE_ROCM + cudnn_attention; convert_boolean_attn_mask is now invoked per-backend (no behavior change for non-ROCm or non-cudnn backends).
- cudnn/hip/MHA.cpp: map kBool -> fe::DataType_t::BOOLEAN in to_fe_data_type so the bool mask is described correctly in the fusilli graph (a sketch of the mapping follows below); drop the bool special-case in check_cudnn_sdpa_support that previously remapped to q.scalar_type.
- Test: test_cudnn_attention_bool_mask compares the CUDNN_ATTENTION and MATH backends with a triangular bool mask (no fully-masked rows, to avoid iree#24175).
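To make the dtype routing concrete, here is a minimal sketch of the kind of mapping described above, assuming a cudnn-frontend-style fe::DataType_t enum and a hipdnn_frontend namespace; only the kBool -> BOOLEAN case is taken from this PR, and the actual switch in cudnn/hip/MHA.cpp may differ in its cases and naming:

```cpp
#include <ATen/ATen.h>

namespace fe = hipdnn_frontend; // assumed alias, mirroring `namespace fe = cudnn_frontend;`

// Sketch: map an ATen scalar type to the frontend's tensor data type. The
// non-bool cases are illustrative assumptions based on the cuDNN frontend API.
fe::DataType_t to_fe_data_type(at::ScalarType dtype) {
  switch (dtype) {
    case at::kHalf:
      return fe::DataType_t::HALF;
    case at::kBFloat16:
      return fe::DataType_t::BFLOAT16;
    case at::kFloat:
      return fe::DataType_t::FLOAT;
    case at::kBool:
      // New in this PR: describe bool attn_mask tensors as BOOLEAN in the
      // fusilli graph instead of remapping them to q.scalar_type.
      return fe::DataType_t::BOOLEAN;
    default:
      TORCH_INTERNAL_ASSERT(false, "unsupported dtype for hipDNN SDPA");
  }
}
```

With this in place, a bool mask keeps its dtype end-to-end, and the additive float mask from convert_boolean_attn_mask is no longer materialized on this path.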
This PR adds hipDNN as a supported backend for SDPA.
The approach taken here is to re-use the existing CUDNN_ATTENTION backend, adding support for it by routing through hipDNN when compiled for ROCm, i.e. torch.nn.attention.SDPBackend.CUDNN_ATTENTION and aten::_scaled_dot_product_cudnn_attention now work on ROCm.

The primary change here is adding cudnn/hip/MHA.cpp, which is a "fork" of cudnn/MHA.cpp, modified to use hipDNN instead of cuDNN. There are various differences between the implementations that made it simpler to fork the entire file rather than to keep both implementations in the same file or rely on hipify for translation.

Additionally, since hipDNN provides an API for querying engine support for a graph, during backend selection sdp::can_use_cudnn_attention calls the newly added at::native::check_cudnn_sdpa_support, which constructs the hipDNN graph for both forwards and backwards (if potentially needed) to query support. This is not cached, as it's assumed that graph construction + querying support is implemented efficiently; the sequence of hipDNN calls is sketched below.
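For reference, a rough sketch of that support-query sequence, modeled on the cudnn-frontend graph API that the CUDA implementation uses; the hipdnn_frontend namespace, hipdnnHandle_t, and exact call names are assumptions here, and the real calls live in check_cudnn_sdpa_support:

```cpp
namespace fe = hipdnn_frontend; // assumed alias

// Sketch: build the graph far enough to ask the backend whether any engine
// supports it. check_cudnn_sdpa_support would run this for the forward graph
// and, if gradients may be needed, the backward graph as well.
bool graph_is_supported(fe::graph::Graph& graph, hipdnnHandle_t handle) {
  // Validate tensor shapes, dtypes, and attributes in the graph description.
  if (graph.validate().is_bad()) return false;
  // Lower the description to a backend operation graph.
  if (graph.build_operation_graph(handle).is_bad()) return false;
  // Query heuristics for candidate execution plans.
  if (graph.create_execution_plans({fe::HeurMode_t::A}).is_bad()) return false;
  // Report whether any candidate plan is supported on this device.
  return graph.check_support(handle).is_good();
}
```

Since nothing is cached, this construction + query runs on every backend-selection call, which is why it relies on the frontend implementing these steps cheaply.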
Both backends still share the same attention.cu kernel (hipDNN uses the hipified version).

This is separated into a stack of PRs for easier review:
PR 1/3: [ROCm] Stop hipifying native/cudnn/ and native/quantized/cudnn/ files (previously hipified but ifdef'd out). This disables the hipification rules and fixes up includes/directives so the CUDA files compile cleanly on ROCm. The primary motivation here is to stop generating cudnn/hip/MHA.cpp; the following changes add this file back with the hipDNN backend implementation.

PR 2/3: [ROCm] Integrate hipDNN as an SDPA backend:
- Add aotriton.images/ to .gitignore (NFC change)
- Extract compute_matching_strides from alloc_with_matching_layout (NFC change)
- Split cudnn/MHA.cpp into separate cuDNN and hipDNN files
- Copy cuDNN MHA implementation verbatim into hipDNN MHA
- [ROCm] Add hipDNN SDPA backend dispatch
- Update tests to reflect the cudnn backend being available on ROCm; CUDNN_ATTENTION tests can now be run on ROCm

PR 3/3: [ROCm] Pass bool attention masks directly to hipDNN (this PR)
Open questions for reviewers:
- Re-using the CUDNN_ATTENTION backend was done due to API similarities, though this could potentially be confusing. Would it be preferable to completely separate hipDNN by adding a new backend?
- Would some other way of sharing code between the two MHA.cpp files be recommended?