
Activation hooks redesign (reuse hooks component across both minitron and puzzletron)#1022

Open
danielkorzekwa wants to merge 6 commits into main from dkorzewa/activation_hooks_redesign_minitron_puzzletron

Conversation


@danielkorzekwa danielkorzekwa commented Mar 11, 2026

What does this PR do?

Type of change: Redesign of existing feature

This PR introduces a shared activation hooks infrastructure for minitron and puzzletron. The activation hooks framework provides a reusable component for collecting and analyzing activations during forward passes, which is used by both minitron pruning and puzzletron algorithms.

Note: the Minitron megatron.py/mcore_minitron.py:ImportanceEstimatorRegistry code does not use this component yet; it will be refactored in a separate MR.

Key changes:

  • Added modelopt/torch/nas/plugins/megatron_hooks/ module with base hooks framework:

    • base_hooks.py: Core hook infrastructure for registering and managing forward hooks
    • base_hooks_analysis.py: Analysis utilities for processing collected activations
    • megatron_hooks.py: Megatron-specific hook implementations
    • compare_module_outputs.py: Utilities for comparing module outputs
  • Added unit tests in tests/gpu/torch/nas/plugins/megatron_hooks/:

    • test_base_hooks.py: Tests for base hooks functionality
    • test_base_hooks_analysis.py: Tests for activation analysis utilities
  • Updated test_mcore_gpt_minitron_pruning.py to validate activation collection

  • Updated test utilities for distributed testing support
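The collect-and-analyze pattern the framework provides can be sketched with a plain PyTorch forward hook. This is an illustrative approximation, not the module's actual API; the class name `L2NormAccumulator` and its methods are invented for the example:

```python
import torch
import torch.nn as nn


class L2NormAccumulator:
    """Illustrative hook: accumulates per-input-channel L2 norms of a module's input."""

    def __init__(self):
        self._sum_sq = None
        self._count = 0

    def __call__(self, module, inputs, output):
        x = inputs[0].detach()
        # Sum squared values over every dim except the channel (last) dim.
        sq = x.pow(2).sum(dim=tuple(range(x.dim() - 1)))
        self._sum_sq = sq if self._sum_sq is None else self._sum_sq + sq
        self._count += x.numel() // x.shape[-1]

    def scores(self):
        # Root-mean-square activation magnitude per input channel.
        return (self._sum_sq / self._count).sqrt()


layer = nn.Linear(8, 4)
hook = L2NormAccumulator()
handle = layer.register_forward_hook(hook)
for _ in range(3):
    layer(torch.randn(2, 8))
handle.remove()
print(hook.scores().shape)  # one importance score per input channel
```

A pruning algorithm would then rank channels by `hook.scores()` and drop the lowest-scoring ones, which is the role the shared hooks component plays for both minitron and puzzletron.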

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ Yes - This is a new module that doesn't affect existing functionality
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ No
  • Did you write any new necessary tests?: ✅ Yes - Added comprehensive tests for the activation hooks infrastructure
  • Did you update Changelog?: ✅ N/A - This is infrastructure code that will be used by subsequent PRs

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Megatron hooks plugin for neural architecture search with activation-based importance estimation for model pruning.
    • Introduced layer output comparison tool to analyze and report statistics across model variants.
    • Added robust JSON serialization utilities for complex Python objects and configurations.
    • Added output flushing print utility for distributed training.
  • Tests

    • Introduced comprehensive unit tests for hook-based importance scoring and pruning evaluation.

- Add base hooks framework in modelopt/torch/nas/plugins/megatron_hooks/
  - base_hooks.py: Core hook infrastructure
  - base_hooks_analysis.py: Analysis utilities for hooks
  - megatron_hooks.py: Megatron-specific hook implementations
  - compare_module_outputs.py: Module comparison utilities
- Add tests for activation hooks
- Update test utilities for distributed testing
- Update minitron pruning tests to use new activation hooks
The activation hooks infrastructure depends on aprint from puzzletron.tools.logger.
Adding minimal logger module to satisfy this dependency.
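The flush-on-print behavior such a minimal logger needs can be sketched as follows (the real `aprint` signature in puzzletron.tools.logger may differ; this only shows the core idea):

```python
def aprint(*args, **kwargs):
    """Print and flush immediately, so output from distributed ranks
    is not lost in stdout buffering when a process dies or hangs."""
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)


aprint("rank output appears immediately")
```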

Note: Some docstring linting warnings are suppressed as this is copied code.
… moved later outside of puzzletron module)

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>

coderabbitai bot commented Mar 11, 2026

📝 Walkthrough

Walkthrough

This PR introduces a new Megatron NAS plugin system for activation-based importance estimation in neural networks. It includes abstract hook classes for capturing and analyzing layer activations, tensor-parallel support, output comparison tools, and comprehensive unit tests. Additionally, utility functions for logging and robust JSON serialization are added.

Changes

Cohort / File(s) Summary
Megatron Hooks Plugin Core
modelopt/torch/nas/plugins/megatron_hooks/__init__.py, modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py, modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py
New plugin infrastructure with abstract ForwardHook base class, five concrete hook implementations (L2NormHook, IndependentChannelContributionHook, IterativeChannelContributionHook, IndependentKvHeadContributionHook, LayerNormContributionHook), tensor-parallel support, activation accumulation, and state management.
Analysis & Comparison Tools
modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py, modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py
New analysis module with evaluate_importance_scores function for pruning simulation and metric computation. Output comparison module with OutputSaveHook for layer output capture and multi-layer comparison utilities (RMSE, cosine similarity) with JSON reporting.
Hook System Tests
tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py, tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py
Comprehensive unit tests validating hook behavior (IterativeChannelContributionHook, L2NormHook) and analysis functions across multiple hook types with deterministic inputs and expected output assertions.
Test Infrastructure
tests/_test_utils/torch/distributed/utils.py, tests/conftest.py, tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
Enhanced distributed test setup with environment variables (RANK, LOCAL_RANK, WORLD_SIZE, WANDB_DISABLED), new pytest fixture for project root path, and extended pruning validation with deterministic initialization and activation statistics assertions.
Utility Functions
modelopt/torch/utils/logging.py, modelopt/torch/utils/robust_json.py
New aprint function for synchronized output, and comprehensive RobustJSONEncoder with serialization utilities (json_dumps, json_dump, json_load) supporting dataclasses, Paths, Enums, OmegaConf, and callables.
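A minimal sketch of the RobustJSONEncoder idea: the PR's encoder also handles OmegaConf containers and more, so this fragment is an approximation showing only the dataclass/Path/Enum/callable fallbacks, not the shipped code:

```python
import json
from dataclasses import asdict, is_dataclass
from enum import Enum
from pathlib import Path


class RobustJSONEncoder(json.JSONEncoder):
    """Fallback encoder for common non-JSON-native Python objects."""

    def default(self, o):
        if is_dataclass(o) and not isinstance(o, type):
            return asdict(o)  # serialize dataclass instances as dicts
        if isinstance(o, Path):
            return str(o)
        if isinstance(o, Enum):
            return o.value
        if callable(o):
            # Callables are not round-trippable; store a readable marker.
            return f"<callable {getattr(o, '__name__', repr(o))}>"
        return super().default(o)


def json_dumps(obj, **kwargs):
    return json.dumps(obj, cls=RobustJSONEncoder, **kwargs)
```

Usage: `json_dumps({"out_dir": Path("/tmp/run")})` serializes the path as a plain string instead of raising `TypeError`.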

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error)

  • Security Anti-Patterns: ❌ Error. The PR introduces critical RCE vulnerabilities by using torch.load() on user-supplied paths without the weights_only=True parameter, violating SECURITY.md guidelines. Resolution: add weights_only=True to the torch.load() calls in compare_module_outputs.py and base_hooks.py, or add security comments justifying weights_only=False usage.
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed (check skipped; CodeRabbit's high-level summary is enabled).
  • Title Check: ✅ Passed. The PR title accurately summarizes the main objective: introducing a shared activation hooks infrastructure to be reused across the minitron and puzzletron projects, which is exactly what the changeset delivers.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 93.41%, above the required threshold of 80.00%.
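The failed security check above asks for the weights_only=True pattern; it looks roughly like this (the file name and saved keys are illustrative, not the PR's actual checkpoint layout):

```python
import os
import tempfile

import torch

# Save hook state as plain tensors so it stays loadable under the
# restricted unpickler.
path = os.path.join(tempfile.gettempdir(), "hook_state_example.pt")
torch.save({"activations": torch.arange(4, dtype=torch.float32)}, path)

# weights_only=True limits deserialization to tensors and primitive
# containers, blocking arbitrary code execution from a malicious file.
state = torch.load(path, weights_only=True)
print(state["activations"])
```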
✨ Finishing Touches
  • 📝 Generate docstrings (stacked PR)
  • 📝 Generate docstrings (commit on current branch)
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch dkorzewa/activation_hooks_redesign_minitron_puzzletron

Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Nitpick comments (6)
tests/_test_utils/torch/distributed/utils.py (1)

26-30: Socket not explicitly closed after getting free port.

The socket is bound to get a free port but never closed. While Python's garbage collector will eventually close it, explicitly closing the socket ensures the port is released promptly and avoids potential resource leaks.

♻️ Suggested fix
 def get_free_port():
     sock = socket.socket()
     sock.bind(("", 0))
     port = sock.getsockname()[1]
+    sock.close()
     return port
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/_test_utils/torch/distributed/utils.py` around lines 26 - 30, The
get_free_port function creates a socket but never closes it; update
get_free_port to explicitly close the socket after binding (e.g., use a context
manager with socket.socket(...) as sock or call sock.close()) so the port is
released promptly; modify the function containing get_free_port to ensure the
socket is closed before returning the port while keeping the same behavior of
binding to ("", 0) and returning getsockname()[1].
modelopt/torch/puzzletron/tools/logger.py (3)

96-108: Fragile stack frame navigation.

The triple f_back navigation assumes a fixed call depth. If the call chain changes (e.g., adding a wrapper or calling dist_log directly), incorrect source locations will be reported. Consider using inspect.stack() with a more robust approach.

♻️ Suggested improvement
     `@staticmethod`
-    def get_caller_location() -> str:
+    def get_caller_location(depth: int = 3) -> str:
         """Get the caller location from the stack frame."""
-        # Get the caller's stack frame
-        frame = inspect.currentframe()
-
-        # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source
-        caller_frame = frame.f_back.f_back.f_back
-
-        # Get the filename and line number from the caller's stack frame
-        filename = os.path.basename(caller_frame.f_code.co_filename)
-        lineno = caller_frame.f_lineno
-        return f"{filename}:{lineno}"
+        stack = inspect.stack()
+        if len(stack) > depth:
+            frame_info = stack[depth]
+            return f"{os.path.basename(frame_info.filename)}:{frame_info.lineno}"
+        return "unknown:0"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` around lines 96 - 108, The
get_caller_location method currently walks a fixed f_back chain which is
brittle; replace that logic in get_caller_location with a robust
inspect.stack()-based search: call inspect.stack() and iterate frames to find
the first frame whose module/filename is not this logger module (or where the
function name is not get_caller_location/dist_log), then use that frame's
filename and lineno to return "filename:lineno"; ensure you fall back to the
last non-None frame if none match to avoid exceptions.

23-23: Unused import of torch.distributed.launch.

This import is marked with # noqa: F401 (unused import) but there's no apparent reason for its presence. If it's needed for a side effect, please add a comment explaining why.

♻️ Suggested fix - remove if not needed
-import torch.distributed.launch  # noqa: F401
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` at line 23, The unused import
torch.distributed.launch (currently annotated with # noqa: F401) should be
removed to eliminate dead code; if the import is intentionally required for a
side effect, replace the bare # noqa with a short explanatory comment (e.g.,
"import to register torch distributed launch entrypoint for CLI/side-effects")
next to the import so future readers understand its purpose—update the import
statement at the top of logger.py accordingly.

111-114: Global logger class modification may have unintended side effects.

logging.setLoggerClass(DistributedLogger) changes the default logger class for the entire process. Any subsequent logging.getLogger() calls in other modules will create DistributedLogger instances, which may not be intended.

Consider using a factory function or creating the logger directly without modifying the global logger class.

♻️ Alternative approach
-# Initialize logger
-logging.setLoggerClass(DistributedLogger)
-logger = logging.getLogger(__name__)
+# Initialize logger without modifying global logger class
+logger = DistributedLogger(__name__)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` around lines 111 - 114, Remove the
global side-effect call to logging.setLoggerClass(DistributedLogger) and instead
instantiate or provide a factory that returns a DistributedLogger explicitly;
replace the current sequence with something like creating the module logger by
calling DistributedLogger(__name__) (or implement a get_distributed_logger(name)
helper that returns DistributedLogger(name)), assign it to the logger variable
and keep logger.propagate = False. Ensure no other global logger class is
changed so other modules calling logging.getLogger() are unaffected.
modelopt/torch/puzzletron/tools/robust_json.py (2)

31-31: Hard import of optional dependency at module level.

Per coding guidelines, optional dependencies should be gated. omegaconf may not be installed in all environments.

♻️ Suggested fix - gate the import
-from omegaconf import DictConfig, ListConfig, OmegaConf
+try:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+    _HAS_OMEGACONF = True
+except ImportError:
+    DictConfig = ListConfig = None
+    _HAS_OMEGACONF = False

Then in RobustJSONEncoder.default:

-        if isinstance(o, (DictConfig, ListConfig)):
+        if _HAS_OMEGACONF and isinstance(o, (DictConfig, ListConfig)):
             return OmegaConf.to_container(o, resolve=True)

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; gate features by install extras."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/robust_json.py` at line 31, The module
currently hard-imports omegaconf at top-level; gate this optional dependency by
moving the import into the code path that needs it and handling ImportError:
remove the top-level "from omegaconf ..." import and instead import omegaconf
(or specific symbols) inside RobustJSONEncoder.default (and/or any other
functions that use DictConfig/ListConfig/OmegaConf), catch ImportError and fall
back to treating those objects as regular mappings or raise a clear runtime
error instructing to install the optional extra; update references to
DictConfig/ListConfig/OmegaConf in RobustJSONEncoder.default to use the local
import or the fallback behavior.

74-78: Return type hint is too restrictive.

json.loads can return any JSON-compatible type (dict, list, str, int, etc.), not just dict. Either update the type hint or add validation.

♻️ Suggested fix
-def json_load(path: Path | str) -> dict:
+def json_load(path: Path | str) -> Any:
     """Load JSON from file and return as dictionary."""
     path = Path(path)
     text = path.read_text()
     return json.loads(text)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/robust_json.py` around lines 74 - 78, The
function json_load currently types its return as dict but json.loads can return
any JSON-compatible type; update json_load's signature and implementation:
either change the return type from dict to a broad JSON type (e.g., Any or a
custom JSONType alias for dict|list|str|int|float|bool|None) or keep dict but
validate the loaded value and raise a clear exception if it's not a dict; refer
to the function name json_load and the code that calls
Path.read_text()/json.loads to implement the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py`:
- Around line 24-103: In evaluate_importance_scores, avoid building autograd
graphs and guard against empty input: wrap the per-batch evaluation loop (the
section using linear_layer, pruned_activations, pruned_output, rmse/cosine
computation) in a torch.no_grad() context to prevent gradient graph allocation,
and add an early validation at the top that checks activations_batches is
non-empty (raise a ValueError with a clear message or return zeroed metrics)
before computing num_to_prune; reference the function name
evaluate_importance_scores and the local symbols activations_batches,
linear_layer, pruned_activations, original_output, pruned_output, rmse_values,
and cosine_values when making the edits.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Around line 154-156: The code currently mutates the live config by calling
args.activation_hooks_kwargs.pop("model") before json_dump, which removes the
model reference for later uses; instead, create a copy of
activation_hooks_kwargs (e.g., tmp = dict(args.activation_hooks_kwargs) or use
copy.deepcopy) and remove "model" from that copy if present, then call
json_dump(OmegaConf.to_container(args_copy_or_modified_args, resolve=True),
activations_log_dir / "args.json") so the original args.activation_hooks_kwargs
remains unchanged; reference the symbols args.activation_hooks_kwargs,
json_dump, OmegaConf.to_container, and activations_log_dir / "args.json".
- Around line 162-180: save_hook_states currently calls hook.state_dict() for
every hook which fails when certain hooks (e.g.,
IndependentKvHeadContributionHook, LayerNormContributionHook) raise
NotImplementedError; modify save_hook_states to safely handle hooks that don't
support checkpointing by wrapping the state collection in a try/except (catch
NotImplementedError and skip that hook) or by checking for a supported
method/attribute before calling state_dict, and optionally record/log skipped
module names; ensure you reference save_hook_states and state_dict in your
change so only hooks that successfully return state are included in the saved
hook_states dict.
- Around line 231-259: The L2NormHook.load_state_dict currently assigns
checkpointed "_activations" directly, which can leave tensors on the wrong
device and later cause device-mismatch in L2NormHook.accumulate (and in the "+="
path). Update L2NormHook.load_state_dict to move the loaded activations to the
current module/device before assigning (follow the pattern used in
IndependentChannelContributionHook.load_state_dict): detect the target device
(e.g., from a provided module or torch.device of existing tensors), call
.to(device) on state_dict["activations"] (preserving dtype), then assign to
self._activations so subsequent accumulate() and add_ operations are
device-safe.
- Around line 456-480: The recomputed forward output (output_curr) omits the
layer bias while output_tensor includes it, skewing scaling_factor_per_token;
modify the recomputation in the block using curr_activations and
self.weight_matrix so it also adds the layer bias when present (e.g., compute
output_curr = F.linear(input=curr_activations, weight=self.weight_matrix,
bias=self.bias) or add self.bias expanded to output_curr's shape), ensuring the
bias shape matches output_curr before computing output_norms and
scaling_factor_per_token; reference symbols: output_tensor, output_curr,
curr_activations, self.weight_matrix, self.bias, scaling_factor_per_token.

In `@modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py`:
- Around line 240-244: Replace the current prints-and-return branch that handles
layer mismatches with a hard failure: instead of printing "ERROR: Layer
mismatch!" and returning, raise a RuntimeError (or other appropriate exception)
that includes the mismatched ref_layers and comp_layers in the message so the
CLI/CI exits non-zero; modify the code in compare_module_outputs.py where
ref_layers and comp_layers are compared (the current if set(ref_layers) !=
set(comp_layers): block) to raise the exception with clear context.

In `@modelopt/torch/puzzletron/tools/logger.py`:
- Around line 69-73: The NotImplementedError message in the ranks validation
(inside the function/method that broadcasts messages, referencing variables msg
and ranks) is missing the 'last' choice; update the error text to list all valid
options consistently with the check — e.g., include 'last' alongside 'all',
'main', and 'local_main' — so the raised message accurately reflects the allowed
ranks values.
- Around line 79-84: The "last" branch incorrectly compares self.local_rank to
self.world_size - 1; replace that check to target the last local rank on node 0
by testing node and local sizes: change the condition (ranks == "last" and
self.local_rank != self.world_size - 1) to require that the process is not the
last local rank on node 0 — e.g., use self.node_rank and self.local_world_size
and only allow printing when (self.node_rank == 0 and self.local_rank ==
self.local_world_size - 1); update the condition accordingly in the function
that contains the ranks logic so it references self.node_rank, self.local_rank,
and self.local_world_size instead of self.world_size.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`:
- Around line 217-257: The test uses fragile, hardware-specific hard-coded
activation values (e.g., pruning_scores["layer_scores"], activations entries
like "decoder.layers.0.mlp" / "decoder.layers.0.self_attention") checked with
_assert_approx and abs=1e-3; replace these strict exact-value checks with more
robust assertions: validate tensor shapes and value ranges or use relative
tolerance (rtol) instead of strict absolute checks, or widen/make the tolerance
configurable via a test constant; for layer-level correctness prefer asserting
expected monotonic/relative relationships (e.g., pruned vs unpruned scores)
rather than exact floats; optionally split the large conditional block into
separate focused tests for MHA vs GQA using the same symbols (pruning_scores,
activations, _assert_approx) to improve readability and maintainability.

In `@tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py`:
- Around line 113-154: The forward hook handle created in _run_hook_and_evaluate
is only removed on the happy path; wrap the work that runs the forward passes
and calls hook.accumulate (the loop that appends to all_activations and the call
to importance_scores = hook.accumulate()) in a try/finally and call
handle.remove() in the finally block to guarantee the hook is detached even on
exceptions; re-raise the exception after cleanup if one occurred so failures are
propagated, and keep the subsequent call to evaluate_importance_scores
unchanged.
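The last prompt above, guaranteeing hook removal even when a forward pass fails, corresponds to a standard try/finally pattern. A generic sketch (the helper name is illustrative, not the test's actual code):

```python
import torch
import torch.nn as nn


def run_with_hook(module, hook_fn, batches):
    """Run forward passes with a temporary forward hook; the hook is
    detached in the finally block even if a forward pass raises."""
    handle = module.register_forward_hook(hook_fn)
    try:
        for x in batches:
            module(x)
    finally:
        handle.remove()


captured = []
layer = nn.Linear(4, 2)
run_with_hook(
    layer,
    lambda m, inputs, output: captured.append(output.detach()),
    [torch.randn(3, 4) for _ in range(2)],
)
print(len(captured))  # 2
```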

---


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: afede6d7-95c6-4028-bab9-e7b8241fd750

📥 Commits

Reviewing files that changed from the base of the PR and between fe83270 and 6ce8345.

📒 Files selected for processing (14)
  • modelopt/torch/nas/plugins/megatron_hooks/__init__.py
  • modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
  • modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py
  • modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py
  • modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/tools/__init__.py
  • modelopt/torch/puzzletron/tools/logger.py
  • modelopt/torch/puzzletron/tools/robust_json.py
  • tests/_test_utils/torch/distributed/utils.py
  • tests/conftest.py
  • tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py
  • tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Comment on lines +24 to +103
def evaluate_importance_scores(
    linear_layer: nn.Linear,
    activations_batches: list[torch.Tensor],
    importance_scores: torch.Tensor,
    prune_ratio: float = 0.2,
) -> dict[str, float]:
    """Compute reconstruction error after pruning input channels of a linear layer.

    This function simulates channel pruning by zeroing out input channels identified as
    least important, then measures how much the layer's output changes.

    Args:
        linear_layer: The linear layer to analyze with shape (out_features, in_features).
            For example: nn.Linear(in_features=1024, out_features=4096)
        activations_batches: List of input activation tensors.
            Each tensor has shape [seq_len, batch_size, in_features].
            The last dimension must match linear_layer.in_features.
            Example: List of [16, 8, 1024] tensors
        importance_scores: Importance score for each input channel (feature).
            Shape: [in_features]. Lower scores = less important.
            Example: [1024] tensor with one score per input feature
        prune_ratio: Fraction of input channels to prune (default: 0.2 means prune 20%).

    Returns:
        Dictionary containing averaged metrics across all activation batches:
        - rmse: Root mean squared error between original and pruned output
        - cosine_similarity: Cosine similarity between original and pruned output
        - num_pruned: Number of input channels pruned

    Example:
        >>> layer = nn.Linear(in_features=1024, out_features=4096)
        >>> # Collect multiple batches for robust evaluation
        >>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)]
        >>> scores = torch.randn(1024)  # one score per input feature
        >>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2)
        >>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels")

    Note:
        - This simulates pruning (zeros out inputs) without modifying layer weights
        - "Channels" refers to INPUT features, not output features

    """
    num_channels = importance_scores.shape[0]
    num_to_prune = int(num_channels * prune_ratio)

    # Identify channels to prune (lowest scoring = least important)
    _, channels_to_prune = torch.topk(importance_scores, num_to_prune, largest=False)

    # Compute metrics for each batch and average
    rmse_values = []
    cosine_values = []

    for activations in activations_batches:
        # Get original output
        original_output = linear_layer(activations)

        # Prune by zeroing out identified channels
        pruned_activations = activations.clone()
        pruned_activations[..., channels_to_prune] = 0

        # Get pruned output
        pruned_output = linear_layer(pruned_activations)

        # Compute metrics for this batch
        rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
        rmse_values.append(rmse)

        # Cosine similarity (flatten to vectors)
        original_flat = original_output.reshape(-1)
        pruned_flat = pruned_output.reshape(-1)
        cosine = F.cosine_similarity(
            original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
        ).item()
        cosine_values.append(cosine)

    # Return averaged metrics
    return {
        "rmse": sum(rmse_values) / len(rmse_values),
        "cosine_similarity": sum(cosine_values) / len(cosine_values),
        "num_pruned": num_to_prune,
    }

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

fd -t f "base_hooks_analysis.py"

Repository: NVIDIA/Model-Optimizer

Length of output: 201


🏁 Script executed:

cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py | head -110

Repository: NVIDIA/Model-Optimizer

Length of output: 5245


🏁 Script executed:

cat -n tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py

Repository: NVIDIA/Model-Optimizer

Length of output: 7539


🏁 Script executed:

# Also check if there are imports or dependencies we should verify
grep -n "torch.no_grad\|activations_batches\|evaluate_importance_scores" modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py

Repository: NVIDIA/Model-Optimizer

Length of output: 395


Wrap the evaluation loop in torch.no_grad() and validate non-empty input.

This analysis-only function builds two autograd graphs per batch unnecessarily, creating memory pressure during calibration. Additionally, passing an empty activations_batches list silently fails at the return statement with a ZeroDivisionError.

🛠️ Suggested fix
 def evaluate_importance_scores(
     linear_layer: nn.Linear,
     activations_batches: list[torch.Tensor],
     importance_scores: torch.Tensor,
     prune_ratio: float = 0.2,
 ) -> dict[str, float]:
@@
     """
     num_channels = importance_scores.shape[0]
+    if not activations_batches:
+        raise ValueError("activations_batches must be non-empty")
     num_to_prune = int(num_channels * prune_ratio)
@@
-    for activations in activations_batches:
-        # Get original output
-        original_output = linear_layer(activations)
-
-        # Prune by zeroing out identified channels
-        pruned_activations = activations.clone()
-        pruned_activations[..., channels_to_prune] = 0
-
-        # Get pruned output
-        pruned_output = linear_layer(pruned_activations)
-
-        # Compute metrics for this batch
-        rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
-        rmse_values.append(rmse)
-
-        # Cosine similarity (flatten to vectors)
-        original_flat = original_output.reshape(-1)
-        pruned_flat = pruned_output.reshape(-1)
-        cosine = F.cosine_similarity(
-            original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
-        ).item()
-        cosine_values.append(cosine)
+    with torch.no_grad():
+        for activations in activations_batches:
+            original_output = linear_layer(activations)
+
+            pruned_activations = activations.clone()
+            pruned_activations[..., channels_to_prune] = 0
+            pruned_output = linear_layer(pruned_activations)
+
+            rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
+            rmse_values.append(rmse)
+
+            original_flat = original_output.reshape(-1)
+            pruned_flat = pruned_output.reshape(-1)
+            cosine = F.cosine_similarity(
+                original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
+            ).item()
+            cosine_values.append(cosine)
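The pattern generalizes beyond this module. Below is a self-contained sketch (the helper name `eval_pruning_rmse` is hypothetical, not the repo's API) showing the no_grad guard and the empty-input check together:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def eval_pruning_rmse(
    layer: nn.Linear,
    batches: list[torch.Tensor],
    channels_to_prune: torch.Tensor,
) -> float:
    """Average RMSE between original and channel-pruned outputs (no autograd graphs)."""
    if not batches:
        raise ValueError("batches must be non-empty")
    rmse_values = []
    with torch.no_grad():  # analysis only: skip building autograd graphs
        for activations in batches:
            original = layer(activations)
            pruned_in = activations.clone()
            pruned_in[..., channels_to_prune] = 0  # simulate pruning the channels
            pruned = layer(pruned_in)
            rmse_values.append(torch.sqrt(F.mse_loss(pruned, original)).item())
    return sum(rmse_values) / len(rmse_values)
```

With the guard first, the failure mode for an empty list is an explicit ValueError rather than a ZeroDivisionError at the final division.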

Comment on lines +154 to +156
        if rank == 0:
            args.activation_hooks_kwargs.pop("model")
            json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
⚠️ Potential issue | 🟠 Major

Don't mutate args.activation_hooks_kwargs just to dump JSON.

pop("model") edits the live config object. If the same args is reused later, the model reference is gone, and a second dump on rank 0 raises a KeyError on the now-missing key.

🛠️ Suggested fix
         if rank == 0:
-            args.activation_hooks_kwargs.pop("model")
-            json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
+            args_to_dump = OmegaConf.to_container(args, resolve=True)
+            args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
+            json_dump(args_to_dump, activations_log_dir / "args.json")
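The copy-then-sanitize idea in isolation, with a plain dict standing in for the resolved OmegaConf container (`dump_args` and the key names are illustrative):

```python
import copy
import json


def dump_args(args: dict, path: str) -> None:
    """Serialize a sanitized copy; the caller's args dict is left untouched."""
    args_to_dump = copy.deepcopy(args)
    # Drop the non-serializable model handle from the copy only
    args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
    with open(path, "w") as f:
        json.dump(args_to_dump, f)
```

Because only the deep copy is mutated, a second dump (or any later consumer of the live config) still sees the model reference.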

Comment on lines +162 to +180
    def save_hook_states(
        cls: type["ForwardHook"],
        activation_hooks: dict[str, "ForwardHook"],
        activations_log_dir: Path | str,
    ) -> None:
        """Save hook states for checkpointing (separate from final results).

        This can be called periodically during scoring.
        Note: Synchronization should be handled at a higher level to avoid deadlocks.
        """
        activations_log_dir = Path(activations_log_dir)
        activations_log_dir.mkdir(exist_ok=True, parents=True)
        rank = dist.rank()

        hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth"
        hook_states = {
            module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
        }
        torch.save(hook_states, hook_states_path)
⚠️ Potential issue | 🟠 Major

save_hook_states() needs to handle hooks without checkpoint support.

This helper blindly calls state_dict() on every hook, but IndependentKvHeadContributionHook and LayerNormContributionHook raise NotImplementedError from state_dict(). A single such hook makes periodic checkpointing fail for the whole run.

🛠️ Suggested fix
-        hook_states = {
-            module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
-        }
+        hook_states = {}
+        unsupported = []
+        for module_name, hook in activation_hooks.items():
+            try:
+                hook_states[module_name] = hook.state_dict()
+            except NotImplementedError:
+                unsupported.append(module_name)
+
+        if unsupported:
+            aprint(
+                "Skipping hook checkpoint save for hooks without state support: "
+                f"{unsupported}"
+            )
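The skip-unsupported pattern on its own, with toy hook classes standing in for the real ForwardHook hierarchy:

```python
class StatefulHook:
    """Toy hook that supports checkpointing."""

    def state_dict(self) -> dict:
        return {"sum": 42}


class StatelessHook:
    """Toy hook that opts out of checkpointing."""

    def state_dict(self) -> dict:
        raise NotImplementedError("this hook has no checkpointable state")


def collect_states(hooks: dict) -> tuple[dict, list[str]]:
    """Gather state dicts, recording hooks that opted out instead of crashing."""
    states, skipped = {}, []
    for name, hook in hooks.items():
        try:
            states[name] = hook.state_dict()
        except NotImplementedError:
            skipped.append(name)
    return states, skipped
```

The caller can then log the skipped names once per save instead of aborting the whole checkpoint.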

Comment on lines +231 to +259
        if self._activations is None:
            self._activations = activations
        else:
            self._activations += activations

    def accumulate(self) -> torch.Tensor:
        """Return the accumulated L2 norm of activations.

        Returns:
            Tensor of accumulated scores, one per channel

        Raises:
            AssertionError: If no activations have been collected yet
        """
        assert self._activations is not None, "No activations collected for importance estimation."
        # Convert squared sum to L2 norm
        return self._activations.pow(0.5)

    def to_dict(self) -> dict[str, torch.Tensor]:
        """Convert to dict format for saving."""
        return {"score": self.accumulate().cpu()}

    def state_dict(self) -> dict:
        """Return the state dictionary containing activations."""
        return {"activations": self._activations}

    def load_state_dict(self, state_dict: dict) -> None:
        """Load activations from checkpoint."""
        self._activations = state_dict["activations"]
⚠️ Potential issue | 🟠 Major

Make L2NormHook state reload device-agnostic.

After restoring from checkpoint, _activations remains on the checkpoint device. When the resumed layer runs on a different device, accumulation fails on line 234 due to device mismatch in the += operation. IndependentChannelContributionHook.load_state_dict() already demonstrates the correct pattern for device-agnostic state loading.

🛠️ Suggested fix
         if self._activations is None:
             self._activations = activations
         else:
+            if self._activations.device != activations.device:
+                self._activations = self._activations.to(activations.device)
             self._activations += activations
@@
     def state_dict(self) -> dict:
         """Return the state dictionary containing activations."""
-        return {"activations": self._activations}
+        return {
+            "activations": None
+            if self._activations is None
+            else self._activations.cpu().clone()
+        }
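The device-handling pattern can be exercised even on CPU. A minimal accumulator sketch (hypothetical `RunningSum`, not the real L2NormHook) that realigns restored state and checkpoints on CPU:

```python
import torch


class RunningSum:
    """Accumulates activations; tolerates state restored on another device."""

    def __init__(self) -> None:
        self._acc = None

    def add(self, activations: torch.Tensor) -> None:
        if self._acc is None:
            self._acc = activations.clone()
        else:
            if self._acc.device != activations.device:
                # Realign restored state to wherever new activations live
                self._acc = self._acc.to(activations.device)
            self._acc += activations

    def state_dict(self) -> dict:
        # Always checkpoint on CPU so reloads are device-neutral
        return {"acc": None if self._acc is None else self._acc.cpu().clone()}

    def load_state_dict(self, state: dict) -> None:
        self._acc = state["acc"]
```

On a single-device CPU run the realignment branch is a no-op, but after a cross-device restore it prevents the `+=` from failing.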

Comment on lines +456 to +480
        # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear)
        # TODO: Consider better design to handle RowParallelLinear and nn.Linear
        if isinstance(output, tuple):
            output_tensor = output[0]
        else:
            output_tensor = output

        activations = args[0]

        # Don't aggregate activations from non-max subnets (e.g. from profiling)
        if self.max_size is not None and activations.shape[-1] != self.max_size:
            return

        n_channels_to_prune = self.pruning_schedule[self.curr_iter]

        curr_activations = activations.clone()  # Shape B,T,I
        curr_activations[..., self.pruned_channels] = 0
        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E

        if self.calibration_method is None:
            scaling_factor_per_token = torch.ones_like(output_tensor[..., 0])  # Shape B,T
        elif self.calibration_method == "scale_by_magnitude":
            output_norms = torch.linalg.vector_norm(output_tensor, dim=-1)  # Shape B,T
            output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1)  # Shape B,T
            scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon)
⚠️ Potential issue | 🟠 Major

Include bias in the pruned forward recomputation.

For biased nn.Linear layers, output_tensor includes the bias term but output_curr does not. This mismatch causes the scaling factor computation (line 480) and subsequent channel importance ranking (line 487) to be incorrect. The bias difference must be accounted for in the recomputed forward pass.

Fix
         curr_activations = activations.clone()  # Shape B,T,I
         curr_activations[..., self.pruned_channels] = 0
-        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E
+        bias = None if isinstance(output, tuple) else getattr(module, "bias", None)
+        output_curr = F.linear(
+            input=curr_activations,
+            weight=self.weight_matrix,
+            bias=bias,
+        )  # Shape B,T,E
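The mismatch is easy to reproduce in isolation: for a biased nn.Linear, recomputing with F.linear and no bias diverges from the hooked output, while passing the layer's bias matches it exactly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(4, 3, bias=True)
x = torch.randn(2, 4)

with torch.no_grad():
    hooked_output = layer(x)                           # what the forward hook observes
    no_bias = F.linear(x, layer.weight)                # recomputation missing the bias term
    with_bias = F.linear(x, layer.weight, layer.bias)  # bias included: matches the hook
```

Any per-token norm ratio computed from `no_bias` is therefore skewed by the (constant) bias contribution, which is exactly the scaling-factor error described above.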

Comment on lines +240 to +244
    if set(ref_layers) != set(comp_layers):
        print("\nERROR: Layer mismatch!")
        print(f"Reference layers: {ref_layers}")
        print(f"Compare layers: {comp_layers}")
        return
⚠️ Potential issue | 🟠 Major

Raise on layer mismatch instead of returning successfully.

A bad comparison is a hard failure. Returning here makes the CLI exit with status 0, so scripts and CI can silently accept invalid results.

🛠️ Suggested fix
     if set(ref_layers) != set(comp_layers):
-        print("\nERROR: Layer mismatch!")
-        print(f"Reference layers: {ref_layers}")
-        print(f"Compare layers: {comp_layers}")
-        return
+        raise ValueError(
+            "Layer mismatch: "
+            f"reference={sorted(ref_layers)}, compare={sorted(comp_layers)}"
+        )

Comment on lines +69 to +73
        if ranks not in ["all", "main", "local_main", "last"]:
            raise NotImplementedError(
                f"Could not broadcast msg {msg} - "
                f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}"
            )
⚠️ Potential issue | 🟡 Minor

Error message is inconsistent with valid choices.

The validation checks for ["all", "main", "local_main", "last"] but the error message only lists ['all', 'main', 'local_main'], missing 'last'.

🐛 Proposed fix
         if ranks not in ["all", "main", "local_main", "last"]:
             raise NotImplementedError(
                 f"Could not broadcast msg {msg} - "
-                f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}"
+                f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
             )

Comment on lines +79 to +84
        elif (
            (ranks == "main" and self.global_rank != 0)
            or (ranks == "last" and self.local_rank != self.world_size - 1)
            or (ranks == "local_main" and self.local_rank != 0)
        ):
            return
⚠️ Potential issue | 🟠 Major

Logic bug in 'last' rank check for multi-node environments.

The condition self.local_rank != self.world_size - 1 compares local_rank (0 to local_world_size-1) against world_size (total processes across all nodes). In a multi-node setup with 8 total GPUs across 2 nodes, local_rank ranges 0-3 but world_size - 1 is 7, so the condition local_rank != 7 is always true and no rank ever prints.

Based on the docstring at line 167 ("rank -1 in node 0"), the intent seems to be the last rank on node 0.

🐛 Proposed fix
         elif (
             (ranks == "main" and self.global_rank != 0)
-            or (ranks == "last" and self.local_rank != self.world_size - 1)
+            or (ranks == "last" and self.global_rank != self.world_size - 1)
             or (ranks == "local_main" and self.local_rank != 0)
         ):
             return

Or if the intent is truly "last local rank on each node":

-            or (ranks == "last" and self.local_rank != self.world_size - 1)
+            or (ranks == "last" and self.local_rank != int(os.environ.get("LOCAL_WORLD_SIZE", 1)) - 1)
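Plugging a 2-node x 4-GPU layout into both conditions shows the difference (`node_rank` and `local_world_size` are the usual torchrun quantities, not attributes the logger currently exposes):

```python
def is_last_local_rank_on_node0(node_rank: int, local_rank: int, local_world_size: int) -> bool:
    """'last' per the docstring: the highest local rank on node 0."""
    return node_rank == 0 and local_rank == local_world_size - 1


# 2 nodes x 4 GPUs: world_size = 8, but local_rank never exceeds 3,
# so the buggy `local_rank != world_size - 1` test excludes every rank.
world_size, local_world_size = 8, 4
buggy_matches = [
    (node, lr)
    for node in range(2)
    for lr in range(local_world_size)
    if lr == world_size - 1  # the original check, inverted
]
fixed_matches = [
    (node, lr)
    for node in range(2)
    for lr in range(local_world_size)
    if is_last_local_rank_on_node0(node, lr, local_world_size)
]
```

The buggy predicate matches no (node, local_rank) pair at all, while the fixed one selects exactly the last local rank on node 0.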

Comment on lines +217 to +257
    # TODO: Simplify it: this unit test is too long,
    # hard to read (the same set of assertions across different test cases with if-else).

    assert len(pruning_scores["activations_per_rank"]) == size
    activations = pruning_scores["activations_per_rank"][rank]

    # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
    if size == 1 and pruned_ffn_div == 4:
        # Layer scores
        _assert_approx(pruning_scores["layer_scores"], {1: 0.028923, 2: 0.046508})

        # Validate decoder.layers.0.mlp activations
        mlp_0_acts = activations["decoder.layers.0.mlp"]
        _assert_approx(mlp_0_acts.min().item(), 0.000026)
        _assert_approx(mlp_0_acts.max().item(), 0.000729)
        _assert_approx(mlp_0_acts.mean().item(), 0.000201)

        # Validate decoder.layers.1.mlp activations
        mlp_1_acts = activations["decoder.layers.1.mlp"]
        _assert_approx(mlp_1_acts.min().item(), 0.000022)
        _assert_approx(mlp_1_acts.max().item(), 0.000762)
        _assert_approx(mlp_1_acts.mean().item(), 0.000162)

    # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
    elif size == 1 and pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
        # Layer scores
        _assert_approx(pruning_scores["layer_scores"], {1: 0.028056, 2: 0.038353})

        # Validate decoder.layers.0.self_attention activations
        attn_0_acts = activations["decoder.layers.0.self_attention"]
        assert attn_0_acts.shape == torch.Size([hidden_size])
        _assert_approx(attn_0_acts.min().item(), 0.010091)
        _assert_approx(attn_0_acts.max().item(), 0.023826)
        _assert_approx(attn_0_acts.mean().item(), 0.014548)

        # Validate decoder.layers.1.self_attention activations
        attn_1_acts = activations["decoder.layers.1.self_attention"]
        assert attn_1_acts.shape == torch.Size([hidden_size])
        _assert_approx(attn_1_acts.min().item(), 0.009982)
        _assert_approx(attn_1_acts.max().item(), 0.035644)
        _assert_approx(attn_1_acts.mean().item(), 0.020140)
⚠️ Potential issue | 🟡 Minor

Hard-coded activation values with tight tolerance may cause test fragility across different hardware and CUDA versions.

The test uses very specific hard-coded values (e.g., 0.028923, 0.046508) with abs=1e-3 tolerance. While proper teardown and seeding infrastructure is in place (via megatron_worker_teardown and set_seed), GPU-specific activation values can vary slightly across hardware and CUDA versions, so these assertions can fail even when the pruning logic is correct.

Consider:

  1. Using relative comparisons or shape/range checks instead of exact activation value assertions
  2. Documenting which GPU/CUDA version these values were captured on and re-validating periodically
  3. If exact values are needed, widening the tolerance or making it configurable per environment

The TODO comment at line 217 also notes this test's complexity; splitting it into focused unit tests would improve maintainability.
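For example, a relative-tolerance assertion along the lines of options 1 and 3 could use stdlib math.isclose (the helper name and 5% rel_tol are illustrative, not values the test should necessarily adopt):

```python
import math


def assert_approx(actual: float, expected: float, rel_tol: float = 0.05) -> None:
    """Pass when actual is within rel_tol (5% by default) of expected."""
    assert math.isclose(actual, expected, rel_tol=rel_tol), (
        f"{actual} not within {rel_tol:.0%} of {expected}"
    )
```

A relative bound tracks small per-hardware drift in the activations proportionally, whereas a fixed abs=1e-3 is tight for values near 0.03 and meaningless for values near 1e-5.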


Comment on lines +113 to +154
def _run_hook_and_evaluate(
    layer: nn.Linear,
    hook,
    num_iterations: int,
    prune_ratio: float,
) -> dict:
    """Shared helper to run hook, collect scores, and evaluate.

    Args:
        layer: Linear layer to test
        hook: Hook instance (already created)
        num_iterations: Number of forward passes
        prune_ratio: Fraction of channels to prune

    Returns:
        Dictionary with evaluation metrics
    """
    handle = layer.register_forward_hook(hook)  # Store the handle

    # Run forward passes
    all_activations = []
    for _ in range(num_iterations):
        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
        all_activations.append(activations)
        _ = layer(activations)

    # Get importance scores from hook
    importance_scores = hook.accumulate()

    # Remove the hook before evaluation to avoid triggering it again
    handle.remove()

    # Evaluate the importance scores by simulating pruning on all collected activations
    # Pass the list of activations to compute averaged metrics across batches
    metrics = evaluate_importance_scores(
        layer,
        all_activations,  # List of activation batches
        importance_scores,
        prune_ratio=prune_ratio,
    )

    return metrics

⚠️ Potential issue | 🟠 Major

Always remove the forward hook in a finally block.

handle.remove() only runs on the happy path. If the forward loop or hook.accumulate() fails, the hook stays attached in the reused worker and can bleed into the next scenario.

🛠️ Suggested fix
 def _run_hook_and_evaluate(
     layer: nn.Linear,
     hook,
     num_iterations: int,
@@
-    handle = layer.register_forward_hook(hook)  # Store the handle
-
-    # Run forward passes
-    all_activations = []
-    for _ in range(num_iterations):
-        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
-        all_activations.append(activations)
-        _ = layer(activations)
-
-    # Get importance scores from hook
-    importance_scores = hook.accumulate()
-
-    # Remove the hook before evaluation to avoid triggering it again
-    handle.remove()
+    handle = layer.register_forward_hook(hook)
+    try:
+        all_activations = []
+        for _ in range(num_iterations):
+            activations = torch.randn(16, 8, layer.in_features)
+            all_activations.append(activations)
+            _ = layer(activations)
+
+        importance_scores = hook.accumulate()
+    finally:
+        handle.remove()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py`
around lines 113 - 154, The forward hook handle created in
_run_hook_and_evaluate is only removed on the happy path; wrap the work that
runs the forward passes and calls hook.accumulate (the loop that appends to
all_activations and the call to importance_scores = hook.accumulate()) in a
try/finally and call handle.remove() in the finally block to guarantee the hook
is detached even on exceptions; re-raise the exception after cleanup if one
occurred so failures are propagated, and keep the subsequent call to
evaluate_importance_scores unchanged.

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
@danielkorzekwa danielkorzekwa requested a review from a team as a code owner March 11, 2026 14:40
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (5)
tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)

248-259: Shape assertions are robust; consider extending this pattern.

The assert attn_0_acts.shape == torch.Size([hidden_size]) checks (lines 249, 256) are hardware-independent and reliable. Consider adopting similar structural assertions for the MLP activations in the first test case (lines 231-240) to improve robustness while maintaining the statistical checks as secondary validation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`
around lines 248 - 259, Add hardware-independent shape assertions for the MLP
activation entries similar to the attention checks: for the MLP activations
retrieved from the activations dict (e.g., activations["decoder.layers.0.mlp"]
and activations["decoder.layers.1.mlp"]) assert their .shape equals
torch.Size([hidden_size]) before or alongside the existing min/max/mean
_assert_approx checks; update the test variables corresponding to the first test
case’s MLP activations (the variables that hold those activation tensors) to
include these assert ... .shape == torch.Size([hidden_size]) statements.
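A minimal sketch of the shape-first assertion pattern for the MLP entries. The dict keys mirror those in the test; `hidden_size` and the random tensors here are illustrative stand-ins for the values the real test derives from the model config.

```python
import torch

# Illustrative value; the real test derives hidden_size from the model config.
hidden_size = 64

# Stand-in for the activations dict collected during pruning.
activations = {
    "decoder.layers.0.mlp": torch.randn(hidden_size),
    "decoder.layers.1.mlp": torch.randn(hidden_size),
}

# Hardware-independent structural check, done before any statistical comparison.
for key in ("decoder.layers.0.mlp", "decoder.layers.1.mlp"):
    assert activations[key].shape == torch.Size([hidden_size]), key
```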
modelopt/torch/utils/robust_json.py (2)

47-48: Fragile dtype detection via string comparison.

Matching type(o).__name__ == "dtype" will catch any class named "dtype", not just numpy.dtype or torch.dtype. This could lead to unexpected behavior with other types. Consider adding a comment documenting this intentional duck-typing approach, or checking o.__class__.__module__ as well (e.g., starts with numpy or torch).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/robust_json.py` around lines 47 - 48, The current
fragile dtype detection uses type(o).__name__ == "dtype"; update it to only
match actual numpy/torch dtypes by also checking the class module (e.g., check
o.__class__.__module__ and ensure it startswith "numpy" or "torch") or
explicitly check for instances via duck-typing attributes, and add a brief
comment explaining the intent; replace the condition referencing
type(o).__name__ == "dtype" with a combined check using o.__class__.__module__
(or equivalent) so only numpy/torch dtype classes are matched in robust_json.py.

15-15: Consider removing the blanket mypy suppression.

Disabling mypy for the entire file with # mypy: ignore-errors bypasses type checking. The file already has type annotations, so targeted # type: ignore comments on specific problematic lines (if any) would be preferable to maintain type safety for the rest of the module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/robust_json.py` at line 15, Remove the top-level blanket
mypy suppression in robust_json.py (the `# mypy: ignore-errors` comment) and
instead address specific typing issues: delete that header, run mypy to find
failing lines, and add targeted `# type: ignore[...]` comments only on the
problematic expressions (or refine annotations on functions/classes such as any
helper functions in robust_json.py) so the rest of the module retains type
checking.
modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py (2)

675-681: Consider documenting why checkpointing is not supported.

The state_dict and load_state_dict methods raise NotImplementedError. While this is a valid design choice, adding a brief explanation in the docstring or error message (e.g., "KV head pruning is designed to complete in a single run") would help users understand the limitation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 675 -
681, Update the docstrings and/or raised NotImplementedError messages for the
state_dict and load_state_dict methods in base_hooks.py to briefly explain why
checkpointing isn't supported (e.g., "Checkpointing not supported because KV
head pruning is a one-time operation and does not maintain persistent mutable
state across runs"), so users understand the design choice; modify the
docstrings of state_dict and load_state_dict and the text passed to
NotImplementedError in those methods to include that brief rationale.

810-815: Consider using json_dump from robust_json for consistency.

The module imports json_dump from modelopt.torch.utils.robust_json (line 30) but uses the standard json.dump here. Using json_dump would ensure consistent JSON serialization across the codebase and automatically create parent directories.

Suggested change
-        output_path = activations_log_dir / "channel_importance_results.json"
-        aprint(f"Saving channel importance data to {output_path}")
-        with open(output_path, "w") as f:
-            json.dump(output_data, f, indent=2)
+        output_path = activations_log_dir / "channel_importance_results.json"
+        aprint(f"Saving channel importance data to {output_path}")
+        json_dump(output_data, output_path)

Note: This would lose the indent=2 formatting. If pretty-printing is important, consider adding an indent parameter to json_dump or keep the current approach.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 810 -
815, Replace the use of json.dump when writing output_data to output_path with
the project utility json_dump (imported from modelopt.torch.utils.robust_json)
so serialization is consistent and parent directories are created automatically;
update the write block that currently uses output_path and json.dump to call
json_dump(output_data, output_path) (or add an indent param to json_dump if
pretty-printing is required) and remove the manual open(...) context since
json_dump handles file creation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/utils/logging.py`:
- Around line 206-208: The aprint wrapper currently always passes flush=True
while also forwarding kwargs, causing a duplicate-key TypeError when callers
pass flush in kwargs; update aprint to handle/merge the flush kwarg by
extracting it from kwargs (e.g., pop 'flush' if present) and then call print
with the resolved flush value (use True as the default if not provided) so
callers can override flush without causing a duplicate keyword error; reference
the aprint function in logging.py to locate the change.

In `@modelopt/torch/utils/robust_json.py`:
- Around line 74-78: The return type of json_load is too narrow (annotated as ->
dict) while json.loads can return dict, list, str, int, float, bool, or None;
update json_load's return annotation to reflect that (e.g., use typing.Any or
define a JSONType = Union[dict, list, str, int, float, bool, None] and use it)
and adjust the docstring to say "Return parsed JSON value" instead of "return as
dictionary"; reference the json_load function and the json.loads call when
making this change.

---

Nitpick comments:
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Around line 675-681: Update the docstrings and/or raised NotImplementedError
messages for the state_dict and load_state_dict methods in base_hooks.py to
briefly explain why checkpointing isn't supported (e.g., "Checkpointing not
supported because KV head pruning is a one-time operation and does not maintain
persistent mutable state across runs"), so users understand the design choice;
modify the docstrings of state_dict and load_state_dict and the text passed to
NotImplementedError in those methods to include that brief rationale.
- Around line 810-815: Replace the use of json.dump when writing output_data to
output_path with the project utility json_dump (imported from
modelopt.torch.utils.robust_json) so serialization is consistent and parent
directories are created automatically; update the write block that currently
uses output_path and json.dump to call json_dump(output_data, output_path) (or
add an indent param to json_dump if pretty-printing is required) and remove the
manual open(...) context since json_dump handles file creation.

In `@modelopt/torch/utils/robust_json.py`:
- Around line 47-48: The current fragile dtype detection uses type(o).__name__
== "dtype"; update it to only match actual numpy/torch dtypes by also checking
the class module (e.g., check o.__class__.__module__ and ensure it startswith
"numpy" or "torch") or explicitly check for instances via duck-typing
attributes, and add a brief comment explaining the intent; replace the condition
referencing type(o).__name__ == "dtype" with a combined check using
o.__class__.__module__ (or equivalent) so only numpy/torch dtype classes are
matched in robust_json.py.
- Line 15: Remove the top-level blanket mypy suppression in robust_json.py (the
`# mypy: ignore-errors` comment) and instead address specific typing issues:
delete that header, run mypy to find failing lines, and add targeted `# type:
ignore[...]` comments only on the problematic expressions (or refine annotations
on functions/classes such as any helper functions in robust_json.py) so the rest
of the module retains type checking.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`:
- Around line 248-259: Add hardware-independent shape assertions for the MLP
activation entries similar to the attention checks: for the MLP activations
retrieved from the activations dict (e.g., activations["decoder.layers.0.mlp"]
and activations["decoder.layers.1.mlp"]) assert their .shape equals
torch.Size([hidden_size]) before or alongside the existing min/max/mean
_assert_approx checks; update the test variables corresponding to the first test
case’s MLP activations (the variables that hold those activation tensors) to
include these assert ... .shape == torch.Size([hidden_size]) statements.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ec61ff4f-cb7a-4262-91ac-c9d486e2dd0f

📥 Commits

Reviewing files that changed from the base of the PR and between 6ce8345 and c3870a5.

📒 Files selected for processing (4)
  • modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
  • modelopt/torch/utils/logging.py
  • modelopt/torch/utils/robust_json.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Comment on lines +206 to +208
def aprint(*args, **kwargs):
    """All ranks from all nodes print."""
    print(*args, **kwargs, flush=True)

⚠️ Potential issue | 🟡 Minor

Avoid duplicating the flush keyword.

aprint(..., flush=False) currently raises TypeError because flush is passed both via **kwargs and as an explicit keyword. Since this is a new public wrapper around print, it should merge or override the kwarg instead of duplicating it.

Proposed fix
 def aprint(*args, **kwargs):
     """All ranks from all nodes print."""
-    print(*args, **kwargs, flush=True)
+    kwargs = {**kwargs, "flush": True}
+    print(*args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/logging.py` around lines 206 - 208, The aprint wrapper
currently always passes flush=True while also forwarding kwargs, causing a
duplicate-key TypeError when callers pass flush in kwargs; update aprint to
handle/merge the flush kwarg by extracting it from kwargs (e.g., pop 'flush' if
present) and then call print with the resolved flush value (use True as the
default if not provided) so callers can override flush without causing a
duplicate keyword error; reference the aprint function in logging.py to locate
the change.
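The pop-and-default variant the prompt describes would keep `flush=True` as the default while letting callers override it, for example:

```python
def aprint(*args, **kwargs):
    """All ranks from all nodes print. Flushes by default; pass flush=False to override."""
    flush = kwargs.pop("flush", True)  # resolve before forwarding to avoid a duplicate keyword
    print(*args, **kwargs, flush=flush)
```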

Comment on lines +74 to +78
def json_load(path: Path | str) -> dict:
    """Load JSON from file and return as dictionary."""
    path = Path(path)
    text = path.read_text()
    return json.loads(text)

⚠️ Potential issue | 🟡 Minor

Return type annotation is too restrictive.

json.loads can return various types (dict, list, str, int, float, bool, or None) depending on the JSON content. The current -> dict annotation is incorrect if the JSON root is an array or primitive value. This could cause runtime type mismatches for callers.

🔧 Proposed fix
-def json_load(path: Path | str) -> dict:
-    """Load JSON from file and return as dictionary."""
+def json_load(path: Path | str) -> Any:
+    """Load JSON from file and return the deserialized object."""
     path = Path(path)
     text = path.read_text()
     return json.loads(text)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/robust_json.py` around lines 74 - 78, The return type of
json_load is too narrow (annotated as -> dict) while json.loads can return dict,
list, str, int, float, bool, or None; update json_load's return annotation to
reflect that (e.g., use typing.Any or define a JSONType = Union[dict, list, str,
int, float, bool, None] and use it) and adjust the docstring to say "Return
parsed JSON value" instead of "return as dictionary"; reference the json_load
function and the json.loads call when making this change.

Collaborator

How about following file structure:

modelopt/torch/prune/importance_hooks/
|- base_hooks.py
|- base_hook_analysis.py
|- compare_module_output.py
|- plugins/
   |- megatron_hooks.py

Author

yes, I like it, will do
