diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..ddb74fd636 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -36,6 +36,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, QuantizerFactory, QuantizeLayout, @@ -150,8 +151,13 @@ def assert_dequantized_grouped_scaled_tensor( a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray ): if isinstance(a, GroupedScaledTensor1x): - assert a.group_sizes.sum() == b.shape[0] - b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0) + group_sizes = ( + a.first_dims + if a.first_dims is not None + else jnp.ones(a.original_shape[0], dtype=jnp.int32) + ) + assert group_sizes.sum() == b.shape[0] + b = jnp.split(b, jnp.cumulative_sum(group_sizes)[:-1], axis=0) dq_a = a.dequantize() for dq_a_i, b_i in zip(dq_a, b): if len(dq_a_i) == 0: @@ -1787,13 +1793,18 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm + lhs_tensor = GroupedNoScaleTensor( + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape + ) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( - lhs, - rhs, - group_sizes, - contracting_dims, + lhs_tensor, + rhs_tensor, + contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1825,8 +1836,17 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + 
data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape + ) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + lhs_tensor, + rhs_tensor, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, ) allclose_dtype = jnp.float8_e4m3fn diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index aaf8e8ecea..aaec5affa8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -37,6 +37,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, Quantizer, GroupedQuantizer, @@ -73,12 +74,14 @@ # Cache whether the CUDA-graphable grouped GEMM implementation is available at import time. # Calling get_grouped_gemm_setup_workspace_size raises a RuntimeError mentioning "cublas" when # compiled against cuBLAS < 13.2, in which case the cuda-graphable path is unavailable. +_v2_grouped_gemm_available_reason = "" try: get_grouped_gemm_setup_workspace_size(1) _v2_grouped_gemm_available = True except RuntimeError as e: if "cublas" in str(e).lower(): _v2_grouped_gemm_available = False + _v2_grouped_gemm_available_reason = str(e) else: raise @@ -1392,17 +1395,47 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) +def _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups: int, +) -> None: + """Assert that all non-empty *_dims arrays have exactly num_groups elements. + + rhs_first_dims / rhs_last_dims describe the ragged contracting K dimension. + K totals need not fill the entire buffer (padding is allowed), so only the + array length is checked, not the per-group sum. 
+ """ + for name, aval in [ + ("lhs_first_dims", lhs_first_dims_aval), + ("lhs_last_dims", lhs_last_dims_aval), + ("out_first_dims", out_first_dims_aval), + ("out_last_dims", out_last_dims_aval), + ("rhs_first_dims", rhs_first_dims_aval), + ("rhs_last_dims", rhs_last_dims_aval), + ]: + if aval.size > 0: + assert ( + aval.size == num_groups + ), f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, + # lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, + # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26) inner_primitive = None outer_primitive = None @@ -1413,53 +1446,85 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - group_sizes_aval, + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, ): """ Grouped GEMM operation. 
Args: - lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_data: Left-hand side input matrix data (may be 1D for quantized) lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_data: Right-hand side input matrix data (may be 1D for quantized) rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group + lhs_first_dims: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_first_dims: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_first_dims: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR * alpha: 1D array of shape (G,) containing alpha values for each group * beta: 1D array of shape (G,) containing beta values for each group - M: Number of rows in the output matrix - N: Number of columns in the output matrix - K: Number of columns in the left-hand side matrix lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided - is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation - where both lhs and rhs are 2D matrices and output is (G, M, N) + out_shape: Pre-computed output shape tuple + lhs_left_size: Product of lhs dims before axis_boundary + lhs_right_size: Product of lhs dims after axis_boundary + rhs_left_size: Product of rhs dims before axis_boundary + rhs_right_size: Product of rhs dims after axis_boundary Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, 
rhs_data_aval, bias_aval - del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + del lhs_data_aval, rhs_data_aval + del lhs_is_trans, rhs_is_trans + del lhs_axis_boundary, rhs_axis_boundary + del lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size + del bias_aval + del has_bias, use_async_d2h_group_sizes + + num_groups = ( + lhs_first_dims_aval.size + or lhs_last_dims_aval.size + or rhs_first_dims_aval.size + or rhs_last_dims_aval.size + or out_first_dims_aval.size + or out_last_dims_aval.size + or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 + ) - num_groups = group_sizes_aval.size + _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups, + ) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1470,9 +1535,6 @@ def abstract( dtype=jnp.uint8, ) - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (num_groups, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) if use_v2_ffi: @@ -1480,7 +1542,24 @@ def abstract( shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8 ) # Temporary buffer for int32 -> int64 conversion of group_sizes on device. - int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize + # Each non-empty *_dims buffer needs its own slot of num_groups int64 elements so that + # make_grouped_tensor can write to a distinct region per ragged dimension. Allocate + # exactly as many slots as there are non-empty buffers (minimum 1 to avoid zero-size). 
+ num_ragged_dim_buffers = sum( + 1 + for aval in [ + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + ] + if aval.size > 0 + ) + int64_workspace_size = ( + max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize + ) int64_workspace_aval = jax.core.ShapedArray( shape=(int64_workspace_size,), dtype=jnp.uint8 ) @@ -1545,45 +1624,52 @@ def outer_abstract(*args, **kwargs): def lowering( ctx, *args, - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, ): - del out_dtype + del out_dtype, out_shape # Python-only; not forwarded to C++ if use_v2_ffi: ffi_name = GroupedGemmPrimitive.name_graph_safe return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + lhs_left_size=lhs_left_size, + lhs_right_size=lhs_right_size, + rhs_left_size=rhs_left_size, + rhs_right_size=rhs_right_size, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + lhs_left_size=lhs_left_size, + lhs_right_size=lhs_right_size, + rhs_left_size=rhs_left_size, + rhs_right_size=rhs_right_size, ) @staticmethod @@ -1593,20 +1679,28 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + 
rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, ): if GroupedGemmPrimitive.inner_primitive is None: raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") @@ -1620,19 +1714,27 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, *additional_args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + out_shape=out_shape, + lhs_left_size=lhs_left_size, + lhs_right_size=lhs_right_size, + rhs_left_size=rhs_left_size, + rhs_right_size=rhs_right_size, ) return (out,) @@ -1922,6 +2024,12 @@ def grouped_gemm_copy_group_sizes( return out +@cache +def _should_enforce_v2_grouped_gemm() -> bool: + """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" + return os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") == "1" + + def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, @@ -1933,21 +2041,42 @@ def _can_use_v2_grouped_gemm( # feature-compatible with the main branch. # Bias can be supported in a kernel or in pure-JAX in the future. 
+ enforce_v2_gmm = _should_enforce_v2_grouped_gemm() + if not _v2_grouped_gemm_available: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM is not available but NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is" + " enabled. The reason for V2 grouped GEMM not being available:" + f" {_v2_grouped_gemm_available_reason}" + ) return False # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). # Fall back to the v1 path on SM90 (Hopper) and older architectures. if get_device_compute_capability(0) < 100: + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM requires SM100+ (Blackwell or newer) but current device" + f" compute capability of GPU 0 is {get_device_compute_capability(0)} and" + " NVTE_JAX_ENFORCE_V2_GROUPED_GEMM is enabled." + ) return False - return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias: + return True + + if enforce_v2_gmm: + raise RuntimeError( + "The TE V2 grouped GEMM currently only supports BF16 with no quantization recipe and" + f" without bias, but received {scaling_mode=}, {dtype=}, {has_bias=}" + ) + return False def grouped_gemm( - lhs: Union[jnp.ndarray, GroupedScaledTensor1x], - rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - group_sizes: jnp.ndarray, + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -1960,9 +2089,8 @@ def grouped_gemm( Grouped GEMM operation. 
Args: - lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - group_sizes: 1D array containing the sizes of each group + lhs: Left-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x + rhs: Right-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -1972,49 +2100,74 @@ def grouped_gemm( Returns: A jnp.ndarray containing the result of the grouped GEMM operation - - Note: - Tested shapes: - lhs: [M, K] or [K, N] - rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ # TODO(Phuong): implement the precision del precision - if isinstance(lhs, jnp.ndarray): - if not isinstance(rhs, jnp.ndarray): - raise TypeError( - f"Expected rhs to be jnp.ndarray when lhs is jnp.ndarray, but got type={type(rhs)}" - ) - out_dtype = lhs.dtype - lhs_shape = lhs.shape - rhs_shape = rhs.shape - lhs_data = lhs - rhs_data = rhs - lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + empty_gs = jnp.empty((0,), jnp.int32) + + # Extract data, dims, and metadata from tensor objects. + # Keep data in its original layout (may be 1D for quantized tensors) to preserve + # JAX sharding; the C++ side uses original_shape to derive m/n/k. 
+ if isinstance(lhs, GroupedNoScaleTensor): + lhs_data = lhs.data + lhs_shape = lhs.original_shape + lhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING + out_dtype = lhs.data.dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs elif isinstance(lhs, GroupedScaledTensor1x): - if not isinstance(rhs, GroupedScaledTensor1x): - raise TypeError( - "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but" - f" got type={type(rhs)}" - ) - out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape - rhs_shape = rhs.original_shape lhs_data = lhs.data - rhs_data = rhs.data lhs_scale_inv = lhs.scale_inv + scaling_mode = lhs.scaling_mode + out_dtype = lhs.dq_dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + else: + raise TypeError( + f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" + ) + + if isinstance(rhs, GroupedNoScaleTensor): + rhs_data = rhs.data + rhs_shape = rhs.original_shape + rhs_scale_inv = jnp.empty((0,), jnp.float32) + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + elif isinstance(rhs, GroupedScaledTensor1x): + rhs_shape = rhs.original_shape + rhs_data = rhs.data rhs_scale_inv = rhs.scale_inv - if lhs.scaling_mode != rhs.scaling_mode: + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode: raise ValueError( f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," f" rhs.scaling_mode={rhs.scaling_mode}" ) - scaling_mode = lhs.scaling_mode + if isinstance(lhs, 
GroupedScaledTensor1x): + scaling_mode = lhs.scaling_mode else: - raise TypeError("Unsupported lhs type object!") + raise TypeError( + f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" + ) + + # Infer output dims from which operand has the ragged non-contracting dim. + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) + out_first_dims = empty_gs + out_last_dims = empty_gs + elif lhs_first_dims.size > 0: + out_first_dims = lhs_first_dims + out_last_dims = empty_gs + elif lhs_last_dims.size > 0: + out_first_dims = empty_gs + out_last_dims = lhs_last_dims + else: + out_first_dims = out_last_dims = empty_gs out_dtype = preferred_element_type or out_dtype @@ -2023,26 +2176,10 @@ def grouped_gemm( lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) - # rhs_shape [G, K, N] - rhs_is_trans = rhs_contract_dim[0] != 1 + # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). + rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - is_grouped_dense_wgrad = False - if len(rhs_shape) == 2: - rhs_is_trans = rhs_contract_dim[0] != 0 - is_grouped_dense_wgrad = True - - # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? 
- if ( - is_grouped_dense_wgrad - and not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - ): - lhs_is_trans = True - rhs_is_trans = False - lhs_flatten_axis = 1 - rhs_flatten_axis = 1 - if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2073,9 +2210,21 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + active_group_sizes = next( + ( + gs + for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0 + ), + empty_gs, + ) + lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data + rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data + lhs_q = grouped_quantize( + lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis + ) rhs_q = grouped_quantize( - rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) lhs_data = lhs_q.data rhs_data = rhs_q.data @@ -2110,38 +2259,66 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if group_sizes.size == rhs_shape[0]: + if ( + lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + ): # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Calling GroupedGEMM Custom Call - K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) - K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - if K_lhs != K_rhs: + # Compute N-D axis boundaries from final (post-adjustment) contracting dims. 
+ lhs_axis_boundary = get_lhs_axis_boundary(lhs_contract_dim, lhs_is_trans) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_contract_dim, rhs_is_trans) + + num_gemms = ( + lhs_first_dims.size + or lhs_last_dims.size + or rhs_first_dims.size + or rhs_last_dims.size + or out_first_dims.size + or out_last_dims.size + ) + if num_gemms == 0: raise ValueError( - f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from" - f" lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + "grouped_gemm requires at least one non-empty dimension array. " + "Ensure lhs or rhs tensor objects carry first_dims or last_dims." ) - M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G - if is_grouped_dense_wgrad: - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) + # Pre-compute collapsed 2D sizes from original N-D shapes. + # These are static Python ints passed as primitive parameters (must be hashable). + lhs_left_size = math.prod(lhs_shape[:lhs_axis_boundary]) + lhs_right_size = math.prod(lhs_shape[lhs_axis_boundary:]) + rhs_left_size = math.prod(rhs_shape[:rhs_axis_boundary]) + rhs_right_size = math.prod(rhs_shape[rhs_axis_boundary:]) + + # Pre-compute output shape from N-D input shapes (static Python ints). + if lhs_is_trans: + lhs_non_contracting = lhs_shape[lhs_axis_boundary:] else: - if group_sizes.size != rhs_shape[0]: - raise ValueError( - "Expected group_sizes.size == rhs_shape[0], but got" - f" group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" - ) + lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + if rhs_is_trans: + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + # wgrad: rhs (e.g. grad_T of shape (N, M)) has no G batch dim; include all dims + rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary)) + else: + # fwd/dgrad: rhs (e.g. 
kernel_T of shape (G, N, K)) has G batch dim at dim 0; skip it + rhs_non_contracting = tuple(rhs_shape[d] for d in range(rhs_axis_boundary) if d != 0) + else: + rhs_non_contracting = rhs_shape[rhs_axis_boundary:] + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + out_shape = (num_gemms, *lhs_non_contracting, *rhs_non_contracting) + else: + out_shape = (*lhs_non_contracting, *rhs_non_contracting) has_bias = bias is not None - if has_bias and bias.shape != (group_sizes.size, N): - raise ValueError( - f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" - ) + if has_bias: + N_dim = math.prod(rhs_non_contracting) + assert bias.shape == ( + num_gemms, + N_dim, + ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" bias = jnp.empty((), jnp.float32) if bias is None else bias if group_offset is not None: @@ -2153,7 +2330,6 @@ def grouped_gemm( use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) if use_v2_ffi: - num_gemms = group_sizes.shape[0] additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta else: @@ -2166,19 +2342,27 @@ def grouped_gemm( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, additional_arg_1, - M=M, - N=N, - K=K_lhs, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + out_shape=tuple(int(d) for d in out_shape), + lhs_left_size=int(lhs_left_size), + lhs_right_size=int(lhs_right_size), + rhs_left_size=int(rhs_left_size), + rhs_right_size=int(rhs_right_size), ) return out diff --git 
a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index bf4e833c89..a3d363e42a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -43,6 +43,7 @@ ScalingMode, compute_scale_from_amax, NoScaleTensor, + GroupedNoScaleTensor, get_rht_matrix, QuantizeLayout, ) @@ -1001,7 +1002,6 @@ class GroupedQuantizePrimitive(BasePrimitive): 5, 6, 7, - 8, ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype inner_primitive = None outer_primitive = None @@ -1016,7 +1016,6 @@ def abstract( scaling_mode, q_layout, flatten_axis, - group_axis, scale_dtype, ): """ @@ -1038,7 +1037,6 @@ def abstract( ).get_grouped_scale_shape_2x( x_aval.shape, group_sizes_aval.size, - group_axis, is_padded=True, flatten_axis=flatten_axis, ) @@ -1099,7 +1097,6 @@ def lowering( scaling_mode, q_layout, flatten_axis, - group_axis, scale_dtype, ): """ @@ -1110,7 +1107,6 @@ def lowering( assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 - assert group_axis == 0 return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1130,7 +1126,6 @@ def impl( scaling_mode, q_layout, flatten_axis, - group_axis, scale_dtype, ): """ @@ -1151,7 +1146,6 @@ def impl( scaling_mode=scaling_mode, q_layout=q_layout, flatten_axis=flatten_axis, - group_axis=group_axis, scale_dtype=scale_dtype, ) return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) @@ -1164,20 +1158,18 @@ def grouped_quantize( x: jnp.ndarray, quantizer: GroupedQuantizer, group_sizes: jnp.ndarray = None, - amax: jnp.ndarray = None, flatten_axis: int = -1, -) -> GroupedScaledTensor1x: +) -> Union[GroupedScaledTensor1x, GroupedNoScaleTensor]: """Quantize a tensor in grouped manner. 
This function quantizes a tensor by splitting it into groups along a specified axis and applying quantization to each group separately. The groups can be either specified - explicitly through group_sizes or automatically split along the group_axis. + explicitly through group_sizes or automatically split along axis 0. Args: x: Input tensor to quantize quantizer: The quantizer to use for quantization group_sizes: Array of ints containing the size of each group (default: None) - amax: The amax of x; if None, it is auto-generated. (default: None) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) Returns: @@ -1185,31 +1177,34 @@ def grouped_quantize( Note: - If group_sizes is not provided, the tensor will be split into equal-sized groups - along the group_axis - - The group_axis is currently fixed to 0 + along axis 0 - The quantizer's q_layout determines whether row-wise, column-wise, or both quantization is applied """ if quantizer is None: - if isinstance(x, NoScaleTensor): + if isinstance(x, GroupedNoScaleTensor): return x - return NoScaleTensor(data=x, amax=None) + return GroupedNoScaleTensor( + data=x, + amax=None, + first_dims=group_sizes, + last_dims=None, + original_shape=x.shape, + ) # TODO(Phuong): add support for flatten_axis = -2 assert flatten_axis in ( -1, x.ndim - 1, ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}" - group_axis = 0 + ragged_first_dims = group_sizes # None if no explicit group_sizes (kernel case) if group_sizes is None: - group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) + group_sizes = jnp.ones(x.shape[0], dtype=jnp.int32) if not GroupedQuantizePrimitive.enabled(): - return quantizer.quantize( - x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis - ) + return quantizer.quantize(x, flatten_axis=flatten_axis, group_sizes=group_sizes) n_groups = group_sizes.size original_shape = x.shape assert n_groups == len( @@ -1222,13 +1217,8 @@ def 
grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - if amax is not None: - row_amax = amax - else: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) - segment_ids = jnp.repeat( - jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] - ) + row_amax = jnp.max(jnp.abs(x), axis=range(1, x.ndim)) + segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[0]) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) @@ -1256,7 +1246,6 @@ def grouped_quantize( scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, - group_axis=group_axis, scale_dtype=quantizer.get_scale_dtype(), ) @@ -1280,9 +1269,8 @@ def grouped_quantize( q_layout=quantizer.q_layout, data_layout=quantizer.get_data_layout(), flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=ragged_first_dims, original_shape=original_shape, - group_axis=group_axis, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..a74b209e4f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -55,6 +55,32 @@ struct GemmConfig { bool use_split_accumulator; }; +struct GroupedGemmV2Config { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; + int64_t lhs_left_size; + int64_t lhs_right_size; + int64_t rhs_left_size; + int64_t rhs_right_size; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + bool has_bias; + bool use_async_d2h_group_sizes; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; + int64_t lhs_left_size; + int64_t lhs_right_size; 
+ int64_t rhs_left_size; + int64_t rhs_right_size; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -192,6 +218,30 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("rhs_transposed"), ::xla::ffi::StructMember("use_split_accumulator")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary"), + ::xla::ffi::StructMember("lhs_left_size"), + ::xla::ffi::StructMember("lhs_right_size"), + ::xla::ffi::StructMember("rhs_left_size"), + ::xla::ffi::StructMember("rhs_right_size")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("has_bias"), + ::xla::ffi::StructMember("use_async_d2h_group_sizes"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary"), + ::xla::ffi::StructMember("lhs_left_size"), + ::xla::ffi::StructMember("lhs_right_size"), + ::xla::ffi::StructMember("rhs_left_size"), + ::xla::ffi::StructMember("rhs_right_size")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 2acefa2d30..0d1ef405f4 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -619,137 +619,99 @@ JAXX_GroupedTensorWrapper 
make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. -Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, - Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, size_t m, - size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode, bool is_grouped_dense_wgrad) { - // Notes on matrix layouts and transpose: - // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major [m, k] for N - [k, m] for T - // B: row-major [k, n] for N - [n, k] for T - // on exiting this function, JAX expect: - // C: row-major with size [m, n]. - // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m] for T - [m, k] for N - // B: column-major with size [n, k] for T - [k, n] for N - // - // If we call cuBLAS GEMM for A * B, the output will be: - // C: column-major with size [m, n] --> row-major with size [n, m]. - // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. +// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. +// int64_offset (in int64 elements) is updated on return to the next available slot so callers can +// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked +// before each slot is used. Only NO_SCALING is supported. 
+JAXX_GroupedTensorWrapper make_grouped_tensor(
+    Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims,
+    int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset,
+    size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) {
+  auto dims = data.dimensions();
+  NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D.");
+  // Flatten dims at axis_boundary to produce a 2D NVTE shape.
+  // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols,
+  // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad).
+  size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast<size_t>(axis_boundary);
+  NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2};
+  JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape);
+  wrapper.set_rowwise(data, std::nullopt);
+  if (first_dims.element_count() > 0) {
+    NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32.");
+    NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity,
+               "int64_workspace overflow: not enough space for first_dims conversion.");
+    auto *slot = int64_workspace_base + int64_offset;
+    nvte_convert_int32_to_int64(reinterpret_cast<int32_t *>(first_dims.untyped_data()), slot,
+                                num_gemms, stream);
+    wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims);
+    int64_offset += num_gemms;
+  }
+  if (last_dims.element_count() > 0) {
+    NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32.");
+    NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity,
+               "int64_workspace overflow: not enough space for last_dims conversion.");
+    auto *slot = int64_workspace_base + int64_offset;
+    nvte_convert_int32_to_int64(reinterpret_cast<int32_t *>(last_dims.untyped_data()), slot,
+                                num_gemms, stream);
+    wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims);
+    int64_offset += num_gemms;
+  }
+  return wrapper;
+}
-  // Inputs
-  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
-  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
-  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
-  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
-  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
-  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
-  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
-  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
-  bool has_bias = product(bias.dimensions()) > 0;
-  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
-  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
+// Returns num_gemms from the first non-empty per-tensor group_sizes buffer,
+// falling back to the element count of alpha for the uniform-batch case.
+size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, Buffer_Type const &out_last_dims, + Buffer_Type const &alpha) { + if (lhs_first_dims.element_count() > 0) { + return lhs_first_dims.element_count(); + } else if (lhs_last_dims.element_count() > 0) { + return lhs_last_dims.element_count(); + } else if (rhs_first_dims.element_count() > 0) { + return rhs_first_dims.element_count(); + } else if (rhs_last_dims.element_count() > 0) { + return rhs_last_dims.element_count(); + } else if (out_first_dims.element_count() > 0) { + return out_first_dims.element_count(); + } else if (out_last_dims.element_count() > 0) { + return out_last_dims.element_count(); + } else { + return alpha.element_count(); // uniform batch: no ragged tensor + } +} + +} // namespace jax +} // namespace transformer_engine - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; +namespace transformer_engine { +namespace jax { - // Convert int32 group_sizes to int64 into the dedicated output buffer. - NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), - int64_sizes_ptr, num_gemms, stream); +// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. 
+Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, + Result_Type cublas_workspace, Result_Type setup_workspace, + Result_Type int64_workspace, GroupedGemmV2Config config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary, + lhs_left_size, lhs_right_size, rhs_left_size, rhs_right_size] = config; NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); - // It is weird that TE/Common GEMM only use colwise for MXFP8 - const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); - const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; - const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; - const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, + rhs_last_dims, out_first_dims, out_last_dims, alpha); - // Outputs - auto out_ptr = reinterpret_cast(output->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + // Workspaces. 
auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); - // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); - auto workspace_total_size = product(cublas_workspace->dimensions()); - - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - const size_t workspace_alignment_padding = 256; - const size_t tensor_scaling_sinv_aligment = 16; - const size_t mxfp8_scaling_sinv_alignment_padding = 256; - auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { - // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned - // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. 
- workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); - } - auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned - auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; - - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); - size_t actual_lhs_size = product(lhs_data.dimensions()); - size_t actual_rhs_size = product(rhs_data.dimensions()); - size_t actual_out_size = product(output->dimensions()); - NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", - expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, - "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, - " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! 
Expect m * n = ", m, - " * ", n, " = ", expected_out_size, ", got ", actual_out_size); - } else { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, - " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, - "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, - " = ", expected_out_size, ", got ", actual_out_size); - } - - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - bool grad = false; - bool accumulate = false; - bool use_split_accumulator = false; - auto bias_shape = std::vector{has_bias ? n : 0}; - const int arch = cuda::sm_arch(); - - if (arch < 100 && is_fp8_gemm) { - NVTE_CHECK(!lhs_is_trans && rhs_is_trans, - "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", - "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); - } - + auto workspace_size = product(cublas_workspace->dimensions()) - 256; TensorWrapper workspace_setup(setup_workspace_ptr, std::vector{product(setup_workspace->dimensions())}, DType::kByte); @@ -763,59 +725,21 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_grouped_dense_wgrad) { - NVTE_CHECK(lhs_is_trans && !rhs_is_trans, - "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); - - //// RHS - NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - - //// LHS - NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; - lhs_is_trans = true; - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - - //// 
OUTPUT - NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, - alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), - workspace_cublas.data(), - nullptr, // config (use defaults) - stream); - - return ffi_with_cuda_error_check(); - } - - // Nominal case for FWD or DGRAD - - //// RHS - NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; - if (rhs_is_trans) { - rhsShape.data[0] = num_gemms * n; - rhsShape.data[1] = k; - } - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - - //// LHS - NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; - if (lhs_is_trans) { - std::swap(lhsShape.data[0], lhsShape.data[1]); - } - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); - - //// OUTPUT - NVTEShape outShape{.data = {m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed. + // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each. + // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims + // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor. 
+  auto *int64_base = reinterpret_cast<int64_t *>(int64_workspace->untyped_data());
+  size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t);
+  size_t int64_offset = 0;
+  auto rhs_tensor =
+      make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity,
+                          int64_offset, num_gemms, stream, rhs_axis_boundary);
+  auto lhs_tensor =
+      make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity,
+                          int64_offset, num_gemms, stream, lhs_axis_boundary);
+  auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base,
+                                        int64_capacity, int64_offset, num_gemms, stream);
 
   nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor,
                     alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(),
