Skip to content

[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD#2596

Open
sudhakarsingh27 wants to merge 20 commits intoNVIDIA:mainfrom
sudhakarsingh27:flash_attn_pad_bw_seqs
Open

[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD#2596
sudhakarsingh27 wants to merge 20 commits intoNVIDIA:mainfrom
sudhakarsingh27:flash_attn_pad_bw_seqs

Conversation

@sudhakarsingh27
Copy link
Collaborator

@sudhakarsingh27 sudhakarsingh27 commented Jan 14, 2026

Description

TLDR

Enable pad_between_seqs=True for FlashAttention 3 with THD format — both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously pad_between_seqs was only supported with FusedAttention.

Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With pad_between_seqs=True, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via cu_seqlens_q_padded, but FlashAttention (both FA2 and FA3) had pad_between_seqs hardcoded to False in the CP path, and FA2 was entirely disabled for pad_between_seqs + thd. FA3 can natively handle this via its seqused_q/seqused_k mechanism.

Solution

Use FA3's seqused_q/seqused_k tensors to communicate actual token counts per batch element. Pass cu_seqlens_q_padded for tensor memory layout while deriving seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] from the real cu_seqlens. This applies to both the CP path (A2A and P2P) and the non-CP path.

Fixes #2399

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

context_parallel.py

  • get_fa_args(): Add seqused_q/seqused_k parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded Nones).
  • cp_p2p_fwd_flash_attn() / cp_p2p_bwd_flash_attn(): Accept pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded. When enabled, derive seqused tensors and override cu_seqlens to padded versions (with half-padding for lower-triangle/upper-triangle sections).
  • AttnFuncWithCPAndKVP2P: Thread pad_between_seqs and padded cu_seqlens through all forward/backward cp_p2p_fwd/bwd_flash_attn call sites. Save ctx.pad_between_seqs for backward.
  • AttnFuncWithCPAndQKVOA2A.forward(): Add pad_between_seqs parameter. When enabled with FA3+THD, derive seqused and swap cu_seqlens for padded versions before calling get_fa_args().
  • AttnFuncWithCPAndQKVOA2A.backward(): Same seqused/cu_seqlens override. Use zeros_like (not empty_like) for gradient init when pad_between_seqs since FA3 skips padding positions. Add extra None in return tuple for the new pad_between_seqs gradient slot.
  • attn_forward_func_with_cp(): Pass pad_between_seqs in A2A args list.

backends.py

  • FlashAttention.forward(): Accept cu_seqlens_q_padded/cu_seqlens_kv_padded. Detect pad_between_seqs by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass seqused_q/seqused_k.

dot_product_attention.py

  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded through to FlashAttention.

utils.py

  • Only disable FA2 (not FA3) when pad_between_seqs + thd. FA3 handles this natively via seqused.

test_attention_with_cp.py

  • Add @pytest.mark.parametrize("pad_between_seqs", [False, True]) to flash attention CP tests.
  • Skip pad_between_seqs=True for non-THD formats, when FA3 is not installed, and for a2a+p2p comm type (not yet supported).

run_attention_with_cp.py

  • Thread pad_between_seqs through generate_input_shapes() and run_dpa_with_cp().
  • When pad_between_seqs, set cu_seqlens_q to actual lengths (not just for FusedAttention).
  • Handle FA3 backward NaN at padding positions: nan_to_num(nan=0.0).
  • Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros at padding slots).
  • Add tensor names to NaN/Inf assertion messages for debuggability.

test_attention.py

  • Group FlashAttention with FusedAttention for padded input/output handling in _run_dot_product_attention() (previously FlashAttention used original unpadded inputs).
  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded and pad_between_seqs to DPA call for FlashAttention backend.
  • Add pad_between_seqs=True to parametrize with skip for non-THD formats.

New Tests

CP tests (test_attention_with_cp.py)

Added @pytest.mark.parametrize("pad_between_seqs", [False, True]) to test_cp_with_flash_attention. Skip conditions: non-THD formats, FA3 not installed, a2a+p2p comm type.

5 new tests that run (all pad_between_seqs=True, thd, bf16):

Test CP comm Model config
True-p2p-thd-cp_1_0-bf16 P2P causal, 1 head
True-p2p-thd-cp_2_1-bf16 P2P causal, 2 heads
True-a2a-thd-cp_1_0-bf16 A2A causal, 1 head
True-a2a-thd-cp_1_2-bf16 A2A causal, sliding window
True-a2a-thd-cp_2_1-bf16 A2A causal, 2 heads

Non-CP tests (test_attention.py)

