Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/llm_sparsity/attention_sparsity/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Attention Sparsity for HuggingFace Models

In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. Two attention backends are supported:

- **pytorch** (default): Patches `F.softmax` to apply skip-softmax sparsity (requires `attn_implementation="eager"`)
- **triton**: Uses a fused Triton Flash Attention kernel with in-kernel sparsity (uses `attn_implementation="modelopt_triton"`)

## Getting Started

Expand Down
37 changes: 18 additions & 19 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,8 @@ def main(args):

print(f"Loading model: {args.pyt_ckpt_path}")

# Load model and tokenizer
# Note: attn_implementation="eager" is required for calibration to work properly
# (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
# No need to specify attn_implementation here — mtsa.sparsify() sets it
# automatically ("eager" for pytorch backend, "modelopt_triton" for triton).
model = AutoModelForCausalLM.from_pretrained(
args.pyt_ckpt_path,
attn_implementation="eager",
Expand All @@ -164,21 +163,21 @@ def main(args):
output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args)

# Apply sparse attention with optional calibration
print(f"\nApplying sparse attention: {args.sparse_attn}")
sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]

# Override calibration options if provided via CLI
print(f"\nApplying sparse attention: {args.sparse_attn} (backend={args.backend})")
sparse_config = copy.deepcopy(SPARSE_ATTN_CFG_CHOICES[args.sparse_attn])

# Apply CLI overrides to sparse_cfg
sparse_cfg = sparse_config.get("sparse_cfg", {})
for layer_cfg in sparse_cfg.values():
if isinstance(layer_cfg, dict) and "method" in layer_cfg:
layer_cfg["backend"] = args.backend
if args.target_sparse_ratio is not None:
sparse_config = copy.deepcopy(sparse_config)
sparse_cfg = sparse_config.get("sparse_cfg", {})
if isinstance(sparse_cfg, dict) and "calibration" in sparse_cfg:
calibration_cfg = sparse_cfg["calibration"]
if isinstance(calibration_cfg, dict):
calibration_cfg["target_sparse_ratio"] = {
"prefill": args.target_sparse_ratio,
"decode": args.target_sparse_ratio,
}
print(f"Overriding target_sparse_ratio to {args.target_sparse_ratio}")
calib = sparse_cfg.setdefault("calibration", {})
assert isinstance(calib, dict)
calib["target_sparse_ratio"] = {
"prefill": args.target_sparse_ratio,
"decode": args.target_sparse_ratio,
}

model = mtsa.sparsify(model, config=sparse_config)
print("Sparse attention applied successfully!")
Expand Down Expand Up @@ -242,8 +241,8 @@ def main(args):
"--backend",
type=str,
default="pytorch",
choices=["pytorch"],
help="Backend for sparse attention (default: pytorch). More backends coming soon.",
choices=["pytorch", "triton"],
help="Backend for sparse attention (default: pytorch). 'triton' uses the fused Triton kernel.",
)

# Sequence length arguments
Expand Down
47 changes: 47 additions & 0 deletions modelopt/torch/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared Triton kernels for modelopt (attention, quantization, etc.)."""

import torch

from modelopt.torch.utils import import_plugin

IS_AVAILABLE = False
attention = None
register_triton_attention = None

if torch.cuda.is_available():
with import_plugin(
"triton",
msg_if_missing=(
"Your device is potentially capable of using the triton attention "
"kernel. Try to install triton with `pip install triton`."
),
):
from .triton_fa import attention as _attention

attention = _attention
IS_AVAILABLE = True
with import_plugin("transformers"):
from .hf_triton_attention import register_triton_attention as _register_triton_attention

register_triton_attention = _register_triton_attention
Comment thread
coderabbitai[bot] marked this conversation as resolved.

__all__ = [
"IS_AVAILABLE",
"attention",
"register_triton_attention",
]
143 changes: 143 additions & 0 deletions modelopt/torch/kernels/hf_triton_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HuggingFace attention backend using the Triton flash attention kernel.

Registers as attn_implementation="modelopt_triton" so HF models dispatch to the
Triton kernel natively. Handles format conversion between HF's [batch, heads, seq, dim]
and the kernel's flat packed [total_tokens, heads, dim] varlen format.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from modelopt.torch.kernels.triton_fa import attention


def _seq_lens_from_mask(
attention_mask: torch.Tensor | None,
fallback: int,
device: torch.device,
) -> tuple[torch.Tensor | None, bool]:
"""Derive per-sequence lengths from attention mask.

