diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..ec836fc3af 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -31,6 +31,7 @@ def generate_input_shapes( config: ModelConfig, world_size: int, kernel_backend: str, + pad_between_seqs: str = "False", ): if qkv_format == "bshd": q_input_shape = ( @@ -99,9 +100,9 @@ def generate_input_shapes( ).cuda() cu_seqlens_q = torch.clone(cu_seqlens_q_padded) - # Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does, - # cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only. - if kernel_backend == "FusedAttention": + # When pad_between_seqs is True, or for FusedAttention, cu_seqlens_q reflects + # non-padded (actual) lengths. FA3 handles this via seqused_q/seqused_k. + if kernel_backend == "FusedAttention" or pad_between_seqs == "True": cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() # NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded` @@ -180,6 +181,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + pad_between_seqs="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" @@ -275,7 +277,7 @@ def run_dpa_with_cp( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, pad_between_seqs) q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() @@ -494,6 +496,7 @@ def run_dpa_with_cp( # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + tensor_names = ["out", "dq", "dk", "dv", "dbias", "out_", "dq_", "dk_", "dv_", "dbias_"] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): @@ -502,11 +505,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + for tensor, name in zip(tensors, tensor_names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + assert torch.all(~torch.isnan(tensor)), f"{name} has nan values" + assert torch.all(~torch.isinf(tensor)), f"{name} has inf values" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -559,13 +562,24 @@ def run_dpa_with_cp( if is_training: dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + out_ = out_.clone() cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + # FA3 forward doesn't zero padding positions in output; + # zero them in out_ (reference) so comparison is valid. + if pad_between_seqs == "True": + out_[cu_seqlens_q_padded[-1] :] = 0.0 + for b in range(config.batch_size): + if num_pads_q[b] > 0: + out_[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] + ] = 0.0 for x in [dq, out, dq_, out_]: assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 for b in range(config.batch_size): diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 2eb307aa48..929f2b5326 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -124,7 +124,7 @@ def reset_global_fp8_state(): @pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("swa", [False]) -@pytest.mark.parametrize("pad_between_seqs", [False]) +@pytest.mark.parametrize("pad_between_seqs", [False, True]) def test_dot_product_attention( dtype, model_configs, @@ -157,6 +157,8 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if pad_between_seqs and qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") if qkv_format == "thd" and "padding" not in config.attn_mask_type: config.attn_mask_type = ( "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" @@ -195,8 +197,10 @@ def test_dot_product_attention( ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention - # mannually pads and unpads the input and output of FlashAttention for testing purposes + # FA3 natively supports pad_between_seqs via seqused_q/seqused_k. + # FA2 does not support pad_between_seqs, but _run_dot_product_attention + # manually pads and unpads the input and output of FlashAttention for testing purposes. + # Flash Attention is not supported on SM100+ if ( pad_between_seqs and FlashAttentionUtils.is_installed @@ -205,6 +209,7 @@ def test_dot_product_attention( and config.attn_mask_type in ["causal", "padding_causal"] ) and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) + and get_device_compute_capability() < (10, 0) ): flash_attn_supported = True @@ -1197,12 +1202,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: block.softmax_offset.requires_grad = True # Run a forward and backward pass - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: q = inp_orig[0] k = inp_orig[1] v = inp_orig[2] d_out = out_grad_orig - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: q = inp[0] k = inp[1] v = inp[2] @@ -1218,14 +1223,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: max_seqlen_kv=config.max_seqlen_kv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None, - cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None, + cu_seqlens_q_padded=( + cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), + cu_seqlens_kv_padded=( + cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, alibi_slopes=alibi_slopes, fast_zero_fill=True, + pad_between_seqs=pad_between_seqs, # Only pass num_splits when exercising the FlashAttention path num_splits=config.num_splits if backend == "FlashAttention" else 1, ) @@ -1239,12 +1249,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: if is_training: return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: return out, max_logit, (None, None, None, d_softmax_offset) - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) if is_training: diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index ecd0090a3b..98f8e56490 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -85,11 +85,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +@pytest.mark.parametrize("pad_between_seqs", [False, True]) +def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + if pad_between_seqs: + if qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") + if not FlashAttentionUtils.v3_is_installed: + pytest.skip("pad_between_seqs with CP requires Flash Attention v3!") + if cp_comm_type == "a2a+p2p": + pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type @@ -133,6 +142,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + pad_between_seqs=pad_between_seqs, log_level=pytest_logging_level, ), check=True, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 442366035a..8a610ff187 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -738,10 +738,13 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + pad_between_seqs: Optional[bool] = False, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), fp8_output: bool = False, num_splits: Optional[int] = 1, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, ) -> torch.Tensor: """flash-attn fprop""" @@ -919,6 +922,7 @@ def forward( use_flash_attn_3 = False if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): use_flash_attn_3 = True + if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -935,8 +939,16 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q if qkv_format == "thd" else None, - cu_seqlens_kv if qkv_format == "thd" else None, + ( + cu_seqlens_q_padded + if pad_between_seqs + else (cu_seqlens_q if qkv_format == "thd" else None) + ), + ( + cu_seqlens_kv_padded + if pad_between_seqs + else (cu_seqlens_kv if qkv_format == "thd" else None) + ), self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -948,7 +960,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, - pad_between_seqs=False, + pad_between_seqs=pad_between_seqs, use_flash_attn_3=use_flash_attn_3, fp8_output=fp8_output, ) @@ -984,8 +996,12 @@ def forward( else: func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment if not use_flash_attn_3 or inference_params is None: - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) + if pad_between_seqs and use_flash_attn_3: + fa_optional_forward_args_thd.append(cu_seqlens_q_padded) + fa_optional_forward_args_thd.append(cu_seqlens_kv_padded) + else: + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if not use_flash_attn_3: @@ -1019,6 +1035,13 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["num_splits"] = num_splits + if pad_between_seqs: + fa_3_optional_forward_kwargs["seqused_q"] = ( + cu_seqlens_q[1:] - cu_seqlens_q[:-1] + ) + fa_3_optional_forward_kwargs["seqused_k"] = ( + cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + ) if inference_params is None: fa_3_optional_forward_kwargs["deterministic"] = self.deterministic else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..dc17d1dfc5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -572,6 +572,8 @@ def get_fa_args( dq=None, dk=None, dv=None, + seqused_q=None, + seqused_k=None, ): """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: @@ -581,7 +583,9 @@ def get_fa_args( *[None] * 4, # k_new, v_new, qv, out cu_seqlens_q, cu_seqlens_kv, - *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + None, # cu_seqlens_k_new + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, *[None] @@ -599,8 +603,8 @@ def get_fa_args( return [ cu_seqlens_q, cu_seqlens_kv, - None, # sequed_q - None, # sequed_k + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, dq, @@ -610,8 +614,8 @@ def get_fa_args( return [ None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sequed_q - None, # sequed_k + None, # seqused_q + None, # seqused_k max_seqlen_q, max_seqlen_kv, dq, @@ -917,6 +921,9 @@ def cp_p2p_fwd_flash_attn( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -943,6 +950,20 @@ def cp_p2p_fwd_flash_attn( fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + seqused_q = None + seqused_k = None + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + # Derive actual token counts per batch element from cu_seqlens + seqused_q = cu_seqlens_q_per_step[1:] - cu_seqlens_q_per_step[:-1] + seqused_k = cu_seqlens_kv_per_step[1:] - cu_seqlens_kv_per_step[:-1] + # Override cu_seqlens to padded layout for tensor memory layout + cu_seqlens_q_ = cu_seqlens_q_padded + cu_seqlens_kv_ = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_ = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_ = cu_seqlens_q_padded // 2 + fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -951,6 +972,8 @@ def cp_p2p_fwd_flash_attn( cu_seqlens_kv=cu_seqlens_kv_, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -1180,6 +1203,9 @@ def cp_p2p_bwd_flash_attn( rng_states, softmax_lse, softmax_lse_, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1188,7 +1214,10 @@ def cp_p2p_bwd_flash_attn( section, ): """Per-tile backward call of CP P2P with FlashAttention backend""" - dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) elif use_flash_attn_3 or fa_utils.v2_7_0_plus: @@ -1213,17 +1242,33 @@ def cp_p2p_bwd_flash_attn( max_seqlen_q_ = max_seqlen_q // 2 softmax_lse__ = softmax_lse_ + seqused_q = None + seqused_k = None + cu_seqlens_q_bwd = cu_seqlens_q_per_step[cp_size - step - 1] + cu_seqlens_kv_bwd = cu_seqlens_kv_per_step[cp_size - step - 1] + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q_bwd[1:] - cu_seqlens_q_bwd[:-1] + seqused_k = cu_seqlens_kv_bwd[1:] - cu_seqlens_kv_bwd[:-1] + cu_seqlens_q_bwd = cu_seqlens_q_padded + cu_seqlens_kv_bwd = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_bwd = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_bwd = cu_seqlens_q_padded // 2 + fa_backward_args_thd = get_fa_args( False, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + cu_seqlens_q=cu_seqlens_q_bwd, + cu_seqlens_kv=cu_seqlens_kv_bwd, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if use_flash_attn_3: fa_backward_kwargs["is_causal"] = causal_ @@ -1643,6 +1688,9 @@ def forward( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # cp_size = 4: @@ -1685,7 +1733,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) elif i <= rank: @@ -1712,7 +1762,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1739,7 +1791,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1764,7 +1818,11 @@ def forward( ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( - cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, + *prepare_outputs, + section, + ) ) # softmax_lse correction @@ -2015,6 +2073,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8_meta = fp8_meta @@ -2414,6 +2473,9 @@ def backward(ctx, dout, *_args): rng_states, softmax_lse, softmax_lse_, + ctx.pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, @@ -2429,7 +2491,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) elif i >= (cp_size - rank - 1): section = "lower-triangle" @@ -2440,7 +2504,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "upper-triangle" @@ -2451,7 +2517,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "all" @@ -2462,7 +2530,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) # dq, dk, dv are reduced across steps in higher precision @@ -3335,6 +3405,7 @@ def forward( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, @@ -3527,14 +3598,25 @@ def forward( ): out_part = O_quantizer(out_) else: + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -3639,6 +3721,7 @@ def forward( ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_type = softmax_type ctx.dQKV_quantizer = dQKV_quantizer @@ -3817,18 +3900,32 @@ def backward(ctx, dout, *_args): dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors - dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + if ctx.pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q, k, v]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if ctx.pad_between_seqs and ctx.use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state @@ -3923,6 +4020,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, d_softmax_offset, None, ) @@ -4141,6 +4239,7 @@ def attn_forward_func_with_cp( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..b9bb1a3e28 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1469,10 +1469,13 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, inference_params=inference_params, flash_attention_backend=flash_attention_backend, fp8_output=fp8_output, num_splits=num_splits, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if use_fused_attention: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 170cb2cd34..914dae4a3c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -549,7 +549,7 @@ def get_attention_backend( # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 - # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | % 256 == 0 # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 @@ -588,9 +588,9 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if inference_params.is_paged: - if use_flash_attention_2 and inference_params.page_size < 256: + if use_flash_attention_2 and inference_params.page_size % 256 != 0: if FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 for page size < 256") + logger.debug("Disabling FlashAttention 2 for page size not divisible by 256") use_flash_attention_2 = False if use_flash_attention_2: if not FlashAttentionUtils.is_installed: @@ -600,6 +600,16 @@ def get_attention_backend( "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" ) use_flash_attention_2 = False + else: + # Non-paged KV cache still passes a block_table to FA2 for thd_2bshd support, + # and FA2 enforces page_size % 256 == 0 on the effective page size (max_seqlen_kv). + if use_flash_attention_2 and max_seqlen_kv % 256 != 0: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for non-paged KV cache" + " with max_seqlen_kv not divisible by 256" + ) + use_flash_attention_2 = False # Filter: Head dimension if head_dim_qk != head_dim_v: @@ -686,14 +696,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # Filter: QKV layout if qkv_format == "thd": if pad_between_seqs: - if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( - use_flash_attention_3 and FlashAttentionUtils.v3_is_installed - ): + if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " + "Disabling FlashAttention 2 for qkv_format = thd when there is " "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) - use_flash_attention = False + use_flash_attention_2 = False + # FA3 supports pad_between_seqs via seqused_q/seqused_k if device_compute_capability == (12, 0): if cudnn_version < (9, 18, 1): if use_fused_attention: @@ -1070,6 +1079,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "please install flash-attn >= 2.4.1." ) use_flash_attention_2 = False + if use_flash_attention_3 and deterministic: + if head_dim_qk > 128: + logger.warning( + "Disabling FlashAttention 3 for deterministic execution with head_dim_qk > 128" + ) + use_flash_attention_3 = False if use_fused_attention and deterministic: if softmax_type != "vanilla": logger.debug(