Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1ececdc
[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD
sudhakarsingh27 Mar 10, 2026
e338049
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 10, 2026
fb27e0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
50839e1
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 12, 2026
5c10658
[PyTorch] Add non-CP pad_between_seqs test support for FlashAttention
sudhakarsingh27 Mar 12, 2026
66e3352
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 12, 2026
d8abce2
fixes from feedback
sudhakarsingh27 Mar 20, 2026
41e431a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2026
9efa48f
remove redundant condition
sudhakarsingh27 Mar 20, 2026
c8a84bf
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Mar 20, 2026
8652dba
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 20, 2026
232b78d
remove unnecessary zeroing logic, fixes from other feedback
sudhakarsingh27 Mar 21, 2026
73f989c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2026
0228d08
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 23, 2026
7b8bc13
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Mar 23, 2026
355252c
add the flag to skip flash attn3 for head_dim_qk>128
sudhakarsingh27 Mar 24, 2026
e02ab58
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 24, 2026
d596db0
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Mar 25, 2026
417b318
fix kv cache block size issue for FA2
sudhakarsingh27 Mar 25, 2026
0530153
add a skip when trying to run FA3 on SM100+
sudhakarsingh27 Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def generate_input_shapes(
config: ModelConfig,
world_size: int,
kernel_backend: str,
pad_between_seqs: str = "False",
):
if qkv_format == "bshd":
q_input_shape = (
Expand Down Expand Up @@ -99,9 +100,9 @@ def generate_input_shapes(
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)

# Since FlashAttention doesn't support pad between sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
# When pad_between_seqs is True, or for FusedAttention, cu_seqlens_q reflects
# non-padded (actual) lengths. FA3 handles this via seqused_q/seqused_k.
if kernel_backend == "FusedAttention" or pad_between_seqs == "True":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()

# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
Expand Down Expand Up @@ -180,6 +181,7 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
pad_between_seqs="False",
log_level=logging.WARNING,
):
"""Test DotProductAttention module with context parallelism"""
Expand Down Expand Up @@ -275,7 +277,7 @@ def run_dpa_with_cp(
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, pad_between_seqs)
q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
Expand Down Expand Up @@ -494,6 +496,7 @@ def run_dpa_with_cp(

# get outputs
tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
tensor_names = ["out", "dq", "dk", "dv", "dbias", "out_", "dq_", "dk_", "dv_", "dbias_"]
if fp8_mha:
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
Expand All @@ -502,11 +505,11 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
for tensor, name in zip(tensors, tensor_names):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
assert torch.all(~torch.isnan(tensor)), f"{name} has nan values"
assert torch.all(~torch.isinf(tensor)), f"{name} has inf values"
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
Expand Down Expand Up @@ -559,13 +562,24 @@ def run_dpa_with_cp(
if is_training:
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
out_ = out_.clone()
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
# FA3 forward doesn't zero padding positions in output;
# zero them in out_ (reference) so comparison is valid.
if pad_between_seqs == "True":
out_[cu_seqlens_q_padded[-1] :] = 0.0
for b in range(config.batch_size):
if num_pads_q[b] > 0:
out_[
(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[
b + 1
]
] = 0.0
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
Expand Down
28 changes: 19 additions & 9 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_dot_product_attention(
dtype,
model_configs,
Expand Down Expand Up @@ -157,6 +157,8 @@ def test_dot_product_attention(

config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if pad_between_seqs and qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
Expand Down Expand Up @@ -195,8 +197,10 @@ def test_dot_product_attention(
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
# FA3 natively supports pad_between_seqs via seqused_q/seqused_k.
# FA2 does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes.
# Flash Attention is not supported on SM100+
if (
pad_between_seqs
and FlashAttentionUtils.is_installed
Expand All @@ -205,6 +209,7 @@ def test_dot_product_attention(
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
and get_device_compute_capability() < (10, 0)
):
flash_attn_supported = True

Expand Down Expand Up @@ -1197,12 +1202,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
block.softmax_offset.requires_grad = True

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
q = inp[0]
k = inp[1]
v = inp[2]
Expand All @@ -1218,14 +1223,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
cu_seqlens_q_padded=(
cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
cu_seqlens_kv_padded=(
cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
pad_between_seqs=pad_between_seqs,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
Expand All @@ -1239,12 +1249,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
if is_training:
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
Expand Down
12 changes: 11 additions & 1 deletion tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

if pad_between_seqs:
if qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
Expand Down Expand Up @@ -133,6 +142,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
pad_between_seqs=pad_between_seqs,
log_level=pytest_logging_level,
),
check=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,10 +738,13 @@ def forward(
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
pad_between_seqs: Optional[bool] = False,
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
num_splits: Optional[int] = 1,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""flash-attn fprop"""

Expand Down Expand Up @@ -919,6 +922,7 @@ def forward(
use_flash_attn_3 = False
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
use_flash_attn_3 = True

if context_parallel and all(
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
):
Expand All @@ -935,8 +939,16 @@ def forward(
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q if qkv_format == "thd" else None,
cu_seqlens_kv if qkv_format == "thd" else None,
(
cu_seqlens_q_padded
if pad_between_seqs
else (cu_seqlens_q if qkv_format == "thd" else None)
),
(
cu_seqlens_kv_padded
if pad_between_seqs
else (cu_seqlens_kv if qkv_format == "thd" else None)
),
self.attention_dropout if self.training else 0.0,
cp_group,
cp_global_ranks,
Expand All @@ -948,7 +960,7 @@ def forward(
deterministic=self.deterministic,
window_size=window_size,
quantizers=quantizers,
pad_between_seqs=False,
pad_between_seqs=pad_between_seqs,
use_flash_attn_3=use_flash_attn_3,
fp8_output=fp8_output,
)
Expand Down Expand Up @@ -984,8 +996,12 @@ def forward(
else:
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
if not use_flash_attn_3 or inference_params is None:
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
if pad_between_seqs and use_flash_attn_3:
fa_optional_forward_args_thd.append(cu_seqlens_q_padded)
fa_optional_forward_args_thd.append(cu_seqlens_kv_padded)
else:
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if not use_flash_attn_3:
Expand Down Expand Up @@ -1019,6 +1035,13 @@ def forward(
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
fa_3_optional_forward_kwargs["num_splits"] = num_splits
if pad_between_seqs:
fa_3_optional_forward_kwargs["seqused_q"] = (
cu_seqlens_q[1:] - cu_seqlens_q[:-1]
)
fa_3_optional_forward_kwargs["seqused_k"] = (
cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
)
if inference_params is None:
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
else:
Expand Down
Loading
Loading