Skip to content

Add block-level CUB JIT support#1180

Open
cliffburdick wants to merge 24 commits into
mainfrom
cburdick/cub-block-jit
Open

Add block-level CUB JIT support#1180
cliffburdick wants to merge 24 commits into
mainfrom
cburdick/cub-block-jit

Conversation

@cliffburdick
Copy link
Copy Markdown
Collaborator

Use CUB Block* APIs for JIT-able reductions, scans, sorts, and argsorts, including reduced-rank block kernels and fusion-safe scalar handling for surrounding MatX operators.

Query CUB temporary storage by compiling an NVRTC probe to PTX and reading the global initializer, then feed that static shared-memory usage into launch capability negotiation.

Cap CUB JIT elements-per-thread to powers of two whose total item width is at most 16 bytes, and add coverage for sum/prod/min/max/cumsum/sort/argsort plus fused expressions with fftshift, fft, and linspace-generated inputs.

Use CUB Block* APIs for JIT-able reductions, scans, sorts, and argsorts, including reduced-rank block kernels and fusion-safe scalar handling for surrounding MatX operators.

Query CUB temporary storage by compiling an NVRTC probe to PTX and reading the global initializer, then feed that static shared-memory usage into launch capability negotiation.

Cap CUB JIT elements-per-thread to powers of two whose total item width is at most 16 bytes, and add coverage for sum/prod/min/max/cumsum/sort/argsort plus fused expressions with fftshift, fft, and linspace-generated inputs.
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

/build

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 12, 2026

Greptile Summary

This PR adds block-level CUB JIT support to MatX, enabling cub::BlockReduce, cub::BlockScan, and cub::BlockRadixSort to be used within JIT-compiled CUDA kernels for reductions, prefix scans, and sorts/argsorts respectively.

  • Introduces SUPPORTS_JIT, BLOCK_REDUCES_RANK, and STATIC_SHM_SIZE operator capabilities, a new CapabilityParams<EPT, JIT> dispatch mechanism, and CUB shared-memory probing via NVRTC PTX inspection to determine TempStorage size at JIT launch-param selection time.
  • Extends sum, min, max, prod, sort, argsort, and cumsum operators with JIT paths; adds ContainsBlockReduction/IsDirectBlockReduction type traits propagated through unary and binary expression templates so scalar promotion is applied correctly at every fusion point.
  • The disk cubin cache key does not include the GPU SM architecture (nvrtc_arch), which can cause a stale cubin compiled for one architecture to be loaded on a different GPU — the shmem_cache was correctly updated to include arch in this PR, but the corresponding fix was not applied to the main kernel cache key.

Confidence Score: 3/5

Safe to merge on single-GPU setups; multi-GPU deployments with warm disk cubin caches risk loading a wrong-architecture cubin and crashing.

The PR is large (23 files, ~4000 LOC) and addresses a complex JIT + CUB integration. Many previously flagged issues (double-free in PostRun, argsort complex guard, shmem_cache arch key, PTX regex, JIT stream, etc.) have been resolved. One new P1 remains: the disk cubin cache key omits the GPU SM architecture, meaning a cubin compiled for architecture X can be silently loaded on architecture Y in multi-GPU or cross-machine scenarios, causing cuModuleLoadDataEx to fail or produce wrong results.

include/matx/core/nvrtc_helper.h — disk cubin cache key construction around line 638 needs the nvrtc_arch string appended, mirroring the fix already applied to shmem_cache.

Important Files Changed

