From df8994d33e212fd6bfb3cd3c56693f1381c1cca1 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 15 Apr 2026 16:45:24 +0800 Subject: [PATCH 01/15] add mask support for flash backend --- src/diffusers/models/attention_dispatch.py | 250 +++++++++++++++++++-- tests/others/test_varlen_pack_helpers.py | 149 ++++++++++++ 2 files changed, 381 insertions(+), 18 deletions(-) create mode 100644 tests/others/test_varlen_pack_helpers.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 837d573d8c4d..df105741963f 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -78,7 +78,12 @@ if _CAN_USE_FLASH_ATTN: try: from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward + from flash_attn.flash_attn_interface import ( + _wrapped_flash_attn_backward, + _wrapped_flash_attn_forward, + _wrapped_flash_attn_varlen_backward, + _wrapped_flash_attn_varlen_forward, + ) except (ImportError, OSError, RuntimeError) as e: # Handle ABI mismatch or other import failures gracefully. # This can happen when flash_attn was compiled against a different PyTorch version. @@ -88,11 +93,15 @@ flash_attn_varlen_func = None _wrapped_flash_attn_backward = None _wrapped_flash_attn_forward = None + _wrapped_flash_attn_varlen_backward = None + _wrapped_flash_attn_varlen_forward = None else: flash_attn_func = None flash_attn_varlen_func = None _wrapped_flash_attn_backward = None _wrapped_flash_attn_forward = None + _wrapped_flash_attn_varlen_backward = None + _wrapped_flash_attn_varlen_forward = None if _CAN_USE_FLASH_ATTN_3: @@ -636,6 +645,74 @@ def _prepare_for_flash_attn_or_sage_varlen( return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) +def _padded_to_unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + """gather valid tokens from a padded `(batch, seq, ...)` tensor into a packed `(nnz, ...)` tensor.""" + return tensor.reshape(-1, *tensor.shape[2:])[indices] + + +def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + """scatter a packed `(nnz, ...)` tensor back to padded `(batch_size, seq_len, ...)`.""" + output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device) + output[indices] = packed + return output.view(batch_size, seq_len, *packed.shape[1:]) + + +@dataclass +class _VarlenPackedInputs: + """Inputs for varlen attention kernels: packed (unpadded) K/V, full-length Q, and KV index metadata.""" + + # tensors: query is full-length (flattened), key/value are packed (unpadded) + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + + # cumulative sequence lengths for K (derived from attn_mask) + cu_seqlens_q: torch.Tensor # (batch_size + 1,) — uniform stride of seq_len_q + cu_seqlens_k: torch.Tensor # (batch_size + 1,) + max_seqlen_k: int + + # shape metadata for unpacking outputs + batch_size: int + seq_len_q: int + seq_len_kv: int + + # flat indices of valid KV tokens in the (batch * seq_kv) dimension. 
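+    # Saved so the backward pass can scatter packed K/V gradients back into the padded layout.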
+ indices_k: torch.Tensor + + def unpack(self, packed_out: torch.Tensor) -> torch.Tensor: + return packed_out.view(self.batch_size, self.seq_len_q, *packed_out.shape[1:]) + + +def _pack_qkv( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, +) -> _VarlenPackedInputs: + """Pack Q/K/V tensors by removing padding tokens identified by *attn_mask*.""" + batch_size = query.shape[0] + seq_len_q = query.shape[1] + seq_len_kv = key.shape[1] + + _, (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size, seq_len_q, attn_mask, query.device + ) + indices_k = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() + + return _VarlenPackedInputs( + query=query.flatten(0, 1), + key=_padded_to_unpad(key, indices_k), + value=_padded_to_unpad(value, indices_k), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_k=max_seqlen_k, + batch_size=batch_size, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + indices_k=indices_k, + ) + + def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: """ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in @@ -1092,8 +1169,6 @@ def _flash_attention_forward_op( _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): - if attn_mask is not None: - raise ValueError("`attn_mask` is not yet supported for flash-attn 2.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.") @@ -1111,6 +1186,63 @@ def _flash_attention_forward_op( if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 + if attn_mask is not None: + if return_lse: + raise NotImplementedError("`return_lse=True` with `attn_mask` is not yet supported for flash-attn 2.") + + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + packed = _pack_qkv(query, key, value, attn_mask_2d) + + with torch.set_grad_enabled(grad_enabled): + out_packed, lse, _, rng_state = _wrapped_flash_attn_varlen_forward( + packed.query, + packed.key, + packed.value, + packed.cu_seqlens_q, + packed.cu_seqlens_k, + packed.seq_len_q, + packed.max_seqlen_k, + dropout_p, + scale, + is_causal, + window_size[0], + window_size[1], + softcap, + alibi_slopes, + return_lse, + ) + + out = packed.unpack(out_packed) + + if _save_ctx: + ctx.save_for_backward( + packed.query, + packed.key, + packed.value, + out_packed, + lse, + rng_state, + packed.cu_seqlens_q, + packed.cu_seqlens_k, + packed.indices_k, + ) + ctx.is_varlen_masked = True + ctx.max_seqlen_k = packed.max_seqlen_k + ctx.batch_size = batch_size + ctx.seq_len_q = seq_len_q + ctx.seq_len_kv = seq_len_kv + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + return out + with torch.set_grad_enabled(grad_enabled): out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward( query, @@ -1146,6 +1278,60 @@ def _flash_attention_backward_op( *args, **kwargs, ): + if getattr(ctx, "is_varlen_masked", False): + ( + query_packed, + key_packed, + value_packed, + out_packed, + lse, + rng_state, + cu_seqlens_q, + cu_seqlens_k, + indices_k, + ) = ctx.saved_tensors + + grad_out_packed = 
grad_out.flatten(0, 1) + + dq = torch.empty_like(query_packed) + dk = torch.empty_like(key_packed) + dv = torch.empty_like(value_packed) + + _wrapped_flash_attn_varlen_backward( # noqa: F841 + grad_out_packed, + query_packed, + key_packed, + value_packed, + out_packed, + lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.seq_len_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.scale, + ctx.is_causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state, + ) + + grad_query = dq.view(ctx.batch_size, ctx.seq_len_q, *dq.shape[1:]) + grad_key = _unpad_to_padded(dk, indices_k, ctx.batch_size, ctx.seq_len_kv) + grad_value = _unpad_to_padded(dv, indices_k, ctx.batch_size, ctx.seq_len_kv) + + grad_query = grad_query[..., : grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] + + return grad_query, grad_key, grad_value + query, key, value, out, lse, rng_state = ctx.saved_tensors grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) @@ -2325,27 +2511,55 @@ def _flash_attention( _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for flash-attn 2.") - if _parallel_config is None: - out = flash_attn_func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - if return_lse: - out, lse, *_ = out + if attn_mask is None: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + if return_lse: + raise NotImplementedError( + "`return_lse=True` with `attn_mask` is not yet supported for the FLASH backend." + ) + batch_size, _, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + packed = _pack_qkv(query, key, value, attn_mask_2d) + + out_packed = flash_attn_varlen_func( + q=packed.query, + k=packed.key, + v=packed.value, + cu_seqlens_q=packed.cu_seqlens_q, + cu_seqlens_k=packed.cu_seqlens_k, + max_seqlen_q=packed.seq_len_q, + max_seqlen_k=packed.max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out_packed, lse, *_ = out_packed + + out = packed.unpack(out_packed) else: + if attn_mask is not None and _parallel_config.context_parallel_config.ring_degree > 1: + raise NotImplementedError("`attn_mask` is not yet supported for flash-attn 2 with ring attention.") + out = _templated_context_parallel_attention( query, key, value, - None, + attn_mask, dropout_p, is_causal, scale, diff --git a/tests/others/test_varlen_pack_helpers.py b/tests/others/test_varlen_pack_helpers.py new file mode 100644 index 000000000000..316fbfa46e94 --- /dev/null +++ b/tests/others/test_varlen_pack_helpers.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +import torch.nn.functional as F + +from diffusers.models.attention_dispatch import AttentionBackendName, _pack_qkv, dispatch_attention_fn + + +# A mask with non-contiguous valid tokens (gaps in the middle of each row). +# Row 0: positions 0-2 valid, 3-4 invalid, 5-9 valid → 8 valid tokens +# Row 1: position 0 valid, 1-3 invalid, 4-9 valid → 7 valid tokens +_NON_PREFIX_MASK = torch.tensor( + [ + [True, True, True, False, False, True, True, True, True, True], + [True, False, False, False, True, True, True, True, True, True], + ], + dtype=torch.bool, +) + + +def _make_qkv(batch_size, seq_len, num_heads, head_dim, dtype=torch.float32): + """Return reproducible (batch_size, seq_len, num_heads, head_dim) Q/K/V tensors.""" + g = torch.Generator().manual_seed(42) + q = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) + return q, k, v + + +class TestPackQkv(unittest.TestCase): + """_pack_qkv: shapes, cu_seqlens, and round-trip via unpack().""" + + def test_kv_packed_q_full_length(self): + """attn_mask is a KV-validity mask: K/V are packed, Q is kept at full length.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 + q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) + + num_valid_tokens = int(_NON_PREFIX_MASK.sum()) + # K and V must be packed down to valid tokens only + self.assertEqual(packed.key.shape, (num_valid_tokens, num_heads, head_dim)) + self.assertEqual(packed.value.shape, (num_valid_tokens, num_heads, head_dim)) + # Q must remain full-length (flattened but not filtered) + self.assertEqual(packed.query.shape, (batch_size * seq_len, num_heads, head_dim)) + self.assertEqual(packed.seq_len_q, seq_len) + self.assertEqual(packed.cu_seqlens_q[-1].item(), batch_size * seq_len) + self.assertEqual(packed.cu_seqlens_k[-1].item(), num_valid_tokens) + + def test_cu_seqlens_reflect_valid_counts_not_positions(self): + """Non-prefix mask: cu_seqlens counts valid tokens per batch item, ignoring gaps.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 + q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) + + # Row 0 has 8 valid tokens; row 1 has 7 valid tokens (see _NON_PREFIX_MASK). 
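+        # cu_seqlens_k is therefore the exclusive cumulative sum of these counts: [0, 8, 15].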
+ valid_per_item = _NON_PREFIX_MASK.sum(dim=-1) + self.assertEqual(packed.cu_seqlens_k[1].item(), valid_per_item[0].item()) + self.assertEqual(packed.cu_seqlens_k[2].item(), int(valid_per_item.sum())) + self.assertEqual(packed.max_seqlen_k, int(valid_per_item.max())) + + def test_unpack_reshapes_full_length_output(self): + """unpack() with indices_q=None just reshapes the flat output back to (batch_size, seq_len, ...).""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 + q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) + + # Q is not packed, so a fake attention output matching the full flattened Q shape + # should round-trip back to the original padded layout unchanged. + recovered = packed.unpack(packed.query) + self.assertEqual(recovered.shape, (batch_size, seq_len, num_heads, head_dim)) + self.assertTrue(torch.allclose(recovered, q)) + + def test_cross_attn_q_is_not_packed(self): + """Cross-attention (seq_q != seq_kv): Q remains full-length instead of being packed.""" + batch_size, seq_len_q, seq_len_kv, num_heads, head_dim = 2, 5, 10, 2, 8 + # _NON_PREFIX_MASK has shape (2, 10) and applies to KV tokens only + q = torch.randn(batch_size, seq_len_q, num_heads, head_dim) + k = torch.randn(batch_size, seq_len_kv, num_heads, head_dim) + v = torch.randn(batch_size, seq_len_kv, num_heads, head_dim) + + packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) + + self.assertEqual(packed.query.shape, (batch_size * seq_len_q, num_heads, head_dim)) + self.assertEqual(packed.seq_len_q, seq_len_q) + + +class TestDispatchAttentionWithMask(unittest.TestCase): + """dispatch_attention_fn must honour attn_mask for all supported mask shapes.""" + + def _sdpa_ref(self, q, k, v, bool_mask_2d): + """SDPA reference: converts a 2D bool mask to an additive float mask and runs SDPA.""" + # Additive mask convention: 0.0 for positions to attend to, -inf for positions to ignore. 
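+        # (F.scaled_dot_product_attention adds a float mask to the attention scores before softmax)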
+ additive_mask = torch.zeros_like(bool_mask_2d, dtype=q.dtype) + additive_mask = additive_mask.masked_fill(~bool_mask_2d, float("-inf")) + additive_mask = additive_mask[:, None, None, :] # (batch_size, 1, 1, seq_len_kv) + q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask) + return out.permute(0, 2, 1, 3) + + def test_non_prefix_mask_matches_sdpa_reference(self): + """Non-prefix mask: NATIVE backend output must match SDPA reference.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 + q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + + ref = self._sdpa_ref(q, k, v, _NON_PREFIX_MASK) + out = dispatch_attention_fn(q, k, v, attn_mask=_NON_PREFIX_MASK, backend=AttentionBackendName.NATIVE) + + self.assertTrue(torch.allclose(ref, out, atol=1e-5), f"Max diff: {(ref - out).abs().max():.2e}") + + def test_all_valid_mask_equals_no_mask(self): + """All-True mask must produce the same output as passing no mask at all.""" + batch_size, seq_len, num_heads, head_dim = 2, 8, 2, 32 + q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + all_valid_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) + + out_masked = dispatch_attention_fn(q, k, v, attn_mask=all_valid_mask, backend=AttentionBackendName.NATIVE) + out_no_mask = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.NATIVE) + + self.assertTrue(torch.allclose(out_masked, out_no_mask, atol=1e-6)) + + def test_4d_bool_mask_equivalent_to_2d(self): + """4D bool mask (batch_size, 1, 1, seq_len) must normalize to the same result as the 2D mask.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 + q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + + out_2d = dispatch_attention_fn(q, k, v, attn_mask=_NON_PREFIX_MASK, backend=AttentionBackendName.NATIVE) + out_4d = dispatch_attention_fn( + q, k, v, attn_mask=_NON_PREFIX_MASK[:, None, None, :], backend=AttentionBackendName.NATIVE + ) + + self.assertTrue(torch.allclose(out_2d, out_4d, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main() From 2d12f46aa41400609f90fc1675d3fd48c24623b5 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 15 Apr 2026 18:41:47 +0800 Subject: [PATCH 02/15] fix test case --- ...elpers.py => test_flash_attention_mask.py} | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) rename tests/others/{test_varlen_pack_helpers.py => test_flash_attention_mask.py} (78%) diff --git a/tests/others/test_varlen_pack_helpers.py b/tests/others/test_flash_attention_mask.py similarity index 78% rename from tests/others/test_varlen_pack_helpers.py rename to tests/others/test_flash_attention_mask.py index 316fbfa46e94..d7a4b67cf8c9 100644 --- a/tests/others/test_varlen_pack_helpers.py +++ b/tests/others/test_flash_attention_mask.py @@ -17,7 +17,12 @@ import torch import torch.nn.functional as F -from diffusers.models.attention_dispatch import AttentionBackendName, _pack_qkv, dispatch_attention_fn +from diffusers.models.attention_dispatch import ( + _CAN_USE_FLASH_ATTN, + AttentionBackendName, + _pack_qkv, + dispatch_attention_fn, +) # A mask with non-contiguous valid tokens (gaps in the middle of each row). 
@@ -98,8 +103,9 @@ def test_cross_attn_q_is_not_packed(self): self.assertEqual(packed.seq_len_q, seq_len_q) -class TestDispatchAttentionWithMask(unittest.TestCase): - """dispatch_attention_fn must honour attn_mask for all supported mask shapes.""" +@unittest.skipUnless(_CAN_USE_FLASH_ATTN, "flash-attn is required for these tests") +class TestFlashAttentionWithMask(unittest.TestCase): + """Flash attention backend must produce results consistent with the SDPA reference when attn_mask is given.""" def _sdpa_ref(self, q, k, v, bool_mask_2d): """SDPA reference: converts a 2D bool mask to an additive float mask and runs SDPA.""" @@ -112,37 +118,46 @@ def _sdpa_ref(self, q, k, v, bool_mask_2d): return out.permute(0, 2, 1, 3) def test_non_prefix_mask_matches_sdpa_reference(self): - """Non-prefix mask: NATIVE backend output must match SDPA reference.""" + """Non-prefix mask: FLASH backend output must match SDPA reference.""" batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 - q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) + device = torch.device("cuda") + q, k, v = ( + t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) + ) + mask = _NON_PREFIX_MASK.to(device) - ref = self._sdpa_ref(q, k, v, _NON_PREFIX_MASK) - out = dispatch_attention_fn(q, k, v, attn_mask=_NON_PREFIX_MASK, backend=AttentionBackendName.NATIVE) + ref = self._sdpa_ref(q, k, v, mask) + out = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) - self.assertTrue(torch.allclose(ref, out, atol=1e-5), f"Max diff: {(ref - out).abs().max():.2e}") + self.assertTrue(torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}") def test_all_valid_mask_equals_no_mask(self): """All-True mask must produce the same output as passing no mask at all.""" batch_size, seq_len, num_heads, head_dim = 2, 8, 2, 32 - q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) - all_valid_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) + device = torch.device("cuda") + q, k, v = ( + t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) + ) + all_valid_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) - out_masked = dispatch_attention_fn(q, k, v, attn_mask=all_valid_mask, backend=AttentionBackendName.NATIVE) - out_no_mask = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.NATIVE) + out_masked = dispatch_attention_fn(q, k, v, attn_mask=all_valid_mask, backend=AttentionBackendName.FLASH) + out_no_mask = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.FLASH) - self.assertTrue(torch.allclose(out_masked, out_no_mask, atol=1e-6)) + self.assertTrue(torch.allclose(out_masked, out_no_mask, atol=1e-3)) def test_4d_bool_mask_equivalent_to_2d(self): """4D bool mask (batch_size, 1, 1, seq_len) must normalize to the same result as the 2D mask.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 - q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) - - out_2d = dispatch_attention_fn(q, k, v, attn_mask=_NON_PREFIX_MASK, backend=AttentionBackendName.NATIVE) - out_4d = dispatch_attention_fn( - q, k, v, attn_mask=_NON_PREFIX_MASK[:, None, None, :], backend=AttentionBackendName.NATIVE + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 + device = torch.device("cuda") + q, k, v = ( + t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) ) + mask = 
_NON_PREFIX_MASK.to(device) + + out_2d = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) + out_4d = dispatch_attention_fn(q, k, v, attn_mask=mask[:, None, None, :], backend=AttentionBackendName.FLASH) - self.assertTrue(torch.allclose(out_2d, out_4d, atol=1e-6)) + self.assertTrue(torch.allclose(out_2d, out_4d, atol=1e-3)) if __name__ == "__main__": From 003fa34d9365759d77cb4286adc412e0308e4ab4 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 16 Apr 2026 10:25:53 +0800 Subject: [PATCH 03/15] refactor test --- tests/others/test_flash_attention.py | 98 +++++++++++++ tests/others/test_flash_attention_mask.py | 164 ---------------------- 2 files changed, 98 insertions(+), 164 deletions(-) create mode 100644 tests/others/test_flash_attention.py delete mode 100644 tests/others/test_flash_attention_mask.py diff --git a/tests/others/test_flash_attention.py b/tests/others/test_flash_attention.py new file mode 100644 index 000000000000..9dc141bae478 --- /dev/null +++ b/tests/others/test_flash_attention.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +import torch.nn.functional as F + +from diffusers.models.attention_dispatch import ( + _CAN_USE_FLASH_ATTN, + AttentionBackendName, + dispatch_attention_fn, +) + + +# A mask with non-contiguous valid tokens. 
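+# Valid tokens are not a contiguous prefix, so these tests exercise the full gather/scatter packing path.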
+_NON_PREFIX_MASK = torch.tensor( + [ + [True, True, True, False, False, True, True, True, True, True], + [True, False, False, False, True, True, True, True, True, True], + ], + dtype=torch.bool, +) + + +def _make_qkv(batch_size, seq_len, num_heads, head_dim, dtype=torch.float32): + g = torch.Generator().manual_seed(42) + q = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) + return q, k, v + + +def _sdpa_ref(q, k, v, bool_mask_2d=None): + if bool_mask_2d is not None: + additive_mask = torch.zeros_like(bool_mask_2d, dtype=q.dtype) + additive_mask = additive_mask.masked_fill(~bool_mask_2d, float("-inf")) + additive_mask = additive_mask[:, None, None, :] # (batch_size, 1, 1, seq_len_kv) + else: + additive_mask = None + q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask) + return out.permute(0, 2, 1, 3) + + +@pytest.mark.skipif(not _CAN_USE_FLASH_ATTN, reason="flash-attn is required for these tests") +class TestFlashAttention: + """Flash attention backend must produce results consistent with the SDPA reference when attn_mask is given.""" + + def test_no_mask_matches_sdpa_reference(self): + """FLASH backend output must match SDPA reference without any masking.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 + device = torch.device("cuda") + q, k, v = ( + t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) + ) + ref = _sdpa_ref(q, k, v) + out = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.FLASH) + + assert torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}" + + def test_mask_matches_sdpa_reference(self): + """FLASH backend output must match SDPA reference with attention mask.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 + device = torch.device("cuda") + q, k, v = ( + t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) + ) + mask = _NON_PREFIX_MASK.to(device) + + ref = _sdpa_ref(q, k, v, mask) + out = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) + + assert torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}" + + def test_4d_bool_mask_equivalent_to_2d(self): + """4D bool mask (batch_size, 1, 1, seq_len) must normalize to the same result as the 2D mask.""" + batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 + device = torch.device("cuda") + q, k, v = ( + t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) + ) + mask = _NON_PREFIX_MASK.to(device) + + out_2d = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) + out_4d = dispatch_attention_fn(q, k, v, attn_mask=mask[:, None, None, :], backend=AttentionBackendName.FLASH) + + assert torch.allclose(out_2d, out_4d, atol=1e-3) diff --git a/tests/others/test_flash_attention_mask.py b/tests/others/test_flash_attention_mask.py deleted file mode 100644 index d7a4b67cf8c9..000000000000 --- a/tests/others/test_flash_attention_mask.py +++ /dev/null @@ -1,164 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -import torch -import torch.nn.functional as F - -from diffusers.models.attention_dispatch import ( - _CAN_USE_FLASH_ATTN, - AttentionBackendName, - _pack_qkv, - dispatch_attention_fn, -) - - -# A mask with non-contiguous valid tokens (gaps in the middle of each row). -# Row 0: positions 0-2 valid, 3-4 invalid, 5-9 valid → 8 valid tokens -# Row 1: position 0 valid, 1-3 invalid, 4-9 valid → 7 valid tokens -_NON_PREFIX_MASK = torch.tensor( - [ - [True, True, True, False, False, True, True, True, True, True], - [True, False, False, False, True, True, True, True, True, True], - ], - dtype=torch.bool, -) - - -def _make_qkv(batch_size, seq_len, num_heads, head_dim, dtype=torch.float32): - """Return reproducible (batch_size, seq_len, num_heads, head_dim) Q/K/V tensors.""" - g = torch.Generator().manual_seed(42) - q = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) - k = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) - v = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) - return q, k, v - - -class TestPackQkv(unittest.TestCase): - """_pack_qkv: shapes, cu_seqlens, and round-trip via unpack().""" - - def test_kv_packed_q_full_length(self): - """attn_mask is a KV-validity mask: K/V are packed, Q is kept at full length.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 - q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) - packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) - - num_valid_tokens = int(_NON_PREFIX_MASK.sum()) - # K and V must be packed down to valid tokens only - self.assertEqual(packed.key.shape, (num_valid_tokens, num_heads, head_dim)) - self.assertEqual(packed.value.shape, (num_valid_tokens, num_heads, head_dim)) - # Q must remain full-length (flattened but not filtered) - self.assertEqual(packed.query.shape, (batch_size * seq_len, num_heads, head_dim)) - self.assertEqual(packed.seq_len_q, seq_len) - self.assertEqual(packed.cu_seqlens_q[-1].item(), batch_size * seq_len) - self.assertEqual(packed.cu_seqlens_k[-1].item(), num_valid_tokens) - - def test_cu_seqlens_reflect_valid_counts_not_positions(self): - """Non-prefix mask: cu_seqlens counts valid tokens per batch item, ignoring gaps.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 - q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) - packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) - - # Row 0 has 8 valid tokens; row 1 has 7 valid tokens (see _NON_PREFIX_MASK). 
- valid_per_item = _NON_PREFIX_MASK.sum(dim=-1) - self.assertEqual(packed.cu_seqlens_k[1].item(), valid_per_item[0].item()) - self.assertEqual(packed.cu_seqlens_k[2].item(), int(valid_per_item.sum())) - self.assertEqual(packed.max_seqlen_k, int(valid_per_item.max())) - - def test_unpack_reshapes_full_length_output(self): - """unpack() with indices_q=None just reshapes the flat output back to (batch_size, seq_len, ...).""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 16 - q, k, v = _make_qkv(batch_size, seq_len, num_heads, head_dim) - packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) - - # Q is not packed, so a fake attention output matching the full flattened Q shape - # should round-trip back to the original padded layout unchanged. - recovered = packed.unpack(packed.query) - self.assertEqual(recovered.shape, (batch_size, seq_len, num_heads, head_dim)) - self.assertTrue(torch.allclose(recovered, q)) - - def test_cross_attn_q_is_not_packed(self): - """Cross-attention (seq_q != seq_kv): Q remains full-length instead of being packed.""" - batch_size, seq_len_q, seq_len_kv, num_heads, head_dim = 2, 5, 10, 2, 8 - # _NON_PREFIX_MASK has shape (2, 10) and applies to KV tokens only - q = torch.randn(batch_size, seq_len_q, num_heads, head_dim) - k = torch.randn(batch_size, seq_len_kv, num_heads, head_dim) - v = torch.randn(batch_size, seq_len_kv, num_heads, head_dim) - - packed = _pack_qkv(q, k, v, _NON_PREFIX_MASK) - - self.assertEqual(packed.query.shape, (batch_size * seq_len_q, num_heads, head_dim)) - self.assertEqual(packed.seq_len_q, seq_len_q) - - -@unittest.skipUnless(_CAN_USE_FLASH_ATTN, "flash-attn is required for these tests") -class TestFlashAttentionWithMask(unittest.TestCase): - """Flash attention backend must produce results consistent with the SDPA reference when attn_mask is given.""" - - def _sdpa_ref(self, q, k, v, bool_mask_2d): - """SDPA reference: converts a 2D bool mask to an additive float mask and runs SDPA.""" - # Additive mask convention: 0.0 for positions to attend to, -inf for positions to ignore. 
- additive_mask = torch.zeros_like(bool_mask_2d, dtype=q.dtype) - additive_mask = additive_mask.masked_fill(~bool_mask_2d, float("-inf")) - additive_mask = additive_mask[:, None, None, :] # (batch_size, 1, 1, seq_len_kv) - q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v)) - out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask) - return out.permute(0, 2, 1, 3) - - def test_non_prefix_mask_matches_sdpa_reference(self): - """Non-prefix mask: FLASH backend output must match SDPA reference.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 - device = torch.device("cuda") - q, k, v = ( - t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) - ) - mask = _NON_PREFIX_MASK.to(device) - - ref = self._sdpa_ref(q, k, v, mask) - out = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) - - self.assertTrue(torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}") - - def test_all_valid_mask_equals_no_mask(self): - """All-True mask must produce the same output as passing no mask at all.""" - batch_size, seq_len, num_heads, head_dim = 2, 8, 2, 32 - device = torch.device("cuda") - q, k, v = ( - t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) - ) - all_valid_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) - - out_masked = dispatch_attention_fn(q, k, v, attn_mask=all_valid_mask, backend=AttentionBackendName.FLASH) - out_no_mask = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.FLASH) - - self.assertTrue(torch.allclose(out_masked, out_no_mask, atol=1e-3)) - - def test_4d_bool_mask_equivalent_to_2d(self): - """4D bool mask (batch_size, 1, 1, seq_len) must normalize to the same result as the 2D mask.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 - device = torch.device("cuda") - q, k, v = ( - t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) - ) - mask = _NON_PREFIX_MASK.to(device) - - out_2d = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) - out_4d = dispatch_attention_fn(q, k, v, attn_mask=mask[:, None, None, :], backend=AttentionBackendName.FLASH) - - self.assertTrue(torch.allclose(out_2d, out_4d, atol=1e-3)) - - -if __name__ == "__main__": - unittest.main() From bc61551bf62a96c172bbd0df23f0aa0770821791 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 16 Apr 2026 10:59:24 +0800 Subject: [PATCH 04/15] add protection --- src/diffusers/models/attention_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index df105741963f..118a0bb5f4a7 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2496,7 +2496,7 @@ def _templated_context_parallel_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_attn_mask_or_causal, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention( From 86fec43029d72cf2afa3b8a7186bb28c16b780d0 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 16 Apr 2026 11:09:30 +0800 Subject: [PATCH 05/15] fix comment --- src/diffusers/models/attention_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 118a0bb5f4a7..11837c0799ea 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -689,7 +689,7 @@ def _pack_qkv( value: torch.Tensor, attn_mask: torch.Tensor, ) -> _VarlenPackedInputs: - """Pack Q/K/V tensors by removing padding tokens identified by *attn_mask*.""" + """Pack K/V tensors by removing padding tokens identified by *attn_mask*.""" batch_size = query.shape[0] seq_len_q = query.shape[1] seq_len_kv = key.shape[1] From 5034b2be73973d6c6dafe4159abb26d34737004a Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 16 Apr 2026 15:04:37 +0800 Subject: [PATCH 06/15] update according to suggestion --- src/diffusers/models/attention_dispatch.py | 479 +++++++++++---------- tests/others/test_flash_attention.py | 98 ----- 2 files changed, 240 insertions(+), 337 deletions(-) delete mode 100644 tests/others/test_flash_attention.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 11837c0799ea..74f88c7874d2 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -78,12 +78,7 @@ if _CAN_USE_FLASH_ATTN: try: from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.flash_attn_interface import ( - _wrapped_flash_attn_backward, - _wrapped_flash_attn_forward, - _wrapped_flash_attn_varlen_backward, - _wrapped_flash_attn_varlen_forward, - ) + from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward except (ImportError, OSError, RuntimeError) as e: # Handle ABI mismatch or other import failures gracefully. # This can happen when flash_attn was compiled against a different PyTorch version. @@ -93,15 +88,11 @@ flash_attn_varlen_func = None _wrapped_flash_attn_backward = None _wrapped_flash_attn_forward = None - _wrapped_flash_attn_varlen_backward = None - _wrapped_flash_attn_varlen_forward = None else: flash_attn_func = None flash_attn_varlen_func = None _wrapped_flash_attn_backward = None _wrapped_flash_attn_forward = None - _wrapped_flash_attn_varlen_backward = None - _wrapped_flash_attn_varlen_forward = None if _CAN_USE_FLASH_ATTN_3: @@ -335,6 +326,13 @@ class _HubKernelConfig: wrapped_backward_attr: str | None = None wrapped_forward_fn: Callable | None = None wrapped_backward_fn: Callable | None = None + # Some backends (e.g. 
flash attention) have separate kernels for variable-length inputs + varlen_function_attr: str | None = None + varlen_kernel_fn: Callable | None = None + wrapped_varlen_forward_attr: str | None = None + wrapped_varlen_backward_attr: str | None = None + wrapped_varlen_forward_fn: Callable | None = None + wrapped_varlen_backward_fn: Callable | None = None # Registry for hub-based attention kernels @@ -354,8 +352,11 @@ class _HubKernelConfig: AttentionBackendName.FLASH_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", + varlen_function_attr="flash_attn_varlen_func", wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward", wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward", + wrapped_varlen_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward", + wrapped_varlen_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward", version=1, ), AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( @@ -657,62 +658,6 @@ def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: in return output.view(batch_size, seq_len, *packed.shape[1:]) -@dataclass -class _VarlenPackedInputs: - """Inputs for varlen attention kernels: packed (unpadded) K/V, full-length Q, and KV index metadata.""" - - # tensors: query is full-length (flattened), key/value are packed (unpadded) - query: torch.Tensor - key: torch.Tensor - value: torch.Tensor - - # cumulative sequence lengths for K (derived from attn_mask) - cu_seqlens_q: torch.Tensor # (batch_size + 1,) — uniform stride of seq_len_q - cu_seqlens_k: torch.Tensor # (batch_size + 1,) - max_seqlen_k: int - - # shape metadata for unpacking outputs - batch_size: int - seq_len_q: int - seq_len_kv: int - - # flat indices of valid KV tokens in the (batch * seq_kv) dimension. 
- indices_k: torch.Tensor - - def unpack(self, packed_out: torch.Tensor) -> torch.Tensor: - return packed_out.view(self.batch_size, self.seq_len_q, *packed_out.shape[1:]) - - -def _pack_qkv( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor, -) -> _VarlenPackedInputs: - """Pack K/V tensors by removing padding tokens identified by *attn_mask*.""" - batch_size = query.shape[0] - seq_len_q = query.shape[1] - seq_len_kv = key.shape[1] - - _, (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen_with_mask( - batch_size, seq_len_q, attn_mask, query.device - ) - indices_k = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() - - return _VarlenPackedInputs( - query=query.flatten(0, 1), - key=_padded_to_unpad(key, indices_k), - value=_padded_to_unpad(value, indices_k), - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - batch_size=batch_size, - seq_len_q=seq_len_q, - seq_len_kv=seq_len_kv, - indices_k=indices_k, - ) - - def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: """ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in @@ -785,10 +730,24 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: config = _HUB_KERNELS_REGISTRY[backend] needs_kernel = config.kernel_fn is None + needs_varlen_kernel = config.varlen_function_attr is not None and config.varlen_kernel_fn is None needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None + needs_wrapped_varlen_forward = ( + config.wrapped_varlen_forward_attr is not None and config.wrapped_varlen_forward_fn is None + ) + needs_wrapped_varlen_backward = ( + config.wrapped_varlen_backward_attr is not None and config.wrapped_varlen_backward_fn is None + ) - if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward): + if not ( + needs_kernel + or needs_varlen_kernel + or needs_wrapped_forward + or needs_wrapped_backward + or needs_wrapped_varlen_forward + or needs_wrapped_varlen_backward + ): return try: @@ -798,12 +757,23 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: if needs_kernel: config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr) + if needs_varlen_kernel: + config.varlen_kernel_fn = _resolve_kernel_attr(kernel_module, config.varlen_function_attr) + if needs_wrapped_forward: config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr) if needs_wrapped_backward: config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr) + if needs_wrapped_varlen_forward: + config.wrapped_varlen_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_varlen_forward_attr) + + if needs_wrapped_varlen_backward: + config.wrapped_varlen_backward_fn = _resolve_kernel_attr( + kernel_module, config.wrapped_varlen_backward_attr + ) + except Exception as e: logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") raise @@ -1186,63 +1156,6 @@ def _flash_attention_forward_op( if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 - if attn_mask is not None: - if return_lse: - raise 
NotImplementedError("`return_lse=True` with `attn_mask` is not yet supported for flash-attn 2.") - - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - packed = _pack_qkv(query, key, value, attn_mask_2d) - - with torch.set_grad_enabled(grad_enabled): - out_packed, lse, _, rng_state = _wrapped_flash_attn_varlen_forward( - packed.query, - packed.key, - packed.value, - packed.cu_seqlens_q, - packed.cu_seqlens_k, - packed.seq_len_q, - packed.max_seqlen_k, - dropout_p, - scale, - is_causal, - window_size[0], - window_size[1], - softcap, - alibi_slopes, - return_lse, - ) - - out = packed.unpack(out_packed) - - if _save_ctx: - ctx.save_for_backward( - packed.query, - packed.key, - packed.value, - out_packed, - lse, - rng_state, - packed.cu_seqlens_q, - packed.cu_seqlens_k, - packed.indices_k, - ) - ctx.is_varlen_masked = True - ctx.max_seqlen_k = packed.max_seqlen_k - ctx.batch_size = batch_size - ctx.seq_len_q = seq_len_q - ctx.seq_len_kv = seq_len_kv - ctx.dropout_p = dropout_p - ctx.scale = scale - ctx.is_causal = is_causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - return out - with torch.set_grad_enabled(grad_enabled): out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward( query, @@ -1278,60 +1191,6 @@ def _flash_attention_backward_op( *args, **kwargs, ): - if getattr(ctx, "is_varlen_masked", False): - ( - query_packed, - key_packed, - value_packed, - out_packed, - lse, - rng_state, - cu_seqlens_q, - cu_seqlens_k, - indices_k, - ) = ctx.saved_tensors - - grad_out_packed = grad_out.flatten(0, 1) - - dq = torch.empty_like(query_packed) - dk = torch.empty_like(key_packed) - dv = torch.empty_like(value_packed) - - _wrapped_flash_attn_varlen_backward( # noqa: F841 - grad_out_packed, - query_packed, - key_packed, - value_packed, - out_packed, - lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.seq_len_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.scale, - ctx.is_causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state, - ) - - grad_query = dq.view(ctx.batch_size, ctx.seq_len_q, *dq.shape[1:]) - grad_key = _unpad_to_padded(dk, indices_k, ctx.batch_size, ctx.seq_len_kv) - grad_value = _unpad_to_padded(dv, indices_k, ctx.batch_size, ctx.seq_len_kv) - - grad_query = grad_query[..., : grad_out.shape[-1]] - grad_key = grad_key[..., : grad_out.shape[-1]] - grad_value = grad_value[..., : grad_out.shape[-1]] - - return grad_query, grad_key, grad_value - query, key, value, out, lse, rng_state = ctx.saved_tensors grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) @@ -1378,8 +1237,6 @@ def _flash_attention_hub_forward_op( _save_ctx: bool = True, _parallel_config: "ParallelConfig" | None = None, ): - if attn_mask is not None: - raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.") @@ -1404,6 +1261,78 @@ def _flash_attention_hub_forward_op( if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 + if attn_mask is not None: + if return_lse: + raise NotImplementedError( + "`return_lse=True` with `attn_mask` is not yet supported for flash-attn hub kernels." 
+ ) + + wrapped_varlen_forward_fn = config.wrapped_varlen_forward_fn + if wrapped_varlen_forward_fn is None: + raise RuntimeError( + "Flash attention hub kernels must expose `_wrapped_flash_attn_varlen_forward` for masked attention." + ) + + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + _, (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size, seq_len_q, attn_mask_2d, query.device + ) + indices_k = torch.nonzero(attn_mask_2d.flatten(), as_tuple=False).flatten() + query_packed = query.flatten(0, 1) + key_packed = _padded_to_unpad(key, indices_k) + value_packed = _padded_to_unpad(value, indices_k) + + with torch.set_grad_enabled(grad_enabled): + out_packed, lse, _, rng_state = wrapped_varlen_forward_fn( + query_packed, + key_packed, + value_packed, + cu_seqlens_q, + cu_seqlens_k, + seq_len_q, + max_seqlen_k, + dropout_p, + scale, + is_causal, + window_size[0], + window_size[1], + softcap, + alibi_slopes, + return_lse, + ) + + out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:]) + + if _save_ctx: + ctx.save_for_backward( + query_packed, + key_packed, + value_packed, + out_packed, + lse, + rng_state, + cu_seqlens_q, + cu_seqlens_k, + indices_k, + ) + ctx.is_varlen_masked = True + ctx.max_seqlen_k = max_seqlen_k + ctx.batch_size = batch_size + ctx.seq_len_q = seq_len_q + ctx.seq_len_kv = seq_len_kv + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + return out + with torch.set_grad_enabled(grad_enabled): out, lse, S_dmask, rng_state = wrapped_forward_fn( query, @@ -1440,6 +1369,67 @@ def _flash_attention_hub_backward_op( **kwargs, ): config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB] + + if getattr(ctx, "is_varlen_masked", False): + wrapped_varlen_backward_fn = config.wrapped_varlen_backward_fn + if wrapped_varlen_backward_fn is None: + raise RuntimeError( + "Flash attention hub kernels must expose `_wrapped_flash_attn_varlen_backward` for masked attention." 
+ ) + + ( + query_packed, + key_packed, + value_packed, + out_packed, + lse, + rng_state, + cu_seqlens_q, + cu_seqlens_k, + indices_k, + ) = ctx.saved_tensors + + grad_out_packed = grad_out.flatten(0, 1) + + dq = torch.empty_like(query_packed) + dk = torch.empty_like(key_packed) + dv = torch.empty_like(value_packed) + + wrapped_varlen_backward_fn( # noqa: F841 + grad_out_packed, + query_packed, + key_packed, + value_packed, + out_packed, + lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.seq_len_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.scale, + ctx.is_causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state, + ) + + grad_query = dq.view(ctx.batch_size, ctx.seq_len_q, *dq.shape[1:]) + grad_key = _unpad_to_padded(dk, indices_k, ctx.batch_size, ctx.seq_len_kv) + grad_value = _unpad_to_padded(dv, indices_k, ctx.batch_size, ctx.seq_len_kv) + + grad_query = grad_query[..., : grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] + + return grad_query, grad_key, grad_value + wrapped_backward_fn = config.wrapped_backward_fn if wrapped_backward_fn is None: raise RuntimeError( @@ -2496,7 +2486,7 @@ def _templated_context_parallel_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH, - constraints=[_check_attn_mask_or_causal, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention( @@ -2511,55 +2501,27 @@ def _flash_attention( _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None - if _parallel_config is None: - if attn_mask is None: - out = flash_attn_func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - if return_lse: - out, lse, *_ = out - else: - if return_lse: - raise NotImplementedError( - "`return_lse=True` with `attn_mask` is not yet supported for the FLASH backend." 
- ) - batch_size, _, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - packed = _pack_qkv(query, key, value, attn_mask_2d) - - out_packed = flash_attn_varlen_func( - q=packed.query, - k=packed.key, - v=packed.value, - cu_seqlens_q=packed.cu_seqlens_q, - cu_seqlens_k=packed.cu_seqlens_k, - max_seqlen_q=packed.seq_len_q, - max_seqlen_k=packed.max_seqlen_k, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - if return_lse: - out_packed, lse, *_ = out_packed + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 2.") - out = packed.unpack(out_packed) + if _parallel_config is None: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out else: - if attn_mask is not None and _parallel_config.context_parallel_config.ring_degree > 1: - raise NotImplementedError("`attn_mask` is not yet supported for flash-attn 2 with ring attention.") - out = _templated_context_parallel_attention( query, key, value, - attn_mask, + None, dropout_p, is_causal, scale, @@ -2577,7 +2539,7 @@ def _flash_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_HUB, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_attn_mask_or_causal, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention_hub( @@ -2592,28 +2554,67 @@ def _flash_attention_hub( _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for flash-attn 2.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn if _parallel_config is None: - out = func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - if return_lse: - out, lse, *_ = out + if attn_mask is None: + out = func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + if return_lse: + raise NotImplementedError( + "`return_lse=True` with `attn_mask` is not yet supported for the FLASH_HUB backend." 
+ ) + batch_size, _, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + _, (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size, query.shape[1], attn_mask_2d, query.device + ) + indices_k = torch.nonzero(attn_mask_2d.flatten(), as_tuple=False).flatten() + query_packed = query.flatten(0, 1) + key_packed = _padded_to_unpad(key, indices_k) + value_packed = _padded_to_unpad(value, indices_k) + + varlen_func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].varlen_kernel_fn + out_packed = varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=query.shape[1], + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out_packed, lse, *_ = out_packed + + out = out_packed.view(batch_size, query.shape[1], *out_packed.shape[1:]) else: + if attn_mask is not None and _parallel_config.context_parallel_config.ring_degree > 1: + raise NotImplementedError( + "`attn_mask` is not yet supported for flash-attn hub kernels with ring attention." + ) + out = _templated_context_parallel_attention( query, key, value, - None, + attn_mask, dropout_p, is_causal, scale, diff --git a/tests/others/test_flash_attention.py b/tests/others/test_flash_attention.py deleted file mode 100644 index 9dc141bae478..000000000000 --- a/tests/others/test_flash_attention.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -import torch -import torch.nn.functional as F - -from diffusers.models.attention_dispatch import ( - _CAN_USE_FLASH_ATTN, - AttentionBackendName, - dispatch_attention_fn, -) - - -# A mask with non-contiguous valid tokens. 
-_NON_PREFIX_MASK = torch.tensor( - [ - [True, True, True, False, False, True, True, True, True, True], - [True, False, False, False, True, True, True, True, True, True], - ], - dtype=torch.bool, -) - - -def _make_qkv(batch_size, seq_len, num_heads, head_dim, dtype=torch.float32): - g = torch.Generator().manual_seed(42) - q = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) - k = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) - v = torch.randn(batch_size, seq_len, num_heads, head_dim, generator=g, dtype=dtype) - return q, k, v - - -def _sdpa_ref(q, k, v, bool_mask_2d=None): - if bool_mask_2d is not None: - additive_mask = torch.zeros_like(bool_mask_2d, dtype=q.dtype) - additive_mask = additive_mask.masked_fill(~bool_mask_2d, float("-inf")) - additive_mask = additive_mask[:, None, None, :] # (batch_size, 1, 1, seq_len_kv) - else: - additive_mask = None - q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v)) - out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask) - return out.permute(0, 2, 1, 3) - - -@pytest.mark.skipif(not _CAN_USE_FLASH_ATTN, reason="flash-attn is required for these tests") -class TestFlashAttention: - """Flash attention backend must produce results consistent with the SDPA reference when attn_mask is given.""" - - def test_no_mask_matches_sdpa_reference(self): - """FLASH backend output must match SDPA reference without any masking.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 - device = torch.device("cuda") - q, k, v = ( - t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) - ) - ref = _sdpa_ref(q, k, v) - out = dispatch_attention_fn(q, k, v, attn_mask=None, backend=AttentionBackendName.FLASH) - - assert torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}" - - def test_mask_matches_sdpa_reference(self): - """FLASH backend output must match SDPA reference with attention mask.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 - device = torch.device("cuda") - q, k, v = ( - t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) - ) - mask = _NON_PREFIX_MASK.to(device) - - ref = _sdpa_ref(q, k, v, mask) - out = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) - - assert torch.allclose(ref, out, atol=1e-2), f"Max diff: {(ref - out).abs().max():.2e}" - - def test_4d_bool_mask_equivalent_to_2d(self): - """4D bool mask (batch_size, 1, 1, seq_len) must normalize to the same result as the 2D mask.""" - batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32 - device = torch.device("cuda") - q, k, v = ( - t.to(device=device, dtype=torch.float16) for t in _make_qkv(batch_size, seq_len, num_heads, head_dim) - ) - mask = _NON_PREFIX_MASK.to(device) - - out_2d = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH) - out_4d = dispatch_attention_fn(q, k, v, attn_mask=mask[:, None, None, :], backend=AttentionBackendName.FLASH) - - assert torch.allclose(out_2d, out_4d, atol=1e-3) From e05bb28a76476612865919ed868473feeb861664 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 16 Apr 2026 15:10:46 +0800 Subject: [PATCH 07/15] revert change --- src/diffusers/models/attention_dispatch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 74f88c7874d2..e329c5b9b0cb 100644 --- 
a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1139,6 +1139,8 @@ def _flash_attention_forward_op(
     _save_ctx: bool = True,
     _parallel_config: "ParallelConfig" | None = None,
 ):
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
     if enable_gqa:
         raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")

From 534fdc1dbcfb86013102ff65ccb3aa592b90b34e Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Fri, 17 Apr 2026 09:46:37 +0800
Subject: [PATCH 08/15] fix according to Claude review

---
 src/diffusers/models/attention_dispatch.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e329c5b9b0cb..95b706a9e8a3 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -2589,6 +2589,10 @@ def _flash_attention_hub(
         value_packed = _padded_to_unpad(value, indices_k)

         varlen_func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].varlen_kernel_fn
+        if varlen_func is None:
+            raise RuntimeError(
+                "Flash attention hub kernels must expose `flash_attn_varlen_func` for masked attention."
+            )
         out_packed = varlen_func(
             q=query_packed,
             k=key_packed,
@@ -2602,8 +2606,6 @@ def _flash_attention_hub(
             causal=is_causal,
             return_attn_probs=return_lse,
         )
-        if return_lse:
-            out_packed, lse, *_ = out_packed

         out = out_packed.view(batch_size, query.shape[1], *out_packed.shape[1:])
     else:

From 1cd670b8f40376f9735ab36a0d3404656b861ee6 Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Mon, 27 Apr 2026 14:37:25 +0800
Subject: [PATCH 09/15] add test coverage for QwenImage

---
 .../test_models_transformer_qwenimage.py | 34 ++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 7933aa98f3f2..199cb3a3600a 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -21,11 +21,12 @@
 from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
 from diffusers.utils.torch_utils import randn_tensor

-from ...testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, is_kernels_available, torch_device
 from ..testing_utils import (
     AttentionTesterMixin,
     BaseModelTesterConfig,
     BitsAndBytesTesterMixin,
+    ContextParallelAttentionBackendsTesterMixin,
     ContextParallelTesterMixin,
     LoraHotSwappingForModelTesterMixin,
     LoraTesterMixin,
@@ -279,6 +280,37 @@ class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig,
     """Context Parallel inference tests for QwenImage Transformer."""


+class TestQwenImageTransformerContextParallelAttnBackends(
+    QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
+):
+    """Context Parallel inference x attention backends tests for QwenImage Transformer."""
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
+    @pytest.mark.parametrize(
+        "attention_backend",
+        [
+            "native",
+            pytest.param(
+                "flash_hub",
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
+            pytest.param(
+                "_flash_3_hub",
+                marks=[
+                    pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
pytest.mark.xfail(reason="`attn_mask` is not supported for flash-attn 3.", strict=True), + ], + ), + ], + ) + @pytest.mark.parametrize("ulysses_anything", [True, False]) + @torch.no_grad() + def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything): + if cp_type == "ring_degree" and attention_backend in ("flash_hub", "_flash_3_hub"): + pytest.xfail("`attn_mask` is not yet supported for flash-attn hub kernels with ring attention.") + super().test_context_parallel_attn_backend_inference(cp_type, attention_backend, ulysses_anything) + + class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin): """LoRA adapter tests for QwenImage Transformer.""" From 99e1660d24ac1a24fcf11176dc37c4c4cd7df500 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Mon, 4 May 2026 17:28:27 +0800 Subject: [PATCH 10/15] add SP support and fix non-contiguous mask for flash_varlen kernel --- src/diffusers/models/attention_dispatch.py | 395 ++++++++++-------- tests/models/testing_utils/parallelism.py | 11 + tests/models/testing_utils/utils.py | 1 + .../test_models_transformer_qwenimage.py | 28 +- tests/others/test_attention_backends.py | 9 + 5 files changed, 239 insertions(+), 205 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 396a16fdee59..2431e23ca8a1 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -326,13 +326,6 @@ class _HubKernelConfig: wrapped_backward_attr: str | None = None wrapped_forward_fn: Callable | None = None wrapped_backward_fn: Callable | None = None - # Some backends (e.g. flash attention) have separate kernels for variable-length inputs - varlen_function_attr: str | None = None - varlen_kernel_fn: Callable | None = None - wrapped_varlen_forward_attr: str | None = None - wrapped_varlen_backward_attr: str | None = None - wrapped_varlen_forward_fn: Callable | None = None - wrapped_varlen_backward_fn: Callable | None = None # Registry for hub-based attention kernels @@ -352,16 +345,15 @@ class _HubKernelConfig: AttentionBackendName.FLASH_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", - varlen_function_attr="flash_attn_varlen_func", wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward", wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward", - wrapped_varlen_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward", - wrapped_varlen_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward", version=1, ), AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", + wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward", + wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward", version=1, ), AttentionBackendName.SAGE_HUB: _HubKernelConfig( @@ -730,24 +722,10 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: config = _HUB_KERNELS_REGISTRY[backend] needs_kernel = config.kernel_fn is None - needs_varlen_kernel = config.varlen_function_attr is not None and config.varlen_kernel_fn is None needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None - needs_wrapped_varlen_forward = ( - 
config.wrapped_varlen_forward_attr is not None and config.wrapped_varlen_forward_fn is None - ) - needs_wrapped_varlen_backward = ( - config.wrapped_varlen_backward_attr is not None and config.wrapped_varlen_backward_fn is None - ) - if not ( - needs_kernel - or needs_varlen_kernel - or needs_wrapped_forward - or needs_wrapped_backward - or needs_wrapped_varlen_forward - or needs_wrapped_varlen_backward - ): + if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward): return try: @@ -757,23 +735,12 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: if needs_kernel: config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr) - if needs_varlen_kernel: - config.varlen_kernel_fn = _resolve_kernel_attr(kernel_module, config.varlen_function_attr) - if needs_wrapped_forward: config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr) if needs_wrapped_backward: config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr) - if needs_wrapped_varlen_forward: - config.wrapped_varlen_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_varlen_forward_attr) - - if needs_wrapped_varlen_backward: - config.wrapped_varlen_backward_fn = _resolve_kernel_attr( - kernel_module, config.wrapped_varlen_backward_attr - ) - except Exception as e: logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") raise @@ -1241,6 +1208,8 @@ def _flash_attention_hub_forward_op( *, window_size: tuple[int, int] = (-1, -1), ): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.") if enable_gqa: raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.") @@ -1264,78 +1233,6 @@ def _flash_attention_hub_forward_op( if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): dropout_p = dropout_p if dropout_p > 0 else 1e-30 - if attn_mask is not None: - if return_lse: - raise NotImplementedError( - "`return_lse=True` with `attn_mask` is not yet supported for flash-attn hub kernels." - ) - - wrapped_varlen_forward_fn = config.wrapped_varlen_forward_fn - if wrapped_varlen_forward_fn is None: - raise RuntimeError( - "Flash attention hub kernels must expose `_wrapped_flash_attn_varlen_forward` for masked attention." 
- ) - - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - _, (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen_with_mask( - batch_size, seq_len_q, attn_mask_2d, query.device - ) - indices_k = torch.nonzero(attn_mask_2d.flatten(), as_tuple=False).flatten() - query_packed = query.flatten(0, 1) - key_packed = _padded_to_unpad(key, indices_k) - value_packed = _padded_to_unpad(value, indices_k) - - with torch.set_grad_enabled(grad_enabled): - out_packed, lse, _, rng_state = wrapped_varlen_forward_fn( - query_packed, - key_packed, - value_packed, - cu_seqlens_q, - cu_seqlens_k, - seq_len_q, - max_seqlen_k, - dropout_p, - scale, - is_causal, - window_size[0], - window_size[1], - softcap, - alibi_slopes, - return_lse, - ) - - out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:]) - - if _save_ctx: - ctx.save_for_backward( - query_packed, - key_packed, - value_packed, - out_packed, - lse, - rng_state, - cu_seqlens_q, - cu_seqlens_k, - indices_k, - ) - ctx.is_varlen_masked = True - ctx.max_seqlen_k = max_seqlen_k - ctx.batch_size = batch_size - ctx.seq_len_q = seq_len_q - ctx.seq_len_kv = seq_len_kv - ctx.dropout_p = dropout_p - ctx.scale = scale - ctx.is_causal = is_causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - return out - with torch.set_grad_enabled(grad_enabled): out, lse, S_dmask, rng_state = wrapped_forward_fn( query, @@ -1372,86 +1269,188 @@ def _flash_attention_hub_backward_op( **kwargs, ): config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB] + wrapped_backward_fn = config.wrapped_backward_fn + if wrapped_backward_fn is None: + raise RuntimeError( + "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution." + ) - if getattr(ctx, "is_varlen_masked", False): - wrapped_varlen_backward_fn = config.wrapped_varlen_backward_fn - if wrapped_varlen_backward_fn is None: - raise RuntimeError( - "Flash attention hub kernels must expose `_wrapped_flash_attn_varlen_backward` for masked attention." 
- ) + query, key, value, out, lse, rng_state = ctx.saved_tensors + grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) - ( - query_packed, - key_packed, - value_packed, - out_packed, - lse, - rng_state, - cu_seqlens_q, - cu_seqlens_k, - indices_k, - ) = ctx.saved_tensors + _ = wrapped_backward_fn( + grad_out, + query, + key, + value, + out, + lse, + grad_query, + grad_key, + grad_value, + ctx.dropout_p, + ctx.scale, + ctx.is_causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state, + ) + + grad_query = grad_query[..., : grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] + + return grad_query, grad_key, grad_value + + +def _flash_varlen_attention_hub_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: "ParallelConfig" | None = None, + *, + window_size: tuple[int, int] = (-1, -1), +): + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for flash-attn varlen hub kernels.") + + config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB] + wrapped_forward_fn = config.wrapped_forward_fn + wrapped_backward_fn = config.wrapped_backward_fn + if wrapped_forward_fn is None or wrapped_backward_fn is None: + raise RuntimeError( + "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and " + "`_wrapped_flash_attn_varlen_backward` for context parallel execution." 
+ ) + + if scale is None: + scale = query.shape[-1] ** (-0.5) + + softcap = 0.0 + alibi_slopes = None + deterministic = False + grad_enabled = any(x.requires_grad for x in (query, key, value)) - grad_out_packed = grad_out.flatten(0, 1) + if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + dropout_p = dropout_p if dropout_p > 0 else 1e-30 - dq = torch.empty_like(query_packed) - dk = torch.empty_like(key_packed) - dv = torch.empty_like(value_packed) + batch_size, seq_len_q, num_heads, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device) + ) + indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten() + query_packed = query.flatten(0, 1) + key_packed = _padded_to_unpad(key, indices_k) + value_packed = _padded_to_unpad(value, indices_k) + max_seqlen_q = seq_len_q + else: + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) + ) + query_packed = query.flatten(0, 1) + key_packed = key.flatten(0, 1) + value_packed = value.flatten(0, 1) + seqlens_k = None - wrapped_varlen_backward_fn( # noqa: F841 - grad_out_packed, + with torch.set_grad_enabled(grad_enabled): + out_packed, lse, _, rng_state = wrapped_forward_fn( query_packed, key_packed, value_packed, - out_packed, - lse, - dq, - dk, - dv, cu_seqlens_q, cu_seqlens_k, - ctx.seq_len_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.scale, - ctx.is_causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state, + max_seqlen_q, + max_seqlen_k, + dropout_p, + scale, + is_causal, + window_size[0], + window_size[1], + softcap, + alibi_slopes, + return_lse, + ) + + out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:]) + + if _save_ctx: + ctx.save_for_backward( + query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k ) + ctx.seqlens_k = seqlens_k # None if unmasked + ctx.indices_k = indices_k if attn_mask is not None else None + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.batch_size = batch_size + ctx.seq_len_q = seq_len_q + ctx.seq_len_kv = seq_len_kv + ctx.num_heads = num_heads + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic - grad_query = dq.view(ctx.batch_size, ctx.seq_len_q, *dq.shape[1:]) - grad_key = _unpad_to_padded(dk, indices_k, ctx.batch_size, ctx.seq_len_kv) - grad_value = _unpad_to_padded(dv, indices_k, ctx.batch_size, ctx.seq_len_kv) + # (num_heads, batch_size * seq_len_q) -> (batch_size, seq_len_q, num_heads) + lse_sp = lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous() - grad_query = grad_query[..., : grad_out.shape[-1]] - grad_key = grad_key[..., : grad_out.shape[-1]] - grad_value = grad_value[..., : grad_out.shape[-1]] + return (out, lse_sp) if return_lse else out - return grad_query, grad_key, grad_value +def _flash_varlen_attention_hub_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + config = 
_HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB] wrapped_backward_fn = config.wrapped_backward_fn if wrapped_backward_fn is None: raise RuntimeError( - "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution." + "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` " + "for context parallel execution." ) - query, key, value, out, lse, rng_state = ctx.saved_tensors - grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + + grad_out_packed = grad_out.flatten(0, 1) + grad_query, grad_key, grad_value = ( + torch.empty_like(query_packed), + torch.empty_like(key_packed), + torch.empty_like(value_packed), + ) _ = wrapped_backward_fn( - grad_out, - query, - key, - value, - out, + grad_out_packed, + query_packed, + key_packed, + value_packed, + out_packed, lse, grad_query, grad_key, grad_value, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, ctx.dropout_p, ctx.scale, ctx.is_causal, @@ -1463,6 +1462,15 @@ def _flash_attention_hub_backward_op( rng_state, ) + grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:]) + + if ctx.seqlens_k is not None: + grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv) + grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv) + else: + grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:]) + grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:]) + grad_query = grad_query[..., : grad_out.shape[-1]] grad_key = grad_key[..., : grad_out.shape[-1]] grad_value = grad_value[..., : grad_out.shape[-1]] @@ -2677,7 +2685,7 @@ def _flash_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_HUB, - constraints=[_check_attn_mask_or_causal, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], supports_context_parallel=True, ) def _flash_attention_hub( @@ -2693,6 +2701,8 @@ def _flash_attention_hub( _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: lse = None + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 2.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn if _parallel_config is None: @@ -2714,7 +2724,7 @@ def _flash_attention_hub( query, key, value, - attn_mask, + None, dropout_p, is_causal, scale, @@ -2733,7 +2743,7 @@ def _flash_attention_hub( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_VARLEN_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=False, + supports_context_parallel=True, ) def _flash_varlen_attention_hub( query: torch.Tensor, @@ -2747,27 +2757,48 @@ def _flash_varlen_attention_hub( return_lse: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: + lse = None batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + if _parallel_config is not 
None: + if _parallel_config.context_parallel_config.ring_degree > 1: + raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.") + forward_op = functools.partial(_flash_varlen_attention_hub_forward_op, window_size=window_size) + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + False, + return_lse, + forward_op=forward_op, + backward_op=_flash_varlen_attention_hub_backward_op, + _parallel_config=_parallel_config, ) - ) + if return_lse: + out, lse = out + return (out, lse) if return_lse else out - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) + if attn_mask is not None: + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device) + ) + indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten() + key_packed = _padded_to_unpad(key, indices_k) + value_packed = _padded_to_unpad(value, indices_k) + else: + (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) + ) + key_packed = key.flatten(0, 1) + value_packed = value.flatten(0, 1) query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn out = func( @@ -2784,9 +2815,13 @@ def _flash_varlen_attention_hub( window_size=window_size, return_attn_probs=return_lse, ) + if return_lse: + out, lse, *_ = out + else: + out = out out = out.unflatten(0, (batch_size, -1)) - return out + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index f88d404f8c5e..e7f54954daad 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -374,6 +374,8 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names) @is_context_parallel @require_torch_multi_accelerator class ContextParallelAttentionBackendsTesterMixin: + unsupported_attn_backends: list[str] = [] + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"]) @pytest.mark.parametrize( "attention_backend", @@ -383,6 +385,10 @@ class ContextParallelAttentionBackendsTesterMixin: "flash_hub", marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), ), + pytest.param( + "flash_varlen_hub", + marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), + ), pytest.param( "_flash_3_hub", marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), @@ -398,9 +404,14 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen if getattr(self.model_class, "_cp_plan", None) is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + if attention_backend in self.unsupported_attn_backends: + pytest.xfail(f"{attention_backend} is not supported for this model.") + if cp_type == "ring_degree": if attention_backend == AttentionBackendName.NATIVE: pytest.skip("Skipping test because 
ring isn't supported with native attention backend.")
+            elif attention_backend == "flash_varlen_hub":
+                pytest.xfail("`ring_degree` is not yet supported for varlen attention hub kernels.")

         if ulysses_anything and "ulysses" not in cp_type:
             pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")

diff --git a/tests/models/testing_utils/utils.py b/tests/models/testing_utils/utils.py
index 7bec37db2496..eda02a79c315 100644
--- a/tests/models/testing_utils/utils.py
+++ b/tests/models/testing_utils/utils.py
@@ -6,6 +6,7 @@
 _BF16_REQUIRED_BACKENDS = {
     AttentionBackendName._NATIVE_CUDNN,
     AttentionBackendName.FLASH_HUB,
+    AttentionBackendName.FLASH_VARLEN_HUB,
     AttentionBackendName._FLASH_3_HUB,
 }

diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 199cb3a3600a..60f830e8b31a 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -21,7 +21,7 @@
 from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
 from diffusers.utils.torch_utils import randn_tensor

-from ...testing_utils import enable_full_determinism, is_kernels_available, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
 from ..testing_utils import (
     AttentionTesterMixin,
     BaseModelTesterConfig,
@@ -285,30 +285,8 @@ class TestQwenImageTransformerContextParallelAttnBackends(
 ):
     """Context Parallel inference x attention backends tests for QwenImage Transformer."""

-    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
-    @pytest.mark.parametrize(
-        "attention_backend",
-        [
-            "native",
-            pytest.param(
-                "flash_hub",
-                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
-            ),
-            pytest.param(
-                "_flash_3_hub",
-                marks=[
-                    pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
-                    pytest.mark.xfail(reason="`attn_mask` is not supported for flash-attn 3.", strict=True),
-                ],
-            ),
-        ],
-    )
-    @pytest.mark.parametrize("ulysses_anything", [True, False])
-    @torch.no_grad()
-    def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
-        if cp_type == "ring_degree" and attention_backend in ("flash_hub", "_flash_3_hub"):
-            pytest.xfail("`attn_mask` is not yet supported for flash-attn hub kernels with ring attention.")
-        super().test_context_parallel_attn_backend_inference(cp_type, attention_backend, ulysses_anything)
+    # flash_hub and _flash_3_hub do not support attn_mask
+    unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]


 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 01f4521c5adc..fe338ab8b7d7 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -38,6 +38,10 @@
         "flash_hub",
         torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
     ),
+    (
+        "flash_varlen_hub",
+        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
+    ),
     (
         "_flash_3_hub",
         torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
@@ -62,6 +66,11 @@
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
         True
     ),
+    (
+        "flash_varlen_hub",
+        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
+        True
+    ),
     (
         "_flash_3_hub",
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),

From 3d8cbf4e8f76ddb71a7465b6e6882da68486bebc Mon Sep 17 00:00:00 2001
From: Mike Cheung
Date: Mon, 4 May 2026 17:49:37 +0800
Subject: [PATCH 11/15] revert change

---
 tests/others/test_attention_backends.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index fe338ab8b7d7..01f4521c5adc 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -38,10 +38,6 @@
         "flash_hub",
         torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
     ),
-    (
-        "flash_varlen_hub",
-        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
-    ),
     (
         "_flash_3_hub",
         torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
@@ -66,11 +62,6 @@
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
         True
     ),
-    (
-        "flash_varlen_hub",
-        torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
-        True
-    ),
     (
         "_flash_3_hub",
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),

From 1b39db42bc8e0e89c6211ffd621f5c1ea78185a0 Mon Sep 17 00:00:00 2001
From: Cheung Ka Wai
Date: Wed, 6 May 2026 14:44:09 +0800
Subject: [PATCH 12/15] Update tests/models/testing_utils/parallelism.py

Co-authored-by: Sayak Paul
---
 tests/models/testing_utils/parallelism.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index e7f54954daad..8f776010fbf7 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -411,7 +411,7 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen
             if attention_backend == AttentionBackendName.NATIVE:
                 pytest.skip("Skipping test because ring isn't supported with native attention backend.")
             elif attention_backend == "flash_varlen_hub":
-                pytest.xfail("`ring_degree` is not yet supported for varlen attention hub kernels.")
+                pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")

         if ulysses_anything and "ulysses" not in cp_type:
             pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")

From 04a1bf507712888b685b7348cc52a9e08ba30352 Mon Sep 17 00:00:00 2001
From: Cheung Ka Wai
Date: Wed, 6 May 2026
14:44:28 +0800 Subject: [PATCH 13/15] Update tests/models/testing_utils/parallelism.py Co-authored-by: Sayak Paul --- tests/models/testing_utils/parallelism.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 8f776010fbf7..d4f5e99d6763 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -405,7 +405,7 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") if attention_backend in self.unsupported_attn_backends: - pytest.xfail(f"{attention_backend} is not supported for this model.") + pytest.skip(f"{attention_backend} is not supported for this model.") if cp_type == "ring_degree": if attention_backend == AttentionBackendName.NATIVE: From 37a6db5491dc4b41bd7f3bb6e8e1ee003be835cf Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 6 May 2026 14:11:47 +0800 Subject: [PATCH 14/15] drop `_padded_to_unpad` --- src/diffusers/models/attention_dispatch.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 2431e23ca8a1..5a1274860460 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -638,11 +638,6 @@ def _prepare_for_flash_attn_or_sage_varlen( return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) -def _padded_to_unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - """gather valid tokens from a padded `(batch, seq, ...)` tensor into a packed `(nnz, ...)` tensor.""" - return tensor.reshape(-1, *tensor.shape[2:])[indices] - - def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: """scatter a packed `(nnz, ...)` tensor back to padded `(batch_size, seq_len, ...)`.""" output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device) @@ -1355,8 +1350,8 @@ def _flash_varlen_attention_hub_forward_op( ) indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten() query_packed = query.flatten(0, 1) - key_packed = _padded_to_unpad(key, indices_k) - value_packed = _padded_to_unpad(value, indices_k) + key_packed = key.reshape(-1, *key.shape[2:])[indices_k] + value_packed = value.reshape(-1, *value.shape[2:])[indices_k] max_seqlen_q = seq_len_q else: (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( @@ -2789,8 +2784,8 @@ def _flash_varlen_attention_hub( _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device) ) indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten() - key_packed = _padded_to_unpad(key, indices_k) - value_packed = _padded_to_unpad(value, indices_k) + key_packed = key.reshape(-1, *key.shape[2:])[indices_k] + value_packed = value.reshape(-1, *value.shape[2:])[indices_k] else: (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) From 849062ac3f578699ef0cf49e3db589ff07504be3 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 6 May 2026 14:16:52 +0800 Subject: [PATCH 15/15] follow `if _parallel_config is None` pattern --- src/diffusers/models/attention_dispatch.py | 77 
+++++++++++----------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 5a1274860460..5990cda9b8cd 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2756,7 +2756,43 @@ def _flash_varlen_attention_hub( batch_size, seq_len_q, _, _ = query.shape _, seq_len_kv, _, _ = key.shape - if _parallel_config is not None: + if _parallel_config is None: + if attn_mask is not None: + attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device) + ) + indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten() + key_packed = key.reshape(-1, *key.shape[2:])[indices_k] + value_packed = value.reshape(-1, *value.shape[2:])[indices_k] + else: + (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) + ) + key_packed = key.flatten(0, 1) + value_packed = value.flatten(0, 1) + + query_packed = query.flatten(0, 1) + + func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn + out = func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out + out = out.unflatten(0, (batch_size, -1)) + else: if _parallel_config.context_parallel_config.ring_degree > 1: raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.") forward_op = functools.partial(_flash_varlen_attention_hub_forward_op, window_size=window_size) @@ -2776,45 +2812,6 @@ def _flash_varlen_attention_hub( ) if return_lse: out, lse = out - return (out, lse) if return_lse else out - - if attn_mask is not None: - attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device) - ) - indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten() - key_packed = key.reshape(-1, *key.shape[2:])[indices_k] - value_packed = value.reshape(-1, *value.shape[2:])[indices_k] - else: - (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) - ) - key_packed = key.flatten(0, 1) - value_packed = value.flatten(0, 1) - - query_packed = query.flatten(0, 1) - - func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn - out = func( - q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - window_size=window_size, - return_attn_probs=return_lse, - ) - if return_lse: - out, lse, *_ = out - else: - out = out - out = out.unflatten(0, (batch_size, -1)) return (out, lse) if return_lse else out
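
For readers following the series, the gather/scatter round trip that these patches rely on can be summarized in a short, self-contained sketch. This is illustrative only, not code from the patches: the cu_seqlens arithmetic is an assumption about what `_prepare_for_flash_attn_or_sage_varlen_with_mask` computes, inferred from how its outputs are consumed in the diffs above.

    import torch

    batch_size, seq_len_kv, num_heads, head_dim = 2, 6, 1, 8
    key = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)
    # A bool mask with non-contiguous valid tokens, as in the deleted test file above.
    attn_mask_2d = torch.tensor(
        [
            [True, True, False, False, True, True],
            [True, False, True, True, True, True],
        ]
    )

    # Per-batch valid-token counts and their prefix sum; the flash-attn varlen
    # kernels use cu_seqlens_k to find each batch's boundaries in the packed tensor.
    seqlens_k = attn_mask_2d.sum(dim=1, dtype=torch.int32)                      # tensor([4, 5])
    cu_seqlens_k = torch.nn.functional.pad(
        seqlens_k.cumsum(dim=0, dtype=torch.int32), (1, 0)
    )                                                                           # tensor([0, 4, 9])
    max_seqlen_k = int(seqlens_k.max())                                         # 5

    # Packing: gather valid tokens into a (nnz, heads, dim) tensor, exactly as
    # `key.reshape(-1, *key.shape[2:])[indices_k]` does in the final patch.
    indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten()
    key_packed = key.reshape(-1, *key.shape[2:])[indices_k]                     # (9, 1, 8)

    # Unpacking: scatter back to the padded layout, as `_unpad_to_padded` does
    # for the K/V gradients in the backward op.
    restored = torch.zeros(batch_size * seq_len_kv, *key_packed.shape[1:], dtype=key.dtype)
    restored[indices_k] = key_packed
    restored = restored.view(batch_size, seq_len_kv, num_heads, head_dim)
    assert torch.equal(restored[attn_mask_2d], key[attn_mask_2d])

Padding positions come back as zeros, which is safe because masked-out KV tokens never contribute to the attention output.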
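A usage sketch of the resulting backend, modeled on the deleted `test_flash_attention.py` shown earlier in the series; it assumes a CUDA device, the `kernels` package, and that `dispatch_attention_fn` accepts the same `backend=` keyword used in that test.

    import torch

    from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn

    batch_size, seq_len, num_heads, head_dim = 2, 10, 2, 32
    device = torch.device("cuda")
    q, k, v = (
        torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
        for _ in range(3)
    )
    # A mask with "holes", i.e. the non-contiguous case the varlen packing path handles.
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
    mask[0, 3:5] = False

    # Dispatch through the hub-provided varlen kernel; padded positions in K/V
    # are dropped before the kernel call and the output keeps the padded shape.
    out = dispatch_attention_fn(q, k, v, attn_mask=mask, backend=AttentionBackendName.FLASH_VARLEN_HUB)
    print(out.shape)  # torch.Size([2, 10, 2, 32])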