Added True to @pytest.mark.parametrize("pad_between_seqs", [False, True]) on test_dot_product_attention, with skip for non-THD. Also changed _run_dot_product_attention so FlashAttention uses padded inputs/cu_seqlens and receives pad_between_seqs=True.

48 new test IDs collected, but all are skipped because the main parametrize uses qkv_layout=None (defaults to sbhd, not thd). The non-CP pad_between_seqs + FA3 code path is exercised indirectly when other test functions call test_dot_product_attention with qkv_layout="thd_thd_thd" (e.g., test_dpa_softmax_thd).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sudhakarsingh27 sudhakarsingh27 self-assigned this Jan 14, 2026
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile Summary

This PR enables pad_between_seqs=True for FlashAttention 3 with THD format across both CP paths (A2A and P2P) and the non-CP path, by leveraging FA3's native seqused_q/seqused_k mechanism. Previously this combination was only supported with FusedAttention.

The core approach is sound: FA3's seqused_q/seqused_k tensors carry actual per-batch token counts (derived from real cu_seqlens), while the padded cu_seqlens_q_padded tensors describe the physical memory layout. FA2 remains disabled for pad_between_seqs + thd (unchanged). Key implementation details are correctly handled:

  • get_fa_args() now propagates seqused_q/seqused_k instead of hardcoded None
  • P2P forward/backward (cp_p2p_fwd_flash_attn, cp_p2p_bwd_flash_attn) compute seqused from cu_seqlens_q_per_step and use cu_seqlens_q_padded (halved for lower/upper-triangle sections) for the memory layout
  • A2A forward/backward (AttnFuncWithCPAndQKVOA2A) apply the same seqused + padded layout pattern
  • Backward gradient buffers correctly use zeros_like (instead of empty_like) so FA3's unwritten padding positions remain zero
  • The A2A backward return tuple is correctly extended with one additional None for the new pad_between_seqs input
  • Test reference outputs (out_) are cloned and explicitly zeroed at padding positions before comparison, since FA3 does not guarantee zeroing unwritten slots in the forward output

Also includes two unrelated but useful fixes in utils.py: FA2 disabled for non-paged KV cache with max_seqlen_kv % 256 != 0, and FA3 disabled for deterministic execution with head_dim_qk > 128.

Issues found:

  • backends.py: When inference_params is not None combined with pad_between_seqs=True and FA3, non-padded cu_seqlens_q is passed as the FA3 memory-layout kwarg — which would be inconsistent with padded input tensors. This combination is out of scope for this PR but is worth a comment or assertion.
  • run_attention_with_cp.py: Reference gradient tensors (dq_, dk_, dv_) are not explicitly zeroed at padding positions. Currently safe because zeros_like initializes them and FA3 skips padding slots in backward, but adding explicit zeroing would make the test more defensive and self-documenting.

Confidence Score: 4/5

  • PR is safe to merge; the two identified concerns are non-blocking style suggestions and an out-of-scope edge case.
  • The core implementation is architecturally sound and correctly wired across all code paths (P2P forward/backward, A2A forward/backward, non-CP FA3, utils backend selection). The return-tuple alignment for A2A backward is correct, zeros_like is used appropriately for gradient safety, and tests exercise the new paths with meaningful skip guards. The inference + pad_between_seqs inconsistency is a theoretical edge case not covered by the PR's stated scope. The unzeroed reference gradient buffers in the CP test are safe today given the zeros_like initializer, but lack defensive depth. No correctness bugs were found in the primary training + CP + FA3 + THD path this PR targets.
  • transformer_engine/pytorch/attention/dot_product_attention/backends.py (inference path with pad_between_seqs); tests/pytorch/attention/run_attention_with_cp.py (reference gradient zeroing)

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Core CP logic updated across P2P and A2A paths to support pad_between_seqs via seqused_q/seqused_k for FA3. Backward return tuple correctly extended with extra None for the new pad_between_seqs param. The cu_seqlens_padded // 2 halving for lower/upper-triangle sections implicitly requires padded lengths to be even (guaranteed by CP padding invariants). Uses zeros_like for gradient init to safely handle unwritten padding positions.
transformer_engine/pytorch/attention/dot_product_attention/backends.py FlashAttention.forward() now accepts pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded and threads them correctly to both CP and non-CP FA3 paths. For the non-CP FA3 path, seqused_q/seqused_k are derived from actual cu_seqlens while padded cu_seqlens provide the memory layout. The inference path (inference_params is not None) combined with pad_between_seqs=True could result in non-padded cu_seqlens being passed as the memory layout kwarg, which may be inconsistent — but this combination is out of scope for the PR.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Correctly narrows FA2 disablement for pad_between_seqs+thd to FA2 only, leaving FA3 enabled. Also includes two unrelated but useful fixes: FA2 page size divisibility check for non-paged KV cache, and FA3 deterministic execution disabled for head_dim_qk > 128.
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py Minimal, correct pass-through of pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded to FlashAttention.forward().
tests/pytorch/attention/test_attention_with_cp.py Properly adds pad_between_seqs parametrize with correct skip guards (non-THD, FA3 not installed, a2a+p2p not yet supported). Correctly passes the parameter through to the subprocess runner.
tests/pytorch/attention/run_attention_with_cp.py Correctly extends generate_input_shapes() and run_dpa_with_cp() with pad_between_seqs. Reference output out_ is cloned and explicitly zeroed at padding positions before comparison (since FA3 does not guarantee zeros there). NaN/Inf assertions now include tensor names for debuggability. One minor note: pad_between_seqs is passed as a string ("True"/"False") consistent with other CLI-style bool args in this file.
tests/pytorch/attention/test_attention.py FlashAttention is now grouped with FusedAttention for padded input/output handling. For pad_between_seqs=False, inp == inp_orig so the grouping change is a no-op. For pad_between_seqs=True, padded inputs are correctly passed and output is correctly unpadded before comparison. The new get_device_compute_capability() < (10, 0) guard correctly restricts the flash_attn_supported re-enablement to pre-SM100 devices.

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention
    participant FA as FlashAttention
    participant utils as get_attention_backend
    participant CP as attn_forward_func_with_cp
    participant P2P as AttnFuncWithCPAndKVP2P
    participant A2A as AttnFuncWithCPAndQKVOA2A
    participant helper as cp_p2p_fwd_flash_attn
    participant FA3 as flash_attn_v3

    User->>DPA: forward(q,k,v, cu_seqlens_q, cu_seqlens_q_padded, pad_between_seqs=True)
    DPA->>utils: get_attention_backend(pad_between_seqs=True)
    utils-->>DPA: FA2 disabled, FA3 enabled
    DPA->>FA: forward(pad_between_seqs=True, cu_seqlens_q_padded)

    alt Context Parallel path
        FA->>CP: forward(pad_between_seqs, cu_seqlens_q_padded)
        alt P2P comm type
            CP->>P2P: apply(pad_between_seqs, cu_seqlens_q_padded)
            loop each CP step
                P2P->>helper: "cp_p2p_fwd_flash_attn(pad_between_seqs, cu_seqlens_q_padded)"
                Note over helper: "seqused_q = cu_seqlens_q_per_step[1:] - [:-1]"
                Note over helper: "cu_seqlens_q = cu_seqlens_q_padded (layout)"
                helper->>FA3: "varlen_func(q,k,v, cu_seqlens_padded, seqused_q, seqused_k)"
                FA3-->>helper: out with padding positions unwritten
            end
            P2P-->>CP: out (LSE accumulation zeros padding)
        else A2A comm type
            CP->>A2A: apply(pad_between_seqs, cu_seqlens_q_padded)
            Note over A2A: "seqused_q = cu_seqlens_q[1:] - [:-1]"
            Note over A2A: "fa_cu_seqlens_q = cu_seqlens_q_padded"
            A2A->>FA3: "varlen_func(q,k,v, cu_seqlens_padded, seqused_q, seqused_k)"
            FA3-->>A2A: out
            A2A-->>CP: out
        end
        CP-->>FA: out
    else Non-CP path
        Note over FA: "seqused_q = cu_seqlens_q[1:] - [:-1]"
        Note over FA: "cu_seqlens = cu_seqlens_q_padded (layout)"
        FA->>FA3: "varlen_func(q,k,v, cu_seqlens_padded, seqused_q)"
        FA3-->>FA: out
    end
    FA-->>DPA: out
    DPA-->>User: out
Loading

Reviews (9): Last reviewed commit: "add a skip when trying to run FA3 on SM1..." | Re-trigger Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +974 to +983
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
fa_3_optional_forward_kwargs["seqused_q"] = (
cu_seqlens_q[1:] - cu_seqlens_q[:-1]
)
fa_3_optional_forward_kwargs["seqused_k"] = (
cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions - the related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). if flash_attn_3 doesn't zero these internally, output may have garbage values in padded positions. have you verified that flash_attn_3 correctly handles padding internally with these parameters?

## TLDR
Enable `pad_between_seqs=True` for A2A and P2P context parallelism comm types
with FlashAttention 3 and THD format. Previously `pad_between_seqs` was only
supported with FusedAttention.