@@ -834,28 +758,31 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI,
                               .Arg<Buffer_Type>()  // rhs_data
                               .Arg<Buffer_Type>()  // rhs_sinv
                               .Arg<Buffer_Type>()  // bias
-                              .Arg<Buffer_Type>()  // group_sizes (int32)
+                              .Arg<Buffer_Type>()  // lhs_first_dims (G,) or empty (0,)
+                              .Arg<Buffer_Type>()  // lhs_last_dims (G,) or empty (0,)
+                              .Arg<Buffer_Type>()  // rhs_first_dims (G,) or empty (0,)
+                              .Arg<Buffer_Type>()  // rhs_last_dims (G,) or empty (0,)
+                              .Arg<Buffer_Type>()  // out_first_dims (G,) or empty (0,)
+                              .Arg<Buffer_Type>()  // out_last_dims (G,) or empty (0,)
                               .Arg<Buffer_Type>()  // alpha
                               .Arg<Buffer_Type>()  // beta
                               .Ret<Buffer_Type>()  // output
                               .Ret<Buffer_Type>()  // cublas_workspace
                               .Ret<Buffer_Type>()  // setup_workspace
                               .Ret<Buffer_Type>()  // int64_workspace
-                              .Attr<int64_t>("M")
-                              .Attr<int64_t>("N")
-                              .Attr<int64_t>("K")
-                              .Attr<bool>("lhs_is_trans")
-                              .Attr<bool>("rhs_is_trans")
-                              .Attr<JAXX_Scaling_Mode>("scaling_mode")
-                              .Attr<bool>("is_grouped_dense_wgrad"),
+                              .Attrs<GroupedGemmV2Config>(),
                           FFI_CudaGraph_Traits);
 
 Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                           Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
