From 1ececdc47947050985ec3fa790b66ea529303720 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 10 Mar 2026 13:41:52 -0700 Subject: [PATCH 01/11] [PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD ## TLDR Enable `pad_between_seqs=True` for A2A and P2P context parallelism comm types with FlashAttention 3 and THD format. Previously `pad_between_seqs` was only supported with FusedAttention. ## Problem When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With `pad_between_seqs=True`, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via `cu_seqlens_q_padded`, but FlashAttention (both FA2 and FA3) had `pad_between_seqs` hardcoded to `False` in the CP path, and FlashAttention (FA2 and FA3 alike) was entirely disabled for `pad_between_seqs + thd`. FA3 can natively handle this via its `seqused_q`/`seqused_k` mechanism. ## Solution Use FA3's `seqused_q`/`seqused_k` tensors to communicate actual token counts per batch element. Pass `cu_seqlens_q_padded` for tensor memory layout while deriving `seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]` from the real `cu_seqlens`. ## Changes ### context_parallel.py - `get_fa_args()`: Add `seqused_q`/`seqused_k` parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded `None`s). - `cp_p2p_fwd_flash_attn()` / `cp_p2p_bwd_flash_attn()`: Accept `pad_between_seqs`, `cu_seqlens_q_padded`, `cu_seqlens_kv_padded`. When enabled, derive `seqused` tensors and override `cu_seqlens` to padded versions (with half-padding for lower-triangle/upper-triangle sections). - `AttnFuncWithCPAndKVP2P`: Thread `pad_between_seqs` and padded cu_seqlens through all forward/backward `cp_p2p_fwd/bwd_flash_attn` call sites. Save `ctx.pad_between_seqs` for backward. - `AttnFuncWithCPAndQKVOA2A.forward()`: Add `pad_between_seqs` parameter. 
When enabled with FA3+THD, derive `seqused` and swap `cu_seqlens` for padded versions before calling `get_fa_args()`. - `AttnFuncWithCPAndQKVOA2A.backward()`: Same seqused/cu_seqlens override. Use `zeros_like` (not `empty_like`) for gradient init when `pad_between_seqs` since FA3 skips padding positions. Add extra `None` in return tuple for the new `pad_between_seqs` gradient slot. - `attn_forward_func_with_cp()`: Pass `pad_between_seqs` in A2A args list. ### backends.py - `FlashAttention.forward()`: Accept `cu_seqlens_q_padded`/`cu_seqlens_kv_padded`. Detect `pad_between_seqs` by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass `seqused_q`/`seqused_k`. ### dot_product_attention.py - Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` through to `FlashAttention`. ### utils.py - Only disable FA2 (not FA3) when `pad_between_seqs + thd`. FA3 handles this natively via `seqused`. ### test_attention_with_cp.py - Add `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to flash attention CP tests. - Skip `pad_between_seqs=True` for non-THD formats, when FA3 is not installed, and for `a2a+p2p` comm type (not yet supported). ### run_attention_with_cp.py - Thread `pad_between_seqs` through `generate_input_shapes()` and `run_dpa_with_cp()`. - When `pad_between_seqs`, set `cu_seqlens_q` to actual lengths (not just for FusedAttention). - Handle FA3 backward NaN at padding positions: `nan_to_num(nan=0.0)`. - Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros at padding slots). - Add tensor names to NaN/Inf assertion messages for debuggability. 
Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 83 +++++---- .../attention/test_attention_with_cp.py | 12 +- .../dot_product_attention/backends.py | 32 +++- .../dot_product_attention/context_parallel.py | 164 +++++++++++++++--- .../dot_product_attention.py | 2 + .../attention/dot_product_attention/utils.py | 9 +- 6 files changed, 237 insertions(+), 65 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0f36a8816d..f28c7d01ca 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -31,6 +31,7 @@ def generate_input_shapes( config: ModelConfig, world_size: int, kernel_backend: str, + pad_between_seqs: str = "False", ): if qkv_format == "bshd": q_input_shape = ( @@ -99,9 +100,9 @@ def generate_input_shapes( ).cuda() cu_seqlens_q = torch.clone(cu_seqlens_q_padded) - # Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does, - # cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only. - if kernel_backend == "FusedAttention": + # When pad_between_seqs is True, or for FusedAttention, cu_seqlens_q reflects + # non-padded (actual) lengths. FA3 handles this via seqused_q/seqused_k. 
+ if kernel_backend == "FusedAttention" or pad_between_seqs == "True": cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() # NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded` @@ -180,6 +181,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + pad_between_seqs="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" @@ -275,7 +277,7 @@ def run_dpa_with_cp( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, pad_between_seqs) q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() @@ -494,6 +496,7 @@ def run_dpa_with_cp( # get outputs tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + tensor_names = ["out", "dq", "dk", "dv", "dbias", "out_", "dq_", "dk_", "dv_", "dbias_"] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): @@ -502,11 +505,14 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for tensor in tensors: + if pad_between_seqs == "True": + # FA3 backward may produce NaN at padding positions; replace with 0 + tensors = [t.detach().nan_to_num(nan=0.0) if t is not None else t for t in tensors] + for tensor, name in zip(tensors, tensor_names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + assert torch.all(~torch.isnan(tensor)), f"NaN in {name}" + assert torch.all(~torch.isinf(tensor)), f"Inf in {name}" out, dq, dk, dv, 
dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -567,19 +573,26 @@ def run_dpa_with_cp( cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 + if pad_between_seqs == "True": + # FA3 doesn't guarantee zeros at padding positions; zero them explicitly + x[cu_seqlens_q_padded[-1] :] = 0.0 + for b in range(config.batch_size): + if num_pads_q[b] > 0: + x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]] = 0.0 + else: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] ] - ] - ).item() - == 0 - ) + ).item() + == 0 + ) cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True @@ -587,19 +600,25 @@ def run_dpa_with_cp( cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 - ) + if pad_between_seqs == "True": + x[cu_seqlens_kv_padded[-1] :] = 0.0 + for b in range(config.batch_size): + if num_pads_kv[b] > 0: + x[(cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[b + 1]] = 0.0 + else: + assert 
torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 + ) else: # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index ecd0090a3b..98f8e56490 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -85,11 +85,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +@pytest.mark.parametrize("pad_between_seqs", [False, True]) +def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + if pad_between_seqs: + if qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") + if not FlashAttentionUtils.v3_is_installed: + pytest.skip("pad_between_seqs with CP requires Flash Attention v3!") + if cp_comm_type == "a2a+p2p": + pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type @@ -133,6 +142,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + pad_between_seqs=pad_between_seqs, log_level=pytest_logging_level, ), 
check=True, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a6a8b0b26a..0b5739fee3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -742,6 +742,8 @@ def forward( flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), fp8_output: bool = False, num_splits: Optional[int] = 1, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, ) -> torch.Tensor: """flash-attn fprop""" @@ -919,6 +921,11 @@ def forward( use_flash_attn_3 = False if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): use_flash_attn_3 = True + + pad_between_seqs = False + if qkv_format == "thd" and cu_seqlens_q_padded is not None: + pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) + if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -935,8 +942,12 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q if qkv_format == "thd" else None, - cu_seqlens_kv if qkv_format == "thd" else None, + cu_seqlens_q_padded if pad_between_seqs else ( + cu_seqlens_q if qkv_format == "thd" else None + ), + cu_seqlens_kv_padded if pad_between_seqs else ( + cu_seqlens_kv if qkv_format == "thd" else None + ), self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -948,7 +959,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, - pad_between_seqs=False, + pad_between_seqs=pad_between_seqs, use_flash_attn_3=use_flash_attn_3, fp8_output=fp8_output, ) @@ -984,8 +995,12 @@ def forward( else: func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment if not use_flash_attn_3 or inference_params is None: - 
fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) + if pad_between_seqs and use_flash_attn_3: + fa_optional_forward_args_thd.append(cu_seqlens_q_padded) + fa_optional_forward_args_thd.append(cu_seqlens_kv_padded) + else: + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if not use_flash_attn_3: @@ -1019,6 +1034,13 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["num_splits"] = num_splits + if pad_between_seqs: + fa_3_optional_forward_kwargs["seqused_q"] = ( + cu_seqlens_q[1:] - cu_seqlens_q[:-1] + ) + fa_3_optional_forward_kwargs["seqused_k"] = ( + cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + ) if inference_params is None: fa_3_optional_forward_kwargs["deterministic"] = self.deterministic else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 10ba99595b..8e21799f3c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -572,6 +572,8 @@ def get_fa_args( dq=None, dk=None, dv=None, + seqused_q=None, + seqused_k=None, ): """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: @@ -581,7 +583,9 @@ def get_fa_args( *[None] * 4, # k_new, v_new, qv, out cu_seqlens_q, cu_seqlens_kv, - *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + None, # cu_seqlens_k_new + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, *[None] @@ -599,8 +603,8 @@ def get_fa_args( return [ cu_seqlens_q, cu_seqlens_kv, - None, # sequed_q - None, # sequed_k + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, dq, @@ -610,8 
+614,8 @@ def get_fa_args( return [ None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sequed_q - None, # sequed_k + None, # seqused_q + None, # seqused_k max_seqlen_q, max_seqlen_kv, dq, @@ -923,6 +927,9 @@ def cp_p2p_fwd_flash_attn( cu_seqlens_q_per_step, cu_seqlens_kv_per_step, section, + pad_between_seqs=False, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, ): """Per-tile forward call of CP P2P with FlashAttention backend""" cu_seqlens_q_ = cu_seqlens_q_per_step @@ -943,6 +950,20 @@ def cp_p2p_fwd_flash_attn( fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + seqused_q = None + seqused_k = None + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + # Derive actual token counts per batch element from cu_seqlens + seqused_q = cu_seqlens_q_per_step[1:] - cu_seqlens_q_per_step[:-1] + seqused_k = cu_seqlens_kv_per_step[1:] - cu_seqlens_kv_per_step[:-1] + # Override cu_seqlens to padded layout for tensor memory layout + cu_seqlens_q_ = cu_seqlens_q_padded + cu_seqlens_kv_ = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_ = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_ = cu_seqlens_q_padded // 2 + fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -951,6 +972,8 @@ def cp_p2p_fwd_flash_attn( cu_seqlens_kv=cu_seqlens_kv_, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -1186,10 +1209,19 @@ def cp_p2p_bwd_flash_attn( out_part, dout_part, section, + pad_between_seqs=False, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, ): """Per-tile backward call of CP P2P with FlashAttention backend""" - dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] - if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: + if pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]] + else: + dq, dk, dv = 
[torch.empty_like(x) for x in [q_part, k_part, v_part]] + if use_flash_attn_3: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + elif fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 @@ -1213,17 +1245,33 @@ def cp_p2p_bwd_flash_attn( max_seqlen_q_ = max_seqlen_q // 2 softmax_lse__ = softmax_lse_ + seqused_q = None + seqused_k = None + cu_seqlens_q_bwd = cu_seqlens_q_per_step[cp_size - step - 1] + cu_seqlens_kv_bwd = cu_seqlens_kv_per_step[cp_size - step - 1] + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q_bwd[1:] - cu_seqlens_q_bwd[:-1] + seqused_k = cu_seqlens_kv_bwd[1:] - cu_seqlens_kv_bwd[:-1] + cu_seqlens_q_bwd = cu_seqlens_q_padded + cu_seqlens_kv_bwd = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_bwd = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_bwd = cu_seqlens_q_padded // 2 + fa_backward_args_thd = get_fa_args( False, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + cu_seqlens_q=cu_seqlens_q_bwd, + cu_seqlens_kv=cu_seqlens_kv_bwd, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if use_flash_attn_3: fa_backward_kwargs["is_causal"] = causal_ @@ -1681,7 +1729,12 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) elif i <= rank: @@ -1708,7 +1761,12 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] 
= ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) else: @@ -1735,7 +1793,12 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) else: @@ -1760,7 +1823,14 @@ def forward( ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( - cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + ) ) # softmax_lse correction @@ -2011,6 +2081,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8_meta = fp8_meta @@ -2425,7 +2496,12 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=ctx.pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) elif i >= (cp_size - rank - 1): section = "lower-triangle" @@ -2436,7 +2512,12 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, 
*prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=ctx.pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) else: section = "upper-triangle" @@ -2447,7 +2528,12 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=ctx.pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) else: section = "all" @@ -2458,7 +2544,12 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, + pad_between_seqs=ctx.pad_between_seqs, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) # dq, dk, dv are reduced across steps in higher precision @@ -3331,6 +3422,7 @@ def forward( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, @@ -3523,14 +3615,25 @@ def forward( ): out_part = O_quantizer(out_) else: + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -3635,6 +3738,7 @@ def forward( ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.fp8_recipe = 
fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_type = softmax_type ctx.dQKV_quantizer = dQKV_quantizer @@ -3813,18 +3917,32 @@ def backward(ctx, dout, *_args): dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors - dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + if ctx.pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q, k, v]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if ctx.pad_between_seqs and ctx.use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state @@ -3919,6 +4037,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, d_softmax_offset, None, ) @@ -4137,6 +4256,7 @@ def attn_forward_func_with_cp( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..e7a14b0664 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1473,6 +1473,8 @@ def forward( 
flash_attention_backend=flash_attention_backend, fp8_output=fp8_output, num_splits=num_splits, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if use_fused_attention: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..3fe49e8744 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -682,14 +682,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # Filter: QKV layout if qkv_format == "thd": if pad_between_seqs: - if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( - use_flash_attention_3 and FlashAttentionUtils.v3_is_installed - ): + if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " + "Disabling FlashAttention 2 for qkv_format = thd when there is " "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" ) - use_flash_attention = False + use_flash_attention_2 = False + # FA3 supports pad_between_seqs via seqused_q/seqused_k if device_compute_capability == (12, 0): if use_fused_attention: logger.debug( From fb27e0ca1d2527f9bef4a0618315cc7316e8b453 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:38:41 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/run_attention_with_cp.py | 18 +++++++++++++----- .../dot_product_attention/backends.py | 12 ++++++++---- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index f28c7d01ca..05ff8e2982 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -578,7 +578,11 @@ def run_dpa_with_cp( x[cu_seqlens_q_padded[-1] :] = 0.0 for b in range(config.batch_size): if num_pads_q[b] > 0: - x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]] = 0.0 + x[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] + ] = 0.0 else: assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 for b in range(config.batch_size): @@ -586,9 +590,9 @@ def run_dpa_with_cp( num_pads_q[b] == 0 or torch.count_nonzero( x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] + ( + cu_seqlens_q_padded[b + 1] - num_pads_q[b] + ) : cu_seqlens_q_padded[b + 1] ] ).item() == 0 @@ -604,7 +608,11 @@ def run_dpa_with_cp( x[cu_seqlens_kv_padded[-1] :] = 0.0 for b in range(config.batch_size): if num_pads_kv[b] > 0: - x[(cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[b + 1]] = 0.0 + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + 
] = 0.0 else: assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 for b in range(config.batch_size): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0b5739fee3..766d86caf9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -942,11 +942,15 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q_padded if pad_between_seqs else ( - cu_seqlens_q if qkv_format == "thd" else None + ( + cu_seqlens_q_padded + if pad_between_seqs + else (cu_seqlens_q if qkv_format == "thd" else None) ), - cu_seqlens_kv_padded if pad_between_seqs else ( - cu_seqlens_kv if qkv_format == "thd" else None + ( + cu_seqlens_kv_padded + if pad_between_seqs + else (cu_seqlens_kv if qkv_format == "thd" else None) ), self.attention_dropout if self.training else 0.0, cp_group, From 5c106584f5e7c556427180b9e86a4393bff7a07f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 12 Mar 2026 11:28:43 -0700 Subject: [PATCH 03/11] [PyTorch] Add non-CP pad_between_seqs test support for FlashAttention Enable FlashAttention backend in test_attention.py to use padded cu_seqlens and pad_between_seqs parameter, matching FusedAttention's test path. FA3 natively supports pad_between_seqs via seqused_q/seqused_k. 
- Group FlashAttention with FusedAttention for padded input/output handling - Pass cu_seqlens_q_padded/cu_seqlens_kv_padded for FlashAttention backend - Pass pad_between_seqs to DPA call - Add pad_between_seqs=True to parametrize with thd-only skip Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention.py | 26 +++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 60ade522e3..936de47921 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -124,7 +124,7 @@ def reset_global_fp8_state(): @pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("swa", [False]) -@pytest.mark.parametrize("pad_between_seqs", [False]) +@pytest.mark.parametrize("pad_between_seqs", [False, True]) def test_dot_product_attention( dtype, model_configs, @@ -157,6 +157,8 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if pad_between_seqs and qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") if qkv_format == "thd" and "padding" not in config.attn_mask_type: config.attn_mask_type = ( "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" @@ -195,8 +197,9 @@ def test_dot_product_attention( ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention - # mannually pads and unpads the input and output of FlashAttention for testing purposes + # FA3 natively supports pad_between_seqs via seqused_q/seqused_k. 
+ # FA2 does not support pad_between_seqs, but _run_dot_product_attention + # manually pads and unpads the input and output of FlashAttention for testing purposes. if ( pad_between_seqs and FlashAttentionUtils.is_installed @@ -1197,12 +1200,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: block.softmax_offset.requires_grad = True # Run a forward and backward pass - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: q = inp_orig[0] k = inp_orig[1] v = inp_orig[2] d_out = out_grad_orig - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: q = inp[0] k = inp[1] v = inp[2] @@ -1218,14 +1221,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: max_seqlen_kv=config.max_seqlen_kv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None, - cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None, + cu_seqlens_q_padded=( + cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), + cu_seqlens_kv_padded=( + cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, alibi_slopes=alibi_slopes, fast_zero_fill=True, + pad_between_seqs=pad_between_seqs, # Only pass num_splits when exercising the FlashAttention path num_splits=config.num_splits if backend == "FlashAttention" else 1, ) @@ -1239,12 +1247,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: if is_training: return out, max_logit, (q.grad, k.grad, v.grad, 
d_softmax_offset) else: return out, max_logit, (None, None, None, d_softmax_offset) - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) if is_training: From d8abce2c6edf2227c7401e1a6b1bb7b14644ae69 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 19 Mar 2026 20:12:49 -0700 Subject: [PATCH 04/11] fixes from feedback Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 9 ++-- .../dot_product_attention/backends.py | 4 +- .../dot_product_attention/context_parallel.py | 42 ++++++------------- .../dot_product_attention.py | 1 + 4 files changed, 17 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 05ff8e2982..552fa6897b 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -505,14 +505,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - if pad_between_seqs == "True": - # FA3 backward may produce NaN at padding positions; replace with 0 - tensors = [t.detach().nan_to_num(nan=0.0) if t is not None else t for t in tensors] for tensor, name in zip(tensors, tensor_names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"NaN in {name}" - assert torch.all(~torch.isinf(tensor)), f"Inf in {name}" + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -565,7 +562,7 @@ def run_dpa_with_cp( if is_training: dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, 
dv_, out_ = [dq_, dk_, dv_, out_] + dq_, dk_, dv_, out_ = [x.clone() for x in [dq_, dk_, dv_, out_]] cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 766d86caf9..48137241ea 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -738,6 +738,7 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + pad_between_seqs: Optional[bool] = False, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), fp8_output: bool = False, @@ -922,9 +923,6 @@ def forward( if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): use_flash_attn_3 = True - pad_between_seqs = False - if qkv_format == "thd" and cu_seqlens_q_padded is not None: - pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 8e21799f3c..a83cc6a06f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -921,15 +921,15 @@ def cp_p2p_fwd_flash_attn( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, cu_seqlens_q_per_step, cu_seqlens_kv_per_step, section, - pad_between_seqs=False, - cu_seqlens_q_padded=None, 
- cu_seqlens_kv_padded=None, ): """Per-tile forward call of CP P2P with FlashAttention backend""" cu_seqlens_q_ = cu_seqlens_q_per_step @@ -1203,15 +1203,15 @@ def cp_p2p_bwd_flash_attn( rng_states, softmax_lse, softmax_lse_, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, out_part, dout_part, section, - pad_between_seqs=False, - cu_seqlens_q_padded=None, - cu_seqlens_kv_padded=None, ): """Per-tile backward call of CP P2P with FlashAttention backend""" if pad_between_seqs: @@ -1687,6 +1687,9 @@ def forward( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # cp_size = 4: @@ -1732,9 +1735,6 @@ def forward( *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) elif i <= rank: @@ -1764,9 +1764,6 @@ def forward( *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) else: @@ -1796,9 +1793,6 @@ def forward( *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) else: @@ -1827,9 +1821,6 @@ def forward( *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) ) @@ -2481,6 +2472,9 @@ def backward(ctx, dout, *_args): rng_states, softmax_lse, softmax_lse_, + ctx.pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # Reverse the steps in forward. In the cp_size x cp_size (i.e. 
GPU x step) matrix, @@ -2499,9 +2493,6 @@ def backward(ctx, dout, *_args): *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=ctx.pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) elif i >= (cp_size - rank - 1): section = "lower-triangle" @@ -2515,9 +2506,6 @@ def backward(ctx, dout, *_args): *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=ctx.pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) else: section = "upper-triangle" @@ -2531,9 +2519,6 @@ def backward(ctx, dout, *_args): *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=ctx.pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) else: section = "all" @@ -2547,9 +2532,6 @@ def backward(ctx, dout, *_args): *flash_attn_inputs, *prepare_outputs, section, - pad_between_seqs=ctx.pad_between_seqs, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) # dq, dk, dv are reduced across steps in higher precision diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index e7a14b0664..b9bb1a3e28 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1469,6 +1469,7 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, inference_params=inference_params, flash_attention_backend=flash_attention_backend, fp8_output=fp8_output, From 41e431afa6addda04513323567d475f4b9da9a89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 03:13:37 +0000 
Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/backends.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 48137241ea..4650d76613 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -923,7 +923,6 @@ def forward( if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): use_flash_attn_3 = True - if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): From 9efa48ffe9566d48ce2f355928abd198dc43a030 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 19 Mar 2026 20:15:38 -0700 Subject: [PATCH 06/11] remove redundant condition Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/context_parallel.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a83cc6a06f..bcdd81a459 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1218,10 +1218,7 @@ def cp_p2p_bwd_flash_attn( dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]] else: dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] - if use_flash_attn_3: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - elif fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) 
elif use_flash_attn_3 or fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size_left"] = -1 From 232b78d6c4d7e5ae517b0094dc03e88b36cb775f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 20 Mar 2026 18:34:39 -0700 Subject: [PATCH 07/11] remove unnecessary zeroing logic, fixes from other feedback Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 70 ++++++++----------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 552fa6897b..fcdcf75426 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -508,8 +508,8 @@ def run_dpa_with_cp( for tensor, name in zip(tensors, tensor_names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + assert torch.all(~torch.isnan(tensor)), f"{name} has nan values" + assert torch.all(~torch.isinf(tensor)), f"{name} has inf values" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -562,38 +562,38 @@ def run_dpa_with_cp( if is_training: dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [x.clone() for x in [dq_, dk_, dv_, out_]] + out_ = out_.clone() cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + # FA3 forward doesn't zero padding positions in output; + # zero them in out_ (reference) so comparison is valid. 
+ if pad_between_seqs == "True": + out_[cu_seqlens_q_padded[-1] :] = 0.0 + for b in range(config.batch_size): + if num_pads_q[b] > 0: + out_[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] + ] = 0.0 for x in [dq, out, dq_, out_]: - if pad_between_seqs == "True": - # FA3 doesn't guarantee zeros at padding positions; zero them explicitly - x[cu_seqlens_q_padded[-1] :] = 0.0 - for b in range(config.batch_size): - if num_pads_q[b] > 0: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] = 0.0 - else: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_q_padded[b + 1] - num_pads_q[b] - ) : cu_seqlens_q_padded[b + 1] - ] - ).item() - == 0 - ) + ( + cu_seqlens_q_padded[b + 1] - num_pads_q[b] + ) : cu_seqlens_q_padded[b + 1] + ] + ).item() + == 0 + ) cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True @@ -601,20 +601,10 @@ def run_dpa_with_cp( cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] for x in [dk, dv, dk_, dv_]: - if pad_between_seqs == "True": - x[cu_seqlens_kv_padded[-1] :] = 0.0 - for b in range(config.batch_size): - if num_pads_kv[b] > 0: - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] = 0.0 - else: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 or 
torch.count_nonzero( x[ ( From 73f989cb5e86d56a524cbf628c3533e9c044d831 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 01:35:42 +0000 Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index fcdcf75426..ec836fc3af 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -587,9 +587,9 @@ def run_dpa_with_cp( num_pads_q[b] == 0 or torch.count_nonzero( x[ - ( - cu_seqlens_q_padded[b + 1] - num_pads_q[b] - ) : cu_seqlens_q_padded[b + 1] + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] ] ).item() == 0 @@ -605,15 +605,15 @@ def run_dpa_with_cp( for b in range(config.batch_size): assert ( num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 - ) + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 + ) else: # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() From 355252c3fcbef52a0853587a2eef4ed1617b7466 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 24 Mar 2026 12:48:44 -0700 Subject: [PATCH 09/11] add the flag to skip flash attn3 for head_dim_qk>128 Signed-off-by: Sudhakar Singh --- .../pytorch/attention/dot_product_attention/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 
dbdf11343e..acce65115a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1069,6 +1069,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "please install flash-attn >= 2.4.1." ) use_flash_attention_2 = False + if use_flash_attention_3 and deterministic: + if head_dim_qk > 128: + logger.warning( + "Disabling FlashAttention 3 for deterministic execution with head_dim_qk > 128" + ) + use_flash_attention_3 = False if use_fused_attention and deterministic: if softmax_type != "vanilla": logger.debug( From 417b318a1eb51d80b154e306c3e3a76fac17430f Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 25 Mar 2026 13:23:18 -0700 Subject: [PATCH 10/11] fix kv cache block size issue for FA2 Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index acce65115a..914dae4a3c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -549,7 +549,7 @@ def get_attention_backend( # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 - # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | % 256 == 0 # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 @@ -588,9 +588,9 @@ def get_attention_backend( 
use_fused_attention = False use_unfused_attention = False if inference_params.is_paged: - if use_flash_attention_2 and inference_params.page_size < 256: + if use_flash_attention_2 and inference_params.page_size % 256 != 0: if FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 for page size < 256") + logger.debug("Disabling FlashAttention 2 for page size not divisible by 256") use_flash_attention_2 = False if use_flash_attention_2: if not FlashAttentionUtils.is_installed: @@ -600,6 +600,16 @@ def get_attention_backend( "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" ) use_flash_attention_2 = False + else: + # Non-paged KV cache still passes a block_table to FA2 for thd_2bshd support, + # and FA2 enforces page_size % 256 == 0 on the effective page size (max_seqlen_kv). + if use_flash_attention_2 and max_seqlen_kv % 256 != 0: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for non-paged KV cache" + " with max_seqlen_kv not divisible by 256" + ) + use_flash_attention_2 = False # Filter: Head dimension if head_dim_qk != head_dim_v: From 05301535893c19cd859eb3587cacdd83af8b1ecf Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 25 Mar 2026 15:24:37 -0700 Subject: [PATCH 11/11] add a skip when trying to run FA3 on SM100+ Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index d1aeb012f7..929f2b5326 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -200,6 +200,7 @@ def test_dot_product_attention( # FA3 natively supports pad_between_seqs via seqused_q/seqused_k. # FA2 does not support pad_between_seqs, but _run_dot_product_attention # manually pads and unpads the input and output of FlashAttention for testing purposes. 
+ # Flash Attention is not supported on SM100+ if ( pad_between_seqs and FlashAttentionUtils.is_installed @@ -208,6 +209,7 @@ def test_dot_product_attention( and config.attn_mask_type in ["causal", "padding_causal"] ) and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) + and get_device_compute_capability() < (10, 0) ): flash_attn_supported = True