[ROCm] Pass bool attention masks directly to hipDNN#3198

Open
rkayaith wants to merge 1 commit into users/rkayaith/sdpa-hipdnn-backend from
users/rkayaith/sdpa-hipdnn-bool-mask

Conversation


@rkayaith rkayaith commented May 1, 2026

This PR adds hipDNN as a supported backend for SDPA.

The approach taken here is to reuse the existing CUDNN_ATTENTION backend, routing through hipDNN when compiled for ROCm; i.e. torch.nn.attention.SDPBackend.CUDNN_ATTENTION and aten::_scaled_dot_product_cudnn_attention now work on ROCm.
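To illustrate the selection behavior this enables, here is a minimal Python sketch of priority-ordered SDPA backend selection (the names `select_backend` and `SDPAParams`' fields are invented for illustration; the real logic lives in PyTorch's C++ `sdp` namespace):

```python
# Illustrative sketch only: each candidate backend reports whether it
# can handle the given inputs, and the first supported backend wins.
from dataclasses import dataclass

@dataclass
class SDPAParams:
    is_rocm: bool
    has_nested_tensors: bool

def can_use_cudnn_attention(p: SDPAParams) -> bool:
    # On ROCm this now routes through hipDNN, which (per this PR stack)
    # does not support nested tensors.
    return not p.has_nested_tensors

def select_backend(p: SDPAParams, priority=("cudnn_attention", "math")) -> str:
    checks = {"cudnn_attention": can_use_cudnn_attention,
              "math": lambda _p: True}  # math fallback always works
    for name in priority:
        if checks[name](p):
            return name
    raise RuntimeError("no SDPA backend supports these inputs")
```

With this sketch, a ROCm build with ordinary dense tensors selects the cudnn (hipDNN-backed) path, while nested tensors fall back to the math backend.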

The primary change here is adding cudnn/hip/MHA.cpp, which is a "fork" of cudnn/MHA.cpp, modified to use hipDNN instead of cuDNN. There are various differences between the implementations that made it simpler to fork the entire file rather than trying to keep both implementations in the same file or rely on hipify for translation:

  • various minor API differences between cuDNN and hipDNN that made it difficult to reuse code directly
  • differences in feature support: hipDNN doesn't support nested tensors, which simplifies a lot of code; however, it is more flexible in other aspects, e.g. dtypes.

Additionally, since hipDNN provides an API for querying engine support for a graph, sdp::can_use_cudnn_attention calls the newly added at::native::check_cudnn_sdpa_support during backend selection, which constructs the hipDNN graph for both the forward and (if potentially needed) backward passes to query support. The result is not cached, as graph construction and the support query are assumed to be implemented efficiently. The sequence of hipDNN calls:

                SELECTION
                ─────────
       sdp::can_use_cudnn_attention()
                    │
                    ▼
    at::native::check_cudnn_sdpa_support()
                    │
                    ▼
     fe::graph::Graph::is_supported_ext()
                    │
                    ▼
                   bool
    
                EXECUTION
                ─────────
  at::native::run_cudnn_SDP_{fprop,bprop}()
                    │
                    ▼
           MHAGraphCache lookup
                    │
               ┌────┴────┐
               ▼         ▼
             HIT:      MISS:
             reuse       │
             cached      ▼
             graph     fe::graph::Graph::validate()
               │       fe::graph::Graph::build_operation_graph()
               │       fe::graph::Graph::create_execution_plans()
               │       fe::graph::Graph::check_support()
               │       fe::graph::Graph::build_plans()
               │         │
               └────┬────┘
                    ▼
         fe::graph::Graph::execute()

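The execution-path cache behavior in the diagram above can be modeled in miniature. This is an illustrative Python sketch only (the real MHAGraphCache is C++, and `build_graph` stands in for the whole validate/build_operation_graph/create_execution_plans/check_support/build_plans sequence):

```python
# Illustrative model of the MHAGraphCache lookup: graphs are cached by
# a shape/dtype key, so the expensive build pipeline runs only on a miss.
def build_graph(key):
    # Stand-in for fe::graph::Graph::validate(), build_operation_graph(),
    # create_execution_plans(), check_support(), build_plans().
    return {"key": key, "built": True}

class MHAGraphCache:
    def __init__(self):
        self._graphs = {}
        self.misses = 0

    def lookup(self, key):
        if key not in self._graphs:        # MISS: build and cache
            self.misses += 1
            self._graphs[key] = build_graph(key)
        return self._graphs[key]           # HIT: reuse cached graph

cache = MHAGraphCache()
g1 = cache.lookup(("bf16", 8, 128, 64))    # first call builds the graph
g2 = cache.lookup(("bf16", 8, 128, 64))    # second call reuses it
```

Note the contrast with the selection path, where support queries are deliberately not cached.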
Both backends still share the same attention.cu kernel (hipDNN uses the hipified version).

This is separated into a stack of PRs for easier review:

  • PR 1/3: [ROCm] Stop hipifying native/cudnn/ and native/quantized/cudnn/ files

    • Various cudnn files are being hipified at the moment, with no functional purpose (it just results in code that's ifdef'd out). This disables the hipification rules and fixes up includes/directives so the CUDA files compile cleanly on ROCm. The primary motivation here is to stop generating cudnn/hip/MHA.cpp; the following changes add this file back with the hipDNN backend implementation.
  • PR 2/3: [ROCm] Integrate hipDNN as an SDPA backend:

    • I'd recommend reviewing the individual commits here:
      • Add aotriton.images/ to .gitignore - NFC change
      • Extract compute_matching_strides from alloc_with_matching_layout - NFC change
      • Split cudnn/MHA.cpp into separate cuDNN and hipDNN files
      • Copy cuDNN MHA implementation verbatim into hipDNN MHA
        • This is just to make it easier to review the next commit, as it shows the diff between cuDNN and hipDNN.
      • [ROCm] Add hipDNN SDPA backend dispatch
        • Primary implementation; looking at this commit is likely the easiest way to review the hipDNN-specific changes.
      • Update tests to reflect the cudnn backend being available on ROCm.
        • CUDNN_ATTENTION tests can now be run on ROCm.
  • PR 3/3: [ROCm] Pass bool attention masks directly to hipDNN (this PR)

    • hipDNN handles boolean attention masks more efficiently than float masks. This skips the bool->float conversion that's normally done, and passes the original mask directly to hipDNN.
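For context, the conversion being skipped turns a bool mask into an additive float mask. A minimal sketch of that semantics (not the actual PyTorch implementation, which operates on tensors):

```python
# Illustrative semantics of convert_boolean_attn_mask: positions where
# attention is allowed (True) become 0.0, masked positions (False)
# become -inf, so the mask can be added to the attention scores.
# With this PR, on ROCm + cudnn_attention the bool mask is passed
# through unchanged and this materialization is skipped.
import math

def convert_boolean_attn_mask(bool_mask):
    return [[0.0 if allow else -math.inf for allow in row]
            for row in bool_mask]

mask = [[True, False],
        [True, True]]
float_mask = convert_boolean_attn_mask(mask)
```

Skipping this step avoids allocating and filling a full float tensor just to re-encode information the bool mask already carries.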

Open questions for reviewers:

  • Reusing the CUDNN_ATTENTION backend was done due to API similarities, though this could potentially be confusing. Would it be preferable to completely separate hipDNN by adding a new backend?
  • Would additional code-sharing between the cuDNN and hipDNN MHA.cpp files be recommended?

@rkayaith rkayaith changed the title PR 3/3: Pass bool attention masks directly to hipDNN [ROCm] Pass bool attention masks directly to hipDNN May 1, 2026
@rkayaith rkayaith force-pushed the users/rkayaith/sdpa-hipdnn-backend branch from ec3343b to 88e9332 on May 1, 2026 22:19
@rkayaith rkayaith force-pushed the users/rkayaith/sdpa-hipdnn-bool-mask branch from 44b86bb to 25e716f on May 1, 2026 22:19

rocm-repo-management-api Bot commented May 1, 2026

Jenkins build for 25e716ffb3ee7c68e18c4b3ebb1f0894f8dbcd37 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

@rkayaith rkayaith force-pushed the users/rkayaith/sdpa-hipdnn-bool-mask branch from 25e716f to d3ead50 on May 1, 2026 22:49
@rkayaith rkayaith force-pushed the users/rkayaith/sdpa-hipdnn-backend branch from 88e9332 to 48379ff on May 1, 2026 22:49
hipDNN now exposes a BOOLEAN tensor data type (rocm-libraries
3bc293796d), so we can route bool attn_mask tensors through the
PyTorch dispatch -> hipDNN -> fusilli pipeline unchanged instead of
materializing a float additive mask via convert_boolean_attn_mask.

attention.cpp: skip the upstream bool->float conversion on
USE_ROCM + cudnn_attention; convert_boolean_attn_mask is now invoked
per-backend (no behavior change for non-ROCm or non-cudnn backends).
cudnn/hip/MHA.cpp: map kBool -> fe::DataType_t::BOOLEAN in
to_fe_data_type so the bool mask is described correctly in the
fusilli graph; drop the bool special-case in check_cudnn_sdpa_support
that previously remapped to q.scalar_type.

Test: test_cudnn_attention_bool_mask compares CUDNN_ATTENTION vs
MATH backend with a triangular bool mask (no fully-masked rows, to
avoid iree#24175).
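A sketch of the kind of mask the test description implies: a lower-triangular bool mask, where every query row can attend to at least the diagonal key, so no row is fully masked out (pure Python, illustrative only):

```python
# Lower-triangular bool mask: position (q, k) is True iff k <= q.
# Every row contains at least one True (the diagonal), avoiding the
# fully-masked-row case referenced as iree#24175.
def tril_bool_mask(n):
    return [[k <= q for k in range(n)] for q in range(n)]

mask = tril_bool_mask(4)
no_fully_masked_rows = all(any(row) for row in mask)
```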

rocm-repo-management-api Bot commented May 1, 2026

Jenkins build for d3ead50515a81b95568859b91731128c31c5d240 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

@rkayaith rkayaith marked this pull request as ready for review May 4, 2026 18:19