Skip to content

[Common] Persistent Grouped MXFP8 quantization kernel#2738

Open
Oleg-Goncharov wants to merge 73 commits intoNVIDIA:mainfrom
Oleg-Goncharov:pr_persistent_grouped_mxfp8_kernel
Open

[Common] Persistent Grouped MXFP8 quantization kernel#2738
Oleg-Goncharov wants to merge 73 commits intoNVIDIA:mainfrom
Oleg-Goncharov:pr_persistent_grouped_mxfp8_kernel

Conversation

@Oleg-Goncharov
Copy link
Collaborator

@Oleg-Goncharov Oleg-Goncharov commented Mar 5, 2026

Description

This PR adds a persistent grouped MXFP8 quantization kernel with static scheduling.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added persistent kernel
  • Added TunableConfig structure to tune performance

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov added enhancement New feature or request MoE labels Mar 5, 2026
@Oleg-Goncharov Oleg-Goncharov requested a review from ptrendx March 5, 2026 16:18
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR significantly refactors the grouped MXFP8 quantization kernel by introducing a persistent grid-stride scheduler and lifting ScalingType and ShapeRepresentation to compile-time template parameters, eliminating runtime switches inside the device code. The main kernel now runs as a static_grid_size = SM_count × 24 block grid where each CTA iterates over multiple logical work items, reducing kernel launch overhead for workloads with many small tensors.

Key changes:

  • TunableConfig struct centralizes CHUNK_DIM_Y/X, THREADS_PER_CHUNK, PERSISTENT, and STATIC_PERSISTENT_BLOCKS_PER_SM=24 as compile-time constants.
  • JobDescriptor / BlockDescriptor / decode_job / decode_block helpers in common.cuh cleanly separate work-item decoding from the kernel body.
  • process_colwise_stage / process_rowwise_stage device functions replace the previous monolithic per-stage inline code, improving readability.
  • TensorMapStorage g_tensor_maps consolidates four separate device-global CUtensorMap arrays into one struct with static __device__ linkage (internal-linkage by design to avoid ODR violations when the header is included in multiple TUs).
  • exp2f_rcp<bf16> specialisation and new mul_cvt_4x overloads in ptx.cuh enable the fast BF16/FP16 cast-only path.
  • New public API nvte_group_quantize_v2 is exposed in cast.h.
  • Tests are updated to use the new API, add empty-tensor configs, and a large (16×4096-row) persistent-scheduler stress case.

Previously raised concerns are resolved: the hardcoded 128 in decode_block now uses CHUNK_DIM_X; rowwise_scale_is_within_bounds correctly scales the index to column units; the column-alignment assertion for non-single-tensor shapes is restored in get_tensor_cols_num; the redundant dead condition on block_offset_X_in_tensor is removed.

Remaining minor points:

  • The quant_config parameter added to group_quantize is threaded through to the call site but no field is currently read from it — a comment explaining the intended future use would clarify intent.
  • The static __device__ TensorMapStorage g_tensor_maps declaration in the shared header depends on both the update kernel and the main kernel being compiled into the same translation unit. This invariant currently holds but is not explicitly documented, which could cause subtle failures if the code is later reorganised.

