Skip softmax calibration with list of thresholds#987

Open
rohansjoshi wants to merge 1 commit into main from rohjoshi/sa-calib

Conversation

@rohansjoshi
Contributor

@rohansjoshi rohansjoshi commented Mar 6, 2026

Modify skip softmax calibration to use a list of thresholds instead of a single threshold. Sparsity during inference is unchanged, but during calibration the list lets us gather statistics for many thresholds in a single forward pass, making calibration 20x faster.
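The speedup comes from the fact that the attention scores only need to be computed once: each threshold just changes which blocks fall below it, so per-threshold sparsity can be read off the same scores. A minimal sketch of that idea, assuming per-block softmax maxima are available (the function and array names here are hypothetical, not the actual modelopt implementation):

```python
import numpy as np

def per_threshold_sparsity(block_scores: np.ndarray, thresholds: list[float]) -> list[float]:
    """Fraction of blocks skipped at each threshold, from one pass over the scores.

    `block_scores` is an illustrative (num_blocks,) array of per-block softmax
    maxima; a block is considered skippable when its max falls below a threshold.
    """
    total = block_scores.size
    return [float((block_scores < t).sum()) / total for t in thresholds]

scores = np.array([0.001, 0.01, 0.5, 0.9])
print(per_threshold_sparsity(scores, [0.005, 0.05, 1.0]))
```

Evaluating N thresholds this way costs one forward pass instead of N, which is where the claimed 20x calibration speedup comes from.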

Summary by CodeRabbit

  • New Features

    • Multi-threshold sparsity configuration allowing per-phase lists of thresholds.
  • Improvements

    • Calibration now gathers multi-threshold sparsity data in one pass for faster collection and fitting.
    • Sparsity/statistics report per-threshold block-level metrics.
  • Breaking Changes

    • Config key renamed from "threshold" → "thresholds" (lists per phase).
    • Stats and outputs changed from scalars to per-threshold lists.
  • Tests & Examples

    • Tests updated for list-based thresholds; example model loading uses automatic device placement.

@copy-pr-bot

copy-pr-bot bot commented Mar 6, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

Walkthrough

Adds multi-threshold support for attention sparsity: configs now accept per-phase lists of thresholds; calibrator collects per-sample sparsity for all thresholds in one forward pass; sparse method, stats aggregation, and tests updated to handle per-threshold arrays and renamed APIs/fields.

Changes

Cohort / File(s) and summary:
  • Configuration Schema (modelopt/torch/sparsity/attention_sparsity/config.py): Replaced threshold (float) with thresholds (dict[str, list[float]]); updated defaults, validators, and global constants to enforce per-phase lists and equal-length requirements.
  • Calibration Logic (modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py): Switched to single-pass collection across all thresholds; renamed _set_threshold to _set_thresholds, which now assigns lists; aggregates per-sample, per-threshold sparsity for fitting; adjusted messages and validation.
  • Sparse Method Implementation (modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py): Renamed internal config/attributes to thresholds*; added _update_thresholds; added multi-threshold computation paths (prefill/decode); returns per-threshold sparsity and block info; updated threshold/correction logic.
  • Statistics Aggregation (modelopt/torch/sparsity/attention_sparsity/stats_manager.py): sparse_blocks transitioned from scalar to per-threshold list; collect, get_summary, and reset updated to initialize, accumulate element-wise, and average list-form block stats.
  • Docs, Examples, Model Sparsify (modelopt/torch/sparsity/attention_sparsity/model_sparsify.py, examples/llm_sparsity/attention_sparsity/hf_sa.py): Documentation and examples updated to show thresholds lists; example model loading now uses torch_dtype="auto" and device_map="auto" and removes manual CUDA placement.
  • Tests, Config & Behavior (tests/_test_utils/.../sparse_attention_common.py, tests/unit/.../test_*.py, tests/gpu/.../test_*.py): All tests updated to use thresholds lists (e.g., {"prefill": [...], "decode": [...]}), call renamed APIs (e.g., _update_thresholds), and assert list-shaped per-threshold sparsity/stats outputs.
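The per-phase list and equal-length rules described for config.py could be enforced by a validator along these lines (a hedged sketch only; the function body and messages are illustrative, not the actual modelopt code):

```python
def validate_thresholds(v):
    """Illustrative validator: thresholds must be a dict mapping known phases
    to equal-length, non-empty lists of floats in (0, 1). Sketch only."""
    if not isinstance(v, dict):
        raise ValueError(
            f"thresholds must be a dict with 'prefill' and/or 'decode' keys, "
            f"got {type(v).__name__}"
        )
    allowed_phases = {"prefill", "decode"}
    lengths = set()
    for phase, values in v.items():
        if phase not in allowed_phases:
            raise ValueError(f"unknown phase {phase!r}")
        if not isinstance(values, list) or not values:
            raise ValueError(f"thresholds[{phase!r}] must be a non-empty list")
        if not all(isinstance(x, float) and 0.0 < x < 1.0 for x in values):
            raise ValueError(f"thresholds[{phase!r}] values must be floats in (0, 1)")
        lengths.add(len(values))
    if len(lengths) > 1:
        raise ValueError("per-phase threshold lists must have equal lengths")
    return v
```

In pydantic-based configs such a function would typically be wired up as a field validator; here it is shown standalone to keep the sketch self-contained.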

Sequence Diagram(s)

sequenceDiagram
    participant Calibrator
    participant Module
    participant SparseMethod as "Sparse Method"
    participant Aggregator

    rect rgba(200,100,150,0.5)
    Note over Calibrator,Aggregator: Old (per-threshold forwards)
    loop For each threshold t
        Calibrator->>Module: forward(threshold=t)
        Module->>SparseMethod: compute sparsity(threshold=t)
        SparseMethod-->>Module: sparsity_t
        Module-->>Calibrator: per-sample sparsity_t
        Calibrator->>Aggregator: aggregate(t, sparsity_t)
    end
    end

    rect rgba(100,160,200,0.5)
    Note over Calibrator,Aggregator: New (single-pass multi-threshold)
    Calibrator->>Module: forward(thresholds=[t1,...,tN])
    Module->>SparseMethod: compute sparsity for all thresholds
    SparseMethod-->>Module: sparsity_list [s1,...,sN]
    Module-->>Calibrator: per-sample sparsity_list
    loop For each threshold index i
        Calibrator->>Aggregator: unpack & aggregate sparsity_list[i]
    end
    end
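The "unpack & aggregate" step in the new flow can be sketched as follows (names mirror the diagram, not the actual calibrator internals):

```python
def aggregate_per_sample(per_sample_stats, threshold_trials):
    """Regroup per-sample sparsity lists into one bucket per threshold.

    Each entry of `per_sample_stats` is assumed to carry a "sample_length"
    and a "sparsity" list with one value per trial threshold.
    """
    buckets = {t: [] for t in threshold_trials}
    for stat in per_sample_stats:
        sparsity_list = stat["sparsity"]
        if len(sparsity_list) != len(threshold_trials):
            raise ValueError(
                f"expected {len(threshold_trials)} sparsity values, "
                f"got {len(sparsity_list)}"
            )
        for t, s in zip(threshold_trials, sparsity_list):
            buckets[t].append((stat["sample_length"], s))
    return buckets

stats = [
    {"sample_length": 128, "sparsity": [0.1, 0.4]},
    {"sample_length": 256, "sparsity": [0.2, 0.5]},
]
buckets = aggregate_per_sample(stats, [1e-4, 1e-3])
```

After regrouping, each threshold's (length, sparsity) pairs can be fitted independently, exactly as in the old per-threshold loop, but sourced from a single forward pass.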

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The pull request title clearly and concisely summarizes the primary change: replacing a single threshold with multiple thresholds in skip softmax calibration, enabling faster calibration.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.
  • Security Anti-Patterns: ✅ Passed. Changes refactor sparse-attention calibration and configuration code without introducing insecure patterns or non-permissive dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@rohansjoshi rohansjoshi marked this pull request as ready for review March 6, 2026 19:14
@rohansjoshi rohansjoshi requested a review from a team as a code owner March 6, 2026 19:14
@kevalmorabia97 kevalmorabia97 requested review from kaix-nv and removed request for kevalmorabia97 March 6, 2026 19:15
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

137-140: Guard against silent data loss when pairing thresholds and sparsities.

At Line 140, zip(self.threshold_trials, sparsity_list) silently truncates on length mismatch, which can hide calibration stat drift.

💡 Suggested fix
         for sample_stat in per_sample_stats:
             length = sample_stat["sample_length"]
             sparsity_list = sample_stat["sparsity"]
+            if len(sparsity_list) != len(self.threshold_trials):
+                raise ValueError(
+                    f"Expected {len(self.threshold_trials)} sparsity values, got {len(sparsity_list)}"
+                )
             for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 137 - 140, The loop silently truncates mismatched pairs by using
zip(self.threshold_trials, sparsity_list); before iterating (in the calibrator
that processes per_sample_stats), validate that len(sparsity_list) ==
len(self.threshold_trials) and if not, raise a clear exception or log an error
and skip the sample to avoid silent data loss—use the
sample_stat/"sample_length" and sparsity_list context to include identifying
info in the message; do not rely on zip_longest to silently fill values,
explicitly enforce or handle length mismatches.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 130-136: Wrap the calibration sequence in a try/finally so
calibration mode is always cleaned up on error: after calling
self._set_thresholds(...) and self._enable_calibration_mode(...), run
forward_loop(model) and self._extract_calibration_stats(...) inside a try block
and call self._disable_calibration_mode(...) (and reset any trial thresholds if
applicable) in the finally block; reference the methods _set_thresholds,
_enable_calibration_mode, forward_loop, _extract_calibration_stats, and
_disable_calibration_mode so you locate and wrap that exact sequence to ensure
modules are disabled and thresholds cleared even when exceptions occur.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 179-183: The code computes total_blocks/total_valid_blocks using
num_causal_blocks when self.is_causal but later counts dense block positions
across all blocks (producing negative sparsity); update the dense-block counting
to mask out non-causal positions or else compute both numerator and denominator
from the same masked positions: use the same causal mask used to derive
num_causal_blocks when counting dense blocks (and when computing
total_blocks/total_valid_blocks) so numerator and denominator align (apply this
fix in the block that sets total_blocks/total_valid_blocks and also in the later
dense-counting section referenced around the second occurrence at Lines
~194-197); refer to self.is_causal, num_causal_blocks, total_valid_blocks,
total_blocks and the dense-block counting logic to locate and change the code.
- Around line 60-61: The code currently falls back to the runtime value
self.thresholds when a phase key is missing, making behavior depend on the order
phases run; instead, when resolving per-phase thresholds use only configuration
defaults (e.g. phase_val = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3]))), so replace any use of
self.thresholds as the fallback with a config-only chain (phase -> "prefill" ->
literal default) in the code that looks up phase thresholds (references:
self.thresholds_config and self.thresholds).


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: aced1124-5d44-4cab-b27a-89ab4c75bffa

📥 Commits

Reviewing files that changed from the base of the PR and between 1ccd945 and 8f455c1.

📒 Files selected for processing (12)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py

Comment on lines +60 to 61
self.thresholds = self.thresholds_config.get("prefill", [1e-3])

Contributor


⚠️ Potential issue | 🟠 Major

Avoid phase-order-dependent threshold fallback.

At Line 71, missing phase keys fall back to the previous runtime value (self.thresholds), so behavior depends on which phase ran first instead of config-only defaults.

💡 Suggested fix
-        # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
+        self.thresholds = list(self._fallback_thresholds)

     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))

Also applies to: 69-72
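The config-only fallback chain the review asks for (phase, then "prefill", then a literal default) can be shown in isolation; everything below is an illustrative sketch, not the modelopt method itself:

```python
DEFAULT_THRESHOLDS = [1e-3]  # illustrative literal default

def resolve_thresholds(thresholds_config: dict, phase: str) -> list:
    """Resolve per-phase thresholds from configuration only.

    Because the fallback never reads a mutable runtime attribute, the
    result cannot depend on which phase happened to run first.
    """
    return list(
        thresholds_config.get(phase, thresholds_config.get("prefill", DEFAULT_THRESHOLDS))
    )
```

Copying with `list(...)` also prevents callers from mutating the config's lists through the returned value.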

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 60 - 61, The code currently falls back to the runtime value
self.thresholds when a phase key is missing, making behavior depend on the order
phases run; instead, when resolving per-phase thresholds use only configuration
defaults (e.g. phase_val = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3]))), so replace any use of
self.thresholds as the fallback with a config-only chain (phase -> "prefill" ->
literal default) in the code that looks up phase thresholds (references:
self.thresholds_config and self.thresholds).

Contributor

Copilot AI left a comment


Pull request overview

Updates FlashSkipSoftmax “skip softmax” calibration to support evaluating multiple sparsity thresholds in a single forward pass, improving calibration throughput while keeping inference sparsity behavior unchanged.

Changes:

  • Rename sparse attention config from threshold (scalar per phase) to thresholds (list per phase) and propagate through configs/tests.
  • Update FlashSkipSoftmax to compute per-threshold sparsity stats in one pass (and use the first threshold for the applied mask).
  • Extend stats aggregation to handle sparse_blocks as either a scalar or a list.
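The scalar-or-list handling for sparse_blocks mentioned above amounts to normalizing to a list and accumulating element-wise. A hedged sketch (function name and error message are invented for illustration, not taken from stats_manager.py):

```python
def accumulate_sparse_blocks(acc, value):
    """Accumulate sparse_blocks that may arrive as a scalar or a per-threshold list.

    `acc` is the running element-wise sum (or None before the first sample);
    a scalar input is treated as a one-element list.
    """
    values = value if isinstance(value, list) else [value]
    if acc is None:
        return list(values)
    if len(acc) != len(values):
        raise ValueError("per-threshold list length changed between samples")
    return [a + v for a, v in zip(acc, values)]
```

Averaging at summary time then divides each accumulated element by the sample count, yielding per-threshold block statistics.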

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 6 comments.

Show a summary per file
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py: Updates expected threshold info to a thresholds dict of lists.
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py: Updates sparse attention conversion tests to use thresholds.
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py: Updates calibration tests/configs to use thresholds.
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py: Updates FlashSkipSoftmax unit tests for list-based sparsity outputs.
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py: Updates GPU integration configs to thresholds.
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py: Updates GPU calibration configs to thresholds.
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py: Updates shared test config fixtures to thresholds.
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py: Adds support for aggregating list-valued sparse_blocks and list-form average sparsity.
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py: Updates public-facing doc/example to thresholds.
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py: Implements multi-threshold stats collection and threshold-list handling.
  • modelopt/torch/sparsity/attention_sparsity/config.py: Renames/validates thresholds as a dict of float lists (with length checks).
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py: Switches calibration data collection to single-pass multi-threshold stats extraction.
Comments suppressed due to low confidence (1)

modelopt/torch/sparsity/attention_sparsity/config.py:132

  • validate_thresholds still raises an error that says "Threshold must be..." even though the field is now thresholds and expects lists. Updating this message will make validation failures much clearer to users.
    def validate_thresholds(cls, v):
        """Validate thresholds is a dict of lists with valid phases and values in range (0, 1)."""
        if not isinstance(v, dict):
            raise ValueError(
                f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
            )


@kaix-nv
Contributor

kaix-nv commented Mar 10, 2026

The change does a great job of reducing calibration time overhead. LGTM overall, left a few comments.

@rohansjoshi rohansjoshi enabled auto-merge (squash) March 12, 2026 22:36
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

137-149: ⚠️ Potential issue | 🟡 Minor

Validate sparsity list length before zip.

At Line 140, zip(self.threshold_trials, sparsity_list) silently truncates if lengths differ. If a module returns fewer sparsity values than expected, data will be silently lost.

💡 Suggested fix
         for sample_stat in per_sample_stats:
             length = sample_stat["sample_length"]
             sparsity_list = sample_stat["sparsity"]
+            if len(sparsity_list) != len(self.threshold_trials):
+                raise ValueError(
+                    f"Expected {len(self.threshold_trials)} sparsity values, got {len(sparsity_list)}"
+                )
             for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 137 - 149, The loop over per_sample_stats uses zip(self.threshold_trials,
sparsity_list) which silently truncates mismatched lengths; before zipping,
validate that len(sparsity_list) == len(self.threshold_trials) (or handle the
mismatch explicitly) inside the same function/loop in calibrator.py (the
variables: per_sample_stats, sparsity_list, self.threshold_trials); if they
differ, either raise a clear ValueError or log an error and skip/pad entries so
no data is silently dropped, then proceed to iterate using the validated/padded
lists.
🧹 Nitpick comments (1)
tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py (1)

143-146: Relax sparsity bounds reflect known causal masking issue.

The sparsity check at Line 145 uses all(-1 <= s <= 1 for s in stats["sparsity"]) rather than the expected [0, 1] range. This appears to accommodate the causal block counting issue flagged in previous reviews where numerator/denominator mismatch can produce invalid sparsity values.

Consider adding a comment explaining this relaxed bound, or fixing the underlying causal masking issue in calc_correction_factor_and_p.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py`
around lines 143 - 146, The test currently asserts sparsity values allow [-1,1]
which hides a known causal masking numerator/denominator mismatch; update the
test at the assertion for stats["sparsity"] to either (a) tighten to the
expected 0..1 range and adjust/fix the underlying calculation in
calc_correction_factor_and_p (or related causal masking logic) so sparsity
cannot go negative, or (b) if you keep the relaxed bound, add a concise comment
above the assertion referencing the causal masking bug and pointing to
calc_correction_factor_and_p so future maintainers know why [-1,1] is allowed
and where to fix it. Ensure the reference is added near the assertion and that
any code changes fix the denominator/numerator handling in
calc_correction_factor_and_p to produce values in [0,1] before reverting the
test.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 150-156: Replace the hardcoded trust_remote_code=True in the
AutoModelForCausalLM.from_pretrained(...) call with a caller-configurable CLI
flag: add a new argument to the script's argument parser (e.g.,
--trust-remote-code as a store_true flag or a boolean option defaulting to
False) and pass that parsed value (e.g., args.trust_remote_code) into the
from_pretrained call (alongside existing args.pyt_ckpt_path,
attn_implementation, torch_dtype, device_map). Ensure the new flag defaults to
False and is referenced where trust_remote_code is currently used.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 507ea991-20b5-4a73-a642-d9e9d329486c

📥 Commits

Reviewing files that changed from the base of the PR and between 8f455c1 and fd82fe3.

📒 Files selected for processing (14)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)

59-71: ⚠️ Potential issue | 🟠 Major

Phase-order-dependent threshold fallback persists.

At line 71, _update_thresholds falls back to self.thresholds when the phase key is missing from thresholds_config. This means behavior depends on which phase ran first, rather than deterministic config-only defaults.

💡 Suggested fix
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
         # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        self.thresholds = list(self._fallback_thresholds)

     ...

     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 59 - 71, _update_thresholds currently falls back to the mutable
instance field self.thresholds when a phase key is missing, causing behavior to
depend on which phase ran first; change _update_thresholds to use a
deterministic config-only fallback by reading from self.thresholds_config (e.g.,
use self.thresholds_config.get(phase, self.thresholds_config.get("prefill",
[1e-3]))) so missing phases always resolve from the config (not prior runtime
state) and keep set_calibration_mode/_calibration_mode logic unchanged.
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

130-136: ⚠️ Potential issue | 🟠 Major

Calibration mode may not be disabled on exception.

If forward_loop(model) or _extract_calibration_stats(...) raises an exception, _disable_calibration_mode is never called, leaving modules in calibration mode with trial thresholds still set.

💡 Suggested fix
         self._set_thresholds(attention_modules, self.threshold_trials)
         self._enable_calibration_mode(attention_modules)
-        with torch.no_grad():
-            forward_loop(model)
-        per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
-        self._disable_calibration_mode(attention_modules)
+        try:
+            with torch.no_grad():
+                forward_loop(model)
+            per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+        finally:
+            self._disable_calibration_mode(attention_modules)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 130 - 136, Wrap the execution of forward_loop(model) and
_extract_calibration_stats(...) in a try/finally so that
_disable_calibration_mode(attention_modules) is always called even if
forward_loop or _extract_calibration_stats raises; keep the torch.no_grad()
context around the try/finally block and ensure _set_thresholds(...) and
_enable_calibration_mode(...) remain before the try so trial thresholds are set
for the attempt and always cleared in the finally via
_disable_calibration_mode(attention_modules).
🧹 Nitpick comments (4)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1)

72-72: Consider adding a test with multiple thresholds.

All threshold configurations in this file use single-element lists. Since the PR's main feature is gathering statistics for multiple thresholds in a single forward pass (~20× faster calibration), consider adding at least one test case that uses multiple threshold values (e.g., {"prefill": [1e-4, 1e-3, 1e-2], "decode": [1e-4, 1e-3]}). This would validate that the multi-threshold functionality works correctly at the conversion/configuration level.

If multi-threshold behavior is tested elsewhere (e.g., in calibration tests), this can be disregarded.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py`
at line 72, Add a new unit test case in test_sparse_attention_conversion.py that
uses a thresholds config with multiple values (e.g., {"prefill": [1e-4, 1e-3,
1e-2], "decode": [1e-4, 1e-3]}) instead of single-element lists; exercise the
same conversion/config path used by the existing tests (the code that reads the
"thresholds" dict during sparse attention conversion) and assert that the
resulting conversion/config contains entries for each provided threshold and
that any gathered statistics or outputs are produced per-threshold (verify
keys/counts match the input lists). Ensure the test targets the same functions
used in the file for conversion/configuration so it validates multi-threshold
handling end-to-end.
modelopt/torch/sparsity/attention_sparsity/config.py (1)

127-132: Minor: Error message uses singular "Threshold" but field is "thresholds".

The error message at line 131 says "Threshold must be a dict..." but the field is now named thresholds (plural).

💡 Suggested fix
         if not isinstance(v, dict):
             raise ValueError(
-                f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
+                f"thresholds must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 127 - 132,
In validate_thresholds (the classmethod in config.py) update the ValueError
message to use the plural field name "thresholds" (e.g., "Thresholds must be a
dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}") so the error
refers to the correct field; locate the validate_thresholds function and replace
the singular "Threshold" text with "Thresholds" in the processLogger.error/raise
ValueError message.
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)

144-146: Redundant assertion after conditional check.

The assert on line 146 is defensive but redundant since use_calibration_params already guarantees that both calibration_params and target_sparse_ratio are not None.

💡 Suggested cleanup
         if use_calibration_params:
             # Calibrated dynamic threshold: bypass thresholds list entirely
-            assert calibration_params is not None and target_sparse_ratio is not None
             a = calibration_params[phase]["a"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 144 - 146, Remove the redundant defensive assert inside the branch
guarded by use_calibration_params: since the if use_calibration_params: check
already guarantees that calibration_params and target_sparse_ratio are provided,
delete the assert calibration_params is not None and target_sparse_ratio is not
None to avoid unnecessary duplication in flash_skip_softmax.py (references:
use_calibration_params, calibration_params, target_sparse_ratio).
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

137-149: Consider adding length validation before zipping thresholds with sparsity.

At line 140, zip(self.threshold_trials, sparsity_list) will silently truncate if the lengths don't match. Since sparsity_list comes from per-sample stats that should match threshold_trials, a mismatch would indicate a bug.

💡 Suggested defensive check
         for sample_stat in per_sample_stats:
             length = sample_stat["sample_length"]
             sparsity_list = sample_stat["sparsity"]
+            assert len(sparsity_list) == len(self.threshold_trials), (
+                f"Sparsity list length {len(sparsity_list)} != threshold_trials {len(self.threshold_trials)}"
+            )
             for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 137 - 149, In the loop over per_sample_stats inside the calibrator (where
you iterate sample_stat and do zip(self.threshold_trials, sparsity_list)), add a
defensive length check that verifies len(sparsity_list) ==
len(self.threshold_trials) before zipping; if they differ, raise a clear
ValueError or log an error with identifying info from sample_stat (e.g.,
"sample_length" or an ID) and either skip that sample or fail fast so the silent
truncation cannot happen — update the code paths in the same function/method
where "sparsity_list" and "self.threshold_trials" are used to perform this
validation and handle the mismatch.
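To illustrate the truncation hazard the nitpick describes, a small stand-in for the pairing step (function name and signature are hypothetical, not the calibrator's API):

```python
def pair_threshold_sparsity(threshold_trials, sparsity_list):
    """Zip trial thresholds with measured sparsity, failing loudly on mismatch."""
    if len(sparsity_list) != len(threshold_trials):
        raise ValueError(
            f"Sparsity list length {len(sparsity_list)} != "
            f"threshold_trials length {len(threshold_trials)}"
        )
    # A plain zip() would silently truncate to the shorter input here
    return list(zip(threshold_trials, sparsity_list))
```

On Python 3.10+, `zip(threshold_trials, sparsity_list, strict=True)` raises on mismatch and could replace the explicit length check.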

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9fef1c11-30e1-4d26-80f0-4440f3ed1d89

📥 Commits

Reviewing files that changed from the base of the PR and between fd82fe3 and 586209f.

📒 Files selected for processing (14)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
🚧 Files skipped from review as they are similar to previous changes (7)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py

Comment on lines +69 to +74
incoming = stats["sparse_blocks"]
if "sparse_blocks" not in self.aggregated_stats:
self.aggregated_stats["sparse_blocks"] = list(incoming)
else:
for i, val in enumerate(incoming):
self.aggregated_stats["sparse_blocks"][i] += val
⚠️ Potential issue | 🟡 Minor

Potential KeyError if sparse_blocks is missing from stats.

Line 69 accesses stats["sparse_blocks"] directly without using .get(), unlike other fields (e.g., total_blocks on line 67). If a caller omits sparse_blocks, this will raise a KeyError.

Consider using .get() with a default or documenting that sparse_blocks is required:

💡 Suggested fix
-        incoming = stats["sparse_blocks"]
+        incoming = stats.get("sparse_blocks")
+        if incoming is None:
+            return
+
         if "sparse_blocks" not in self.aggregated_stats:
             self.aggregated_stats["sparse_blocks"] = list(incoming)
         else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py` around lines 69
- 74, The code reads stats["sparse_blocks"] directly which can raise KeyError if
absent; change to use stats.get("sparse_blocks", []) (assign to incoming) and
treat empty list as no-op when updating self.aggregated_stats["sparse_blocks"]
so the loop which sums elements only runs when incoming is non-empty; ensure you
still initialize self.aggregated_stats["sparse_blocks"] = list(incoming) when
incoming is present, mirroring the pattern used for total_blocks.
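A sketch of the suggested fix, mirroring the `.get()` pattern the comment says `total_blocks` already uses. The function and dict shapes are illustrative stand-ins for the stats manager's method, not its actual code:

```python
def update_sparse_blocks(aggregated_stats, stats):
    """Accumulate per-threshold sparse-block counts, tolerating a missing field."""
    incoming = stats.get("sparse_blocks", [])  # no KeyError if the caller omits it
    if not incoming:
        return  # empty or absent: no-op
    if "sparse_blocks" not in aggregated_stats:
        aggregated_stats["sparse_blocks"] = list(incoming)  # copy, don't alias
    else:
        for i, val in enumerate(incoming):
            aggregated_stats["sparse_blocks"][i] += val
```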

…single pass

Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
@coderabbitai coderabbitai bot left a comment
♻️ Duplicate comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)

59-71: ⚠️ Potential issue | 🟠 Major

Avoid phase-order-dependent threshold fallback.

At line 71, missing phase keys fall back to the previous runtime value (self.thresholds), so behavior depends on which phase ran first instead of config-only defaults.

💡 Suggested fix
-        # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
+        self.thresholds = list(self._fallback_thresholds)
 
     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 59 - 71, The _update_thresholds method currently falls back to the
current runtime self.thresholds when a phase key is missing, causing behavior to
depend on phase order; change _update_thresholds to load from the configuration
only by setting self.thresholds = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3])) (or another config-level default
key) so missing phase entries always fall back to a config-defined default
instead of the previous runtime value; update references to
thresholds/thresholds_config and preserve the calibration flag logic in
set_calibration_mode/_update_thresholds.
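The config-only fallback the review proposes can be sketched as a free function (names and the `[1e-3]` default are assumptions taken from the suggested diff, not the module's real API):

```python
DEFAULT_THRESHOLDS = [1e-3]

def resolve_thresholds(thresholds_config, phase):
    """Resolve a phase's thresholds from config alone, never from runtime state."""
    # A missing phase falls back to the configured prefill list, then a fixed
    # default, so the result never depends on which phase happened to run first.
    return list(
        thresholds_config.get(phase, thresholds_config.get("prefill", DEFAULT_THRESHOLDS))
    )
```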
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

130-136: ⚠️ Potential issue | 🟠 Major

Ensure calibration mode is always disabled on failure.

An exception in forward_loop(model) or _extract_calibration_stats(...) leaves modules in calibration mode with trial thresholds still set.

💡 Suggested fix
         self._set_thresholds(attention_modules, self.threshold_trials)
         self._enable_calibration_mode(attention_modules)
-        with torch.no_grad():
-            forward_loop(model)
-        per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
-        self._disable_calibration_mode(attention_modules)
+        try:
+            with torch.no_grad():
+                forward_loop(model)
+            per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+        finally:
+            self._disable_calibration_mode(attention_modules)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 130 - 136, Wrap the forward/measurement block in a try/finally to
guarantee cleanup: before calling self._set_thresholds(...) save the original
thresholds, then call self._set_thresholds(attention_modules,
self.threshold_trials) and self._enable_calibration_mode(attention_modules),
perform forward_loop(model) and self._extract_calibration_stats(... ) inside the
try, and in the finally always call
self._disable_calibration_mode(attention_modules) and restore/clear thresholds
by calling self._set_thresholds(attention_modules, original_thresholds) (or
self._set_thresholds(attention_modules, None) if no original) so modules never
remain in calibration mode or with trial thresholds after an exception.
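The try/finally shape being requested, reduced to a self-contained sketch. Every callable parameter stands in for one of the calibrator's private helpers (`_set_thresholds`, `_enable_calibration_mode`, etc.) and is hypothetical, not the library's actual interface:

```python
def run_calibration_pass(model, modules, forward_loop, trials,
                         set_thresholds, enable, disable, extract_stats):
    """Run one calibration forward pass, always restoring module state."""
    set_thresholds(modules, trials)
    enable(modules)
    try:
        forward_loop(model)
        return extract_stats(modules)
    finally:
        # Runs whether forward_loop/extract_stats succeed or raise, so modules
        # never stay stuck in calibration mode with trial thresholds set.
        disable(modules)
```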
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)

137-149: Consider validating sparsity_list length matches threshold_trials.

The zip at line 140 silently truncates if sparsity_list has fewer entries than self.threshold_trials. This could mask bugs where modules don't report all thresholds.

💡 Suggested validation
         for sample_stat in per_sample_stats:
             length = sample_stat["sample_length"]
             sparsity_list = sample_stat["sparsity"]
+            if len(sparsity_list) != len(self.threshold_trials):
+                raise ValueError(
+                    f"Sparsity list length {len(sparsity_list)} doesn't match "
+                    f"threshold_trials length {len(self.threshold_trials)}"
+                )
             for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around
lines 137 - 149, The code currently zips self.threshold_trials with
sparsity_list (from sample_stat["sparsity"]) which silently truncates when
sparsity_list is shorter; update the loop in the per_sample_stats processing to
validate that len(sparsity_list) == len(self.threshold_trials) (or at least >=)
before iterating, and if the lengths mismatch raise an exception or log an
explicit error including the offending sample (sample_stat) and the lengths;
only proceed to append to all_data_points when the validation passes to avoid
masked bugs in threshold reporting.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ade8fd6f-a93b-4d32-9878-61b65c4b9f23

📥 Commits

Reviewing files that changed from the base of the PR and between 586209f and 8430617.

📒 Files selected for processing (15)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
  • tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
  • tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
  • tests/_test_utils/torch/sparsity/sparse_attention_common.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py

@codecov

codecov bot commented Mar 12, 2026

Codecov Report

❌ Patch coverage is 82.60870% with 16 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.09%. Comparing base (69c0d47) to head (8430617).

Files with missing lines Patch % Lines
...rsity/attention_sparsity/calibration/calibrator.py 45.45% 12 Missing ⚠️
...y/attention_sparsity/methods/flash_skip_softmax.py 92.15% 4 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #987   +/-   ##
=======================================
  Coverage   70.09%   70.09%           
=======================================
  Files         221      221           
  Lines       25459    25491   +32     
=======================================
+ Hits        17845    17868   +23     
- Misses       7614     7623    +9     

☔ View full report in Codecov by Sentry.