## Problem
When using THD format with variable-length sequences, sequences are padded for
divisibility across CP ranks. With `pad_between_seqs=True`, the attention kernel
needs to know actual (unpadded) token counts so it doesn't compute attention over
padding tokens. FusedAttention already handled this via `cu_seqlens_q_padded`, but
FlashAttention (both FA2 and FA3) had `pad_between_seqs` hardcoded to `False` in
the CP path, and FA2 was entirely disabled for `pad_between_seqs + thd`. FA3 can
natively handle this via its `seqused_q`/`seqused_k` mechanism.

## Solution
Use FA3's `seqused_q`/`seqused_k` tensors to communicate actual token counts per
batch element. Pass `cu_seqlens_q_padded` for tensor memory layout while deriving
`seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]` from the real `cu_seqlens`.

## Changes

### context_parallel.py
- `get_fa_args()`: Add `seqused_q`/`seqused_k` parameters, pass through to FA3
  forward and backward positional arg lists (replacing hardcoded `None`s).
- `cp_p2p_fwd_flash_attn()` / `cp_p2p_bwd_flash_attn()`: Accept `pad_between_seqs`,
  `cu_seqlens_q_padded`, `cu_seqlens_kv_padded`. When enabled, derive `seqused`
  tensors and override `cu_seqlens` to padded versions (with half-padding for
  lower-triangle/upper-triangle sections).
- `AttnFuncWithCPAndKVP2P`: Thread `pad_between_seqs` and padded cu_seqlens
  through all forward/backward `cp_p2p_fwd/bwd_flash_attn` call sites. Save
  `ctx.pad_between_seqs` for backward.
- `AttnFuncWithCPAndQKVOA2A.forward()`: Add `pad_between_seqs` parameter. When
  enabled with FA3+THD, derive `seqused` and swap `cu_seqlens` for padded versions
  before calling `get_fa_args()`.
- `AttnFuncWithCPAndQKVOA2A.backward()`: Same seqused/cu_seqlens override.
  Use `zeros_like` (not `empty_like`) for gradient init when `pad_between_seqs`
  since FA3 skips padding positions. Add extra `None` in return tuple for the
  new `pad_between_seqs` gradient slot.
- `attn_forward_func_with_cp()`: Pass `pad_between_seqs` in A2A args list.

### backends.py
- `FlashAttention.forward()`: Accept `cu_seqlens_q_padded`/`cu_seqlens_kv_padded`.
  Detect `pad_between_seqs` by comparing padded vs actual cu_seqlens. Pass padded
  cu_seqlens to CP path. For non-CP FA3 path, derive and pass `seqused_q`/`seqused_k`.

### dot_product_attention.py
- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` through to `FlashAttention`.

### utils.py
- Only disable FA2 (not FA3) when `pad_between_seqs + thd`. FA3 handles this
  natively via `seqused`.

### test_attention_with_cp.py
- Add `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to flash
  attention CP tests.
- Skip `pad_between_seqs=True` for non-THD formats, when FA3 is not installed,
  and for `a2a+p2p` comm type (not yet supported).

### run_attention_with_cp.py
- Thread `pad_between_seqs` through `generate_input_shapes()` and `run_dpa_with_cp()`.
- When `pad_between_seqs`, set `cu_seqlens_q` to actual lengths (not just for
  FusedAttention).
- Handle FA3 backward NaN at padding positions: `nan_to_num(nan=0.0)`.
- Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros
  at padding slots).
- Add tensor names to NaN/Inf assertion messages for debuggability.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from ea51821 to e338049 Compare March 10, 2026 23:37
@sudhakarsingh27
Copy link
Collaborator Author

/te-ci pytorch L2

@sudhakarsingh27 sudhakarsingh27 changed the title Flash attn pad bw seqs [PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD Mar 11, 2026
Enable FlashAttention backend in test_attention.py to use padded
cu_seqlens and pad_between_seqs parameter, matching FusedAttention's
test path. FA3 natively supports pad_between_seqs via seqused_q/seqused_k.

- Group FlashAttention with FusedAttention for padded input/output handling
- Pass cu_seqlens_q_padded/cu_seqlens_kv_padded for FlashAttention backend
- Pass pad_between_seqs to DPA call
- Add pad_between_seqs=True to parametrize with thd-only skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

pad_between_seqs = False
if qkv_format == "thd" and cu_seqlens_q_padded is not None:
pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can pad_between_seqs be decided ahead of time, passed by the user or something? This wouldn't be CUDA Graph-compatible right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern exists in dpa.py as well. But yes, it's definitely redundant here

sudhakarsingh27 and others added 5 commits March 19, 2026 20:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ransformerEngine into flash_attn_pad_bw_seqs
@sudhakarsingh27
Copy link
Collaborator Author

/te-ci pytorch L1

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support FlashAttention with pad_between_seqs=True

2 participants