Confidence Score: 4/5

  • PR is safe to merge; previously reported correctness bugs are fixed and the persistent scheduler logic is sound.
  • All P0/P1 concerns from prior review rounds are addressed: the hardcoded-128 mapping bug, the units mismatch in rowwise_scale_is_within_bounds, the missing column-alignment check, and the dead bounds condition are all resolved. The persistent scheduler design (grid-stride loop, double-buffered TMA pipeline, empty-tensor skip, fence_acquire caching) is logically correct. New test cases cover zero-sized tensors and a large-scale persistent-scheduler scenario. The two remaining comments are P2 style/documentation items (unused quant_config parameter and undocumented single-TU invariant for g_tensor_maps) that do not affect correctness.
  • transformer_engine/common/cast/core/common.cuh (static device g_tensor_maps ODR contract) and transformer_engine/common/cast/dispatch/quantize.cuh (unused quant_config parameter).

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Core kernel file: adds TunableConfig struct, introduces persistent grid-stride scheduling (PERSISTENT=true, 24 blocks/SM), refactors the main kernel to use a while-loop over multiple jobs per CTA. ScalingType and ShapeRepresentation are now compile-time template parameters (no runtime switch in the kernel). Breaking apart the old monolithic stage loop into process_colwise_stage / process_rowwise_stage device functions significantly improves readability. Previously-flagged bugs (hardcoded 128 in decode_block, redundant bounds check, rowwise_scale_is_within_bounds units) are all fixed. The unused quant_config parameter is the only remaining minor concern.
transformer_engine/common/cast/core/common.cuh New JobDescriptor / BlockDescriptor structs and the decode_job / decode_block / is_job_valid / advance_to_next_job helpers are added here. The update_tma_descriptors kernel is moved from the group-quantize header to this shared header. The four separate CUtensorMap arrays are consolidated into a single TensorMapStorage struct declared as static device with an explanatory comment about the internal-linkage ODR strategy. barrier helpers (initialize, destroy, prefetch_input_stage, store_output_stage) are also extracted here for reuse. All previously flagged issues (signed/unsigned arithmetic in dbias offset, output stride assumption for VARYING_LAST_DIM) are addressed.
transformer_engine/common/util/ptx.cuh Adds exp2f_rcp specialization (SM 9.0+) and four new mul_cvt_4x overloads for (bf16x4, bf16x2) and (fp16x4, fp16) × fp8e4m3/fp8e5m2, enabling the fast BF16/FP16 cast-only path in the new kernel. FPx4 gets alignas(4*sizeof(T)) to satisfy TMA alignment requirements. mxfp8_scaling.cu callers updated to use explicit template argument exp2f_rcp.
tests/cpp/operator/test_cast_mxfp8_grouped.cu Updates tests to use the new nvte_group_quantize_v2 API; adds empty-tensor configs ({128,0,0,256} and {128,0,128}) to verify the persistent loop correctly skips zero-sized tensors; adds a large VARYING_FIRST_DIM case (16×4096 rows) to stress the persistent scheduler; local divide_round_up_blocks lambda handles the N=0 case; comparison now uses elts_num as a flat 1D range to avoid stale trailing data.
transformer_engine/common/utils.cuh ShapeRepresentation enum moved here from cast/core/common.cuh (where it was inside the dispatch::common namespace) to make it accessible without pulling in the full cast headers. The duplicate definition in hadamard_transform/graph_safe_group_hadamard_transform.cu is removed accordingly.
transformer_engine/common/common.h Two new dispatch macros: TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH and TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH. These lift runtime enum values to compile-time template arguments, enabling the new kernel specialisations. SCALING_FACTORS_SWIZZLE_ALIGNMENT constant (128) added.
transformer_engine/common/include/transformer_engine/cast.h New public API nvte_group_quantize_v2 added; documentation notes the last-dimension alignment requirement for varying-last-dim tensors. Backward-compatible addition.
transformer_engine/common/cast/dispatch/quantize.cuh Passes &quant_config_cpp to group_quantize in both forward and backward helper paths. Minimal change; the parameter is threaded through but not yet consumed by the kernel launch code.

Sequence Diagram

sequenceDiagram
    participant Host as Host (group_quantize)
    participant UPD as update_tma_descriptors kernel<br/>(num_tensors blocks × 1 thread)
    participant GMEM as g_tensor_maps<br/>(static __device__)
    participant KRNL as group_quantize_mxfp8_kernel<br/>(persistent: SM_count × 24 blocks)

    Host->>Host: compute work_blocks_X/Y
    Host->>Host: create base CUtensorMaps (input, output)
    alt non-single-tensor (VARYING_*)
        Host->>UPD: launch<<<num_tensors, 1, stream>>>
        UPD->>GMEM: modify_base_tensor_map → g_tensor_maps.input/output[tensor_id]
    end
    Host->>KRNL: launch<<<static_grid, THREADS_PER_CHUNK, shmem, stream>>>
    loop while !job_finished  (persistent grid-stride)
        KRNL->>KRNL: decode_job(ctaid_X, ctaid_Y) → JobDescriptor
        KRNL->>KRNL: is_job_valid? / job_has_work?
        alt is_single_tensor
            KRNL->>KRNL: use tensor_map_*_static (passed as grid_constant)
        else non-single-tensor
            KRNL->>GMEM: fence_acquire_tensormap (once per tensor_id change)
            KRNL->>GMEM: read g_tensor_maps.*[tensor_id]
        end
        KRNL->>KRNL: prime pipeline: prefetch_input_stage (PREFETCH_STAGES=1)
        loop STAGES=4 (32-row slices)
            KRNL->>KRNL: prefetch next stage (double-buffer)
            KRNL->>KRNL: mbarrier_wait (TMA global→shared complete)
            KRNL->>KRNL: process_colwise_stage / process_rowwise_stage
            KRNL->>KRNL: store_output_stage (TMA shared→global)
        end
        KRNL->>KRNL: advance_to_next_job (stride by static_block_stride)
    end
    KRNL->>Host: atomicMaxFloat(amax_ptr, block_amax)
Loading

Reviews (22): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 924ff91 to 325181b Compare March 6, 2026 10:39
Oleg-Goncharov and others added 15 commits March 10, 2026 11:58
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@ptrendx ptrendx force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 5815335 to aa484a3 Compare March 10, 2026 19:07
@ptrendx ptrendx marked this pull request as draft March 12, 2026 16:26
Oleg-Goncharov and others added 2 commits March 13, 2026 17:08
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov and others added 23 commits March 18, 2026 18:29
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov
Copy link
Collaborator Author

/te-ci

Oleg-Goncharov and others added 2 commits March 25, 2026 13:01
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request MoE

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants