[Common] Persistent Grouped MXFP8 quantization kernel#2738
Oleg-Goncharov wants to merge 73 commits into NVIDIA:main from
Conversation
Greptile Summary
This PR significantly refactors the grouped MXFP8 quantization kernel by introducing a persistent grid-stride scheduler.
Key changes:
Previously raised concerns are resolved.
Remaining minor points:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host (group_quantize)
    participant UPD as update_tma_descriptors kernel<br/>(num_tensors blocks × 1 thread)
    participant GMEM as g_tensor_maps<br/>(static __device__)
    participant KRNL as group_quantize_mxfp8_kernel<br/>(persistent: SM_count × 24 blocks)
    Host->>Host: compute work_blocks_X/Y
    Host->>Host: create base CUtensorMaps (input, output)
    alt non-single-tensor (VARYING_*)
        Host->>UPD: launch<<<num_tensors, 1, stream>>>
        UPD->>GMEM: modify_base_tensor_map → g_tensor_maps.input/output[tensor_id]
    end
    Host->>KRNL: launch<<<static_grid, THREADS_PER_CHUNK, shmem, stream>>>
    loop while !job_finished (persistent grid-stride)
        KRNL->>KRNL: decode_job(ctaid_X, ctaid_Y) → JobDescriptor
        KRNL->>KRNL: is_job_valid? / job_has_work?
        alt is_single_tensor
            KRNL->>KRNL: use tensor_map_*_static (passed as grid_constant)
        else non-single-tensor
            KRNL->>GMEM: fence_acquire_tensormap (once per tensor_id change)
            KRNL->>GMEM: read g_tensor_maps.*[tensor_id]
        end
        KRNL->>KRNL: prime pipeline: prefetch_input_stage (PREFETCH_STAGES=1)
        loop STAGES=4 (32-row slices)
            KRNL->>KRNL: prefetch next stage (double-buffer)
            KRNL->>KRNL: mbarrier_wait (TMA global→shared complete)
            KRNL->>KRNL: process_colwise_stage / process_rowwise_stage
            KRNL->>KRNL: store_output_stage (TMA shared→global)
        end
        KRNL->>KRNL: advance_to_next_job (stride by static_block_stride)
    end
    KRNL->>Host: atomicMaxFloat(amax_ptr, block_amax)
```
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
/te-ci
Description
This PR adds a persistent grouped MXFP8 quantization kernel with static scheduling.
Type of change
Changes
Checklist: