Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 125 additions & 129 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,131 @@ def fill_userbuffers_buffer_for_all_gather(
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")


def _is_weight_workspace_valid(
    workspace: QuantizedTensorStorage,
    quantizer: Quantizer,
) -> bool:
    """Return ``True`` if a cached weight workspace can serve the quantizer's current usages.

    A workspace is invalid when it lacks data layouts (rowwise/columnwise/transpose)
    that the quantizer currently requires, or when the debug-ness of the workspace
    and quantizer disagree.
    """
    if isinstance(workspace, Float8TensorStorage):
        # Without non-TN FP8 GEMM support, columnwise usage requires a transpose.
        needs_transpose = quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
        if needs_transpose and workspace._transpose is None:
            return False
    elif isinstance(workspace, (MXFP8TensorStorage, NVFP4TensorStorage)):
        # Both formats store rowwise and columnwise data separately; each
        # required layout must actually be present.
        missing_rowwise = quantizer.rowwise_usage and workspace._rowwise_data is None
        missing_columnwise = quantizer.columnwise_usage and workspace._columnwise_data is None
        if missing_rowwise or missing_columnwise:
            return False
    # Debug workspaces pair only with debug quantizers (and vice versa).
    return isinstance(workspace, DebugQuantizedTensor) == isinstance(quantizer, DebugQuantizer)


def quantize_weight(
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    workspace: Optional[QuantizedTensorStorage] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional["dist_group_type"] = None,
    workspace_dtype: Optional[torch.dtype] = None,
    cache: bool = False,
) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]:
    """Quantize a weight tensor, optionally reusing a cached workspace.

    This is a standalone function (no module reference) that can be called
    from inside ``torch.autograd.Function.forward``.

    Parameters
    ----------
    tensor: torch.Tensor, optional
        Weight tensor to quantize.
    quantizer: Quantizer, optional
        Quantizer for casting the weight.
    workspace: QuantizedTensorStorage, optional
        Previously cached workspace (from the module's ``_fp8_workspaces``).
        ``None`` indicates a cache miss.
    update_workspace: bool, default = True
        Whether to update an existing workspace with fresh values.
    skip_update_flag: torch.Tensor, optional
        GPU flag to conditionally skip the update. Takes precedence over
        ``update_workspace`` if provided.
    fsdp_group: dist_group_type, optional
        FSDP process group the weights are distributed over.
    workspace_dtype: torch.dtype, optional
        High-precision dtype for debug quantization workspaces.
    cache: bool, default = False
        If ``True`` and a new workspace is created, it will be returned
        as the second element so the caller can store it.

    Returns
    -------
    (weightmat, new_workspace)
        *weightmat*: quantized weight ready for GEMM.
        *new_workspace*: non-``None`` only when a brand-new workspace was
        created **and** ``cache=True``. The caller should store it in
        ``_fp8_workspaces``.
    """

    # Already-quantized weight (primary FP8 parameters).
    # Only enable the usages the quantizer requires; never disable existing
    # usages, since they may be needed later in the iteration.
    if isinstance(tensor, QuantizedTensor):
        tensor.update_usage(
            rowwise_usage=True if quantizer.rowwise_usage else None,
            columnwise_usage=True if quantizer.columnwise_usage else None,
        )
        if isinstance(quantizer, DebugQuantizer):
            tensor = quantizer.wrap_quantized_tensor(tensor)
        return tensor, None

    # Discard cached workspace if it lacks layouts the quantizer now needs
    if workspace is not None and quantizer is not None:
        if not _is_weight_workspace_valid(workspace, quantizer):
            workspace = None

    # Gather cached workspace if it is FSDP-sharded.
    # NOTE: FSDP sharding is supported only for FP8 buffers and will not work
    # for models initialized with FP8 primary weights.
    if (
        workspace is not None
        and tensor is not None
        and fsdp_group is not None
        and workspace.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], workspace)

    # Cache hit — update in-place and return
    if workspace is not None:
        if skip_update_flag is not None:
            # The GPU-side flag decides at kernel launch; the update must be
            # issued unconditionally so the flag can no-op it.
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(workspace, "quantize_"):
                workspace.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, workspace, skip_update_flag)
        return workspace, None

    # Cache miss — create new workspace
    if tensor is None or quantizer is None:
        raise ValueError("tensor and quantizer kwargs must be provided to construct FP8 workspace")
    if cache:
        # Cached tensors persist beyond a single forward pass, so they must be
        # real torch.Tensors: internal=True would cause the data to be removed
        # in prepare_for_saving(...). Restore the flag even if quantize raises,
        # so a failure here does not corrupt the quantizer for later calls.
        saved_internal = quantizer.internal
        quantizer.internal = False
        try:
            out = quantizer.quantize(tensor, dtype=workspace_dtype)
        finally:
            quantizer.internal = saved_internal
        return out, out
    return quantizer.quantize(tensor, dtype=workspace_dtype), None