-                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
-                          Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
-                          bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
-                          bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) {
+ Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, + GroupedGemmConfig config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes, + lhs_axis_boundary, rhs_axis_boundary, lhs_left_size, lhs_right_size, rhs_left_size, + rhs_right_size] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -872,6 +799,54 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); + // Determine which group_sizes buffers are active (non-empty = ragged dimension). + bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else + NVTE_CHECK(false, + "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to " + "determine num_gemms."); + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + 
active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; + + // Derive m, n, k from pre-computed original shape sizes (passed from Python). + // lhs_left_size = product of original lhs dims before axis_boundary + // lhs_right_size = product of original lhs dims after axis_boundary + // Same pattern for rhs. + size_t k = lhs_is_trans ? lhs_left_size : lhs_right_size; + size_t m, n; + if (is_rhs_ragged) { + // wgrad: non-contracting lhs dims form M; non-contracting rhs dims form N + m = lhs_is_trans ? lhs_right_size : lhs_left_size; + n = rhs_is_trans ? rhs_left_size : rhs_right_size; + } else { + m = lhs_is_trans ? lhs_right_size : lhs_left_size; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_left_size / num_gemms : rhs_right_size; + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -884,9 +859,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -953,14 +925,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? 
(num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -976,25 +948,28 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! 
K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + if (any_ragged) { + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + auto gs_data_ptr = reinterpret_cast(active_gs_ptr->untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_rhs_ragged) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -1042,7 +1017,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto lhs_shape_i = std::vector{m_i, k}; auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { size_t k_i = dim_list_host[i]; lhs_shape_i[0] = lhs_is_trans ? k_i : m; lhs_shape_i[1] = lhs_is_trans ? 
m : k_i; @@ -1237,19 +1212,16 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // rhs_data .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("M") - .Attr("N") - .Attr("K") - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("has_bias") - .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attrs()); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index fe02e61fc0..dbd7bbb1ff 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -18,15 +18,11 @@ from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope from .quantize import ( - ScaledTensorFactory, ScaledTensor, - ScalingMode, QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, - is_fp8_gemm_with_all_layouts_supported, TensorUsage, - QuantizeLayout, ) @@ -325,7 +321,6 @@ def grouped_dense( group_sizes: jnp.ndarray, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), bias: jnp.ndarray = None, - kernel_amax: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, @@ -342,7 +337,6 @@ def grouped_dense( contracting_dims: Tuple of sequences specifying which dimensions to contract (currently only supports ((1,), (1,))) bias: Bias tensor of shape (G, N) - kernel_amax: The amax values of weight matrix of shape (G,) precision: JAX precision for the GEMM operation preferred_element_type: Preferred data type for the output tensor 
group_offset: 1D array containing offsets for each group (not yet implemented) @@ -361,7 +355,6 @@ def grouped_dense( group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -371,14 +364,13 @@ def grouped_dense( return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7, 9)) def _grouped_dense( x, kernel, group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -391,7 +383,6 @@ def _grouped_dense( group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -407,7 +398,6 @@ def _grouped_dense_fwd_rule( group_sizes, contracting_dims, bias, - kernel_amax, precision, preferred_element_type, group_offset, @@ -415,118 +405,42 @@ def _grouped_dense_fwd_rule( kernel_fsdp_info, ): use_bias = bias is not None - is_noop_quantizer_set = quantizer_set == noop_quantizer_set kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." 
+ del kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx, kernel_fsdp_info, kernel_fsdp_enabled - if is_noop_quantizer_set: - grouped_gemm_x = x - grouped_gemm_kernel = kernel - ctx_x = x - ctx_kernel = kernel - flatten_axis_k = None - - if kernel_fsdp_enabled: - kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) - else: - original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout - - x_contracting_dims, k_contracting_dims = contracting_dims - flatten_axis_x = -len(x_contracting_dims) - flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis - - assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)" - assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)" - # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? - assert x_contracting_dims == (1,) and k_contracting_dims == (1,), ( - "grouped_dense for FP8 can only handle x_contracting_dims=(1,) " - "and k_contracting_dims=(1,) for now, " - f"got {x_contracting_dims=} and {k_contracting_dims=}" - ) + x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis - casted_x = tex.grouped_quantize( - x, - quantizer_set.x, - group_sizes, - flatten_axis=flatten_axis_x, - ) + casted_x = tex.grouped_quantize( + x, + quantizer_set.x, + group_sizes, + flatten_axis=flatten_axis_x, + ) - ctx_kernel_usage = TensorUsage.RHS_TRANS - if kernel_fsdp_enabled: - assert quantizer_set.kernel.scaling_mode in [ - ScalingMode.CURRENT_TENSOR_SCALING, - ScalingMode.DELAYED_TENSOR_SCALING, - ] - # Perform `cast` only - ctx_kernel_usage = TensorUsage.LHS - quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE - - casted_kernel = tex.grouped_quantize( - kernel, 
quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k - ) - contracting_dims = (x_contracting_dims, k_contracting_dims) - - # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have - # rowwise_casted_x.original_shape == (M, K) - # colwise_casted_kernel.original_shape == (G, N, K) - grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) - ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) - ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage) - - if kernel_fsdp_enabled: - ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape) - global_ctx_kernel_data = _all_gather_kernel( - ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx - ) - kernel_shape = global_ctx_kernel_data.shape - - ctx_kernel = ScaledTensorFactory.create_1x( - global_ctx_kernel_data.reshape(-1), - ctx_kernel.scale_inv, - scaling_mode=ctx_kernel.scaling_mode, - dq_dtype=ctx_kernel.dq_dtype, - is_colwise=False, - data_layout="N", - flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, - original_shape=kernel_shape, - group_axis=ctx_kernel.group_axis, - ) - - if is_fp8_gemm_with_all_layouts_supported(): - grouped_gemm_kernel = ctx_kernel - else: - grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1) - grouped_gemm_kernel = ScaledTensorFactory.create_1x( - grouped_gemm_kernel_data.reshape(-1), - ctx_kernel.scale_inv, - scaling_mode=ctx_kernel.scaling_mode, - dq_dtype=ctx_kernel.dq_dtype, - is_colwise=True, - data_layout="T", - flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, - original_shape=kernel_shape, - group_axis=ctx_kernel.group_axis, - ) - else: - grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) - - # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one. - # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. 
- quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout + casted_kernel = tex.grouped_quantize(kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k) + contracting_dims = (x_contracting_dims, k_contracting_dims) + # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have + # rowwise_casted_x.original_shape == (M, K) + # colwise_casted_kernel.original_shape == (G, N, K) + grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) + ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) + + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, + contracting_dims=contracting_dims, + bias=bias, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) ctx = ( @@ -540,7 +454,6 @@ def _grouped_dense_fwd_rule( x.shape, kernel.shape, use_bias, - is_noop_quantizer_set, quantizer_set, flatten_axis_k, ) @@ -550,6 +463,10 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): + kernel_fsdp_mesh_axis, _ = kernel_fsdp_info + kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims ( @@ -559,62 +476,41 @@ def _grouped_dense_bwd_rule( x_shape, kernel_shape, use_bias, - is_noop_quantizer_set, quantizer_set, flatten_axis_k, ) = ctx - if is_noop_quantizer_set: - # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) 
- # g_contracting_dim = (1, ) - # k_contracting_dim = (2, ) - g_contracting_dim = tuple( - range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) - ) - k_contracting_dim = tuple( - dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims - ) - dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = grad - dgrad_kernel_T = ctx_kernel - - # g_contracting_dim = (0, ) - # x_contracting_dim = (0, ) - g_contracting_dim = x_contracting_dim = tuple( - range(0, len(x_shape) - len(fwd_x_contracting_dims)) - ) - wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_x_T = ctx_x - wgrad_grad = grad - else: - casted_grad = tex.grouped_quantize( - grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k - ) + # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?) + # g_contracting_dim = (1, ) + # k_contracting_dim = (2, ) + g_contracting_dim = tuple( + range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + ) + k_contracting_dim = tuple( + dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims + ) - # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use - # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the - # extra transpose for FP8 in grouped_gemm - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? - g_contracting_dim = (1,) - k_contracting_dim = (2,) - dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) - dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS) - dgrad_kernel_T = ctx_kernel - - # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work - # after the extra transpose for FP8 in grouped_gemm - # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? 
- g_contracting_dim = (0,) - x_contracting_dim = (0,) - wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) - wgrad_x_T = ctx_x - wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) + casted_grad = tex.grouped_quantize( + grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k + ) + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS) + dgrad_kernel_T = ctx_kernel + + # g_contracting_dim = (0, ) + # x_contracting_dim = (0, ) + g_contracting_dim = x_contracting_dim = tuple( + range(0, len(x_shape) - len(fwd_x_contracting_dims)) + ) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + + wgrad_x_T = ctx_x + wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - group_sizes, - dgrad_contracting_dims, + contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, @@ -623,23 +519,16 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - group_sizes, - wgrad_contracting_dims, + contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, ) - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - if kernel_fsdp_mesh_axis is not None: - wgrad = _psum_scatter_kernel( - wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx - ) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None - dkernel_amax = None - return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set + return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 74787b9308..5abb2e74df 100644 --- 
a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -275,29 +275,45 @@ def _grouped_dequantize(grouped_scaled_tensor): """ data = grouped_scaled_tensor.data scale_inv = grouped_scaled_tensor.scale_inv - group_sizes = grouped_scaled_tensor.group_sizes + group_sizes = ( + grouped_scaled_tensor.first_dims + if grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 + else grouped_scaled_tensor.last_dims + ) + # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape + if group_sizes is None: + group_sizes = jnp.ones(grouped_scaled_tensor.original_shape[0], dtype=jnp.int32) flatten_axis = grouped_scaled_tensor.flatten_axis scaling_mode = grouped_scaled_tensor.scaling_mode original_shape = grouped_scaled_tensor.original_shape - group_axis = grouped_scaled_tensor.group_axis - flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis output = [] - non_group_shape = tuple( - original_shape[i] for i in range(len(original_shape)) if i != group_axis + # For transposed (colwise) tensors with ragged groups, the group dimension is the last + # axis of original_shape (e.g. original_shape = (N, M) with groups along M), while the + # non-group dimensions are all axes before it. For the uniform-groups case the group + # dimension stays at axis 0, so the existing axis-0 logic applies. 
+ is_transposed_ragged = ( + grouped_scaled_tensor.data_layout == "T" and group_sizes.size != original_shape[0] ) + if is_transposed_ragged: + non_group_shape = original_shape[:-1] + else: + non_group_shape = tuple(original_shape[i] for i in range(len(original_shape)) if i != 0) matrix_sizes = group_sizes * math.prod(non_group_shape) data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1]) scale_inv_ptr = 0 for i, data_i in enumerate(data): - data_shape_i = ( - *original_shape[:group_axis], - group_sizes[i], - *original_shape[group_axis + 1 :], - ) + if is_transposed_ragged: + data_shape_i = (*non_group_shape, group_sizes[i]) + else: + data_shape_i = ( + group_sizes[i], + *original_shape[1:], + ) assert math.prod(data_shape_i) == data_i.size, ( f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to" f" {data_i.size}" diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..db56db935d 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -920,7 +920,7 @@ def __post_init__(self): self.data_layout = self.quantizers[0].data_layout def _create_grouped_tensor_from_tensor_list( - self, tensor_list, group_sizes, original_shape, group_axis, mode + self, tensor_list, group_sizes, original_shape, mode ): # mode 0 = concate, mode 1 = add # TODO(Ming Huang): Consider to apply Enum for mode. @@ -948,9 +948,8 @@ def _create_grouped_tensor_from_tensor_list( is_colwise=tensor_list[0].is_colwise, data_layout=tensor_list[0].data_layout, flatten_axis=tensor_list[0].flatten_axis, - group_sizes=group_sizes, + first_dims=group_sizes, original_shape=original_shape, - group_axis=group_axis, ) def _quantize_func(self, *args, **kwargs): @@ -964,12 +963,11 @@ def quantize( dq_dtype=None, flatten_axis=-1, group_sizes=None, - group_axis=0, ): """Quantize a tensor in grouped manner. 
Expected input shape: [M, K] or [G, K, N] - Split to x.shape[group_axis] number of groups if group_sizes is not given + Split to x.shape[0] number of groups if group_sizes is not given Args: x: Input tensor to quantize @@ -978,12 +976,10 @@ def quantize( dq_dtype: Data type for dequantized values flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) group_sizes: Array of ints containing the size of each group (default: None) - group_axis: The axis along which grouping is performed (default: 0) Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ - assert group_axis == 0, "Only group_axis == 0 is supported now!" dq_dtype = dq_dtype if dq_dtype is not None else x.dtype if flatten_axis < 0: @@ -1023,8 +1019,8 @@ def quantize( tensor_list.append(tensor) combine_mode = 1 # Add else: - group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) - x = jnp.split(x, x.shape[group_axis], axis=group_axis) + group_sizes = jnp.ones(x.shape[0], dtype=jnp.int32) + x = jnp.split(x, x.shape[0], axis=0) tensor_list = [] for i in range(len(group_sizes)): @@ -1038,12 +1034,12 @@ def quantize( if is_rowwise: rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list] grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list( - rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode + rowwise_tensor_list, group_sizes, original_shape, combine_mode ) if is_colwise: colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list] grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list( - colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode + colwise_tensor_list, group_sizes, original_shape, combine_mode ) if is_colwise and is_rowwise: diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 61c3af178c..26b998ba90 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py 
+++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -135,14 +135,13 @@ def get_scale_shape( @abstractmethod def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Original shape of the data tensor n_groups: Number of groups in grouped quantization - group_axis: The axis along which grouping is performed is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -253,7 +252,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return QuantizeLayout.ROWWISE def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for scale tensors in this mode. @@ -266,7 +265,7 @@ def get_grouped_scale_shape( Returns: The shape for scale tensors """ - del data_shape, group_axis, is_colwise + del data_shape, is_colwise assert isinstance(n_groups, int) return (n_groups,) @@ -370,7 +369,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return QuantizeLayout.COLWISE def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for scale tensors in this mode. 
@@ -383,7 +382,7 @@ def get_grouped_scale_shape( Returns: The shape for scale tensors """ - del data_shape, group_axis, is_colwise + del data_shape, is_colwise assert isinstance(n_groups, int) return (n_groups,) @@ -613,7 +612,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return QuantizeLayout.COLWISE def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[int]: """Get the shape for grouped scale tensors in this mode. If padded: The estimated maximal possible shape for grouped scale tensor is returned instead. @@ -937,14 +936,13 @@ get_shardy_sharding_rules( ) def get_grouped_scale_shape_2x( - self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_padded=True, flatten_axis=-1 ) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor n_groups: Number of groups for grouped quantization - group_axis: The axis along which grouping is performed is_padded: Whether to use padded shapes flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -954,7 +952,6 @@ get_grouped_scale_shape_2x( rowwise_scale_shape = self.get_grouped_scale_shape( data_shape, n_groups, - group_axis, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis, @@ -962,7 +959,6 @@ colwise_scale_shape = self.get_grouped_scale_shape( data_shape, n_groups, - group_axis, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis, @@ -970,7 +966,7 @@ return (rowwise_scale_shape, colwise_scale_shape) def get_grouped_scale_shape( - self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + self, data_shape, n_groups, is_colwise, is_padded=True, flatten_axis=-1 ) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling. @@ -985,7 +981,6 @@ def get_grouped_scale_shape( return self._get_impl().get_grouped_scale_shape( data_shape, n_groups, - group_axis, is_colwise=is_colwise, is_padded=is_padded, flatten_axis=flatten_axis, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c26cb8a531..b1f49dacdc 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -9,7 +9,7 @@ rowwise and colwise quantization modes with proper scaling and dequantization. """ from dataclasses import dataclass -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple from abc import ABC, abstractmethod import jax.numpy as jnp @@ -32,6 +32,7 @@ "ScaledTensor1x", "ScaledTensor2x", "GroupedScaledTensor1x", + "GroupedNoScaleTensor", "ScaledTensorFactory", "with_sharding_constraint_by_logical_axes", ] @@ -365,21 +366,22 @@ class GroupedScaledTensor1x(ScaledTensor1x): where elements are grouped along a specified axis. 
Attributes: - group_sizes: Array containing the size of each group + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping - group_axis: The axis along which grouping is performed (default: 0) """ - group_sizes: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] original_shape: Tuple - group_axis: int def __init__( self, data, scale_inv, amax, - group_sizes, + first_dims, + last_dims, scaling_mode, dq_dtype, _dq_func, @@ -387,12 +389,11 @@ def __init__( data_layout, flatten_axis, original_shape, - group_axis=0, ): self.flatten_axis = flatten_axis - self.group_sizes = group_sizes + self.first_dims = first_dims + self.last_dims = last_dims self.original_shape = original_shape - self.group_axis = group_axis # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, @@ -410,7 +411,6 @@ def __init__( def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" - assert self.group_axis >= 0 assert self.flatten_axis > 0 data_ndim = len(self.original_shape) @@ -418,14 +418,19 @@ def __post_init__(self): 0 < self.flatten_axis < data_ndim ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}" - assert ( - 0 <= self.group_axis < data_ndim - ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" + active_dims = ( + self.first_dims + if self.first_dims is not None and self.first_dims.size > 0 + else self.last_dims + ) + if active_dims is not None: + num_groups = active_dims.size + else: + num_groups = self.original_shape[0] expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( self.original_shape, - self.group_sizes.size, - self.group_axis, + num_groups, 
self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis, @@ -442,7 +447,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax, self.group_sizes) + children = (self.data, self.scale_inv, self.amax, self.first_dims, self.last_dims) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -451,7 +456,6 @@ def tree_flatten(self): self.data_layout, self.flatten_axis, self.original_shape, - self.group_axis, ) return (children, aux_data) @@ -473,6 +477,81 @@ def checkpoint(self, quantizer): return jax_checkpoint_name(self, name=quantizer.checkpoint_name) +@register_pytree_node_class +@dataclass +class GroupedNoScaleTensor(AbstractBaseTensor1x): + """Unquantized grouped tensor. + + Stores N-D data with per-group dimension sizes so that grouped_gemm() + can extract first/last dims automatically without explicit parameters. + + Attributes: + data: The raw (unquantized) tensor data in N-D layout + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged + original_shape: Shape of data (same as data.shape for N-D unquantized) + """ + + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] + original_shape: Tuple + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations.""" + children = (self.data, self.amax, self.first_dims, self.last_dims) + aux_data = (self.original_shape,) + return (children, aux_data) + + @property + def ndim(self): + """Number of dimensions of the underlying array.""" + return self.data.ndim + + def dequantize(self): + """This is a no-op for a higher-precision tensor so this simply returns the tensor's data.""" + return self.data + + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) + assert q_layout.is_rowwise_only, 
"Only ROWWISE layout is supported for NoScaleTensor" + return self + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names) + + return GroupedNoScaleTensor( + data=data, + amax=self.amax, + first_dims=self.first_dims, + last_dims=self.last_dims, + original_shape=self.original_shape, + ) + + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + assert quantizer is None, "NoScaleTensor does not support quantization." + return self + + @register_pytree_node_class @dataclass class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): @@ -570,9 +649,9 @@ def create_1x( is_colwise=False, data_layout="N", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, - group_axis=0, has_rht_applied=False, ): """Creates a single-scale quantized tensor. 
@@ -586,29 +665,37 @@ def create_1x( is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor - group_sizes: Array of ints containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) - group_axis: The axis along which grouping is performed (default: 0) has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: - A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided + A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether first_dims or last_dims is provided """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if group_sizes is not None: - flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + if first_dims is not None or last_dims is not None or original_shape is not None: assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + + # Determine num_groups from whichever dims array is provided, or from original_shape + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) + if active_dims is not None: + num_groups = active_dims.size + else: + num_groups = original_shape[0] # Handling attrs of transposed tensors - group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": - if original_shape[0] == group_sizes.size: + if original_shape[0] == num_groups: original_shape = ( original_shape[0], *original_shape[flatten_axis:], @@ 
-620,7 +707,6 @@ def create_1x( *original_shape[flatten_axis:], *original_shape[:flatten_axis], ) - group_axis = flatten_axis flatten_axis = len(original_shape) - flatten_axis return GroupedScaledTensor1x( @@ -633,9 +719,9 @@ def create_1x( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, ) # Handling attrs of transposed tensors @@ -668,9 +754,9 @@ def create_2x( dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, - group_axis=0, rowwise_has_rht_applied=False, colwise_has_rht_applied=False, ): @@ -686,9 +772,9 @@ def create_2x( dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") flatten_axis: The quantization axis for the tensor - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) - group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -710,9 +796,9 @@ def create_2x( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, ) colwise_tensor = ScaledTensorFactory.create_1x( @@ -724,9 +810,9 @@ def create_2x( is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, - 
group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -744,9 +830,9 @@ def create( data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, flatten_axis: int = -1, - group_sizes: jnp.ndarray = None, + first_dims: jnp.ndarray = None, + last_dims: jnp.ndarray = None, original_shape: Tuple[int] = None, - group_axis: int = 0, rowwise_has_rht_applied: bool = False, colwise_has_rht_applied: bool = False, ): @@ -762,9 +848,9 @@ def create( data_layout: The data_layout specification (default: "NN") q_layout: The quantization axis (default: ROWWISE) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) - group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -785,9 +871,9 @@ def create( dq_dtype, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, rowwise_has_rht_applied=rowwise_has_rht_applied, colwise_has_rht_applied=colwise_has_rht_applied, ) @@ -802,9 +888,9 @@ def create( is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, 
has_rht_applied=colwise_has_rht_applied, ) @@ -817,9 +903,9 @@ def create( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, - group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, )