Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
28e5f53
Refactor to group_sizes per tensor
jberchtold-nvidia Mar 9, 2026
4a57485
Support first_dims and last_dims instead of a single group_sizes per
jberchtold-nvidia Mar 10, 2026
345d940
Refactor GMM FFIs to store static attrs as structs
jberchtold-nvidia Mar 10, 2026
ed9c8e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
ed0deaf
Cleanup C++ v2 FFI
jberchtold-nvidia Mar 10, 2026
88bb7da
Fix int64 workspace usage
jberchtold-nvidia Mar 10, 2026
60312c8
Address greptile comments
jberchtold-nvidia Mar 10, 2026
025f598
Refactor wgrad-specific checks to be generic for GMM in gemm.py
jberchtold-nvidia Mar 10, 2026
089e530
Refactor XLA FFI struct setup
jberchtold-nvidia Mar 10, 2026
8ad2294
Fix edge case in TE v1 GMM
jberchtold-nvidia Mar 11, 2026
bac092d
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 11, 2026
4ff5d1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
0cb7289
Fix issues on Hopper
jberchtold-nvidia Mar 11, 2026
37d300a
Merge remote-tracking branch 'github-upstream/main…
jberchtold-nvidia Mar 11, 2026
cc236ad
Refactor
jberchtold-nvidia Mar 12, 2026
2b84dfd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
2902eb2
Merge remote-tracking branch 'github-upstream/main' into jberchtold/g…
jberchtold-nvidia Mar 17, 2026
bee7f3b
Address comments
jberchtold-nvidia Mar 23, 2026
d9b9c44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
ef0d498
Merge branch 'main' into jberchtold/gmm-refactor
jberchtold-nvidia Mar 23, 2026
9438478
Lint
jberchtold-nvidia Mar 23, 2026
09dfd9c
Fixes for Hopper
jberchtold-nvidia Mar 24, 2026
e25538e
Address review comments
jberchtold-nvidia Mar 24, 2026
78674e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
d5229e2
Merge branch 'main' into jberchtold/gmm-refactor
jberchtold-nvidia Mar 24, 2026
a7b11a3
Grouped quantization test fixes
jberchtold-nvidia Mar 25, 2026
cc0b33c
Merge branch 'main' into jberchtold/gmm-refactor
jberchtold-nvidia Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
GroupedNoScaleTensor,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
Expand Down Expand Up @@ -150,8 +151,13 @@ def assert_dequantized_grouped_scaled_tensor(
a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
if isinstance(a, GroupedScaledTensor1x):
assert a.group_sizes.sum() == b.shape[0]
b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
group_sizes = (
a.first_dims
if a.first_dims is not None
else jnp.ones(a.original_shape[0], dtype=jnp.int32)
)
assert group_sizes.sum() == b.shape[0]
b = jnp.split(b, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
dq_a = a.dequantize()
for dq_a_i, b_i in zip(dq_a, b):
if len(dq_a_i) == 0:
Expand Down Expand Up @@ -1787,13 +1793,18 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

# jitting grouped_gemm
lhs_tensor = GroupedNoScaleTensor(
data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape
)
prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs,
rhs,
group_sizes,
contracting_dims,
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
use_async_d2h_group_sizes=True,
)

Expand Down Expand Up @@ -1825,8 +1836,17 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

lhs_tensor = GroupedNoScaleTensor(
data=lhs, amax=None, first_dims=group_sizes, last_dims=None, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, amax=None, first_dims=None, last_dims=None, original_shape=rhs.shape
)
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
quantizer_set=quantizer_set,
)

allclose_dtype = jnp.float8_e4m3fn
Expand Down
Loading
Loading