def _is_weight_workspace_valid(
    workspace: QuantizedTensorStorage,
    quantizer: Quantizer,
) -> bool:
    """Check if a cached weight workspace is compatible with the quantizer's current usage.

    Returns ``False`` when the cached workspace is missing data that the
    quantizer's requested usages (rowwise/columnwise) would need, so the
    caller can discard the cache and re-quantize from scratch.
    """
    if isinstance(workspace, Float8TensorStorage):
        # Without non-TN FP8 GEMM support, columnwise (wgrad) usage needs an
        # explicit transpose; a workspace cached without one is unusable.
        if (
            not is_non_tn_fp8_gemm_supported()
            and quantizer.columnwise_usage
            and workspace._transpose is None
        ):
            return False
    elif isinstance(workspace, (MXFP8TensorStorage, NVFP4TensorStorage)):
        # Block-scaled formats keep rowwise and columnwise data separately;
        # every requested usage must already be materialized in the cache.
        if quantizer.rowwise_usage and workspace._rowwise_data is None:
            return False
        if quantizer.columnwise_usage and workspace._columnwise_data is None:
            return False
    # Debug workspaces are only valid with debug quantizers, and vice versa.
    if isinstance(workspace, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
        return False
    return True


def quantize_weight(
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    workspace: Optional[QuantizedTensorStorage] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional["dist_group_type"] = None,
    workspace_dtype: Optional[torch.dtype] = None,
    cache: bool = False,
) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]:
    """Quantize a weight tensor, optionally reusing a cached workspace.

    This is a standalone function (no module reference) that can be called
    from inside ``torch.autograd.Function.forward``.

    Parameters
    ----------
    tensor: torch.Tensor, optional
        Weight tensor to quantize. Required if the workspace is being
        constructed or updated.
    quantizer: Quantizer, optional
        Quantizer for casting the weight. Required if the workspace is
        being constructed or updated.
    workspace: QuantizedTensorStorage, optional
        Previously cached workspace (from the module's ``_fp8_workspaces``).
        ``None`` indicates a cache miss.
    update_workspace: bool, default = True
        Whether to update an existing workspace with fresh values.
    skip_update_flag: torch.Tensor, optional
        GPU flag to conditionally skip the update. Takes precedence over
        ``update_workspace`` if provided.
    fsdp_group: dist_group_type, optional
        FSDP process group the weights are distributed over.
    workspace_dtype: torch.dtype, optional
        High-precision dtype for debug quantization workspaces.
    cache: bool, default = False
        If ``True`` and a new workspace is created, it will be returned
        as the second element so the caller can store it.

    Returns
    -------
    (weightmat, new_workspace)
        *weightmat*: quantized weight ready for GEMM.
        *new_workspace*: non-``None`` only when a brand-new workspace was
        created **and** ``cache=True``. The caller should store it in
        ``_fp8_workspaces``.
    """

    # Already-quantized weight (primary FP8 parameters).
    # Note: Make sure weights have the required usages, but do not destroy
    # unnecessary usages since they may be used later.
    if isinstance(tensor, QuantizedTensor):
        tensor.update_usage(
            rowwise_usage=True if quantizer.rowwise_usage else None,
            columnwise_usage=True if quantizer.columnwise_usage else None,
        )
        if isinstance(quantizer, DebugQuantizer):
            tensor = quantizer.wrap_quantized_tensor(tensor)
        return tensor, None

    # Discard the cached workspace if it no longer matches the quantizer
    if workspace is not None and quantizer is not None:
        if not _is_weight_workspace_valid(workspace, quantizer):
            workspace = None

    # Gather the cached workspace if it is FSDP-sharded.
    # NOTE(review): per the original code, FSDP sharding is supported only
    # for FP8 buffers and will not work for models initialized with FP8
    # primary weights.
    if (
        workspace is not None
        and tensor is not None
        and fsdp_group is not None
        and workspace.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], workspace)

    # Cache hit — update in-place and return
    if workspace is not None:
        if skip_update_flag is not None:
            # The GPU-side flag decides at runtime; force the update path
            # so the flag is actually consulted.
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(workspace, "quantize_"):
                workspace.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, workspace, skip_update_flag)
        return workspace, None

    # Cache miss — create new workspace
    if tensor is None or quantizer is None:
        raise ValueError("tensor and quantizer kwargs must be provided to construct FP8 workspace")
    if cache:
        # Ensure the cached tensor is a plain torch.Tensor, as it persists
        # beyond a single forward pass. Keeping internal=True would cause the
        # data to be removed in prepare_for_saving(...).
        saved_internal = quantizer.internal
        quantizer.internal = False
        try:
            out = quantizer.quantize(tensor, dtype=workspace_dtype)
        finally:
            # Restore the flag even if quantization raises, so a long-lived
            # quantizer is never left with a corrupted internal state.
            quantizer.internal = saved_internal
        return out, out
    out = quantizer.quantize(tensor, dtype=workspace_dtype)
    return out, None