Returns (b_seq_len, has_padding). If the mask is not a usable 2D format,
returns (None, False).
"""
if attention_mask is not None and attention_mask.dim() == 2:
mask = attention_mask.bool() if attention_mask.dtype != torch.bool else attention_mask
b_seq_len = mask.sum(dim=1).to(torch.int32).to(device)
has_padding = bool((b_seq_len != fallback).any())
return b_seq_len, has_padding
return None, False
Comment on lines +31 to +46
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Reject left-padded 2D masks or repack them.

Collapsing the mask to lengths only works for right padding. A row like [0, 0, 1, 1] will make the kernel read positions 0..1 as the valid prefix, and the post-mask then zeros the real tokens at 2..3. Please either pack from the mask positions or fail fast on non-right-padded masks.

🐛 Minimal guard
     if attention_mask is not None and attention_mask.dim() == 2:
         mask = attention_mask.bool() if attention_mask.dtype != torch.bool else attention_mask
+        if bool((~mask[:, :-1] & mask[:, 1:]).any()):
+            raise NotImplementedError(
+                "modelopt_triton currently supports only right-padded 2D attention masks"
+            )
         b_seq_len = mask.sum(dim=1).to(torch.int32).to(device)

Also applies to: 112-117

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/hf_triton_attention.py` around lines 31 - 46, The
current _seq_lens_from_mask collapses any 2D mask to lengths which only works
for right-padded masks; detect rows that are not right-padded (i.e., rows where
a zero appears before a later one) and either repack those sequences into a
contiguous prefix ordering or raise a clear error; implement the check by
validating for each row of attention_mask that once a 0 appears no subsequent 1
exists, and if the check fails either 1) build packed indices from the mask and
return packed b_seq_len/indicator for the Triton kernel or 2) raise ValueError
with a message referencing _seq_lens_from_mask so callers know they must provide
right-padded masks (also apply the same guard where similar logic occurs around
the other block referenced at lines ~112-117).



def triton_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Attention forward compatible with HF AttentionInterface.

    Converts HF tensors to the kernel's flat varlen format, calls the Triton
    kernel, then converts back. Handles both prefill (seq_len > 1) and decode
    (seq_len == 1).

    Args:
        module: The attention module (LlamaAttention etc.); unused here, but
            part of the HF AttentionInterface signature.
        query: [batch, num_heads, seq_len, head_dim].
        key: [batch, num_kv_heads, seq_k, head_dim].
        value: [batch, num_kv_heads, seq_k, head_dim].
        attention_mask: Optional; the kernel handles causal masking internally.
            2D [batch, seq_len] masks are used to derive per-sequence lengths.
            Other formats (e.g. 4D causal masks) are ignored.
        scaling: Softmax scale (e.g. 1/sqrt(head_dim)).
        dropout: Must be 0.0 — the kernel has no dropout support; a nonzero
            value raises instead of being silently ignored.
        **kwargs: HF backend flags. ``is_causal`` is honored when present
            (falling back to the decode heuristic otherwise); ``dropout_p``
            takes precedence over ``dropout`` when provided.

    Returns:
        (attn_output, None) with attn_output [batch, seq_len, num_heads, head_dim].

    Raises:
        NotImplementedError: If a nonzero dropout probability is requested.
    """
    # Validate the interface contract explicitly rather than silently applying
    # incorrect behavior: the kernel has no dropout.
    dropout_p = kwargs.get("dropout_p", dropout)
    if dropout_p:
        raise NotImplementedError("modelopt_triton attention does not support dropout")

    batch, num_heads, seq_len, head_dim = query.shape
    seq_k = key.shape[2]
    num_kv_heads = key.shape[1]
    device = query.device
    is_decode = seq_len <= 1

    # Honor HF's explicit causality flag when provided; only fall back to the
    # heuristic (decode of a single token needs no causal mask) when absent.
    is_causal = kwargs.get("is_causal")
    if is_causal is None:
        is_causal = not is_decode

    # Reshape from HF [batch, heads, seq, dim] -> flat [batch*seq, heads, dim]
    q = query.permute(0, 2, 1, 3).reshape(batch * seq_len, num_heads, head_dim).contiguous()
    k = key.permute(0, 2, 1, 3).reshape(batch * seq_k, num_kv_heads, head_dim).contiguous()
    v = value.permute(0, 2, 1, 3).reshape(batch * seq_k, num_kv_heads, head_dim).contiguous()

    # Build varlen metadata; without a usable mask, assume dense full-length rows.
    b_seq_len_q, has_padding = _seq_lens_from_mask(attention_mask, seq_len, device)
    if b_seq_len_q is None:
        b_seq_len_q = torch.full((batch,), seq_len, device=device, dtype=torch.int32)

    kw = {
        "b_start_loc": torch.arange(batch, device=device, dtype=torch.int32) * seq_len,
        "b_seq_len": b_seq_len_q,
        "max_input_len": seq_len,
        "is_causal": is_causal,
        "softmax_scale": scaling,
    }
    # Decode: Q has 1 token, K/V have seq_k tokens (KV cache, no padding).
    if is_decode:
        kw["b_start_loc_k"] = torch.arange(batch, device=device, dtype=torch.int32) * seq_k
        kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
        kw["max_input_len_k"] = seq_k

    o = attention(q, k, v, **kw)

    attn_output = o.view(batch, seq_len, num_heads, head_dim)

    # Zero out padding positions (kernel produces NaN for all-padding rows due
    # to 0/0). Assumes right-padding (valid tokens at positions 0..n-1), which
    # is the HF convention during prefill; _seq_lens_from_mask rejects other
    # layouts.
    if has_padding:
        pad_mask = torch.arange(seq_len, device=device)[None, :] >= b_seq_len_q[:, None]
        attn_output = attn_output.masked_fill(pad_mask[:, :, None, None], 0.0)

    return (attn_output, None)


def register_triton_attention() -> bool:
    """Register the Triton backend with HF AttentionInterface.

    Called by _set_attn_implementation() during sparsification. Must run before
    the model's first forward pass so HF dispatches to the Triton kernel.

    Returns:
        True if registration succeeded, False if the transformers API is not
        available.
    """
    try:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS as _registry
    except (ImportError, AttributeError):
        # transformers is missing, or too old to expose the registry.
        return False
    else:
        _registry.register("modelopt_triton", triton_attention_forward)
        return True


__all__ = [
"register_triton_attention",
"triton_attention_forward",
]
Loading
Loading