class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""

Expand Down Expand Up @@ -1392,135 +1517,6 @@ def clear(self):
def forward(self):
    """Forward pass. Abstract; subclasses must override."""

def get_weight_workspace(
    self,
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    cache_name: Optional[str] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional[dist_group_type] = None,
    workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
    """Get workspace buffer for weights and maybe update its values

    The workspace buffer may be cached for future function calls.

    Parameters
    ----------
    tensor : torch.Tensor, optional
        Values to copy into workspace. Required if the workspace
        is being constructed or updated.
    quantizer: Quantizer, optional
        Quantizer used to cast the weights. Required if the
        workspace is being constructed or updated.
    cache_name: str, optional
        Key for caching.
    update_workspace: bool, default = True
        Update workspace with values from `tensor`.
    skip_update_flag: torch.Tensor, optional
        GPU flag to skip updating the workspace. Take precedence
        over `update_workspace` if provided.
    fsdp_group: bool, default = None
        FSDP process group that the weights are distributed over.
    workspace_dtype: torch.dtype, default = None
        If weight workspace contains high-precision tensor - for example
        for debug quantization, this is dtype of the tensor.
    """

    # Handle case where weights are already quantized
    # Note: Make sure weights have required usages, but do not
    # destroy unnecessary usages since they may be used later.
    if isinstance(tensor, QuantizedTensor):
        update_rowwise_usage = True if quantizer.rowwise_usage else None
        update_columnwise_usage = True if quantizer.columnwise_usage else None
        tensor.update_usage(
            rowwise_usage=update_rowwise_usage,
            columnwise_usage=update_columnwise_usage,
        )

        if isinstance(quantizer, DebugQuantizer):
            tensor = quantizer.wrap_quantized_tensor(tensor)

        return tensor

    # Try getting workspace from cache
    out = None
    if cache_name is not None:
        out = self._fp8_workspaces.get(cache_name, None)

    # Reset cache if workspace is invalid
    if out is not None and quantizer is not None:
        reset_cache = False
        if isinstance(out, Float8TensorStorage):
            if (
                not is_non_tn_fp8_gemm_supported()
                and quantizer.columnwise_usage
                and out._transpose is None
            ):
                reset_cache = True
        elif isinstance(out, MXFP8TensorStorage):
            # Check rowwise and columnwise layouts independently: a workspace
            # can have rowwise data present yet still be missing the
            # columnwise data the quantizer requires (a chained `elif` here
            # would silently accept such a stale workspace).
            if quantizer.rowwise_usage and out._rowwise_data is None:
                reset_cache = True
            if quantizer.columnwise_usage and out._columnwise_data is None:
                reset_cache = True
        elif isinstance(out, NVFP4TensorStorage):
            if quantizer.rowwise_usage and out._rowwise_data is None:
                reset_cache = True
            if quantizer.columnwise_usage and out._columnwise_data is None:
                reset_cache = True
        if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
            reset_cache = True
        if reset_cache:
            out = None
            del self._fp8_workspaces[cache_name]

    # Gather cached Fp8 workspace if it's distributed
    # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
    # for models initialized with Fp8 primary weights.
    if (
        out is not None
        and tensor is not None
        and fsdp_group is not None
        and out.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

    # Construct workspace if needed
    if out is None:
        if tensor is None or quantizer is None:
            raise ValueError(
                "tensor and quantizer kwargs must be provided to construct FP8 workspace"
            )

        if cache_name is not None:
            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in
            # prepare_for_saving(...). Restore the flag even if quantize
            # raises, so a failure does not corrupt the quantizer state.
            quantizer_internal = quantizer.internal
            quantizer.internal = False
            try:
                out = quantizer.quantize(tensor, dtype=workspace_dtype)
            finally:
                quantizer.internal = quantizer_internal
            # Update cache
            self._fp8_workspaces[cache_name] = out
        else:
            out = quantizer.quantize(tensor, dtype=workspace_dtype)
        return out

    # Update workspace if needed
    if skip_update_flag is not None:
        update_workspace = True
    if update_workspace:
        if tensor is None:
            raise ValueError("tensor kwarg must be provided to update FP8 workspace")
        if hasattr(out, "quantize_"):
            out.quantize_(tensor, noop_flag=skip_update_flag)
        else:
            tex.quantize(tensor, quantizer, out, skip_update_flag)
    return out

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
Expand Down
42 changes: 31 additions & 11 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from .base import (
get_dummy_wgrad,
quantize_weight,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
Expand Down Expand Up @@ -69,7 +70,7 @@ def forward(
inp: torch.Tensor,
non_tensor_args: Tuple,
*weights_and_biases,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, list]:
# pylint: disable=missing-function-docstring

# Reduce number of arguments to autograd function in order
Expand All @@ -92,7 +93,8 @@ def forward(
sequence_parallel,
activation_dtype,
is_grad_enabled,
module,
weight_workspaces,
cache_weight,
skip_fp8_weight_update,
save_original_input,
debug,
Expand Down Expand Up @@ -166,18 +168,19 @@ def forward(

# Initialize weights
weights_fp8: list
new_workspaces = [None] * num_gemms
if fp8 or debug:
# FP8 cast to workspace buffer
weights_fp8 = []
update_workspace = is_first_microbatch is None or is_first_microbatch
update_ws = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
weight_fp8, new_workspaces[i] = quantize_weight(
tensor=weights[i],
quantizer=weight_quantizers[i],
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
workspace=weight_workspaces[i] if weight_workspaces else None,
update_workspace=update_ws,
skip_update_flag=skip_fp8_weight_update,
workspace_dtype=activation_dtype,
cache=cache_weight,
)
weights_fp8.append(weight_fp8)

Expand Down Expand Up @@ -310,10 +313,12 @@ def forward(
ctx.input_quantizers = input_quantizers

# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
def backward(
ctx, grad_output: torch.Tensor, _grad_workspaces
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with get_nvtx_range_context("_GroupedLinear_backward"):
saved_tensors = restore_from_func_ctx(ctx)
Expand Down Expand Up @@ -987,6 +992,13 @@ def forward(
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

num_gemms = len(m_splits)
cache_weight = is_first_microbatch is not None
weight_workspaces = [
self._fp8_workspaces.get(f"weight{i}") if cache_weight else None
for i in range(num_gemms)
]

non_tensor_args = (
m_splits,
self.apply_bias,
Expand All @@ -1005,12 +1017,20 @@ def forward(
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
weight_workspaces,
cache_weight,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
out, new_workspaces = linear_fn(
*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors
)

if cache_weight:
for i, ws in enumerate(new_workspaces):
if ws is not None:
self._fp8_workspaces[f"weight{i}"] = ws

finally:
self.end_forward()
Expand Down
Loading
Loading