diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index d3114dd0753e..5990cda9b8cd 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -352,6 +352,8 @@ class _HubKernelConfig:
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_varlen_func",
+        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward",
+        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward",
         version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
@@ -636,6 +638,13 @@ def _prepare_for_flash_attn_or_sage_varlen(
     return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


+def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+    """Scatter a packed `(nnz, ...)` tensor back to a padded `(batch_size, seq_len, ...)` tensor."""
+    output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
+    output[indices] = packed
+    return output.view(batch_size, seq_len, *packed.shape[1:])
+
+
 def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
     """
     Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
@@ -1292,6 +1301,178 @@ def _flash_attention_hub_backward_op(
     return grad_query, grad_key, grad_value


+def _flash_varlen_attention_hub_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float | None = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: "ParallelConfig" | None = None,
+    *,
+    window_size: tuple[int, int] = (-1, -1),
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for flash-attn varlen hub kernels.")
+
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
+    wrapped_forward_fn = config.wrapped_forward_fn
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_forward_fn is None or wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and "
+            "`_wrapped_flash_attn_varlen_backward` for context parallel execution."
+        )
+
+    if scale is None:
+        scale = query.shape[-1] ** (-0.5)
+
+    softcap = 0.0
+    alibi_slopes = None
+    deterministic = False
+    grad_enabled = any(x.requires_grad for x in (query, key, value))
+
+    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+        dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+    batch_size, seq_len_q, num_heads, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device)
+        )
+        indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten()
+        query_packed = query.flatten(0, 1)
+        key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+        value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        max_seqlen_q = seq_len_q
+    else:
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+        )
+        query_packed = query.flatten(0, 1)
+        key_packed = key.flatten(0, 1)
+        value_packed = value.flatten(0, 1)
+        seqlens_k = None
+
+    with torch.set_grad_enabled(grad_enabled):
+        out_packed, lse, _, rng_state = wrapped_forward_fn(
+            query_packed,
+            key_packed,
+            value_packed,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            scale,
+            is_causal,
+            window_size[0],
+            window_size[1],
+            softcap,
+            alibi_slopes,
+            return_lse,
+        )
+
+    out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])
+
+    if _save_ctx:
+        ctx.save_for_backward(
+            query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k
+        )
+        ctx.seqlens_k = seqlens_k  # None if unmasked
+        ctx.indices_k = indices_k if attn_mask is not None else None
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.batch_size = batch_size
+        ctx.seq_len_q = seq_len_q
+        ctx.seq_len_kv = seq_len_kv
+        ctx.num_heads = num_heads
+        ctx.dropout_p = dropout_p
+        ctx.scale = scale
+        ctx.is_causal = is_causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+
+    # (num_heads, batch_size * seq_len_q) -> (batch_size, seq_len_q, num_heads)
+    lse_sp = lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()
+
+    return (out, lse_sp) if return_lse else out
+
+
+def _flash_varlen_attention_hub_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` "
+            "for context parallel execution."
+        )
+
+    query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+
+    grad_out_packed = grad_out.flatten(0, 1)
+    grad_query, grad_key, grad_value = (
+        torch.empty_like(query_packed),
+        torch.empty_like(key_packed),
+        torch.empty_like(value_packed),
+    )
+
+    _ = wrapped_backward_fn(
+        grad_out_packed,
+        query_packed,
+        key_packed,
+        value_packed,
+        out_packed,
+        lse,
+        grad_query,
+        grad_key,
+        grad_value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        ctx.max_seqlen_q,
+        ctx.max_seqlen_k,
+        ctx.dropout_p,
+        ctx.scale,
+        ctx.is_causal,
+        ctx.window_size[0],
+        ctx.window_size[1],
+        ctx.softcap,
+        ctx.alibi_slopes,
+        ctx.deterministic,
+        rng_state,
+    )
+
+    grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:])
+
+    if ctx.seqlens_k is not None:
+        grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+        grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+    else:
+        grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:])
+        grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:])
+
+    grad_query = grad_query[..., : grad_out.shape[-1]]
+    grad_key = grad_key[..., : grad_out.shape[-1]]
+    grad_value = grad_value[..., : grad_out.shape[-1]]
+
+    return grad_query, grad_key, grad_value
+
+
 def _flash_attention_3_hub_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
@@ -2557,7 +2738,7 @@ def _flash_attention_hub(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=False,
+    supports_context_parallel=True,
 )
 def _flash_varlen_attention_hub(
     query: torch.Tensor,
@@ -2571,46 +2752,68 @@ def _flash_varlen_attention_hub(
     return_lse: bool = False,
     _parallel_config: "ParallelConfig" | None = None,
 ) -> torch.Tensor:
+    lse = None
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
+    if _parallel_config is None:
+        if attn_mask is not None:
+            attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device)
+            )
+            indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten()
+            key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+            value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        else:
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+            )
+            key_packed = key.flatten(0, 1)
+            value_packed = value.flatten(0, 1)

-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+        query_packed = query.flatten(0, 1)

-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
-    out = func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        window_size=window_size,
-        return_attn_probs=return_lse,
-    )
-    out = out.unflatten(0, (batch_size, -1))
+        func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+        out = func(
+            q=query_packed,
+            k=key_packed,
+            v=value_packed,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            window_size=window_size,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+        out = out.unflatten(0, (batch_size, -1))
+    else:
+        if _parallel_config.context_parallel_config.ring_degree > 1:
+            raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.")
+        forward_op = functools.partial(_flash_varlen_attention_hub_forward_op, window_size=window_size)
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            False,
+            return_lse,
+            forward_op=forward_op,
+            backward_op=_flash_varlen_attention_hub_backward_op,
+            _parallel_config=_parallel_config,
+        )
+        if return_lse:
+            out, lse = out

-    return out
+    return (out, lse) if return_lse else out


 @_AttentionBackendRegistry.register(
diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
index f88d404f8c5e..d4f5e99d6763 100644
--- a/tests/models/testing_utils/parallelism.py
+++ b/tests/models/testing_utils/parallelism.py
@@ -374,6 +374,8 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
 @is_context_parallel
 @require_torch_multi_accelerator
 class ContextParallelAttentionBackendsTesterMixin:
+    unsupported_attn_backends: list[str] = []
+
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
     @pytest.mark.parametrize(
         "attention_backend",
@@ -383,6 +385,10 @@ class ContextParallelAttentionBackendsTesterMixin:
                 "flash_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
             ),
+            pytest.param(
+                "flash_varlen_hub",
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
             pytest.param(
                 "_flash_3_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
@@ -398,9 +404,14 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen
         if getattr(self.model_class, "_cp_plan", None) is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

+        if attention_backend in self.unsupported_attn_backends:
+            pytest.skip(f"{attention_backend} is not supported for this model.")
+
         if cp_type == "ring_degree":
             if attention_backend == AttentionBackendName.NATIVE:
                 pytest.skip("Skipping test because ring isn't supported with native attention backend.")
+            elif attention_backend in ("flash_varlen_hub",):
+                pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")

         if ulysses_anything and "ulysses" not in cp_type:
             pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
diff --git a/tests/models/testing_utils/utils.py b/tests/models/testing_utils/utils.py
index 7bec37db2496..eda02a79c315 100644
--- a/tests/models/testing_utils/utils.py
+++ b/tests/models/testing_utils/utils.py
@@ -6,6 +6,7 @@
 _BF16_REQUIRED_BACKENDS = {
     AttentionBackendName._NATIVE_CUDNN,
     AttentionBackendName.FLASH_HUB,
+    AttentionBackendName.FLASH_VARLEN_HUB,
     AttentionBackendName._FLASH_3_HUB,
 }

diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 516850c4a281..061f61fa36ef 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -25,6 +25,7 @@
     AttentionTesterMixin,
     BaseModelTesterConfig,
     BitsAndBytesTesterMixin,
+    ContextParallelAttentionBackendsTesterMixin,
     ContextParallelTesterMixin,
     LoraHotSwappingForModelTesterMixin,
     LoraTesterMixin,
@@ -253,6 +254,15 @@ class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig,
     """Context Parallel inference tests for QwenImage Transformer."""


+class TestQwenImageTransformerContextParallelAttnBackends(
+    QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
+):
+    """Context Parallel inference x attention backends tests for QwenImage Transformer."""
+
+    # flash_hub and _flash_3_hub do not support attn_mask
+    unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]
+
+
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""
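Note: the pack/scatter round trip at the heart of this change is easy to sanity-check in isolation. The sketch below is illustrative only — it re-implements `_unpad_to_padded` locally (as `unpad_to_padded`, with made-up shapes and CPU tensors). It packs key rows the same way `_flash_varlen_attention_hub_forward_op` does (flatten the batch and sequence dims, keep only rows selected by `indices_k`), then scatters them back the way the backward op recovers `grad_key`/`grad_value`, with padding positions restored as zeros. Queries are packed with a plain flatten and never dropped, which is why `grad_query` only needs a `view`.

import torch

def unpad_to_padded(packed, indices, batch_size, seq_len):
    # Local mirror of `_unpad_to_padded` above: scatter `(nnz, ...)` rows into a
    # zero-initialized `(batch_size * seq_len, ...)` buffer, then reshape.
    output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
    output[indices] = packed
    return output.view(batch_size, seq_len, *packed.shape[1:])

batch_size, seq_len_kv, num_heads, head_dim = 2, 6, 4, 8
key = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)

# Boolean keep-mask per (batch, kv position): sample 0 keeps 4 tokens, sample 1 keeps all 6.
attn_mask_2d = torch.zeros(batch_size, seq_len_kv, dtype=torch.bool)
attn_mask_2d[0, :4] = True
attn_mask_2d[1, :] = True

# Packing, as in the forward op.
indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten()
key_packed = key.reshape(-1, num_heads, head_dim)[indices_k]  # (nnz, num_heads, head_dim)

# Scattering back, as in the backward op: masked-out rows come back as zeros.
key_padded = unpad_to_padded(key_packed, indices_k, batch_size, seq_len_kv)
assert torch.equal(key_padded, key * attn_mask_2d[..., None, None])

On the user-facing side, the new path should be reachable by selecting the backend by name (e.g. via `set_attention_backend("flash_varlen_hub")` on the model) together with a context parallel config using `ulysses_degree > 1`; as both the forward path and the test skip above encode, `ring_degree > 1` currently raises `NotImplementedError` for this backend.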