Filename Overview
include/matx/core/nvrtc_helper.h Adds CUB shared-memory probe via NVRTC (PTX regex, shmem cache keyed by arch+type), generate_capability_params_string now passes JIT=true, kernel cache now device-keyed in-memory but disk cubin cache key still missing GPU architecture — cross-architecture cache reuse would crash.
include/matx/transforms/cub_device.h Introduces BlockReduce, BlockSort, BlockArgsort, BlockScan device-side implementations, helper index mapping utilities, and host-side GetCubBlockShmRequired. BlockScan correctly uses InclusiveSum method-template overload for CCCL CUB.
include/matx/core/get_grid_dims.h Adds get_grid_dims_block_reduce with INT_MAX guard for blocks.x; rank-preserving operators correctly use get_grid_dims_block so batch-dim sizing is correct.
include/matx/executors/jit_cuda.h JIT launch params cache now keyed by device ID+operator type; ExecWithRank correctly dispatches block_reduces_rank vs block-preserving grid dims.
include/matx/executors/cuda_executor_common.h find_best_launch_params extended to query STATIC_SHM_SIZE and combine with DYN_SHM_SIZE for occupancy check against correct per-block and per-SM limits.
include/matx/operators/sum.h CUB JIT block-reduce path added with !is_complex_v guard; PostRun now resets ptr=nullptr and prerun_done_=false preventing double-free; BLOCK_REDUCES_RANK returns true for correct grid dispatch.
include/matx/operators/sort.h CUB JIT BlockRadixSort path using CubJitMaxPowerOfTwoCollectiveEPT; BLOCK_REDUCES_RANK not set so rank-preserving grid is used correctly.
include/matx/operators/cumsum.h CUB JIT BlockScan path added with !is_complex_v guard and power-of-two critical-dim requirement.
include/matx/operators/argsort.h SUPPORTS_JIT complex guard now correctly checks !is_complex_v<key_type> instead of the output index type.
include/matx/operators/unary_operators.h Adds ContainsBlockReduction/IsDirectBlockReduction helpers; I1Cap now correctly uses CapType for block-reduction inputs and ScalarCap for elementwise inputs; JIT-generated code mirrors host logic.
include/matx/operators/binary_operators.h Adds block-reduction awareness: I1Cap/I2Cap dispatch using ContainsBlockReduction and IsDirectBlockReduction; JIT-generated code now includes IsDirectBlockReduction, fixing divergence from host code.
include/matx/executors/jit_kernel.h matxOpT1KernelBlock now branches on KernelContainsBlockReduction for block-collective indexing; separate *KernelBlockReduce variants added for rank-2/3/4.
include/matx/core/capabilities.h Adds STATIC_SHM_SIZE capability (SUM_QUERY) and BLOCK_REDUCES_RANK (OR_QUERY, default=false); combine_capabilities correctly handles SUM_QUERY using sum_identity.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Exec as JIT CUDA Executor
    participant Cache as JIT Launch Cache
    participant NVRTC as nvrtc_helper
    participant ShmCache as shmem_cache
    participant Disk as Disk Cubin Cache
    participant GPU as GPU (cuLaunchKernel)

    User->>Exec: "ExecWithRank<RANK>(op, stream)"
    Exec->>Cache: GetJITLaunchParamsCacheKey(device_id + ":" + kernel_op_type)
    alt Cache HIT
        Cache-->>Exec: blocks, threads, shm_size, global_kernel
    else Cache MISS
        Exec->>NVRTC: find_best_launch_params(op, ept, block_size)
        NVRTC->>ShmCache: nvrtc_get_cub_block_shmem_size(arch, algo, type, ept, bs)
        alt shmem_cache HIT
            ShmCache-->>NVRTC: cached shm_size
        else shmem_cache MISS
            NVRTC->>NVRTC: make_cub_shmem_probe_source() to PTX compile
            NVRTC->>NVRTC: regex extract temp_storage_size from PTX
            NVRTC-->>ShmCache: store keyed by arch+algo+type+ept+bs
        end
        NVRTC->>Exec: best_ept, shm_size, block_size
        Exec->>NVRTC: nvrtc_compile_and_run(kernel_name, kernel_op_type)
        NVRTC->>Disk: lookup cubin by device_N_+kernel_name+kernel_op_type
        alt Disk HIT
            Disk-->>NVRTC: cubin bytes
        else Disk MISS
            NVRTC->>NVRTC: NVRTC compile to NVJITLINK to cubin
            NVRTC-->>Disk: store cubin
        end
        NVRTC->>NVRTC: cuModuleLoadDataEx to cuModuleGetFunction
        NVRTC-->>Cache: store launch params
    end
    alt "block_reduces_rank == true"
        Exec->>Exec: get_grid_dims_block_reduce (batch dims only)
    else
        Exec->>Exec: get_grid_dims_block (rank-preserving)
    end
    Exec->>GPU: cuLaunchKernel(func, blocks, threads, shm, stream, args)
    GPU->>GPU: matxOpT1KernelBlock / matxOpT2KernelBlockReduce
    note over GPU: BlockReduce/BlockScan/BlockSort via CUB
    GPU-->>User: results written to output tensor