Optional[str] = None, - update_workspace: bool = True, - skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: Optional[dist_group_type] = None, - workspace_dtype: Optional[torch.dtype] = None, - ) -> QuantizedTensor: - """Get workspace buffer for weights and maybe update its values - - The workspace buffer may be cached for future function calls. - - Parameters - ---------- - tensor : torch.Tensor, optional - Values to copy into workspace. Required if the workspace - is being constructed or updated. - quantizer: Quantizer, optional - Quantizer used to cast the weights. Required if the - workspace is being constructed or updated. - cache_name: str, optional - Key for caching. - update_workspace: bool, default = True - Update workspace with values from `tensor`. - skip_update_flag: torch.Tensor, optional - GPU flag to skip updating the workspace. Take precedence - over `update_workspace` if provided. - fsdp_group: bool, default = None - FSDP process group that the weights are distributed over. - workspace_dtype: torch.dtype, default = None - If weight workspace contains high-precision tensor - for example - for debug quantization, this is dtype of the tensor. - """ - - # Handle case where weights are already quantized - # Note: Make sure weights have required usages, but do not - # destroy unnecessary usages since they may be used later. 
- if isinstance(tensor, QuantizedTensor): - update_rowwise_usage = True if quantizer.rowwise_usage else None - update_columnwise_usage = True if quantizer.columnwise_usage else None - tensor.update_usage( - rowwise_usage=update_rowwise_usage, - columnwise_usage=update_columnwise_usage, - ) - - if isinstance(quantizer, DebugQuantizer): - tensor = quantizer.wrap_quantized_tensor(tensor) - - return tensor - - # Try getting workspace from cache - out = None - if cache_name is not None: - out = self._fp8_workspaces.get(cache_name, None) - - # Reset cache if workspace is invalid - if out is not None and quantizer is not None: - reset_cache = False - if isinstance(out, Float8TensorStorage): - if ( - not is_non_tn_fp8_gemm_supported() - and quantizer.columnwise_usage - and out._transpose is None - ): - reset_cache = True - elif isinstance(out, MXFP8TensorStorage): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - elif isinstance(out, NVFP4TensorStorage): - if quantizer.rowwise_usage and out._rowwise_data is None: - reset_cache = True - elif quantizer.columnwise_usage and out._columnwise_data is None: - reset_cache = True - if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): - reset_cache = True - if reset_cache: - out = None - del self._fp8_workspaces[cache_name] - - # Gather cached Fp8 workspace if it's distributed - # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work - # for models initialized with Fp8 primary weights. 
- if ( - out is not None - and tensor is not None - and fsdp_group is not None - and out.data.shape != tensor.data.shape - ): - _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) - - # Construct workspace if needed - if out is None: - if tensor is None or quantizer is None: - raise ValueError( - "tensor and quantizer kwargs must be provided to construct FP8 workspace" - ) - - if cache_name is not None: - # Ensure the tensor in the cache is an instance of torch.Tensor, - # as it persists beyond a single forward pass. - # Setting internal=True would cause the data to be removed in prepare_for_saving(...). - quantizer_internal = quantizer.internal - quantizer.internal = False - out = quantizer.quantize(tensor, dtype=workspace_dtype) - if cache_name is not None: - quantizer.internal = quantizer_internal - - # Update cache - if cache_name is not None: - self._fp8_workspaces[cache_name] = out - return out - - # Update workspace if needed - if skip_update_flag is not None: - update_workspace = True - if update_workspace: - if tensor is None: - raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if hasattr(out, "quantize_"): - out.quantize_(tensor, noop_flag=skip_update_flag) - else: - tex.quantize(tensor, quantizer, out, skip_update_flag) - return out - def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0adda48e36..385a266b5b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -16,6 +16,7 @@ from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, + quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -69,7 +70,7 @@ def forward( inp: torch.Tensor, non_tensor_args: Tuple, *weights_and_biases, - ) 
-> torch.Tensor: + ) -> Tuple[torch.Tensor, list]: # pylint: disable=missing-function-docstring # Reduce number of arguments to autograd function in order @@ -92,7 +93,8 @@ def forward( sequence_parallel, activation_dtype, is_grad_enabled, - module, + weight_workspaces, + cache_weight, skip_fp8_weight_update, save_original_input, debug, @@ -166,18 +168,19 @@ def forward( # Initialize weights weights_fp8: list + new_workspaces = [None] * num_gemms if fp8 or debug: - # FP8 cast to workspace buffer weights_fp8 = [] - update_workspace = is_first_microbatch is None or is_first_microbatch + update_ws = is_first_microbatch is None or is_first_microbatch for i in range(num_gemms): - weight_fp8 = module.get_weight_workspace( + weight_fp8, new_workspaces[i] = quantize_weight( tensor=weights[i], quantizer=weight_quantizers[i], - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, + workspace=weight_workspaces[i] if weight_workspaces else None, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, workspace_dtype=activation_dtype, + cache=cache_weight, ) weights_fp8.append(weight_fp8) @@ -310,10 +313,12 @@ def forward( ctx.input_quantizers = input_quantizers # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp.shape[1:-1], out.shape[-1]) + return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + def backward( + ctx, grad_output: torch.Tensor, _grad_workspaces + ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx) @@ -987,6 +992,13 @@ def forward( linear_fn = _GroupedLinear.forward autograd_ctx = [None] + num_gemms = len(m_splits) + cache_weight = is_first_microbatch is not None + weight_workspaces = [ + 
self._fp8_workspaces.get(f"weight{i}") if cache_weight else None + for i in range(num_gemms) + ] + non_tensor_args = ( m_splits, self.apply_bias, @@ -1005,12 +1017,20 @@ def forward( self.sequence_parallel, self.activation_dtype, is_grad_enabled, - self, + weight_workspaces, + cache_weight, None, # skip_fp8_weight_update self.save_original_input, debug, ) - out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + out, new_workspaces = linear_fn( + *autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors + ) + + if cache_weight: + for i, ws in enumerate(new_workspaces): + if ws is not None: + self._fp8_workspaces[f"weight{i}"] = ws finally: self.end_forward() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ed91bc1235..42d2ee13b8 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -20,6 +20,7 @@ from .base import ( fill_userbuffers_buffer_for_all_gather, get_ub, + quantize_weight, TransformerEngineBaseModule, get_dummy_wgrad, _2X_ACC_FPROP, @@ -93,9 +94,10 @@ def forward( ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], weight: torch.Tensor, + weight_workspace: Optional[torch.Tensor], bias: torch.Tensor, non_tensor_args: Tuple, - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring # Reduce number of arguments to autograd function in order @@ -135,7 +137,7 @@ def forward( ub_bulk_dgrad, ub_name, fsdp_group, - module, + cache_weight, skip_fp8_weight_update, symmetric_ar_type, debug, @@ -284,6 +286,7 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + new_weight_workspace = None weightmat = weight is_weight_param_quantized = False if fp8 or debug: @@ -298,15 +301,16 @@ def forward( 
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) # Get quantized weight - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( + update_ws = is_first_microbatch is None or is_first_microbatch + weightmat, new_weight_workspace = quantize_weight( tensor=weight, quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, + workspace=weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) weightmat.update_usage(rowwise_usage=True) @@ -526,13 +530,15 @@ def forward( # Cached state for backward pass is ready... # ------------------------------------------------------ + ln_out_for_return = None if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp_shape) shape[0] *= tp_size if with_input_all_gather else 1 - return out, ln_out_return.view(shape) - return out, ln_out_return.view(inp_shape) - return out + ln_out_for_return = ln_out_return.view(shape) + else: + ln_out_for_return = ln_out_return.view(inp_shape) + return out, ln_out_for_return, new_weight_workspace @staticmethod def backward( @@ -1019,6 +1025,7 @@ def wgrad_gemm( dgamma, dbeta, wgrad, + None, # weight_workspace grad_bias, None, ) @@ -1538,6 +1545,11 @@ def forward( else: fwd_fn = _LayerNormLinear.forward autograd_ctx = [None] + cache_name = None if is_first_microbatch is None else "weight" + weight_workspace = ( + self._fp8_workspaces.get(cache_name) if cache_name is not None else None + ) + non_tensor_args = ( self.eps, is_first_microbatch, @@ -1573,27 +1585,28 @@ def forward( self.ub_bulk_dgrad, self.ub_name, self.fsdp_group, - self, + cache_name is not None, skip_fp8_weight_update, self.symmetric_ar_type, debug, ) - out = fwd_fn( + out, ln_out, new_weight_workspace = fwd_fn( *autograd_ctx, inp, self.layer_norm_weight, 
self.layer_norm_bias, weight_tensor, + weight_workspace, bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, non_tensor_args, ) + if new_weight_workspace is not None and cache_name is not None: + self._fp8_workspaces[cache_name] = new_weight_workspace + finally: self.end_forward() - if self.return_layernorm_output: - out, ln_out = out - if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index cc3dcc4064..48dc60fbbf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -22,6 +22,7 @@ fill_userbuffers_buffer_for_all_gather, _ub_communicators, get_ub, + quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -175,8 +176,10 @@ def _forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, + fc1_weight_workspace: Optional[torch.Tensor], fc1_bias: torch.Tensor, fc2_weight: torch.Tensor, + fc2_weight_workspace: Optional[torch.Tensor], fc2_bias: torch.Tensor, non_tensor_args: Tuple, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: @@ -227,7 +230,8 @@ def _forward( ub_bulk_dgrad, gemm_gelu_fusion, fsdp_group, - module, + fp8_meta, + cache_weight, skip_fp8_weight_update, symmetric_ar_type, checkpoint, @@ -249,7 +253,7 @@ def _forward( == "DelayedScaling" ): # only applicable for delayed scaling FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute( - module.fp8_meta + fp8_meta ) # to restore quantizers during recomputation # save the rng states ctx.cpu_rng_state = torch.get_rng_state() @@ -320,7 +324,8 @@ def _forward( "ub_bulk_dgrad": ub_bulk_dgrad, "gemm_gelu_fusion": gemm_gelu_fusion, "fsdp_group": fsdp_group, - "module": module, + "fp8_meta": fp8_meta, + "cache_weight": False, "skip_fp8_weight_update": skip_fp8_weight_update, "symmetric_ar_type": 
symmetric_ar_type, "checkpoint": checkpoint, @@ -463,13 +468,12 @@ def _forward( ln_out_total = ln_out # Cast weights to expected dtype + new_fc1_weight_workspace = None + new_fc2_weight_workspace = None fc1_weight_final = fc1_weight fc2_weight_final = fc2_weight if fp8 or debug: - # If weights are not quantized, we call get_weight_workspace, - # which handles weight caching etc. - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch + update_ws = is_first_microbatch is None or is_first_microbatch # No need to set the quantizer states if weights are already quantized # for debug mode we create quantizer every iteration, thus we need to set the quantizer states if isinstance(fc1_weight, QuantizedTensorStorage) and not debug: @@ -482,23 +486,25 @@ def _forward( elif fc2_weight_quantizer is not None: fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - fc1_weight_final = module.get_weight_workspace( + fc1_weight_final, new_fc1_weight_workspace = quantize_weight( tensor=fc1_weight, quantizer=fc1_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc1_weight"), - update_workspace=update_workspace, + workspace=fc1_weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) - fc2_weight_final = module.get_weight_workspace( + fc2_weight_final, new_fc2_weight_workspace = quantize_weight( tensor=fc2_weight, quantizer=fc2_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc2_weight"), - update_workspace=update_workspace, + workspace=fc2_weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) fc1_weight_final.update_usage(rowwise_usage=True) fc2_weight_final.update_usage(rowwise_usage=True) @@ -857,13 +863,15 @@ def _forward( ) # we only get 
to this point if we are not recomputing for bwd, since that would have returned in the block above + ln_out_for_return = None if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp_shape) shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1 - return fc2_out, ln_out_return.view(shape) - return fc2_out, ln_out_return.view(inp_shape) - return fc2_out + ln_out_for_return = ln_out_return.view(shape) + else: + ln_out_for_return = ln_out_return.view(inp_shape) + return fc2_out, ln_out_for_return, new_fc1_weight_workspace, new_fc2_weight_workspace @staticmethod def forward( @@ -872,11 +880,13 @@ def forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, + fc1_weight_workspace: Optional[torch.Tensor], fc1_bias: torch.Tensor, fc2_weight: torch.Tensor, + fc2_weight_workspace: Optional[torch.Tensor], fc2_bias: torch.Tensor, non_tensor_args: Tuple, - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring # add recompute_for_bwd @@ -888,8 +898,10 @@ def forward( ln_weight, ln_bias, fc1_weight, + fc1_weight_workspace, fc1_bias, fc2_weight, + fc2_weight_workspace, fc2_bias, non_tensor_args, ) @@ -913,7 +925,7 @@ def _recompute(ctx): and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" ): # only applicable for delayed scaling FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute( - ctx.other_args["module"].fp8_meta + ctx.other_args["fp8_meta"] ) # set old quantizer state # get current rng state @@ -924,9 +936,27 @@ def _recompute(ctx): torch.set_rng_state(ctx.cpu_rng_state) _set_cuda_rng_state(ctx.cuda_rng_state) + # Unpack saved tensors and pass None for weight workspaces (recomputed from scratch) + ( + inp_r, + ln_weight_r, + ln_bias_r, + fc1_weight_r, + fc1_bias_r, + fc2_weight_r, + fc2_bias_r, + ) = tensors out = _LayerNormMLP._forward( # recompute ctx, - *tensors, + inp_r, + ln_weight_r, + 
ln_bias_r, + fc1_weight_r, + None, + fc1_bias_r, + fc2_weight_r, + None, + fc2_bias_r, tuple(ctx.other_args.values()), ) @@ -936,7 +966,7 @@ def _recompute(ctx): and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling" ): FP8GlobalStateManager.restore_fp8_meta_tensors( - ctx.other_args["module"].fp8_meta + ctx.other_args["fp8_meta"] ) # restore quantizers # set rng state for fwd @@ -1637,8 +1667,10 @@ def fc1_wgrad_gemm( dgamma, dbeta, fc1_wgrad, + None, # fc1_weight_workspace fc1_bias_grad if fc1_bias is not None else None, fc2_wgrad, # pylint: disable=possibly-used-before-assignment + None, # fc2_weight_workspace fc2_bias_grad, None, ) @@ -2102,6 +2134,15 @@ def forward( fwd_fn = _LayerNormMLP.forward autograd_ctx = [None] + cache_name_fc1 = None if is_first_microbatch is None else "fc1_weight" + cache_name_fc2 = None if is_first_microbatch is None else "fc2_weight" + fc1_weight_workspace = ( + self._fp8_workspaces.get(cache_name_fc1) if cache_name_fc1 is not None else None + ) + fc2_weight_workspace = ( + self._fp8_workspaces.get(cache_name_fc2) if cache_name_fc2 is not None else None + ) + non_tensor_args = ( self.eps, is_first_microbatch, @@ -2145,30 +2186,35 @@ def forward( self.ub_bulk_wgrad, self.gemm_gelu_fusion and not debug, self.fsdp_group, - self, + self.fp8_meta, + cache_name_fc1 is not None, skip_fp8_weight_update, self.symmetric_ar_type, self.checkpoint, debug, ) - out = fwd_fn( + out, ln_out, new_fc1_ws, new_fc2_ws = fwd_fn( *autograd_ctx, inp, self.layer_norm_weight, self.layer_norm_bias, fc1_weight, + fc1_weight_workspace, fc1_bias, fc2_weight, + fc2_weight_workspace, fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, non_tensor_args, ) + if new_fc1_ws is not None and cache_name_fc1 is not None: + self._fp8_workspaces[cache_name_fc1] = new_fc1_ws + if new_fc2_ws is not None and cache_name_fc2 is not None: + self._fp8_workspaces[cache_name_fc2] = new_fc2_ws + finally: self.end_forward() - if 
self.return_layernorm_output: - out, ln_out = out - if self.gemm_bias_unfused_add: out = out + cast_if_needed(fc2_bias, self.activation_dtype) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ea921341a4..d4dfa9277b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -19,6 +19,7 @@ fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, + quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -87,10 +88,11 @@ class _Linear(torch.autograd.Function): def forward( ctx, weight: torch.Tensor, + weight_workspace: Optional[torch.Tensor], inp: torch.Tensor, bias: Optional[torch.Tensor], non_tensor_args: Tuple, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # pylint: disable=missing-function-docstring ( @@ -122,7 +124,7 @@ def forward( ub_name, fp8_output, # pylint: disable=unused-variable fsdp_group, - module, + cache_weight, skip_fp8_weight_update, symmetric_ar_type, save_original_input, @@ -247,6 +249,7 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + new_weight_workspace = None weightmat = weight if fp8 or debug: # Configure quantizer @@ -261,18 +264,18 @@ def forward( ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) elif isinstance(weight, QuantizedTensor): - # If weight is already quantized, no need to set quantizer states weight_quantizer = weight._quantizer # Get quantized weight - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( + update_ws = is_first_microbatch is None or is_first_microbatch + weightmat, new_weight_workspace = quantize_weight( tensor=weight, quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, + 
workspace=weight_workspace, + update_workspace=update_ws, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + cache=cache_weight, ) weightmat.update_usage(rowwise_usage=True) @@ -489,10 +492,12 @@ def forward( # Cached state for backward pass is ready... # ------------------------------------------------------ - return out + return out, new_weight_workspace @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + def backward( + ctx, grad_output: torch.Tensor, _grad_weight_workspace + ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring # NVTX label for profiling @@ -971,6 +976,7 @@ def wgrad_gemm( _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( wgrad, + None, # weight_workspace dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, grad_bias, None, @@ -1418,6 +1424,11 @@ def forward( linear_fn = _Linear.forward autograd_ctx = [None] + cache_name = None if is_first_microbatch is None else "weight" + weight_workspace = ( + self._fp8_workspaces.get(cache_name) if cache_name is not None else None + ) + non_tensor_args = ( is_first_microbatch, self.fp8, @@ -1447,19 +1458,24 @@ def forward( self.ub_name, fp8_output, self.fsdp_group, - self, + cache_name is not None, skip_fp8_weight_update, self.symmetric_ar_type, self.save_original_input, debug, ) - out = linear_fn( + out, new_weight_workspace = linear_fn( *autograd_ctx, weight_tensor, + weight_workspace, inp, bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, non_tensor_args, ) + + if new_weight_workspace is not None and cache_name is not None: + self._fp8_workspaces[cache_name] = new_weight_workspace + finally: self.end_forward() if self.gemm_bias_unfused_add: