From 67a594faa5762eab63d8c3b90c6a9cfdd3fe8a6d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 23 Mar 2026 14:55:17 +0100 Subject: [PATCH 1/5] Remove module reference from autograd function args Extract weight quantization into standalone `quantize_weight()` function in base.py, eliminating the need to pass `self` (nn.Module) into autograd functions. Each op's autograd function now receives/returns Optional[Tensor] weight workspaces instead, with cache management handled by the nn.Module before/after the autograd call. Signed-off-by: Pawel Gadzinski Made-with: Cursor --- transformer_engine/pytorch/module/base.py | 265 ++++++++++-------- .../pytorch/module/grouped_linear.py | 34 ++- .../pytorch/module/layernorm_linear.py | 41 ++- .../pytorch/module/layernorm_mlp.py | 89 ++++-- transformer_engine/pytorch/module/linear.py | 38 ++- 5 files changed, 288 insertions(+), 179 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 09b12afa21..4f5554907c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -603,6 +603,131 @@ def fill_userbuffers_buffer_for_all_gather( raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})") +def _is_weight_workspace_valid( + workspace: QuantizedTensorStorage, + quantizer: Quantizer, +) -> bool: + """Check if a cached weight workspace is compatible with the quantizer's current usage.""" + if isinstance(workspace, Float8TensorStorage): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and workspace._transpose is None + ): + return False + elif isinstance(workspace, MXFP8TensorStorage): + if quantizer.rowwise_usage and workspace._rowwise_data is None: + return False + if quantizer.columnwise_usage and workspace._columnwise_data is None: + return False + elif isinstance(workspace, NVFP4TensorStorage): + if quantizer.rowwise_usage and workspace._rowwise_data is None: + return False + if quantizer.columnwise_usage and workspace._columnwise_data is None: + return False + if isinstance(workspace, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): + return False + return True + + +def quantize_weight( + *, + tensor: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + workspace: Optional[QuantizedTensorStorage] = None, + update_workspace: bool = True, + skip_update_flag: Optional[torch.Tensor] = None, + fsdp_group: Optional["dist_group_type"] = None, + workspace_dtype: Optional[torch.dtype] = None, + cache: bool = False, +) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]: + """Quantize a weight tensor, optionally reusing a cached workspace. + + This is a standalone function (no module reference) that can be called + from inside ``torch.autograd.Function.forward``. + + Parameters + ---------- + tensor: torch.Tensor, optional + Weight tensor to quantize. + quantizer: Quantizer, optional + Quantizer for casting the weight. + workspace: QuantizedTensorStorage, optional + Previously cached workspace (from the module's ``_fp8_workspaces``). + ``None`` indicates a cache miss. + update_workspace: bool, default = True + Whether to update an existing workspace with fresh values. + skip_update_flag: torch.Tensor, optional + GPU flag to conditionally skip the update. + fsdp_group: dist_group_type, optional + FSDP process group the weights are distributed over. + workspace_dtype: torch.dtype, optional + High-precision dtype for debug quantization workspaces. + cache: bool, default = False + If ``True`` and a new workspace is created, it will be returned + as the second element so the caller can store it. + + Returns + ------- + (weightmat, new_workspace) + *weightmat*: quantized weight ready for GEMM. + *new_workspace*: non-``None`` only when a brand-new workspace was + created **and** ``cache=True``. The caller should store it in + ``_fp8_workspaces``. + """ + + # Already-quantized weight (primary FP8 parameters) + if isinstance(tensor, QuantizedTensor): + update_rowwise = True if quantizer.rowwise_usage else None + update_columnwise = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise, + columnwise_usage=update_columnwise, + ) + return tensor, None + + # Validate workspace + if workspace is not None and quantizer is not None: + if not _is_weight_workspace_valid(workspace, quantizer): + workspace = None + + # FSDP gather on cached workspace + if ( + workspace is not None + and tensor is not None + and fsdp_group is not None + and workspace.data.shape != tensor.data.shape + ): + _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], workspace) + + # Cache hit — update in-place and return + if workspace is not None: + if skip_update_flag is not None: + update_workspace = True + if update_workspace: + if tensor is None: + raise ValueError("tensor kwarg must be provided to update FP8 workspace") + if hasattr(workspace, "quantize_"): + workspace.quantize_(tensor, noop_flag=skip_update_flag) + else: + tex.quantize(tensor, quantizer, workspace, skip_update_flag) + return workspace, None + + # Cache miss — create new workspace + if tensor is None or quantizer is None: + raise ValueError( + "tensor and quantizer kwargs must be provided to construct FP8 workspace" + ) + if cache: + saved_internal = quantizer.internal + quantizer.internal = False + out = quantizer.quantize(tensor, dtype=workspace_dtype) + if cache: + quantizer.internal = saved_internal + return out, out + return out, None + + class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" @@ -1343,130 +1468,26 @@ def clear(self): def forward(self): """Needs override.""" - def get_weight_workspace( - self, - *, - tensor: Optional[torch.Tensor] = None, - quantizer: Optional[Quantizer] = None, - cache_name: Optional[str] = None, - update_workspace: bool = True, - skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: Optional[dist_group_type] = None, - workspace_dtype: Optional[torch.dtype] = None, - ) -> QuantizedTensor: - """Get workspace buffer for weights and maybe update its values - - The workspace buffer may be cached for future function calls. - - Parameters - ---------- - tensor : torch.Tensor, optional - Values to copy into workspace. Required if the workspace - is being constructed or updated. - quantizer: Quantizer, optional - Quantizer used to cast the weights. Required if the - workspace is being constructed or updated. - cache_name: str, optional - Key for caching. - update_workspace: bool, default = True - Update workspace with values from `tensor`. - skip_update_flag: torch.Tensor, optional - GPU flag to skip updating the workspace. Take precedence - over `update_workspace` if provided. - fsdp_group: bool, default = None - FSDP process group that the weights are distributed over. - workspace_dtype: torch.dtype, default = None - If weight workspace contains high-precision tensor - for example - for debug quantization, this is dtype of the tensor. - """ - - # Handle case where weights are already quantized - # Note: Make sure weights have required usages, but do not - # destroy unnecessary usages since they may be used later. - if isinstance(tensor, QuantizedTensor): - update_rowwise_usage = True if quantizer.rowwise_usage else None - update_columnwise_usage = True if quantizer.columnwise_usage else None - tensor.update_usage( - rowwise_usage=update_rowwise_usage, - columnwise_usage=update_columnwise_usage, - ) - return tensor - - # Try getting workspace from cache - out = None - if cache_name is not None: - out = self._fp8_workspaces.get(cache_name, None) + def get_weight_workspace(self, **kwargs) -> "QuantizedTensor": + """Get workspace buffer for weights and maybe update its values. - # Reset cache if workspace is invalid - if out is not None and quantizer is not None: - reset_cache = False - if isinstance(out, Float8TensorStorage): - if ( - not is_non_tn_fp8_gemm_supported() - and quantizer.columnwise_usage - and out._transpose is None - ): - reset_cache = True - elif isinstance(out, MXFP8TensorStorage): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - elif isinstance(out, NVFP4TensorStorage): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): - reset_cache = True - if reset_cache: - out = None - del self._fp8_workspaces[cache_name] - - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. - if ( - out is not None - and tensor is not None - and fsdp_group is not None - and out.data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) + Thin wrapper around :func:`quantize_weight` that manages the + ``_fp8_workspaces`` cache on the module. - # Construct workspace if needed - if out is None: - if tensor is None or quantizer is None: - raise ValueError( - "tensor and quantizer kwargs must be provided to construct FP8 workspace" - ) - - if cache_name is not None: - # Ensure the tensor in the cache is an instance of torch.Tensor, - # as it persists beyond a single forward pass. - # Setting internal=True would cause the data to be removed in prepare_for_saving(...). - quantizer_internal = quantizer.internal - quantizer.internal = False - out = quantizer.quantize(tensor, dtype=workspace_dtype) - if cache_name is not None: - quantizer.internal = quantizer_internal - - # Update cache - if cache_name is not None: - self._fp8_workspaces[cache_name] = out - return out - - # Update workspace if needed - if skip_update_flag is not None: - update_workspace = True - if update_workspace: - if tensor is None: - raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if hasattr(out, "quantize_"): - out.quantize_(tensor, noop_flag=skip_update_flag) - else: - tex.quantize(tensor, quantizer, out, skip_update_flag) - return out + See :func:`quantize_weight` for the full parameter list. The only + difference is that *cache_name* controls lookup/storage in + ``self._fp8_workspaces``. + """ + cache_name = kwargs.pop("cache_name", None) + workspace = self._fp8_workspaces.get(cache_name) if cache_name is not None else None + result, new_workspace = quantize_weight( + workspace=workspace, + cache=cache_name is not None, + **kwargs, + ) + if new_workspace is not None and cache_name is not None: + self._fp8_workspaces[cache_name] = new_workspace + return result def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2f859e748b..9b1e86f6f9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -17,6 +17,7 @@ from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, + quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -93,7 +94,9 @@ def forward( sequence_parallel, activation_dtype, is_grad_enabled, - module, + weight_workspaces, + new_workspaces_out, + cache_weight, skip_fp8_weight_update, save_original_input, debug, @@ -167,19 +170,21 @@ def forward( # Initialize weights weights_fp8: list if fp8 or debug: - # FP8 cast to workspace buffer weights_fp8 = [] - update_workspace = is_first_microbatch is None or is_first_microbatch + update_ws = is_first_microbatch is None or is_first_microbatch for i in range(num_gemms): - weight_fp8 = module.get_weight_workspace( + weight_fp8, new_ws = quantize_weight( tensor=weights[i], quantizer=weight_quantizers[i], - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, + workspace=weight_workspaces[i] if weight_workspaces else None, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, workspace_dtype=activation_dtype, + cache=cache_weight, ) weights_fp8.append(weight_fp8) + if new_workspaces_out is not None: + new_workspaces_out[i] = new_ws else: weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] @@ -893,6 +898,14 @@ def forward( linear_fn = _GroupedLinear.forward autograd_ctx = [None] + num_gemms = len(m_splits) + cache_weight = is_first_microbatch is not None + weight_workspaces = [ + self._fp8_workspaces.get(f"weight{i}") if cache_weight else None + for i in range(num_gemms) + ] + new_workspaces_out = [None] * num_gemms if cache_weight else None + non_tensor_args = ( m_splits, self.apply_bias, @@ -911,13 +924,20 @@ def forward( self.sequence_parallel, self.activation_dtype, is_grad_enabled, - self, + weight_workspaces, + new_workspaces_out, + cache_weight, None, # skip_fp8_weight_update self.save_original_input, debug, ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + if new_workspaces_out is not None: + for i, ws in enumerate(new_workspaces_out): + if ws is not None: + self._fp8_workspaces[f"weight{i}"] = ws + finally: self.end_forward() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..af2c59b330 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -20,6 +20,7 @@ from .base import ( fill_userbuffers_buffer_for_all_gather, get_ub, + quantize_weight, TransformerEngineBaseModule, get_dummy_wgrad, _2X_ACC_FPROP, @@ -94,9 +95,10 @@ def forward( ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], weight: torch.Tensor, + weight_workspace: Optional[torch.Tensor], bias: torch.Tensor, non_tensor_args: Tuple, - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring # Reduce number of arguments to autograd function in order @@ -136,7 +138,7 @@ def forward( ub_bulk_dgrad, ub_name, fsdp_group, - module, + cache_weight, skip_fp8_weight_update, symmetric_ar_type, debug, @@ -286,6 +288,7 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + new_weight_workspace = None weightmat = weight is_weight_param_quantized = False if fp8 or debug: @@ -299,15 +302,16 @@ def forward( weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) # Get quantized weight - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( + update_ws = is_first_microbatch is None or is_first_microbatch + weightmat, new_weight_workspace = quantize_weight( tensor=weight, quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, + workspace=weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) weightmat.update_usage(rowwise_usage=True) @@ -527,13 +531,15 @@ def forward( # Cached state for backward pass is ready... # ------------------------------------------------------ + ln_out_for_return = None if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp_shape) shape[0] *= tp_size if with_input_all_gather else 1 - return out, ln_out_return.view(shape) - return out, ln_out_return.view(inp_shape) - return out + ln_out_for_return = ln_out_return.view(shape) + else: + ln_out_for_return = ln_out_return.view(inp_shape) + return out, ln_out_for_return, new_weight_workspace @staticmethod def backward( @@ -1025,6 +1031,7 @@ def wgrad_gemm( dgamma, dbeta, wgrad, + None, # weight_workspace grad_bias, None, ) @@ -1540,6 +1547,11 @@ def forward( else: fwd_fn = _LayerNormLinear.forward autograd_ctx = [None] + cache_name = None if is_first_microbatch is None else "weight" + weight_workspace = ( + self._fp8_workspaces.get(cache_name) if cache_name is not None else None + ) + non_tensor_args = ( self.eps, is_first_microbatch, @@ -1575,27 +1587,28 @@ def forward( self.ub_bulk_dgrad, self.ub_name, self.fsdp_group, - self, + cache_name is not None, skip_fp8_weight_update, self.symmetric_ar_type, debug, ) - out = fwd_fn( + out, ln_out, new_weight_workspace = fwd_fn( *autograd_ctx, inp, self.layer_norm_weight, self.layer_norm_bias, weight_tensor, + weight_workspace, bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, non_tensor_args, ) + if new_weight_workspace is not None and cache_name is not None: + self._fp8_workspaces[cache_name] = new_weight_workspace + finally: self.end_forward() - if self.return_layernorm_output: - out, ln_out = out - if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4532ea60e7..d4476d6ac8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -22,6 +22,7 @@ fill_userbuffers_buffer_for_all_gather, _ub_communicators, get_ub, + quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -176,8 +177,10 @@ def _forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, + fc1_weight_workspace: Optional[torch.Tensor], fc1_bias: torch.Tensor, fc2_weight: torch.Tensor, + fc2_weight_workspace: Optional[torch.Tensor], fc2_bias: torch.Tensor, non_tensor_args: Tuple, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: @@ -228,7 +231,8 @@ def _forward( ub_bulk_dgrad, gemm_gelu_fusion, fsdp_group, - module, + fp8_meta, + cache_weight, skip_fp8_weight_update, symmetric_ar_type, checkpoint, @@ -250,7 +254,7 @@ def _forward( == "DelayedScaling" ): # only applicable for delayed scaling FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute( - module.fp8_meta + fp8_meta ) # to restore quantizers during recomputation # save the rng states ctx.cpu_rng_state = torch.get_rng_state() @@ -321,7 +325,8 @@ def _forward( "ub_bulk_dgrad": ub_bulk_dgrad, "gemm_gelu_fusion": gemm_gelu_fusion, "fsdp_group": fsdp_group, - "module": module, + "fp8_meta": fp8_meta, + "cache_weight": False, "skip_fp8_weight_update": skip_fp8_weight_update, "symmetric_ar_type": symmetric_ar_type, "checkpoint": checkpoint, @@ -465,13 +470,12 @@ def _forward( ln_out_total = ln_out # Cast weights to expected dtype + new_fc1_weight_workspace = None + new_fc2_weight_workspace = None fc1_weight_final = fc1_weight fc2_weight_final = fc2_weight if fp8 or debug: - # If weights are not quantized, we call get_weight_workspace, - # which handles weight caching etc. - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch + update_ws = is_first_microbatch is None or is_first_microbatch # No need to set the quantizer states if weights are already quantized if isinstance(fc1_weight, QuantizedTensorStorage): fc1_weight_quantizer = fc1_weight._quantizer @@ -483,23 +487,25 @@ def _forward( elif fc2_weight_quantizer is not None: fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - fc1_weight_final = module.get_weight_workspace( + fc1_weight_final, new_fc1_weight_workspace = quantize_weight( tensor=fc1_weight, quantizer=fc1_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc1_weight"), - update_workspace=update_workspace, + workspace=fc1_weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) - fc2_weight_final = module.get_weight_workspace( + fc2_weight_final, new_fc2_weight_workspace = quantize_weight( tensor=fc2_weight, quantizer=fc2_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc2_weight"), - update_workspace=update_workspace, + workspace=fc2_weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) fc1_weight_final.update_usage(rowwise_usage=True) fc2_weight_final.update_usage(rowwise_usage=True) @@ -858,13 +864,15 @@ def _forward( ) # we only get to this point if we are not recomputing for bwd, since that would have returned in the block above + ln_out_for_return = None if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp_shape) shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1 - return fc2_out, ln_out_return.view(shape) - return fc2_out, ln_out_return.view(inp_shape) - return fc2_out + ln_out_for_return = ln_out_return.view(shape) + else: + ln_out_for_return = ln_out_return.view(inp_shape) + return fc2_out, ln_out_for_return, new_fc1_weight_workspace, new_fc2_weight_workspace @staticmethod def forward( @@ -873,11 +881,13 @@ def forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, + fc1_weight_workspace: Optional[torch.Tensor], fc1_bias: torch.Tensor, fc2_weight: torch.Tensor, + fc2_weight_workspace: Optional[torch.Tensor], fc2_bias: torch.Tensor, non_tensor_args: Tuple, - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring # add recompute_for_bwd @@ -889,8 +899,10 @@ def forward( ln_weight, ln_bias, fc1_weight, + fc1_weight_workspace, fc1_bias, fc2_weight, + fc2_weight_workspace, fc2_bias, non_tensor_args, ) @@ -918,7 +930,7 @@ def _recompute(ctx): and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" ): # only applicable for delayed scaling FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute( - ctx.other_args["module"].fp8_meta + ctx.other_args["fp8_meta"] ) # set old quantizer state # get current rng state @@ -929,9 +941,16 @@ def _recompute(ctx): torch.set_rng_state(ctx.cpu_rng_state) _set_cuda_rng_state(ctx.cuda_rng_state) + # Unpack saved tensors and pass None for weight workspaces (recomputed from scratch) + ( + inp_r, ln_weight_r, ln_bias_r, + fc1_weight_r, fc1_bias_r, fc2_weight_r, fc2_bias_r, + ) = tensors out = _LayerNormMLP._forward( # recompute ctx, - *tensors, + inp_r, ln_weight_r, ln_bias_r, + fc1_weight_r, None, fc1_bias_r, + fc2_weight_r, None, fc2_bias_r, tuple(ctx.other_args.values()), ) @@ -941,7 +960,7 @@ def _recompute(ctx): and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" ): FP8GlobalStateManager.restore_fp8_meta_tensors( - ctx.other_args["module"].fp8_meta + ctx.other_args["fp8_meta"] ) # restore quantizers # set rng state for fwd @@ -1642,8 +1661,10 @@ def fc1_wgrad_gemm( dgamma, dbeta, fc1_wgrad, + None, # fc1_weight_workspace fc1_bias_grad if fc1_bias is not None else None, fc2_wgrad, # pylint: disable=possibly-used-before-assignment + None, # fc2_weight_workspace fc2_bias_grad, None, ) @@ -2107,6 +2128,19 @@ def forward( fwd_fn = _LayerNormMLP.forward autograd_ctx = [None] + cache_name_fc1 = None if is_first_microbatch is None else "fc1_weight" + cache_name_fc2 = None if is_first_microbatch is None else "fc2_weight" + fc1_weight_workspace = ( + self._fp8_workspaces.get(cache_name_fc1) + if cache_name_fc1 is not None + else None + ) + fc2_weight_workspace = ( + self._fp8_workspaces.get(cache_name_fc2) + if cache_name_fc2 is not None + else None + ) + non_tensor_args = ( self.eps, is_first_microbatch, @@ -2150,30 +2184,35 @@ def forward( self.ub_bulk_wgrad, self.gemm_gelu_fusion and not debug, self.fsdp_group, - self, + self.fp8_meta, + cache_name_fc1 is not None, skip_fp8_weight_update, self.symmetric_ar_type, self.checkpoint, debug, ) - out = fwd_fn( + out, ln_out, new_fc1_ws, new_fc2_ws = fwd_fn( *autograd_ctx, inp, self.layer_norm_weight, self.layer_norm_bias, fc1_weight, + fc1_weight_workspace, fc1_bias, fc2_weight, + fc2_weight_workspace, fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, non_tensor_args, ) + if new_fc1_ws is not None and cache_name_fc1 is not None: + self._fp8_workspaces[cache_name_fc1] = new_fc1_ws + if new_fc2_ws is not None and cache_name_fc2 is not None: + self._fp8_workspaces[cache_name_fc2] = new_fc2_ws + finally: self.end_forward() - if self.return_layernorm_output: - out, ln_out = out - if self.gemm_bias_unfused_add: out = out + cast_if_needed(fc2_bias, self.activation_dtype) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..520c82ade4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -19,6 +19,7 @@ fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, + quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -88,10 +89,11 @@ class _Linear(torch.autograd.Function): def forward( ctx, weight: torch.Tensor, + weight_workspace: Optional[torch.Tensor], inp: torch.Tensor, bias: Optional[torch.Tensor], non_tensor_args: Tuple, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # pylint: disable=missing-function-docstring ( @@ -123,7 +125,7 @@ def forward( ub_name, fp8_output, # pylint: disable=unused-variable fsdp_group, - module, + cache_weight, skip_fp8_weight_update, symmetric_ar_type, save_original_input, @@ -249,6 +251,7 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + new_weight_workspace = None weightmat = weight if fp8 or debug: # Configure quantizer @@ -262,18 +265,18 @@ def forward( ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) elif isinstance(weight, QuantizedTensor): - # If weight is already quantized, no need to set quantizer states weight_quantizer = weight._quantizer # Get quantized weight - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( + update_ws = is_first_microbatch is None or is_first_microbatch + weightmat, new_weight_workspace = quantize_weight( tensor=weight, quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, + workspace=weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) weightmat.update_usage(rowwise_usage=True) @@ -490,10 +493,12 @@ def forward( # Cached state for backward pass is ready... # ------------------------------------------------------ - return out + return out, new_weight_workspace @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + def backward( + ctx, grad_output: torch.Tensor, _grad_weight_workspace + ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring # NVTX label for profiling @@ -977,6 +982,7 @@ def wgrad_gemm( _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( wgrad, + None, # weight_workspace dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, grad_bias, None, @@ -1424,6 +1430,11 @@ def forward( linear_fn = _Linear.forward autograd_ctx = [None] + cache_name = None if is_first_microbatch is None else "weight" + weight_workspace = ( + self._fp8_workspaces.get(cache_name) if cache_name is not None else None + ) + non_tensor_args = ( is_first_microbatch, self.fp8, @@ -1453,19 +1464,24 @@ def forward( self.ub_name, fp8_output, self.fsdp_group, - self, + cache_name is not None, skip_fp8_weight_update, self.symmetric_ar_type, self.save_original_input, debug, ) - out = linear_fn( + out, new_weight_workspace = linear_fn( *autograd_ctx, weight_tensor, + weight_workspace, inp, bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, non_tensor_args, ) + + if new_weight_workspace is not None and cache_name is not None: + self._fp8_workspaces[cache_name] = new_weight_workspace + finally: self.end_forward() if self.gemm_bias_unfused_add: From e803782eb49863cb2b7404f72ec5c29030b38b2f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:05:23 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 4 +-- .../pytorch/module/layernorm_mlp.py | 29 ++++++++++++------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 51be4bb914..3c41f1da62 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -749,9 +749,7 @@ def quantize_weight( # Cache miss — create new workspace if tensor is None or quantizer is None: - raise ValueError( - "tensor and quantizer kwargs must be provided to construct FP8 workspace" - ) + raise ValueError("tensor and quantizer kwargs must be provided to construct FP8 workspace") if cache: saved_internal = quantizer.internal quantizer.internal = False diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f1fb234b83..48dc60fbbf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -938,14 +938,25 @@ def _recompute(ctx): # Unpack saved tensors and pass None for weight workspaces (recomputed from scratch) ( - inp_r, ln_weight_r, ln_bias_r, - fc1_weight_r, fc1_bias_r, fc2_weight_r, fc2_bias_r, + inp_r, + ln_weight_r, + ln_bias_r, + fc1_weight_r, + fc1_bias_r, + fc2_weight_r, + fc2_bias_r, ) = tensors out = _LayerNormMLP._forward( # recompute ctx, - inp_r, ln_weight_r, ln_bias_r, - fc1_weight_r, None, fc1_bias_r, - fc2_weight_r, None, fc2_bias_r, + inp_r, + ln_weight_r, + ln_bias_r, + fc1_weight_r, + None, + fc1_bias_r, + fc2_weight_r, + None, + fc2_bias_r, tuple(ctx.other_args.values()), ) @@ -2126,14 +2137,10 @@ def forward( cache_name_fc1 = None if is_first_microbatch is None else "fc1_weight" cache_name_fc2 = None if is_first_microbatch is None else "fc2_weight" fc1_weight_workspace = ( - self._fp8_workspaces.get(cache_name_fc1) - if cache_name_fc1 is not None - else None + self._fp8_workspaces.get(cache_name_fc1) if cache_name_fc1 is not None else None ) fc2_weight_workspace = ( - self._fp8_workspaces.get(cache_name_fc2) - if cache_name_fc2 is not None - else None + self._fp8_workspaces.get(cache_name_fc2) if cache_name_fc2 is not None else None ) non_tensor_args = ( From 64e8ac29f87fe06489cb749f53e03e4034e6cbd6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 23 Mar 2026 15:07:06 +0100 Subject: [PATCH 3/5] Remove unused get_weight_workspace wrapper No callers remain after the quantize_weight refactor. Signed-off-by: Pawel Gadzinski Made-with: Cursor --- transformer_engine/pytorch/module/base.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3c41f1da62..47228a99cd 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1517,27 +1517,6 @@ def clear(self): def forward(self): """Needs override.""" - def get_weight_workspace(self, **kwargs) -> "QuantizedTensor": - """Get workspace buffer for weights and maybe update its values. - - Thin wrapper around :func:`quantize_weight` that manages the - ``_fp8_workspaces`` cache on the module. - - See :func:`quantize_weight` for the full parameter list. The only - difference is that *cache_name* controls lookup/storage in - ``self._fp8_workspaces``. - """ - cache_name = kwargs.pop("cache_name", None) - workspace = self._fp8_workspaces.get(cache_name) if cache_name is not None else None - result, new_workspace = quantize_weight( - workspace=workspace, - cache=cache_name is not None, - **kwargs, - ) - if new_workspace is not None and cache_name is not None: - self._fp8_workspaces[cache_name] = new_workspace - return result - def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): From 29ae3931d5adb689ce758490b008d8a515090b85 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 23 Mar 2026 15:31:11 +0100 Subject: [PATCH 4/5] Return workspaces from _GroupedLinear via tuple instead of mutable list Signed-off-by: Pawel Gadzinski Made-with: Cursor --- .../pytorch/module/grouped_linear.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 86e48bc2da..f609d18490 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -70,7 +70,7 @@ def forward( inp: torch.Tensor, non_tensor_args: Tuple, *weights_and_biases, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, list]: # pylint: disable=missing-function-docstring # Reduce number of arguments to autograd function in order @@ -94,7 +94,6 @@ def forward( activation_dtype, is_grad_enabled, weight_workspaces, - new_workspaces_out, cache_weight, skip_fp8_weight_update, save_original_input, @@ -169,11 +168,12 @@ def forward( # Initialize weights weights_fp8: list + new_workspaces = [None] * num_gemms if fp8 or debug: weights_fp8 = [] update_ws = is_first_microbatch is None or is_first_microbatch for i in range(num_gemms): - weight_fp8, new_ws = quantize_weight( + weight_fp8, new_workspaces[i] = quantize_weight( tensor=weights[i], quantizer=weight_quantizers[i], workspace=weight_workspaces[i] if weight_workspaces else None, @@ -183,8 +183,6 @@ def forward( cache=cache_weight, ) weights_fp8.append(weight_fp8) - if new_workspaces_out is not None: - new_workspaces_out[i] = new_ws else: weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] @@ -315,10 +313,10 @@ def forward( ctx.input_quantizers = input_quantizers # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp.shape[1:-1], out.shape[-1]) + return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + def backward(ctx, grad_output: torch.Tensor, _grad_workspaces) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx) @@ -998,7 +996,6 @@ def forward( self._fp8_workspaces.get(f"weight{i}") if cache_weight else None for i in range(num_gemms) ] - new_workspaces_out = [None] * num_gemms if cache_weight else None non_tensor_args = ( m_splits, @@ -1019,16 +1016,17 @@ def forward( self.activation_dtype, is_grad_enabled, weight_workspaces, - new_workspaces_out, cache_weight, None, # skip_fp8_weight_update self.save_original_input, debug, ) - out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + out, new_workspaces = linear_fn( + *autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors + ) - if new_workspaces_out is not None: - for i, ws in enumerate(new_workspaces_out): + if cache_weight: + for i, ws in enumerate(new_workspaces): if ws is not None: self._fp8_workspaces[f"weight{i}"] = ws From d978de252dd859b7d1b4d62281a7b08594772eec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:32:07 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f609d18490..385a266b5b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -316,7 +316,9 @@ def forward( return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces @staticmethod - def backward(ctx, grad_output: torch.Tensor, _grad_workspaces) -> Tuple[Union[torch.Tensor, None], ...]: + def backward( + ctx, grad_output: torch.Tensor, _grad_workspaces + ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx)