Add block-level CUB JIT support#1180
Conversation
Use CUB Block* APIs for JIT-able reductions, scans, sorts, and argsorts, including reduced-rank block kernels and fusion-safe scalar handling for surrounding MatX operators. Query CUB temporary storage by compiling an NVRTC probe to PTX and reading the global initializer, then feed that static shared-memory usage into launch capability negotiation. Cap CUB JIT elements-per-thread to powers of two whose total item width is at most 16 bytes, and add coverage for sum/prod/min/max/cumsum/sort/argsort plus fused expressions with fftshift, fft, and linspace-generated inputs.
|
/build |
Greptile SummaryThis PR adds block-level CUB JIT support to MatX, enabling
Confidence Score: 3/5Safe to merge on single-GPU setups; multi-GPU deployments with warm disk cubin caches risk loading a wrong-architecture cubin and crashing. The PR is large (23 files, ~4000 LOC) and addresses a complex JIT + CUB integration. Many previously flagged issues (double-free in PostRun, argsort complex guard, shmem_cache arch key, PTX regex, JIT stream, etc.) have been resolved. One new P1 remains: the disk cubin cache key omits the GPU SM architecture, meaning a cubin compiled for architecture X can be silently loaded on architecture Y in multi-GPU or cross-machine scenarios, causing cuModuleLoadDataEx to fail or produce wrong results. include/matx/core/nvrtc_helper.h — disk cubin cache key construction around line 638 needs the nvrtc_arch string appended, mirroring the fix already applied to shmem_cache. Important Files Changed
Sequence DiagramsequenceDiagram
participant User as User Code
participant Exec as JIT CUDA Executor
participant Cache as JIT Launch Cache
participant NVRTC as nvrtc_helper
participant ShmCache as shmem_cache
participant Disk as Disk Cubin Cache
participant GPU as GPU (cuLaunchKernel)
User->>Exec: "ExecWithRank<RANK>(op, stream)"
Exec->>Cache: GetJITLaunchParamsCacheKey(device_id + ":" + kernel_op_type)
alt Cache HIT
Cache-->>Exec: blocks, threads, shm_size, global_kernel
else Cache MISS
Exec->>NVRTC: find_best_launch_params(op, ept, block_size)
NVRTC->>ShmCache: nvrtc_get_cub_block_shmem_size(arch, algo, type, ept, bs)
alt shmem_cache HIT
ShmCache-->>NVRTC: cached shm_size
else shmem_cache MISS
NVRTC->>NVRTC: make_cub_shmem_probe_source() to PTX compile
NVRTC->>NVRTC: regex extract temp_storage_size from PTX
NVRTC-->>ShmCache: store keyed by arch+algo+type+ept+bs
end
NVRTC->>Exec: best_ept, shm_size, block_size
Exec->>NVRTC: nvrtc_compile_and_run(kernel_name, kernel_op_type)
NVRTC->>Disk: lookup cubin by device_N_+kernel_name+kernel_op_type
alt Disk HIT
Disk-->>NVRTC: cubin bytes
else Disk MISS
NVRTC->>NVRTC: NVRTC compile to NVJITLINK to cubin
NVRTC-->>Disk: store cubin
end
NVRTC->>NVRTC: cuModuleLoadDataEx to cuModuleGetFunction
NVRTC-->>Cache: store launch params
end
alt "block_reduces_rank == true"
Exec->>Exec: get_grid_dims_block_reduce (batch dims only)
else
Exec->>Exec: get_grid_dims_block (rank-preserving)
end
Exec->>GPU: cuLaunchKernel(func, blocks, threads, shm, stream, args)
GPU->>GPU: matxOpT1KernelBlock / matxOpT2KernelBlockReduce
note over GPU: BlockReduce/BlockScan/BlockSort via CUB
GPU-->>User: results written to output tensor
Reviews (46): Last reviewed commit: "Address Greptile CUB JIT review issues" | Re-trigger Greptile |
Use explicit SUM_QUERY identities for shared-memory aggregation, make non-JIT CUB shared-memory queries fail loudly, and make the NVRTC temp-storage probe robust to signed or unsigned PTX declarations. Also align generated binary JIT capability routing with direct block-reduction detection so fused CUB reductions keep block-level thread participation while non-reduction operands use scalar indexing.
Evaluate direct block-reduction RHS operators with the active CapType so CUB sees the negotiated EPT and block size. Keep ScalarCap only for the output write, where the block aggregate is stored by a single thread.
|
/build |
Pass the executor stream through nvrtc_compile_and_run and into both cuLaunchKernel paths so CUDAJITExecutor(stream) preserves caller stream ordering. Add a nonblocking-stream regression test that would fail if the JIT kernel launched on stream 0.
|
@greptile review |
Check the block-reduction batch grid count before narrowing it into dim3::x, and add grid-dim coverage for the normal and oversized block-reduction cases.
|
@greptile review |
Normalize CUB shared-memory probe cache keys for algorithms whose temp storage does not depend on EPT, clarify the dynamic shared-memory limit name/logging, and make JIT launch-parameter caches inline so they are shared across translation units.
|
@greptile review |
|
/build |
|
@greptile review |
|
/build |
|
@greptile review |
1 similar comment
|
@greptile review |
|
@greptile review |
|
@greptile review |
|
@greptile review |
|
@greptile review |
|
@greptile review |
1 similar comment
|
@greptile review |
|
@greptile review |
|
/build |
Replace local power-of-two checks in CUB JIT sizing with libcudacxx math helpers. Keep existing invalid-size guards while using cuda::is_power_of_two and cuda::prev_power_of_two for EPT and block-thread selection.
Move shared CUB JIT EPT and block-thread sizing policy into cub_device.h. Reduce sum/min/max/prod/sort/argsort/cumsum operators to calls into the common helpers while preserving the RTC-only CUB probe path.
|
@greptile review |
|
@greptile review |
|
/build |
|
@greptile review |
|
@greptile review |
| std::string kernel_name = get_kernel_name_for_rank<RANK>(stride, global_kernel, pass_through_threads); | ||
| std::string cache_key = kernel_name + "_" + kernel_op_type; | ||
| std::string kernel_name = get_kernel_name_for_rank<RANK>(stride, global_kernel, pass_through_threads, block_reduces_rank); | ||
| std::string cache_key = device_cache_prefix + kernel_name + "_" + kernel_op_type; |
There was a problem hiding this comment.
Disk cubin cache key does not include GPU architecture
cache_key is built as "device_N_" + kernel_name + "_" + kernel_op_type, which encodes only the CUDA device index. The cubin stored on disk is compiled for a specific SM architecture (via get_preprocessor_options()). If the same build-directory cache is accessed from a machine with a different GPU at device index 0 (e.g. an NFS-mounted build tree shared across cluster nodes with sm_80 and sm_70 GPUs), cuModuleLoadDataEx will receive an sm_80 cubin on an sm_70 device, the driver will return an error, and CUDA_CHECK converts that to std::exit. The shmem_cache in nvrtc_get_cub_block_shmem_size was correctly updated in this PR to include nvrtc_arch in the key — the kernel cubin cache needs the same treatment.
Use CUB Block* APIs for JIT-able reductions, scans, sorts, and argsorts, including reduced-rank block kernels and fusion-safe scalar handling for surrounding MatX operators.
Query CUB temporary storage by compiling an NVRTC probe to PTX and reading the global initializer, then feed that static shared-memory usage into launch capability negotiation.
Cap CUB JIT elements-per-thread to powers of two whose total item width is at most 16 bytes, and add coverage for sum/prod/min/max/cumsum/sort/argsort plus fused expressions with fftshift, fft, and linspace-generated inputs.