
Added block wise RHT #1014

Open
kinjalpatel27 wants to merge 9 commits into main from kinjal/block_rht

Conversation

@kinjalpatel27
Contributor

@kinjalpatel27 kinjalpatel27 commented Mar 10, 2026

What does this PR do?

Added support for RHT with non-power-of-2 dimensions. The rotate quantization configuration can now be used to specify block_size as well.

NVFP4_KV_ROTATE_BLOCK_32_CFG = {
    "quant_cfg": {
        "*q_bmm_quantizer": {
            "enable": False,
            "rotate": {"enable": True, "block_size": 32},
        },
        "*k_bmm_quantizer": {
            **_nvfp4_quantizer,
            "rotate": {"enable": True, "block_size": 32},
        },
        "*v_bmm_quantizer": _nvfp4_quantizer,
    },
    "algorithm": "max",
}
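For intuition, here is a minimal NumPy sketch (not the modelopt implementation; the helper name is illustrative) of what `block_size: 32` means: each contiguous 32-wide block of the last dimension is rotated by a normalized Hadamard matrix. The rotation is orthogonal, so `x @ x.T` is unchanged:

```python
import numpy as np
from scipy.linalg import hadamard

def blockwise_hadamard(x: np.ndarray, block_size: int) -> np.ndarray:
    """Apply a normalized Hadamard rotation to each block of the last dim."""
    dim = x.shape[-1]
    assert dim % block_size == 0, "last dim must be divisible by block_size"
    # hadamard(n) requires n to be a power of 2; dividing by sqrt(n)
    # makes the matrix orthonormal, so the transform preserves norms.
    h = hadamard(block_size) / np.sqrt(block_size)
    blocks = x.reshape(*x.shape[:-1], dim // block_size, block_size)
    return (blocks @ h).reshape(x.shape)

x = np.random.rand(4, 1920)       # 1920 is not a power of 2, but 32 | 1920
x_h = blockwise_hadamard(x, 32)
# Block-diagonal orthogonal rotation preserves the Gram matrix:
assert np.allclose(x_h @ x_h.T, x @ x.T)
```

This is the covariance-preservation property the tests in this PR check, just written out in NumPy instead of torch.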

Testing

pytest tests/gpu/torch/quantization/test_hadamard.py -k test_hadamard_transform_block

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

Summary by CodeRabbit

  • New Features

    • Block-granular randomized Hadamard transform (RHT) for non-power-of-2 dimensions.
    • Rotation configuration expanded to accept block-size and an option to perform rotation in FP32; rotation settings are now exposed via the public API.
  • Tests

    • Added tests validating block-granular RHT across varied dimensions and block sizes.
  • Documentation

    • Changelog updated to mention the new block-granular RHT capability.

@copy-pr-bot

copy-pr-bot bot commented Mar 10, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds block-granular randomized Hadamard transform (RHT) for non-power-of-2 dimensions: new RotateConfig with block_size in config, normalized_hadamard_transform supports block-wise FHT and block_size inference/validation, TensorQuantizer exposes/forwards block_size, and tests cover block-mode behavior.

Changes

  • Changelog (CHANGELOG.rst): Added a note about support for block-granular RHT for non-power-of-2 dimensions.
  • Configuration (modelopt/torch/quantization/config.py): Added RotateConfig type (enable: bool, rotate_fp32: bool, block_size: int | None).
  • Core Transform (modelopt/torch/quantization/nn/functional.py): Added a _largest_pow2_divisor helper and extended normalized_hadamard_transform(inputs, rotate_fp32=False, block_size=None) to infer/validate block_size, reshape inputs into blocks, apply FHT per block, and rescale by sqrt(block_size); preserves rotate_fp32 casting behavior.
  • Quantizer Integration (modelopt/torch/quantization/nn/modules/tensor_quantizer.py): Imported RotateConfig; added a rotate_block_size property; updated the rotate enable/fp32 checks to accept RotateConfig; forwards block_size to normalized_hadamard_transform; includes the block size in extra_repr.
  • Tests (tests/gpu/torch/quantization/test_hadamard.py): Added test_hadamard_transform_block, parametrized over dims and block sizes, to validate that block-granular RHT preserves covariance within tolerance.
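As a rough sketch of the block-size inference described above (the actual helper is _largest_pow2_divisor in functional.py; this standalone version is an assumption about its behavior, not a copy of it): the largest power-of-2 divisor of a dimension is the biggest Hadamard block that tiles it exactly.

```python
def largest_pow2_divisor(n: int) -> int:
    """Largest power of 2 dividing n (for positive n).

    n & -n isolates the lowest set bit of n, which for a positive
    integer is exactly its largest power-of-2 divisor.
    """
    return n & -n

# Non-power-of-2 dims such as those in the test file resolve to
# usable block sizes:
assert largest_pow2_divisor(1920) == 128  # 1920 = 15 * 128
assert largest_pow2_divisor(1536) == 512  # 1536 = 3 * 512
assert largest_pow2_divisor(64) == 64     # already a power of 2
```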

Sequence Diagram(s)

sequenceDiagram
    participant Config as RotateConfig
    participant TensorQ as TensorQuantizer
    participant FNS as normalized_hadamard_transform
    participant FHT as FastHadamardTransform

    Config->>TensorQ: provide rotate settings (enable, rotate_fp32, block_size)
    TensorQ->>FNS: forward(inputs, rotate_fp32, block_size)
    FNS->>FNS: validate or infer block_size (largest pow2 divisor)
    FNS->>FNS: reshape inputs into blocks (num_blocks, block_size)
    FNS->>FHT: apply FHT per block
    FHT-->>FNS: transformed blocks
    FNS-->>TensorQ: return transformed & rescaled tensor

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 4 passed

  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'Added block wise RHT' directly summarizes the main feature addition of block-granular randomized Hadamard transform support, which is the primary change across the codebase modifications.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 92.31%, which is sufficient; the required threshold is 80.00%.
  • Security Anti-Patterns: ✅ Passed. No instances of the six critical security anti-patterns specified in the SECURITY.md guidelines were detected in the pull request changes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov bot commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 48.93617% with 24 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.07%. Comparing base (26cad67) to head (fb6d9ec).
⚠️ Report is 3 commits behind head on main.

Files with missing lines (patch coverage / lines missing):
  • modelopt/torch/quantization/nn/functional.py: 12.50%, 14 lines missing ⚠️
  • .../torch/quantization/nn/modules/tensor_quantizer.py: 61.90%, 8 lines missing ⚠️
  • modelopt/torch/quantization/config.py: 80.00%, 2 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1014      +/-   ##
==========================================
- Coverage   70.09%   70.07%   -0.02%     
==========================================
  Files         221      221              
  Lines       25459    25499      +40     
==========================================
+ Hits        17845    17869      +24     
- Misses       7614     7630      +16     

☔ View full report in Codecov by Sentry.

@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/block_rht branch 2 times, most recently from 35a0aa6 to a614f6a, on March 10, 2026 at 18:45
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
@kinjalpatel27 kinjalpatel27 marked this pull request as ready for review March 11, 2026 21:16
@kinjalpatel27 kinjalpatel27 requested a review from a team as a code owner March 11, 2026 21:16
@kinjalpatel27 kinjalpatel27 requested a review from ajrasane March 11, 2026 21:16
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/config.py (1)

990-995: ⚠️ Potential issue | 🟡 Minor

The public rotate docs still describe only the old full-size Hadamard path.

When block_size is set, or auto-selected for a non-power-of-2 dimension, the implementation applies a block-diagonal transform and normalizes each block by sqrt(block_size). The formula here still describes hadamard(input.shape[-1]) / sqrt(input.shape[-1]), which is incorrect for the new mode.
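Written out (a sketch of the distinction the comment draws, with b for block_size and d = m·b for the last dimension), the two normalizations are:

```latex
% Full-dimension rotation (d a power of 2):
y = x \, \frac{H_d}{\sqrt{d}}, \qquad H_d = \mathrm{hadamard}(d)

% Block mode: block-diagonal transform, each block normalized by \sqrt{b}:
y = x \cdot \frac{1}{\sqrt{b}} \operatorname{diag}(H_b, H_b, \ldots, H_b)

% Both matrices are orthogonal, so x x^\top is preserved in either mode.
```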

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/config.py` around lines 990 - 995, The docstring
for the rotate option currently states the input is transformed with
scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1]), which is
incorrect when block_size is set or auto-selected; update the documentation for
the rotate parameter to state that when block_size is specified (or
auto-selected for non-power-of-2 dims) a block-diagonal Hadamard transform is
applied (i.e., each block uses scipy.linalg.hadamard(block_size)) and each block
is normalized by 1/sqrt(block_size), whereas the full Hadamard
(hadamard(N)/sqrt(N)) only applies when using the full dimension without
blocking. Ensure you reference the rotate parameter, block_size behavior, and
scipy.linalg.hadamard in the docstring so readers know the per-block
normalization and auto-selection behavior.
🧹 Nitpick comments (1)
tests/gpu/torch/quantization/test_hadamard.py (1)

56-66: Add one config-driven coverage case for rotate.block_size.

This only exercises normalized_hadamard_transform() directly, so it will not catch regressions in QuantizerAttributeConfig.rotate parsing or the new TensorQuantizer.forward(..., block_size=self.rotate_block_size) plumbing. A small set_quantizer_by_cfg_context(..., {"rotate": {"enable": True, "block_size": 32}}) case would close that gap.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/quantization/test_hadamard.py` around lines 56 - 66, Add a
config-driven coverage case that exercises the rotate.block_size plumbing: in
test_hadamard_transform_block add (or a new short test) a context using
set_quantizer_by_cfg_context({"rotate": {"enable": True, "block_size": 32}})
then call normalized_hadamard_transform on the same input tensor without passing
block_size (so the call path uses QuantizerAttributeConfig.rotate /
TensorQuantizer.forward via rotate_block_size), and assert the xxt preservation
as before; this ensures rotate parsing and forward(...,
block_size=self.rotate_block_size) plumbing are exercised in addition to direct
normalized_hadamard_transform calls.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/config.py`:
- Around line 978-988: Add a pydantic field validator for the rotate
ModeloptField to validate that if "block_size" is present it must be an int or
None and must be >0 (reject booleans and non-positive values), and update the
rotate field docstring to describe both full-dimension Hadamard and
block-granular RHT behaviors; specifically, add a `@field_validator`("rotate")
that inspects the dict (or coerced bool) and raises a ValueError when block_size
is not None and not an int > 0 (or when it is a bool), ensure the type
annotation allows None for block_size, and mention in the docstring how
rotate_block_size and normalized_hadamard_transform handle None (auto-select
largest power-of-2 divisor) vs an explicit positive block_size for
block-granular transforms.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ee2ffa98-eca1-431e-8c8e-9afc20b99a49

📥 Commits

Reviewing files that changed from the base of the PR and between 52f8783 and 719b0d7.

📒 Files selected for processing (5)
  • CHANGELOG.rst
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/nn/functional.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • tests/gpu/torch/quantization/test_hadamard.py

…fig-driven test

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/config.py`:
- Around line 978-1012: Replace the raw dict for the rotate field with a nested
Pydantic model to enforce allowed keys and per-field types: define a
RotateConfig model (fields: enable: bool = False, rotate_fp32: bool = False,
block_size: Optional[int] = None) with Config extra=forbid to reject unknown
keys, then change the field type from dict[str, ...] to bool | RotateConfig in
the rotate ModeloptField declaration; update the validate_rotate classmethod to
accept bool (return as-is or wrap into RotateConfig) and to convert/validate
incoming dicts by instantiating RotateConfig (e.g., RotateConfig.model_validate
or RotateConfig(**v)), raising any validation errors so callers like
normalized_hadamard_transform get a guaranteed typed config.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c48b57a9-36f5-4714-b058-29dd0a740b9f

📥 Commits

Reviewing files that changed from the base of the PR and between 8782400 and 08e1a6a.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/config.py

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/gpu/torch/quantization/test_hadamard.py (1)

56-68: Consider adding a small atol fallback and rotate_fp32 coverage.

Two observations on test robustness:

  1. Using atol=0.0 relies entirely on relative tolerance. If any element in xxt or xxt_h is very close to zero, even tiny absolute differences could cause spurious failures. A small absolute tolerance (e.g., atol=1e-6) provides a safety margin.

  2. The existing test_hadamard_transform validates both default and rotate_fp32=True modes. Consider extending this test (or adding a parameter) to cover rotate_fp32=True with block_size to ensure numerical stability in float32 rotation mode.

Suggested improvement
 @pytest.mark.parametrize(
     ("dim", "block_size"),
     [(1920, 128), (1536, 128), (1920, None), (64, 32)],
 )
-def test_hadamard_transform_block(dim, block_size):
+@pytest.mark.parametrize("rotate_fp32", [False, True])
+def test_hadamard_transform_block(dim, block_size, rotate_fp32):
     """Block-granular RHT for non-power-of-2 dimensions (e.g. MoE expert channels)."""
     x = torch.rand(4, dim, device="cuda")
     xxt = x @ x.T
-    x_h = normalized_hadamard_transform(x, block_size=block_size)
+    x_h = normalized_hadamard_transform(x, rotate_fp32=rotate_fp32, block_size=block_size)
     xxt_h = x_h @ x_h.T
     # Use rtol instead of atol: float32 accumulated error scales with value magnitude,
     # which grows with dim. 1e-3 relative tolerance is appropriate for float32 block RHT.
-    assert torch.allclose(xxt_h, xxt, rtol=1e-3, atol=0.0)
+    assert torch.allclose(xxt_h, xxt, rtol=1e-3, atol=1e-6)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/quantization/test_hadamard.py` around lines 56 - 68, Update
test_hadamard_transform_block to add a small absolute tolerance and cover
rotate_fp32: change the assertion to use atol=1e-6 instead of 0.0 and
parametrize the test (or add a loop) to call normalized_hadamard_transform(...,
rotate_fp32=bool) so both rotate_fp32=True and False are exercised alongside
block_size; reference the test function test_hadamard_transform_block and the
normalized_hadamard_transform call when making these changes.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0568022d-4b9e-4de8-aec0-3dcdba267050

📥 Commits

Reviewing files that changed from the base of the PR and between 359cd58 and 217e2ff.

📒 Files selected for processing (1)
  • tests/gpu/torch/quantization/test_hadamard.py

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
