Activation hooks redesign (reuse hooks component across both minitron and puzzletron) #1022
danielkorzekwa wants to merge 6 commits into main from
Conversation
- Add base hooks framework in modelopt/torch/nas/plugins/megatron_hooks/
  - base_hooks.py: Core hook infrastructure
  - base_hooks_analysis.py: Analysis utilities for hooks
  - megatron_hooks.py: Megatron-specific hook implementations
  - compare_module_outputs.py: Module comparison utilities
- Add tests for activation hooks
- Update test utilities for distributed testing
- Update minitron pruning tests to use new activation hooks
The activation hooks infrastructure depends on aprint from puzzletron.tools.logger. A minimal logger module is added to satisfy this dependency. Note: some docstring linting warnings are suppressed because this is copied code.
… moved later outside of puzzletron module)
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
📝 Walkthrough

This PR introduces a new Megatron NAS plugin system for activation-based importance estimation in neural networks. It includes abstract hook classes for capturing and analyzing layer activations, tensor-parallel support, output comparison tools, and comprehensive unit tests. Additionally, utility functions for logging and robust JSON serialization are added.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error)
✅ Passed checks (3 passed)
Actionable comments posted: 10
🧹 Nitpick comments (6)
tests/_test_utils/torch/distributed/utils.py (1)
26-30: Socket not explicitly closed after getting a free port. The socket is bound to get a free port but never closed. While Python's garbage collector will eventually close it, explicitly closing the socket ensures the port is released promptly and avoids potential resource leaks.
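The fix can be sketched as a self-contained function, using a context manager so the socket is closed on every exit path:

```python
import socket


def get_free_port() -> int:
    """Ask the OS for a free port by binding to port 0, then release it."""
    # The `with` block guarantees sock.close() runs even if bind() raises.
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
```

Note that another process can still grab the port between this call and its eventual use; that race is inherent to the pattern and usually acceptable in tests.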
♻️ Suggested fix
 def get_free_port():
     sock = socket.socket()
     sock.bind(("", 0))
     port = sock.getsockname()[1]
+    sock.close()
     return port

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/_test_utils/torch/distributed/utils.py` around lines 26 - 30, The get_free_port function creates a socket but never closes it; update get_free_port to explicitly close the socket after binding (e.g., use a context manager with socket.socket(...) as sock or call sock.close()) so the port is released promptly; modify the function containing get_free_port to ensure the socket is closed before returning the port while keeping the same behavior of binding to ("", 0) and returning getsockname()[1].

modelopt/torch/puzzletron/tools/logger.py (3)
96-108: Fragile stack frame navigation. The triple f_back navigation assumes a fixed call depth. If the call chain changes (e.g., adding a wrapper or calling dist_log directly), incorrect source locations will be reported. Consider using inspect.stack() with a more robust approach.
♻️ Suggested improvement
 @staticmethod
-def get_caller_location() -> str:
+def get_caller_location(depth: int = 3) -> str:
     """Get the caller location from the stack frame."""
-    # Get the caller's stack frame
-    frame = inspect.currentframe()
-
-    # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source
-    caller_frame = frame.f_back.f_back.f_back
-
-    # Get the filename and line number from the caller's stack frame
-    filename = os.path.basename(caller_frame.f_code.co_filename)
-    lineno = caller_frame.f_lineno
-    return f"{filename}:{lineno}"
+    stack = inspect.stack()
+    if len(stack) > depth:
+        frame_info = stack[depth]
+        return f"{os.path.basename(frame_info.filename)}:{frame_info.lineno}"
+    return "unknown:0"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/logger.py` around lines 96 - 108, The get_caller_location method currently walks a fixed f_back chain which is brittle; replace that logic in get_caller_location with a robust inspect.stack()-based search: call inspect.stack() and iterate frames to find the first frame whose module/filename is not this logger module (or where the function name is not get_caller_location/dist_log), then use that frame's filename and lineno to return "filename:lineno"; ensure you fall back to the last non-None frame if none match to avoid exceptions.
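A hypothetical sketch of the inspect.stack()-based search the comment describes (the skip_module parameter and the "unknown:0" fallback are illustrative choices, not the project's API):

```python
import inspect
import os


def get_caller_location(skip_module: str = __name__) -> str:
    """Return 'filename:lineno' of the first frame outside `skip_module`."""
    stack = inspect.stack()
    try:
        for frame_info in stack[1:]:
            if frame_info.frame.f_globals.get("__name__") != skip_module:
                return f"{os.path.basename(frame_info.filename)}:{frame_info.lineno}"
        # Fall back instead of raising when no outside frame is found.
        return "unknown:0"
    finally:
        # FrameInfo objects keep frame references alive; drop them promptly.
        del stack
```

Unlike a fixed f_back chain, this keeps reporting correct locations if a wrapper is inserted into the call path.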
23-23: Unused import of torch.distributed.launch. This import is marked with # noqa: F401 (unused import) but there's no apparent reason for its presence. If it's needed for a side effect, please add a comment explaining why.
♻️ Suggested fix - remove if not needed
-import torch.distributed.launch  # noqa: F401

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/logger.py` at line 23, The unused import torch.distributed.launch (currently annotated with # noqa: F401) should be removed to eliminate dead code; if the import is intentionally required for a side effect, replace the bare # noqa with a short explanatory comment (e.g., "import to register torch distributed launch entrypoint for CLI/side-effects") next to the import so future readers understand its purpose—update the import statement at the top of logger.py accordingly.
111-114: Global logger class modification may have unintended side effects.
logging.setLoggerClass(DistributedLogger) changes the default logger class for the entire process. Any subsequent logging.getLogger() calls in other modules will create DistributedLogger instances, which may not be intended. Consider using a factory function or creating the logger directly without modifying the global logger class.
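The direct-instantiation approach can be sketched as follows; the DistributedLogger body here is a minimal placeholder, not the real class:

```python
import logging


class DistributedLogger(logging.Logger):
    """Placeholder subclass standing in for the real DistributedLogger."""


# Instantiate directly rather than calling logging.setLoggerClass(...),
# leaving the process-wide default logger class untouched.
logger = DistributedLogger(__name__)
logger.propagate = False

# Loggers created elsewhere remain plain logging.Logger instances.
other = logging.getLogger("some.other.module")
```

One trade-off to be aware of: a directly constructed logger is not registered in the logging manager's hierarchy, so logging.getLogger(__name__) elsewhere returns a different object.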
♻️ Alternative approach
-# Initialize logger
-logging.setLoggerClass(DistributedLogger)
-logger = logging.getLogger(__name__)
+# Initialize logger without modifying global logger class
+logger = DistributedLogger(__name__)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/logger.py` around lines 111 - 114, Remove the global side-effect call to logging.setLoggerClass(DistributedLogger) and instead instantiate or provide a factory that returns a DistributedLogger explicitly; replace the current sequence with something like creating the module logger by calling DistributedLogger(__name__) (or implement a get_distributed_logger(name) helper that returns DistributedLogger(name)), assign it to the logger variable and keep logger.propagate = False. Ensure no other global logger class is changed so other modules calling logging.getLogger() are unaffected.

modelopt/torch/puzzletron/tools/robust_json.py (2)
31-31: Hard import of optional dependency at module level. Per coding guidelines, optional dependencies should be gated; omegaconf may not be installed in all environments.
♻️ Suggested fix - gate the import
-from omegaconf import DictConfig, ListConfig, OmegaConf
+try:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+    _HAS_OMEGACONF = True
+except ImportError:
+    DictConfig = ListConfig = None
+    _HAS_OMEGACONF = False

Then in RobustJSONEncoder.default:

-        if isinstance(o, (DictConfig, ListConfig)):
+        if _HAS_OMEGACONF and isinstance(o, (DictConfig, ListConfig)):
             return OmegaConf.to_container(o, resolve=True)

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; gate features by install extras."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/robust_json.py` at line 31, The module currently hard-imports omegaconf at top-level; gate this optional dependency by moving the import into the code path that needs it and handling ImportError: remove the top-level "from omegaconf ..." import and instead import omegaconf (or specific symbols) inside RobustJSONEncoder.default (and/or any other functions that use DictConfig/ListConfig/OmegaConf), catch ImportError and fall back to treating those objects as regular mappings or raise a clear runtime error instructing to install the optional extra; update references to DictConfig/ListConfig/OmegaConf in RobustJSONEncoder.default to use the local import or the fallback behavior.
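Put together, the gated-import pattern looks roughly like this; it degrades to the stock encoder behavior when omegaconf is not installed, which is exactly the case the guard handles:

```python
import json

try:
    from omegaconf import DictConfig, ListConfig, OmegaConf

    _HAS_OMEGACONF = True
except ImportError:  # omegaconf is an optional extra
    DictConfig = ListConfig = OmegaConf = None
    _HAS_OMEGACONF = False


class RobustJSONEncoder(json.JSONEncoder):
    """Encoder that only special-cases omegaconf types when available."""

    def default(self, o):
        if _HAS_OMEGACONF and isinstance(o, (DictConfig, ListConfig)):
            return OmegaConf.to_container(o, resolve=True)
        # Everything else falls through to the standard error path.
        return super().default(o)
```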
74-78: Return type hint is too restrictive.
json.loads can return any JSON-compatible type (dict, list, str, int, etc.), not just dict. Either update the type hint or add validation.
♻️ Suggested fix
-def json_load(path: Path | str) -> dict:
+def json_load(path: Path | str) -> Any:
     """Load JSON from file and return as dictionary."""
     path = Path(path)
     text = path.read_text()
     return json.loads(text)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/robust_json.py` around lines 74 - 78, The function json_load currently types its return as dict but json.loads can return any JSON-compatible type; update json_load's signature and implementation: either change the return type from dict to a broad JSON type (e.g., Any or a custom JSONType alias for dict|list|str|int|float|bool|None) or keep dict but validate the loaded value and raise a clear exception if it's not a dict; refer to the function name json_load and the code that calls Path.read_text()/json.loads to implement the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py`:
- Around line 24-103: In evaluate_importance_scores, avoid building autograd
graphs and guard against empty input: wrap the per-batch evaluation loop (the
section using linear_layer, pruned_activations, pruned_output, rmse/cosine
computation) in a torch.no_grad() context to prevent gradient graph allocation,
and add an early validation at the top that checks activations_batches is
non-empty (raise a ValueError with a clear message or return zeroed metrics)
before computing num_to_prune; reference the function name
evaluate_importance_scores and the local symbols activations_batches,
linear_layer, pruned_activations, original_output, pruned_output, rmse_values,
and cosine_values when making the edits.
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Around line 154-156: The code currently mutates the live config by calling
args.activation_hooks_kwargs.pop("model") before json_dump, which removes the
model reference for later uses; instead, create a copy of
activation_hooks_kwargs (e.g., tmp = dict(args.activation_hooks_kwargs) or use
copy.deepcopy) and remove "model" from that copy if present, then call
json_dump(OmegaConf.to_container(args_copy_or_modified_args, resolve=True),
activations_log_dir / "args.json") so the original args.activation_hooks_kwargs
remains unchanged; reference the symbols args.activation_hooks_kwargs,
json_dump, OmegaConf.to_container, and activations_log_dir / "args.json".
- Around line 162-180: save_hook_states currently calls hook.state_dict() for
every hook which fails when certain hooks (e.g.,
IndependentKvHeadContributionHook, LayerNormContributionHook) raise
NotImplementedError; modify save_hook_states to safely handle hooks that don't
support checkpointing by wrapping the state collection in a try/except (catch
NotImplementedError and skip that hook) or by checking for a supported
method/attribute before calling state_dict, and optionally record/log skipped
module names; ensure you reference save_hook_states and state_dict in your
change so only hooks that successfully return state are included in the saved
hook_states dict.
- Around line 231-259: The L2NormHook.load_state_dict currently assigns
checkpointed "_activations" directly, which can leave tensors on the wrong
device and later cause device-mismatch in L2NormHook.accumulate (and in the "+="
path). Update L2NormHook.load_state_dict to move the loaded activations to the
current module/device before assigning (follow the pattern used in
IndependentChannelContributionHook.load_state_dict): detect the target device
(e.g., from a provided module or torch.device of existing tensors), call
.to(device) on state_dict["activations"] (preserving dtype), then assign to
self._activations so subsequent accumulate() and add_ operations are
device-safe.
- Around line 456-480: The recomputed forward output (output_curr) omits the
layer bias while output_tensor includes it, skewing scaling_factor_per_token;
modify the recomputation in the block using curr_activations and
self.weight_matrix so it also adds the layer bias when present (e.g., compute
output_curr = F.linear(input=curr_activations, weight=self.weight_matrix,
bias=self.bias) or add self.bias expanded to output_curr's shape), ensuring the
bias shape matches output_curr before computing output_norms and
scaling_factor_per_token; reference symbols: output_tensor, output_curr,
curr_activations, self.weight_matrix, self.bias, scaling_factor_per_token.
In `@modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py`:
- Around line 240-244: Replace the current prints-and-return branch that handles
layer mismatches with a hard failure: instead of printing "ERROR: Layer
mismatch!" and returning, raise a RuntimeError (or other appropriate exception)
that includes the mismatched ref_layers and comp_layers in the message so the
CLI/CI exits non-zero; modify the code in compare_module_outputs.py where
ref_layers and comp_layers are compared (the current if set(ref_layers) !=
set(comp_layers): block) to raise the exception with clear context.
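The hard-failure branch can be sketched as a small helper (check_layers_match is a hypothetical name):

```python
def check_layers_match(ref_layers, comp_layers) -> None:
    """Raise instead of printing so a CLI/CI run exits non-zero on mismatch."""
    if set(ref_layers) != set(comp_layers):
        missing = sorted(set(ref_layers) - set(comp_layers))
        extra = sorted(set(comp_layers) - set(ref_layers))
        raise RuntimeError(
            f"Layer mismatch: missing from comparison={missing}, "
            f"unexpected in comparison={extra}"
        )
```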
In `@modelopt/torch/puzzletron/tools/logger.py`:
- Around line 69-73: The NotImplementedError message in the ranks validation
(inside the function/method that broadcasts messages, referencing variables msg
and ranks) is missing the 'last' choice; update the error text to list all valid
options consistently with the check — e.g., include 'last' alongside 'all',
'main', and 'local_main' — so the raised message accurately reflects the allowed
ranks values.
- Around line 79-84: The "last" branch incorrectly compares self.local_rank to
self.world_size - 1; replace that check to target the last local rank on node 0
by testing node and local sizes: change the condition (ranks == "last" and
self.local_rank != self.world_size - 1) to require that the process is not the
last local rank on node 0 — e.g., use self.node_rank and self.local_world_size
and only allow printing when (self.node_rank == 0 and self.local_rank ==
self.local_world_size - 1); update the condition accordingly in the function
that contains the ranks logic so it references self.node_rank, self.local_rank,
and self.local_world_size instead of self.world_size.
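The corrected 'ranks' logic can be sketched as a pure function; the signature and the unhandled 'main'/'local_main' branches are illustrative:

```python
def should_log(ranks: str, node_rank: int, local_rank: int, local_world_size: int) -> bool:
    """Decide whether this process logs under the given 'ranks' policy."""
    if ranks == "all":
        return True
    if ranks == "last":
        # The last *local* rank on node 0, not the globally last rank.
        return node_rank == 0 and local_rank == local_world_size - 1
    raise NotImplementedError(
        f"ranks must be one of 'all', 'main', 'local_main', 'last'; got {ranks!r}"
    )
```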
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`:
- Around line 217-257: The test uses fragile, hardware-specific hard-coded
activation values (e.g., pruning_scores["layer_scores"], activations entries
like "decoder.layers.0.mlp" / "decoder.layers.0.self_attention") checked with
_assert_approx and abs=1e-3; replace these strict exact-value checks with more
robust assertions: validate tensor shapes and value ranges or use relative
tolerance (rtol) instead of strict absolute checks, or widen/make the tolerance
configurable via a test constant; for layer-level correctness prefer asserting
expected monotonic/relative relationships (e.g., pruned vs unpruned scores)
rather than exact floats; optionally split the large conditional block into
separate focused tests for MHA vs GQA using the same symbols (pruning_scores,
activations, _assert_approx) to improve readability and maintainability.
In `@tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py`:
- Around line 113-154: The forward hook handle created in _run_hook_and_evaluate
is only removed on the happy path; wrap the work that runs the forward passes
and calls hook.accumulate (the loop that appends to all_activations and the call
to importance_scores = hook.accumulate()) in a try/finally and call
handle.remove() in the finally block to guarantee the hook is detached even on
exceptions; re-raise the exception after cleanup if one occurred so failures are
propagated, and keep the subsequent call to evaluate_importance_scores
unchanged.
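An equivalent way to guarantee cleanup is a small context manager; forward_hook here is a hypothetical helper, and in the real test the module would be a torch nn.Module:

```python
from contextlib import contextmanager


@contextmanager
def forward_hook(module, hook_fn):
    """Attach a forward hook and always detach it, even if the body raises."""
    handle = module.register_forward_hook(hook_fn)
    try:
        yield handle
    finally:
        handle.remove()
```

Usage is `with forward_hook(layer, hook): ...` around the forward passes; the handle is removed on success and on failure alike, which is the try/finally shape the comment asks for.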
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: afede6d7-95c6-4028-bab9-e7b8241fd750
📒 Files selected for processing (14)
- modelopt/torch/nas/plugins/megatron_hooks/__init__.py
- modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
- modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py
- modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py
- modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/tools/__init__.py
- modelopt/torch/puzzletron/tools/logger.py
- modelopt/torch/puzzletron/tools/robust_json.py
- tests/_test_utils/torch/distributed/utils.py
- tests/conftest.py
- tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py
- tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
    def evaluate_importance_scores(
        linear_layer: nn.Linear,
        activations_batches: list[torch.Tensor],
        importance_scores: torch.Tensor,
        prune_ratio: float = 0.2,
    ) -> dict[str, float]:
        """Compute reconstruction error after pruning input channels of a linear layer.

        This function simulates channel pruning by zeroing out input channels identified as
        least important, then measures how much the layer's output changes.

        Args:
            linear_layer: The linear layer to analyze with shape (out_features, in_features).
                For example: nn.Linear(in_features=1024, out_features=4096)
            activations_batches: List of input activation tensors.
                Each tensor has shape [seq_len, batch_size, in_features].
                The last dimension must match linear_layer.in_features.
                Example: List of [16, 8, 1024] tensors
            importance_scores: Importance score for each input channel (feature).
                Shape: [in_features]. Lower scores = less important.
                Example: [1024] tensor with one score per input feature
            prune_ratio: Fraction of input channels to prune (default: 0.2 means prune 20%).

        Returns:
            Dictionary containing averaged metrics across all activation batches:
            - rmse: Root mean squared error between original and pruned output
            - cosine_similarity: Cosine similarity between original and pruned output
            - num_pruned: Number of input channels pruned

        Example:
            >>> layer = nn.Linear(in_features=1024, out_features=4096)
            >>> # Collect multiple batches for robust evaluation
            >>> activations_list = [torch.randn(16, 8, 1024) for _ in range(100)]
            >>> scores = torch.randn(1024)  # one score per input feature
            >>> metrics = evaluate_importance_scores(layer, activations_list, scores, 0.2)
            >>> print(f"RMSE: {metrics['rmse']:.4f}, Pruned: {metrics['num_pruned']} channels")

        Note:
            - This simulates pruning (zeros out inputs) without modifying layer weights
            - "Channels" refers to INPUT features, not output features
        """
        num_channels = importance_scores.shape[0]
        num_to_prune = int(num_channels * prune_ratio)

        # Identify channels to prune (lowest scoring = least important)
        _, channels_to_prune = torch.topk(importance_scores, num_to_prune, largest=False)

        # Compute metrics for each batch and average
        rmse_values = []
        cosine_values = []

        for activations in activations_batches:
            # Get original output
            original_output = linear_layer(activations)

            # Prune by zeroing out identified channels
            pruned_activations = activations.clone()
            pruned_activations[..., channels_to_prune] = 0

            # Get pruned output
            pruned_output = linear_layer(pruned_activations)

            # Compute metrics for this batch
            rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
            rmse_values.append(rmse)

            # Cosine similarity (flatten to vectors)
            original_flat = original_output.reshape(-1)
            pruned_flat = pruned_output.reshape(-1)
            cosine = F.cosine_similarity(
                original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
            ).item()
            cosine_values.append(cosine)

        # Return averaged metrics
        return {
            "rmse": sum(rmse_values) / len(rmse_values),
            "cosine_similarity": sum(cosine_values) / len(cosine_values),
            "num_pruned": num_to_prune,
        }
🧩 Analysis chain
🏁 Script executed:
fd -t f "base_hooks_analysis.py"

Repository: NVIDIA/Model-Optimizer
Length of output: 201
🏁 Script executed:
cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py | head -110

Repository: NVIDIA/Model-Optimizer
Length of output: 5245
🏁 Script executed:
cat -n tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py

Repository: NVIDIA/Model-Optimizer
Length of output: 7539
🏁 Script executed:
# Also check if there are imports or dependencies we should verify
grep -n "torch.no_grad\|activations_batches\|evaluate_importance_scores" modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py

Repository: NVIDIA/Model-Optimizer
Length of output: 395
Wrap the evaluation loop in torch.no_grad() and validate non-empty input.
This analysis-only function builds two autograd graphs per batch unnecessarily, creating memory pressure during calibration. Additionally, passing an empty activations_batches list silently fails at the return statement with a ZeroDivisionError.
🛠️ Suggested fix
def evaluate_importance_scores(
linear_layer: nn.Linear,
activations_batches: list[torch.Tensor],
importance_scores: torch.Tensor,
prune_ratio: float = 0.2,
) -> dict[str, float]:
@@
"""
num_channels = importance_scores.shape[0]
+ if not activations_batches:
+ raise ValueError("activations_batches must be non-empty")
num_to_prune = int(num_channels * prune_ratio)
@@
- for activations in activations_batches:
- # Get original output
- original_output = linear_layer(activations)
-
- # Prune by zeroing out identified channels
- pruned_activations = activations.clone()
- pruned_activations[..., channels_to_prune] = 0
-
- # Get pruned output
- pruned_output = linear_layer(pruned_activations)
-
- # Compute metrics for this batch
- rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
- rmse_values.append(rmse)
-
- # Cosine similarity (flatten to vectors)
- original_flat = original_output.reshape(-1)
- pruned_flat = pruned_output.reshape(-1)
- cosine = F.cosine_similarity(
- original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
- ).item()
- cosine_values.append(cosine)
+ with torch.no_grad():
+ for activations in activations_batches:
+ original_output = linear_layer(activations)
+
+ pruned_activations = activations.clone()
+ pruned_activations[..., channels_to_prune] = 0
+ pruned_output = linear_layer(pruned_activations)
+
+ rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
+ rmse_values.append(rmse)
+
+ original_flat = original_output.reshape(-1)
+ pruned_flat = pruned_output.reshape(-1)
+ cosine = F.cosine_similarity(
+ original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
+ ).item()
+            cosine_values.append(cosine)
    if rank == 0:
        args.activation_hooks_kwargs.pop("model")
        json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
Don't mutate args.activation_hooks_kwargs just to dump JSON.
pop("model") edits the live config object. Reusing the same args later loses the model reference, and a second dump on rank 0 can fail on the missing key.
🛠️ Suggested fix
if rank == 0:
- args.activation_hooks_kwargs.pop("model")
- json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
+ args_to_dump = OmegaConf.to_container(args, resolve=True)
+ args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
+ json_dump(args_to_dump, activations_log_dir / "args.json")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if rank == 0:
        args_to_dump = OmegaConf.to_container(args, resolve=True)
        args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
        json_dump(args_to_dump, activations_log_dir / "args.json")
    def save_hook_states(
        cls: type["ForwardHook"],
        activation_hooks: dict[str, "ForwardHook"],
        activations_log_dir: Path | str,
    ) -> None:
        """Save hook states for checkpointing (separate from final results).

        This can be called periodically during scoring.
        Note: Synchronization should be handled at a higher level to avoid deadlocks.
        """
        activations_log_dir = Path(activations_log_dir)
        activations_log_dir.mkdir(exist_ok=True, parents=True)
        rank = dist.rank()

        hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth"
        hook_states = {
            module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
        }
        torch.save(hook_states, hook_states_path)
save_hook_states() needs to handle hooks without checkpoint support.
This helper blindly calls state_dict() on every hook, but IndependentKvHeadContributionHook and LayerNormContributionHook later raise NotImplementedError. One such hook makes periodic checkpointing fail for the whole run.
🛠️ Suggested fix
- hook_states = {
- module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
- }
+ hook_states = {}
+ unsupported = []
+ for module_name, hook in activation_hooks.items():
+ try:
+ hook_states[module_name] = hook.state_dict()
+ except NotImplementedError:
+ unsupported.append(module_name)
+
+ if unsupported:
+ aprint(
+ "Skipping hook checkpoint save for hooks without state support: "
+ f"{unsupported}"
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 162 -
180, save_hook_states currently calls hook.state_dict() for every hook which
fails when certain hooks (e.g., IndependentKvHeadContributionHook,
LayerNormContributionHook) raise NotImplementedError; modify save_hook_states to
safely handle hooks that don't support checkpointing by wrapping the state
collection in a try/except (catch NotImplementedError and skip that hook) or by
checking for a supported method/attribute before calling state_dict, and
optionally record/log skipped module names; ensure you reference
save_hook_states and state_dict in your change so only hooks that successfully
return state are included in the saved hook_states dict.
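The skip-unsupported pattern above can be exercised without Megatron; a minimal pure-Python sketch (the hook classes and the `collect_hook_states` helper are illustrative stand-ins, not the real `ForwardHook` API):

```python
class GoodHook:
    """Stand-in for a hook that supports checkpointing."""

    def state_dict(self) -> dict:
        return {"activations": [1.0, 2.0]}


class NoStateHook:
    """Stand-in for hooks like IndependentKvHeadContributionHook."""

    def state_dict(self) -> dict:
        raise NotImplementedError


def collect_hook_states(activation_hooks: dict) -> tuple[dict, list]:
    """Gather state dicts, skipping hooks that cannot checkpoint."""
    hook_states, unsupported = {}, []
    for module_name, hook in activation_hooks.items():
        try:
            hook_states[module_name] = hook.state_dict()
        except NotImplementedError:
            unsupported.append(module_name)
    return hook_states, unsupported


states, skipped = collect_hook_states({"mlp": GoodHook(), "kv": NoStateHook()})
```

With this shape, one unsupported hook no longer aborts the whole periodic save; it is merely reported.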
```python
        if self._activations is None:
            self._activations = activations
        else:
            self._activations += activations

    def accumulate(self) -> torch.Tensor:
        """Return the accumulated L2 norm of activations.

        Returns:
            Tensor of accumulated scores, one per channel

        Raises:
            AssertionError: If no activations have been collected yet
        """
        assert self._activations is not None, "No activations collected for importance estimation."
        # Convert squared sum to L2 norm
        return self._activations.pow(0.5)

    def to_dict(self) -> dict[str, torch.Tensor]:
        """Convert to dict format for saving."""
        return {"score": self.accumulate().cpu()}

    def state_dict(self) -> dict:
        """Return the state dictionary containing activations."""
        return {"activations": self._activations}

    def load_state_dict(self, state_dict: dict) -> None:
        """Load activations from checkpoint."""
        self._activations = state_dict["activations"]
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
wc -l modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | sed -n '200,280p'
# Find the IndependentChannelContributionHook class
rg "class IndependentChannelContributionHook" -A 50 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
# Get the full L2NormHook class definition
rg "class L2NormHook" -B 5 -A 100 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
# Find IndependentChannelContributionHook's load_state_dict method
rg "class IndependentChannelContributionHook" -A 200 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | grep -A 20 "load_state_dict"
# Get more context on IndependentChannelContributionHook's state management
sed -n '262,500p' modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
```
Make L2NormHook state reload device-agnostic.
After restoring from checkpoint, _activations remains on the checkpoint device. When the resumed layer runs on a different device, accumulation fails on line 234 due to device mismatch in the += operation. IndependentChannelContributionHook.load_state_dict() already demonstrates the correct pattern for device-agnostic state loading.
🛠️ Suggested fix

```diff
         if self._activations is None:
             self._activations = activations
         else:
+            if self._activations.device != activations.device:
+                self._activations = self._activations.to(activations.device)
             self._activations += activations
@@
     def state_dict(self) -> dict:
         """Return the state dictionary containing activations."""
-        return {"activations": self._activations}
+        return {
+            "activations": None
+            if self._activations is None
+            else self._activations.cpu().clone()
+        }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
        if self._activations is None:
            self._activations = activations
        else:
            if self._activations.device != activations.device:
                self._activations = self._activations.to(activations.device)
            self._activations += activations

    def accumulate(self) -> torch.Tensor:
        """Return the accumulated L2 norm of activations.

        Returns:
            Tensor of accumulated scores, one per channel

        Raises:
            AssertionError: If no activations have been collected yet
        """
        assert self._activations is not None, "No activations collected for importance estimation."
        # Convert squared sum to L2 norm
        return self._activations.pow(0.5)

    def to_dict(self) -> dict[str, torch.Tensor]:
        """Convert to dict format for saving."""
        return {"score": self.accumulate().cpu()}

    def state_dict(self) -> dict:
        """Return the state dictionary containing activations."""
        return {
            "activations": None
            if self._activations is None
            else self._activations.cpu().clone()
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Load activations from checkpoint."""
        self._activations = state_dict["activations"]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 231 -
259, The L2NormHook.load_state_dict currently assigns checkpointed
"_activations" directly, which can leave tensors on the wrong device and later
cause device-mismatch in L2NormHook.accumulate (and in the "+=" path). Update
L2NormHook.load_state_dict to move the loaded activations to the current
module/device before assigning (follow the pattern used in
IndependentChannelContributionHook.load_state_dict): detect the target device
(e.g., from a provided module or torch.device of existing tensors), call
.to(device) on state_dict["activations"] (preserving dtype), then assign to
self._activations so subsequent accumulate() and add_ operations are
device-safe.
```python
        # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear)
        # TODO: Consider better design to handle RowParallelLinear and nn.Linear
        if isinstance(output, tuple):
            output_tensor = output[0]
        else:
            output_tensor = output

        activations = args[0]

        # Don't aggregate activations from non-max subnets (e.g. from profiling)
        if self.max_size is not None and activations.shape[-1] != self.max_size:
            return

        n_channels_to_prune = self.pruning_schedule[self.curr_iter]

        curr_activations = activations.clone()  # Shape B,T,I
        curr_activations[..., self.pruned_channels] = 0
        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E

        if self.calibration_method is None:
            scaling_factor_per_token = torch.ones_like(output_tensor[..., 0])  # Shape B,T
        elif self.calibration_method == "scale_by_magnitude":
            output_norms = torch.linalg.vector_norm(output_tensor, dim=-1)  # Shape B,T
            output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1)  # Shape B,T
            scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon)
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | sed -n '440,500p'
cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | head -100
ast-grep --pattern 'def $method($args) {
  $$$
}'
rg -A 50 "Handle case where output is a tuple" modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
wc -l modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
```
Include bias in the pruned forward recomputation.
For biased nn.Linear layers, output_tensor includes the bias term but output_curr does not. This mismatch causes the scaling factor computation (line 480) and subsequent channel importance ranking (line 487) to be incorrect. The bias difference must be accounted for in the recomputed forward pass.
Fix

```diff
     curr_activations = activations.clone()  # Shape B,T,I
     curr_activations[..., self.pruned_channels] = 0
+    bias = None if isinstance(output, tuple) else getattr(module, "bias", None)
-    output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E
+    output_curr = F.linear(
+        input=curr_activations,
+        weight=self.weight_matrix,
+        bias=bias,
+    )  # Shape B,T,E
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 456 -
480, The recomputed forward output (output_curr) omits the layer bias while
output_tensor includes it, skewing scaling_factor_per_token; modify the
recomputation in the block using curr_activations and self.weight_matrix so it
also adds the layer bias when present (e.g., compute output_curr =
F.linear(input=curr_activations, weight=self.weight_matrix, bias=self.bias) or
add self.bias expanded to output_curr's shape), ensuring the bias shape matches
output_curr before computing output_norms and scaling_factor_per_token;
reference symbols: output_tensor, output_curr, curr_activations,
self.weight_matrix, self.bias, scaling_factor_per_token.
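The bias mismatch is easy to see with a pure-Python stand-in for `F.linear` (the `linear` helper below is illustrative; real code would use `torch.nn.functional.linear`):

```python
def linear(x, weight, bias=None):
    """Row-vector x times each weight row, plus an optional bias (F.linear stand-in)."""
    out = [sum(xi * wij for xi, wij in zip(x, w_row)) for w_row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


weight = [[1.0, 2.0], [3.0, 4.0]]
bias = [10.0, 20.0]
x = [1.0, 1.0]

ref = linear(x, weight, bias)            # what a biased layer actually outputs
recomputed_no_bias = linear(x, weight)   # the bug: bias term dropped
recomputed_with_bias = linear(x, weight, bias)
```

Norms taken from `recomputed_no_bias` are shifted relative to those of `ref`, which is exactly the skew the comment describes in the scaling-factor computation.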
```python
    if set(ref_layers) != set(comp_layers):
        print("\nERROR: Layer mismatch!")
        print(f"Reference layers: {ref_layers}")
        print(f"Compare layers: {comp_layers}")
        return
```
Raise on layer mismatch instead of returning successfully.
A bad comparison is a hard failure. Returning here makes the CLI exit with status 0, so scripts and CI can silently accept invalid results.
🛠️ Suggested fix

```diff
     if set(ref_layers) != set(comp_layers):
-        print("\nERROR: Layer mismatch!")
-        print(f"Reference layers: {ref_layers}")
-        print(f"Compare layers: {comp_layers}")
-        return
+        raise ValueError(
+            "Layer mismatch: "
+            f"reference={sorted(ref_layers)}, compare={sorted(comp_layers)}"
+        )
```

📝 Committable suggestion
```python
    if set(ref_layers) != set(comp_layers):
        raise ValueError(
            "Layer mismatch: "
            f"reference={sorted(ref_layers)}, compare={sorted(comp_layers)}"
        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py` around
lines 240 - 244, Replace the current prints-and-return branch that handles layer
mismatches with a hard failure: instead of printing "ERROR: Layer mismatch!" and
returning, raise a RuntimeError (or other appropriate exception) that includes
the mismatched ref_layers and comp_layers in the message so the CLI/CI exits
non-zero; modify the code in compare_module_outputs.py where ref_layers and
comp_layers are compared (the current if set(ref_layers) != set(comp_layers):
block) to raise the exception with clear context.
```python
        if ranks not in ["all", "main", "local_main", "last"]:
            raise NotImplementedError(
                f"Could not broadcast msg {msg} - "
                f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}"
            )
```
Error message is inconsistent with valid choices.
The validation checks for ["all", "main", "local_main", "last"] but the error message only lists ['all', 'main', 'local_main'], missing 'last'.
🐛 Proposed fix

```diff
         if ranks not in ["all", "main", "local_main", "last"]:
             raise NotImplementedError(
                 f"Could not broadcast msg {msg} - "
-                f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}"
+                f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
             )
```

📝 Committable suggestion
```python
        if ranks not in ["all", "main", "local_main", "last"]:
            raise NotImplementedError(
                f"Could not broadcast msg {msg} - "
                f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
            )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/tools/logger.py` around lines 69 - 73, The
NotImplementedError message in the ranks validation (inside the function/method
that broadcasts messages, referencing variables msg and ranks) is missing the
'last' choice; update the error text to list all valid options consistently with
the check — e.g., include 'last' alongside 'all', 'main', and 'local_main' — so
the raised message accurately reflects the allowed ranks values.
```python
        elif (
            (ranks == "main" and self.global_rank != 0)
            or (ranks == "last" and self.local_rank != self.world_size - 1)
            or (ranks == "local_main" and self.local_rank != 0)
        ):
            return
```
Logic bug in 'last' rank check for multi-node environments.
The condition self.local_rank != self.world_size - 1 compares local_rank (0 to local_world_size-1) against world_size (total processes across all nodes). In a multi-node setup with 8 total GPUs across 2 nodes, local_rank ranges 0-3 but world_size - 1 is 7, so the condition local_rank != 7 is always true and no rank ever prints.
Based on the docstring at line 167 ("rank -1 in node 0"), the intent seems to be the last rank on node 0.
🐛 Proposed fix

```diff
         elif (
             (ranks == "main" and self.global_rank != 0)
-            or (ranks == "last" and self.local_rank != self.world_size - 1)
+            or (ranks == "last" and self.global_rank != self.world_size - 1)
             or (ranks == "local_main" and self.local_rank != 0)
         ):
             return
```

Or, if the intent is truly "last local rank on each node":

```diff
-            or (ranks == "last" and self.local_rank != self.world_size - 1)
+            or (ranks == "last" and self.local_rank != int(os.environ.get("LOCAL_WORLD_SIZE", 1)) - 1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/tools/logger.py` around lines 79 - 84, The "last"
branch incorrectly compares self.local_rank to self.world_size - 1; replace that
check to target the last local rank on node 0 by testing node and local sizes:
change the condition (ranks == "last" and self.local_rank != self.world_size -
1) to require that the process is not the last local rank on node 0 — e.g., use
self.node_rank and self.local_world_size and only allow printing when
(self.node_rank == 0 and self.local_rank == self.local_world_size - 1); update
the condition accordingly in the function that contains the ranks logic so it
references self.node_rank, self.local_rank, and self.local_world_size instead of
self.world_size.
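A self-contained sketch of the corrected gating (the function name `should_print` and the explicit rank arguments are assumptions for illustration; the key point is that "last" compares the global rank):

```python
def should_print(ranks: str, global_rank: int, local_rank: int, world_size: int) -> bool:
    """Decide whether this process should emit the log line."""
    if ranks == "all":
        return True
    if ranks == "main":
        return global_rank == 0
    if ranks == "local_main":
        return local_rank == 0
    if ranks == "last":
        # Compare the *global* rank so multi-node runs still have
        # exactly one printing process.
        return global_rank == world_size - 1
    raise NotImplementedError(
        f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
    )


# 2 nodes x 4 GPUs: global ranks 0..7, local ranks 0..3.
printers = [g for g in range(8) if should_print("last", g, g % 4, 8)]
```

With the buggy `local_rank != world_size - 1` check, the `printers` list above would be empty; with the global-rank comparison exactly one process prints.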
```python
    # TODO: Simplify it: this unit test is too long,
    # hard to read (the same set of assertions across different test cases with if-else).

    assert len(pruning_scores["activations_per_rank"]) == size
    activations = pruning_scores["activations_per_rank"][rank]

    # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
    if size == 1 and pruned_ffn_div == 4:
        # Layer scores
        _assert_approx(pruning_scores["layer_scores"], {1: 0.028923, 2: 0.046508})

        # Validate decoder.layers.0.mlp activations
        mlp_0_acts = activations["decoder.layers.0.mlp"]
        _assert_approx(mlp_0_acts.min().item(), 0.000026)
        _assert_approx(mlp_0_acts.max().item(), 0.000729)
        _assert_approx(mlp_0_acts.mean().item(), 0.000201)

        # Validate decoder.layers.1.mlp activations
        mlp_1_acts = activations["decoder.layers.1.mlp"]
        _assert_approx(mlp_1_acts.min().item(), 0.000022)
        _assert_approx(mlp_1_acts.max().item(), 0.000762)
        _assert_approx(mlp_1_acts.mean().item(), 0.000162)

    # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
    elif size == 1 and pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
        # Layer scores
        _assert_approx(pruning_scores["layer_scores"], {1: 0.028056, 2: 0.038353})

        # Validate decoder.layers.0.self_attention activations
        attn_0_acts = activations["decoder.layers.0.self_attention"]
        assert attn_0_acts.shape == torch.Size([hidden_size])
        _assert_approx(attn_0_acts.min().item(), 0.010091)
        _assert_approx(attn_0_acts.max().item(), 0.023826)
        _assert_approx(attn_0_acts.mean().item(), 0.014548)

        # Validate decoder.layers.1.self_attention activations
        attn_1_acts = activations["decoder.layers.1.self_attention"]
        assert attn_1_acts.shape == torch.Size([hidden_size])
        _assert_approx(attn_1_acts.min().item(), 0.009982)
        _assert_approx(attn_1_acts.max().item(), 0.035644)
        _assert_approx(attn_1_acts.mean().item(), 0.020140)
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
#!/bin/bash
# Check if there are any hook cleanup or state reset mechanisms in the test utilities
rg -n "cleanup|reset|teardown" tests/_test_utils/torch/ -A 3
# Check the test file structure and imports
head -50 tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
# Look for fixture definitions and pytest markers in the test file
rg -n "@pytest|@fixture|def setup|def teardown" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
# Check if the test uses DistributedProcess or similar distributed utilities
rg -n "DistributedProcess|distributed|teardown" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
# Check for conftest.py and pytest configuration
find tests/gpu_megatron -name "conftest.py" -exec cat {} \;
# Check pytest configuration
cat pyproject.toml | grep -A 20 "\[tool.pytest"
# Look at the full test function containing lines 217-257
sed -n '180,260p' tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
# Get the full function signature and parametrize decorators
sed -n '110,180p' tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
# Search for where this test function is defined to see if it uses dist_workers fixture
grep -n "def test_mcore_gpt_minitron_pruning_and_export" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
# Check how set_seed is implemented
rg -n "def set_seed" tests/_test_utils/torch/ -A 10
# Look for test function wrapper that calls _test_mcore_gpt_pruning
rg -n "test_mcore_gpt_pruning_and_export|test_mcore_gpt_pruning[^_]" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py -B 5 -A 15 | head -100
```
Hard-coded activation values with tight tolerance may cause test fragility across different hardware and CUDA versions.
The test uses very specific hard-coded values (e.g., 0.028923, 0.046508) with abs=1e-3 tolerance. While proper teardown and seeding infrastructure is in place (via megatron_worker_teardown and set_seed), GPU-specific activation values can vary slightly across different hardware and CUDA versions, causing these assertions to fail even though the pruning logic is correct.
Consider:
- Using relative comparisons or shape/range checks instead of exact activation value assertions
- Documenting which GPU/CUDA version these values were captured on and re-validating periodically
- If exact values are needed, widen tolerance or make it configurable per environment
The TODO comment at line 217 also notes this test's complexity—splitting into focused unit tests would improve maintainability.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`
around lines 217 - 257, The test uses fragile, hardware-specific hard-coded
activation values (e.g., pruning_scores["layer_scores"], activations entries
like "decoder.layers.0.mlp" / "decoder.layers.0.self_attention") checked with
_assert_approx and abs=1e-3; replace these strict exact-value checks with more
robust assertions: validate tensor shapes and value ranges or use relative
tolerance (rtol) instead of strict absolute checks, or widen/make the tolerance
configurable via a test constant; for layer-level correctness prefer asserting
expected monotonic/relative relationships (e.g., pruned vs unpruned scores)
rather than exact floats; optionally split the large conditional block into
separate focused tests for MHA vs GQA using the same symbols (pruning_scores,
activations, _assert_approx) to improve readability and maintainability.
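One hardware-tolerant alternative to a fixed absolute tolerance is a relative comparison; a standard-library sketch (the 5% `rel_tol` is an arbitrary illustrative choice, not a validated threshold):

```python
import math


def assert_close(actual: float, expected: float,
                 rel_tol: float = 0.05, abs_tol: float = 1e-6) -> None:
    """Pass on small GPU/CUDA drift, fail on real regressions."""
    if not math.isclose(actual, expected, rel_tol=rel_tol, abs_tol=abs_tol):
        raise AssertionError(f"{actual} not within {rel_tol:.0%} of {expected}")


assert_close(0.028930, 0.028923)  # ~0.02% drift: passes
try:
    assert_close(0.058, 0.028923)  # roughly 2x off: fails
    regression_caught = False
except AssertionError:
    regression_caught = True
```

The `abs_tol` floor keeps near-zero statistics (such as the ~1e-4 MLP means above) from failing on pure floating-point noise.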
```python
def _run_hook_and_evaluate(
    layer: nn.Linear,
    hook,
    num_iterations: int,
    prune_ratio: float,
) -> dict:
    """Shared helper to run hook, collect scores, and evaluate.

    Args:
        layer: Linear layer to test
        hook: Hook instance (already created)
        num_iterations: Number of forward passes
        prune_ratio: Fraction of channels to prune

    Returns:
        Dictionary with evaluation metrics
    """
    handle = layer.register_forward_hook(hook)  # Store the handle

    # Run forward passes
    all_activations = []
    for _ in range(num_iterations):
        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
        all_activations.append(activations)
        _ = layer(activations)

    # Get importance scores from hook
    importance_scores = hook.accumulate()

    # Remove the hook before evaluation to avoid triggering it again
    handle.remove()

    # Evaluate the importance scores by simulating pruning on all collected activations
    # Pass the list of activations to compute averaged metrics across batches
    metrics = evaluate_importance_scores(
        layer,
        all_activations,  # List of activation batches
        importance_scores,
        prune_ratio=prune_ratio,
    )

    return metrics
```
Always remove the forward hook in a finally block.
handle.remove() only runs on the happy path. If the forward loop or hook.accumulate() fails, the hook stays attached in the reused worker and can bleed into the next scenario.
🛠️ Suggested fix

```diff
 def _run_hook_and_evaluate(
     layer: nn.Linear,
     hook,
     num_iterations: int,
@@
-    handle = layer.register_forward_hook(hook)  # Store the handle
-
-    # Run forward passes
-    all_activations = []
-    for _ in range(num_iterations):
-        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
-        all_activations.append(activations)
-        _ = layer(activations)
-
-    # Get importance scores from hook
-    importance_scores = hook.accumulate()
-
-    # Remove the hook before evaluation to avoid triggering it again
-    handle.remove()
+    handle = layer.register_forward_hook(hook)
+    try:
+        all_activations = []
+        for _ in range(num_iterations):
+            activations = torch.randn(16, 8, layer.in_features)
+            all_activations.append(activations)
+            _ = layer(activations)
+
+        importance_scores = hook.accumulate()
+    finally:
+        handle.remove()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py`
around lines 113 - 154, The forward hook handle created in
_run_hook_and_evaluate is only removed on the happy path; wrap the work that
runs the forward passes and calls hook.accumulate (the loop that appends to
all_activations and the call to importance_scores = hook.accumulate()) in a
try/finally and call handle.remove() in the finally block to guarantee the hook
is detached even on exceptions; re-raise the exception after cleanup if one
occurred so failures are propagated, and keep the subsequent call to
evaluate_importance_scores unchanged.
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Actionable comments posted: 2
🧹 Nitpick comments (5)
tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)
248-259: Shape assertions are robust; consider extending this pattern. The
assert attn_0_acts.shape == torch.Size([hidden_size])checks (lines 249, 256) are hardware-independent and reliable. Consider adopting similar structural assertions for the MLP activations in the first test case (lines 231-240) to improve robustness while maintaining the statistical checks as secondary validation.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py` around lines 248 - 259, Add hardware-independent shape assertions for the MLP activation entries similar to the attention checks: for the MLP activations retrieved from the activations dict (e.g., activations["decoder.layers.0.mlp"] and activations["decoder.layers.1.mlp"]) assert their .shape equals torch.Size([hidden_size]) before or alongside the existing min/max/mean _assert_approx checks; update the test variables corresponding to the first test case’s MLP activations (the variables that hold those activation tensors) to include these assert ... .shape == torch.Size([hidden_size]) statements.modelopt/torch/utils/robust_json.py (2)
47-48: Fragile dtype detection via string comparison.Matching
type(o).__name__ == "dtype"will catch any class named "dtype", not justnumpy.dtypeortorch.dtype. This could lead to unexpected behavior with other types. Consider adding a comment documenting this intentional duck-typing approach, or checkingo.__class__.__module__as well (e.g., starts withnumpyortorch).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/robust_json.py` around lines 47 - 48, The current fragile dtype detection uses type(o).__name__ == "dtype"; update it to only match actual numpy/torch dtypes by also checking the class module (e.g., check o.__class__.__module__ and ensure it startswith "numpy" or "torch") or explicitly check for instances via duck-typing attributes, and add a brief comment explaining the intent; replace the condition referencing type(o).__name__ == "dtype" with a combined check using o.__class__.__module__ (or equivalent) so only numpy/torch dtype classes are matched in robust_json.py.
15-15: Consider removing the blanket mypy suppression. Disabling mypy for the entire file with
# mypy: ignore-errorsbypasses type checking. The file already has type annotations, so targeted# type: ignorecomments on specific problematic lines (if any) would be preferable to maintain type safety for the rest of the module.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/utils/robust_json.py` at line 15, Remove the top-level blanket mypy suppression in robust_json.py (the `# mypy: ignore-errors` comment) and instead address specific typing issues: delete that header, run mypy to find failing lines, and add targeted `# type: ignore[...]` comments only on the problematic expressions (or refine annotations on functions/classes such as any helper functions in robust_json.py) so the rest of the module retains type checking.modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py (2)
675-681: Consider documenting why checkpointing is not supported. The
state_dictandload_state_dictmethods raiseNotImplementedError. While this is a valid design choice, adding a brief explanation in the docstring or error message (e.g., "KV head pruning is designed to complete in a single run") would help users understand the limitation.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 675 - 681, Update the docstrings and/or raised NotImplementedError messages for the state_dict and load_state_dict methods in base_hooks.py to briefly explain why checkpointing isn't supported (e.g., "Checkpointing not supported because KV head pruning is a one-time operation and does not maintain persistent mutable state across runs"), so users understand the design choice; modify the docstrings of state_dict and load_state_dict and the text passed to NotImplementedError in those methods to include that brief rationale.
810-815: Consider using `json_dump` from `robust_json` for consistency. The
json_dumpfrommodelopt.torch.utils.robust_json(line 30) but uses the standardjson.dumphere. Usingjson_dumpwould ensure consistent JSON serialization across the codebase and automatically create parent directories.Suggested change
- output_path = activations_log_dir / "channel_importance_results.json" - aprint(f"Saving channel importance data to {output_path}") - with open(output_path, "w") as f: - json.dump(output_data, f, indent=2) + output_path = activations_log_dir / "channel_importance_results.json" + aprint(f"Saving channel importance data to {output_path}") + json_dump(output_data, output_path)Note: This would lose the
indent=2formatting. If pretty-printing is important, consider adding anindentparameter tojson_dumpor keep the current approach.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 810 - 815, Replace the use of json.dump when writing output_data to output_path with the project utility json_dump (imported from modelopt.torch.utils.robust_json) so serialization is consistent and parent directories are created automatically; update the write block that currently uses output_path and json.dump to call json_dump(output_data, output_path) (or add an indent param to json_dump if pretty-printing is required) and remove the manual open(...) context since json_dump handles file creation.
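If `json_dump` were extended as the note suggests, it might look like the sketch below. This is a hypothetical signature — the real helper in `modelopt.torch.utils.robust_json` may differ, and the `indent` parameter is precisely the proposed addition:

```python
import json
import tempfile
from pathlib import Path


def json_dump(obj, path, indent=None):
    """Hypothetical helper: serialize obj to path, creating parent directories."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(obj, indent=indent))


with tempfile.TemporaryDirectory() as d:
    # The nested directory does not exist yet; json_dump creates it.
    out = Path(d) / "activations_log" / "channel_importance_results.json"
    json_dump({"layer": 0, "importance": [0.1, 0.9]}, out, indent=2)
    print(out.exists())  # True
```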
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/utils/logging.py`:
- Around line 206-208: The aprint wrapper currently always passes flush=True
while also forwarding kwargs, causing a duplicate-key TypeError when callers
pass flush in kwargs; update aprint to handle/merge the flush kwarg by
extracting it from kwargs (e.g., pop 'flush' if present) and then call print
with the resolved flush value (use True as the default if not provided) so
callers can override flush without causing a duplicate keyword error; reference
the aprint function in logging.py to locate the change.
In `@modelopt/torch/utils/robust_json.py`:
- Around line 74-78: The return type of json_load is too narrow (annotated as ->
dict) while json.loads can return dict, list, str, int, float, bool, or None;
update json_load's return annotation to reflect that (e.g., use typing.Any or
define a JSONType = Union[dict, list, str, int, float, bool, None] and use it)
and adjust the docstring to say "Return parsed JSON value" instead of "return as
dictionary"; reference the json_load function and the json.loads call when
making this change.
---
Nitpick comments:
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Around line 675-681: Update the docstrings and/or raised NotImplementedError
messages for the state_dict and load_state_dict methods in base_hooks.py to
briefly explain why checkpointing isn't supported (e.g., "Checkpointing not
supported because KV head pruning is a one-time operation and does not maintain
persistent mutable state across runs"), so users understand the design choice;
modify the docstrings of state_dict and load_state_dict and the text passed to
NotImplementedError in those methods to include that brief rationale.
- Around line 810-815: Replace the use of json.dump when writing output_data to
output_path with the project utility json_dump (imported from
modelopt.torch.utils.robust_json) so serialization is consistent and parent
directories are created automatically; update the write block that currently
uses output_path and json.dump to call json_dump(output_data, output_path) (or
add an indent param to json_dump if pretty-printing is required) and remove the
manual open(...) context since json_dump handles file creation.
In `@modelopt/torch/utils/robust_json.py`:
- Around line 47-48: The current fragile dtype detection uses type(o).__name__
== "dtype"; update it to only match actual numpy/torch dtypes by also checking
the class module (e.g., check o.__class__.__module__ and ensure it startswith
"numpy" or "torch") or explicitly check for instances via duck-typing
attributes, and add a brief comment explaining the intent; replace the condition
referencing type(o).__name__ == "dtype" with a combined check using
o.__class__.__module__ (or equivalent) so only numpy/torch dtype classes are
matched in robust_json.py.
- Line 15: Remove the top-level blanket mypy suppression in robust_json.py (the
`# mypy: ignore-errors` comment) and instead address specific typing issues:
delete that header, run mypy to find failing lines, and add targeted `# type:
ignore[...]` comments only on the problematic expressions (or refine annotations
on functions/classes such as any helper functions in robust_json.py) so the rest
of the module retains type checking.
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`:
- Around line 248-259: Add hardware-independent shape assertions for the MLP
activation entries similar to the attention checks: for the MLP activations
retrieved from the activations dict (e.g., activations["decoder.layers.0.mlp"]
and activations["decoder.layers.1.mlp"]) assert their .shape equals
torch.Size([hidden_size]) before or alongside the existing min/max/mean
_assert_approx checks; update the test variables corresponding to the first test
case’s MLP activations (the variables that hold those activation tensors) to
include these assert ... .shape == torch.Size([hidden_size]) statements.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ec61ff4f-cb7a-4262-91ac-c9d486e2dd0f
📒 Files selected for processing (4)
- modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
- modelopt/torch/utils/logging.py
- modelopt/torch/utils/robust_json.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
def aprint(*args, **kwargs):
    """All ranks from all nodes print."""
    print(*args, **kwargs, flush=True)
Avoid duplicating the `flush` keyword.
`aprint(..., flush=False)` currently raises `TypeError` because `flush` is passed both via `**kwargs` and as an explicit keyword. Since this is a new public wrapper around `print`, it should merge or override the kwarg instead of duplicating it.
Proposed fix
 def aprint(*args, **kwargs):
     """All ranks from all nodes print."""
-    print(*args, **kwargs, flush=True)
+    kwargs = {**kwargs, "flush": True}
+    print(*args, **kwargs)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/utils/logging.py` around lines 206 - 208, The aprint wrapper
currently always passes flush=True while also forwarding kwargs, causing a
duplicate-key TypeError when callers pass flush in kwargs; update aprint to
handle/merge the flush kwarg by extracting it from kwargs (e.g., pop 'flush' if
present) and then call print with the resolved flush value (use True as the
default if not provided) so callers can override flush without causing a
duplicate keyword error; reference the aprint function in logging.py to locate
the change.
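The failure mode and the fix can be reproduced standalone. Only `aprint`'s shape comes from the diff above; the `_buggy`/`_fixed` names are for illustration:

```python
def aprint_buggy(*args, **kwargs):
    """Original form: flush=True is passed alongside **kwargs."""
    print(*args, **kwargs, flush=True)


def aprint_fixed(*args, **kwargs):
    """Fixed form: merge flush into kwargs so print sees a single value."""
    kwargs = {**kwargs, "flush": True}
    print(*args, **kwargs)


try:
    aprint_buggy("hello", flush=False)
except TypeError as e:
    print(f"buggy version raised: {e}")

aprint_fixed("hello", flush=False)  # no error; flush is forced to True
```

Note that this variant still forces `flush=True` even when the caller passes `flush=False`; popping `flush` from `kwargs` with a default of `True` would instead let callers override it.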
def json_load(path: Path | str) -> dict:
    """Load JSON from file and return as dictionary."""
    path = Path(path)
    text = path.read_text()
    return json.loads(text)
Return type annotation is too restrictive.
`json.loads` can return various types (dict, list, str, int, float, bool, or None) depending on the JSON content. The current `-> dict` annotation is incorrect if the JSON root is an array or primitive value. This could cause runtime type mismatches for callers.
🔧 Proposed fix
-def json_load(path: Path | str) -> dict:
-    """Load JSON from file and return as dictionary."""
+def json_load(path: Path | str) -> Any:
+    """Load JSON from file and return the deserialized object."""
     path = Path(path)
     text = path.read_text()
     return json.loads(text)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/utils/robust_json.py` around lines 74 - 78, The return type of
json_load is too narrow (annotated as -> dict) while json.loads can return dict,
list, str, int, float, bool, or None; update json_load's return annotation to
reflect that (e.g., use typing.Any or define a JSONType = Union[dict, list, str,
int, float, bool, None] and use it) and adjust the docstring to say "Return
parsed JSON value" instead of "return as dictionary"; reference the json_load
function and the json.loads call when making this change.
How about the following file structure:

modelopt/torch/prune/importance_hooks/
|- base_hooks.py
|- base_hook_analysis.py
|- compare_module_output.py
|- plugins/
    |- megatron_hooks.py
What does this PR do?
Type of change: Redesign of existing feature
This PR introduces a shared activation hooks infrastructure for minitron and puzzletron. The activation hooks framework provides a reusable component for collecting and analyzing activations during forward passes, which is used by both minitron pruning and puzzletron algorithms.
Note! The Minitron `megatron.py`/`mcore_minitron.py:ImportanceEstimatorRegistry` code does not use this component yet; it will be refactored in a separate MR.
Key changes:
- Added modelopt/torch/nas/plugins/megatron_hooks/ module with base hooks framework:
  - base_hooks.py: Core hook infrastructure for registering and managing forward hooks
  - base_hooks_analysis.py: Analysis utilities for processing collected activations
  - megatron_hooks.py: Megatron-specific hook implementations
  - compare_module_outputs.py: Utilities for comparing module outputs
- Added unit tests in tests/gpu/torch/nas/plugins/megatron_hooks/:
  - test_base_hooks.py: Tests for base hooks functionality
  - test_base_hooks_analysis.py: Tests for activation analysis utilities
- Updated test_mcore_gpt_minitron_pruning.py to validate activation collection
- Updated test utilities for distributed testing support
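The collection pattern the hooks framework builds on can be sketched without Megatron or even PyTorch: a toy module that honors PyTorch's forward-hook signature `(module, input, output)`. All names here are illustrative, not the modelopt API:

```python
class TinyModule:
    """Stand-in for an nn.Module with register_forward_hook support."""

    def __init__(self, name, fn):
        self.name, self.fn = name, fn
        self._hooks = []

    def register_forward_hook(self, hook):
        self._hooks.append(hook)

    def __call__(self, x):
        out = self.fn(x)
        for hook in self._hooks:
            hook(self, x, out)  # PyTorch hook signature: (module, input, output)
        return out


activations = {}


def capture(module, inp, out):
    """Record each module's output, as the base hooks do for real layers."""
    activations[module.name] = out


mlp = TinyModule("decoder.layers.0.mlp", lambda x: [v * 2 for v in x])
mlp.register_forward_hook(capture)
mlp([1, 2, 3])
print(activations)  # {'decoder.layers.0.mlp': [2, 4, 6]}
```

In the real framework the hook accumulates importance statistics over many batches instead of storing raw outputs, but the registration and capture flow is the same.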
Before your PR is "Ready for review"
CONTRIBUTING.md: ✅ No

Summary by CodeRabbit
Release Notes
New Features
Tests