Loading

Reviews (46): Last reviewed commit: "Address Greptile CUB JIT review issues" | Re-trigger Greptile

Comment thread include/matx/core/capabilities.h
Comment thread include/matx/transforms/cub_device.h
Use explicit SUM_QUERY identities for shared-memory aggregation, make non-JIT CUB shared-memory queries fail loudly, and make the NVRTC temp-storage probe robust to signed or unsigned PTX declarations.

Also align generated binary JIT capability routing with direct block-reduction detection so fused CUB reductions keep block-level thread participation while non-reduction operands use scalar indexing.
Evaluate direct block-reduction RHS operators with the active CapType so CUB sees the negotiated EPT and block size. Keep ScalarCap only for the output write, where the block aggregate is stored by a single thread.
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

/build

@coveralls
Copy link
Copy Markdown

coveralls commented May 13, 2026

Coverage Status

Coverage is 94.08%cburdick/cub-block-jit into main. No base build found for main.

Pass the executor stream through nvrtc_compile_and_run and into both cuLaunchKernel paths so CUDAJITExecutor(stream) preserves caller stream ordering. Add a nonblocking-stream regression test that would fail if the JIT kernel launched on stream 0.
Comment thread include/matx/core/get_grid_dims.h Outdated
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

Check the block-reduction batch grid count before narrowing it into dim3::x, and add grid-dim coverage for the normal and oversized block-reduction cases.
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

Normalize CUB shared-memory probe cache keys for algorithms whose temp storage does not depend on EPT, clarify the dynamic shared-memory limit name/logging, and make JIT launch-parameter caches inline so they are shared across translation units.
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

/build

Comment thread include/matx/operators/sum.h
Comment thread include/matx/executors/cuda_executor_common.h Outdated
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

/build

Comment thread include/matx/operators/sum.h
Comment thread include/matx/operators/argsort.h
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

1 similar comment
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

Comment thread include/matx/transforms/cub_device.h Outdated
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

Comment thread include/matx/operators/cumsum.h Outdated
Comment thread include/matx/executors/jit_cuda.h
Comment thread include/matx/transforms/cub_device.h
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

1 similar comment
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

Comment thread include/matx/operators/cumsum.h Outdated
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

/build

Comment thread include/matx/core/get_grid_dims.h
Replace local power-of-two checks in CUB JIT sizing with libcudacxx math helpers. Keep existing invalid-size guards while using cuda::is_power_of_two and cuda::prev_power_of_two for EPT and block-thread selection.
Move shared CUB JIT EPT and block-thread sizing policy into cub_device.h. Reduce sum/min/max/prod/sort/argsort/cumsum operators to calls into the common helpers while preserving the RTC-only CUB probe path.
Comment thread include/matx/core/nvrtc_helper.h
Comment thread include/matx/transforms/cub_device.h
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

Comment thread include/matx/operators/unary_operators.h
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

/build

Comment thread include/matx/executors/jit_kernel.h Outdated
@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

@cliffburdick
Copy link
Copy Markdown
Collaborator Author

@greptile review

std::string kernel_name = get_kernel_name_for_rank<RANK>(stride, global_kernel, pass_through_threads);
std::string cache_key = kernel_name + "_" + kernel_op_type;
std::string kernel_name = get_kernel_name_for_rank<RANK>(stride, global_kernel, pass_through_threads, block_reduces_rank);
std::string cache_key = device_cache_prefix + kernel_name + "_" + kernel_op_type;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Disk cubin cache key does not include GPU architecture

cache_key is built as "device_N_" + kernel_name + "_" + kernel_op_type, which encodes only the CUDA device index. The cubin stored on disk is compiled for a specific SM architecture (via get_preprocessor_options()). If the same build-directory cache is accessed from a machine with a different GPU at device index 0 (e.g. an NFS-mounted build tree shared across cluster nodes with sm_80 and sm_70 GPUs), cuModuleLoadDataEx will receive an sm_80 cubin on an sm_70 device, the driver will return an error, and CUDA_CHECK converts that to std::exit. The shmem_cache in nvrtc_get_cub_block_shmem_size was correctly updated in this PR to include nvrtc_arch in the key — the kernel cubin cache needs the same treatment.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants