Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
107a865
Enabled persistency with WorkID Query feature
Oleg-Goncharov Mar 4, 2026
caf664f
Added a struct with tunable parameters
Oleg-Goncharov Mar 4, 2026
68dbc62
Added persistency with static scheduling
Oleg-Goncharov Mar 4, 2026
051d925
Fixed test cases
Oleg-Goncharov Mar 4, 2026
2f9a299
Ready for benchmarking
Oleg-Goncharov Mar 4, 2026
c040d59
Fixed out-of-boundary error
Oleg-Goncharov Mar 4, 2026
30c28fb
Tuned kernel parameters
Oleg-Goncharov Mar 4, 2026
977168e
Refactoring
Oleg-Goncharov Mar 4, 2026
885fcb9
Refactoring 2
Oleg-Goncharov Mar 4, 2026
d787847
Refactoring 3
Oleg-Goncharov Mar 4, 2026
79c1ac2
Removed the dynamic (WorkID Query) persistency
Oleg-Goncharov Mar 5, 2026
12b8712
Ready for PR
Oleg-Goncharov Mar 5, 2026
2812d55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
f24afb2
Fixes per the review
Oleg-Goncharov Mar 6, 2026
aa484a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
f066851
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 13, 2026
74722a5
Ready for benchmark
Oleg-Goncharov Mar 13, 2026
5c570cd
Ready for benchmark - Regular kernel
Oleg-Goncharov Mar 13, 2026
c5b1f7d
Added the source code to the profiler
Oleg-Goncharov Mar 13, 2026
3edcb5d
Added constructors to Job and Block descriptors
Oleg-Goncharov Mar 13, 2026
6e00237
Removed the prefetch overlapping between jobs
Oleg-Goncharov Mar 13, 2026
274f91e
Cache tensor ID
Oleg-Goncharov Mar 13, 2026
38b7e4e
ShapeRepresentation is not a template parameter
Oleg-Goncharov Mar 13, 2026
4405255
Removed redundant fence_proxy
Oleg-Goncharov Mar 13, 2026
8cad6e6
Refactoring
Oleg-Goncharov Mar 16, 2026
c6622d4
Used mixed precision FMA
Oleg-Goncharov Mar 17, 2026
e6a737c
Added Quantize parameters
Oleg-Goncharov Mar 17, 2026
7be1136
Added the fast math branch
Oleg-Goncharov Mar 17, 2026
4c2bed5
Added the fast math to cpp test suite
Oleg-Goncharov Mar 17, 2026
e296b0b
Align tests
Oleg-Goncharov Mar 17, 2026
e63eee9
Use STS instead of generic ST
Oleg-Goncharov Mar 17, 2026
6874206
Add zero-tensor cases
Oleg-Goncharov Mar 17, 2026
a02c71c
Used LDS instead of generic LD in colwise path
Oleg-Goncharov Mar 17, 2026
4c992b0
Used LDS instead of generic LD in rowwise
Oleg-Goncharov Mar 17, 2026
8ceeed0
Ready for merge
Oleg-Goncharov Mar 17, 2026
2c20675
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 17, 2026
ef973d7
Merge branch 'moe_mxfp8_benchmark' into pr_persistent_grouped_mxfp8_k…
Oleg-Goncharov Mar 17, 2026
f119d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2026
6874935
Uncommented test cases
Oleg-Goncharov Mar 18, 2026
f3e07e5
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 18, 2026
f985c01
Added FP16 Fast math path to rowwise processing
Oleg-Goncharov Mar 18, 2026
5068556
Refactoring
Oleg-Goncharov Mar 18, 2026
6c945d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
a38888c
Fixed lint
Oleg-Goncharov Mar 18, 2026
20e354a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
3d2d1ba
Fix
Oleg-Goncharov Mar 18, 2026
ac75ea2
Fixes
Oleg-Goncharov Mar 18, 2026
3fc8a3e
Fix
Oleg-Goncharov Mar 18, 2026
62dfbd4
Fixed test suite
Oleg-Goncharov Mar 18, 2026
1b6938a
Fixed test suite
Oleg-Goncharov Mar 18, 2026
c319671
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 18, 2026
add9e9c
Fixes per the review
Oleg-Goncharov Mar 18, 2026
86abab8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
4e28663
Modifications per the review
Oleg-Goncharov Mar 19, 2026
a6a9bb6
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 19, 2026
b6b8697
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
2ae38cb
Assert the buffer size
Oleg-Goncharov Mar 19, 2026
d87c5e1
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 19, 2026
9fb2aa8
Added fast math RCP for bf16
Oleg-Goncharov Mar 23, 2026
d231218
Fast math for BF16 is now default
Oleg-Goncharov Mar 24, 2026
0db3063
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 24, 2026
d9faa2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
5e79ae7
Fixed compilation error when compiling on previous archs
Oleg-Goncharov Mar 24, 2026
9838da8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
5885049
Boundary condition fix
Oleg-Goncharov Mar 24, 2026
4fd4162
Fixed compilation error
Oleg-Goncharov Mar 24, 2026
c422ee1
Refactoring. Moved helpers to core-common
Oleg-Goncharov Mar 24, 2026
d264190
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
1ebc952
Refactoring
Oleg-Goncharov Mar 24, 2026
861e226
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
1ed8b58
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 25, 2026
8214619
Refactoring per the review
Oleg-Goncharov Mar 25, 2026
1564d2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
5aedbcc
Addressed the PR review comments
Oleg-Goncharov Mar 30, 2026
bf9ec5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2026
692ca8c
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 30, 2026
646147b
Fixed the compilation error when PTX was compiled for CUDA 13.0
Oleg-Goncharov Mar 31, 2026
737cc7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
8b09ef4
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 31, 2026
0c8a26f
Fixed pytorch extensions
Oleg-Goncharov Apr 1, 2026
ff029bf
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/test_cast_mxfp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ std::vector<std::vector<size_t>> matrix_sizes = {
{1024},
{8, 32, 1024},
{16, 8, 4, 512},
{8192, 7168},
};

std::vector<std::pair<size_t, size_t>> block_sizes = {
Expand Down
83 changes: 49 additions & 34 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ void performTest(const ProcessingMethod processing_method,

NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());

std::vector<size_t> dbias_logical_shape_vec= {num_tensors, cols};
std::vector<size_t> dbias_logical_shape_vec = {num_tensors, cols};
NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(),
dbias_logical_shape_vec.size());

Expand Down Expand Up @@ -499,11 +499,13 @@ void performTest(const ProcessingMethod processing_method,
scales_stride_colwise);
}

QuantizationConfigWrapper quant_config;

// GPU
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_group_quantize(in_group_tensor, out_group_tensor, 0);
nvte_group_quantize(in_group_tensor, out_group_tensor, quant_config, 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
Expand Down Expand Up @@ -554,6 +556,11 @@ void performTest(const ProcessingMethod processing_method,
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;

// Compare only allocated contiguous output range.
// In graph-safe mode logical shape may include trailing garbage beyond offsets_h.back().
const size_t compare_rows = 1;
const size_t compare_cols = elts_num;

if (rowwise) {
cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost);
Expand All @@ -566,7 +573,8 @@ void performTest(const ProcessingMethod processing_method,
const size_t mismatches_elts = 32 * mismatches_scales;

compare_scaled_elts<OutputType>("rowwise_output", out_data_rowwise_ref.data(),
out_data_rowwise_h.data(), rows, cols, true, mismatches_elts);
out_data_rowwise_h.data(), compare_rows, compare_cols,
true, mismatches_elts);
}

if (colwise) {
Expand All @@ -581,7 +589,8 @@ void performTest(const ProcessingMethod processing_method,
const size_t mismatches_elts = 32 * mismatches_scales;

compare_scaled_elts<OutputType>("colwise_output", out_data_colwise_ref.data(),
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
out_data_colwise_h.data(), compare_rows, compare_cols,
false, mismatches_elts);
}

if (compute_dbias) {
Expand Down Expand Up @@ -652,9 +661,13 @@ std::vector<std::vector<size_t>> input_config = {
{VARYING_FIRST_DIM, 4, 1024,144, 128,384,0,512},
{VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
// Empty tensor in the middle of the group must not terminate the persistent work loop.
{VARYING_FIRST_DIM, 4, 512,160, 128,0,0,256},
{VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128},
};

} // namespace
Expand Down Expand Up @@ -808,6 +821,37 @@ std::string to_string(const ActivationKind activation) {
}
}

std::string MakeGroupedFusedCastMXFP8TestName(
    const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
  // Builds the parameterized test name:
  //   <method>X<activation>_<scaling>_<shape-kind>_N_<count>_SHAPE_<d0>X<d1>_<in-type>_<out-type>
  const auto& param = info.param;

  // Maps the scaling direction to its (underscore-delimited) tag; unknown values add nothing.
  const auto direction_tag = [](const ScalingDirection dir) -> const char* {
    switch (dir) {
      case ScalingDirection::ROWWISE: return "_ROWWISE_";
      case ScalingDirection::COLWISE: return "_COLWISE_";
      case ScalingDirection::BOTH: return "_BIDIMENSIONAL_";
    }
    return "";
  };

  // Maps the shape-representation enum (encoded in the config's first element) to its tag.
  const auto shape_tag = [](const ShapeRepresentation rep) -> const char* {
    switch (rep) {
      case ShapeRepresentation::SAME_BOTH_DIMS: return "SAME_BOTH_DIMS";
      case ShapeRepresentation::VARYING_FIRST_DIM: return "VARYING_FIRST_DIM";
      case ShapeRepresentation::VARYING_LAST_DIM: return "VARYING_LAST_DIM";
      case ShapeRepresentation::VARYING_BOTH_DIMS: return "VARYING_BOTH_DIMS";
    }
    return "";
  };

  std::string test_name = to_string(std::get<0>(param));
  test_name += "X" + to_string(std::get<1>(param));
  test_name += direction_tag(std::get<2>(param));

  const std::vector<size_t>& shape_config = std::get<3>(param);
  test_name += shape_tag(static_cast<ShapeRepresentation>(shape_config[0]));
  test_name += "_N_" + std::to_string(shape_config[1]);
  test_name += "_SHAPE_" + std::to_string(shape_config[2]) + "X" + std::to_string(shape_config[3]);

  test_name += "_" + test::typeName(std::get<4>(param));
  test_name += "_" + test::typeName(std::get<5>(param));

  return test_name;
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
GroupedFusedCastMXFP8TestSuite,
Expand All @@ -818,33 +862,4 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(input_config),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)),
[](const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
const ProcessingMethod method = std::get<0>(info.param);
std::string name = to_string(method);
name += "X" + to_string(std::get<1>(info.param));

switch (std::get<2>(info.param)) {
case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break;
case ScalingDirection::COLWISE: name += "_COLWISE_"; break;
case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break;
}

const std::vector<size_t> input = std::get<3>(info.param);

switch(static_cast<ShapeRepresentation>(input[0])) {
case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break;
case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break;
case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break;
case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break;
};

name += "_N_" + std::to_string(input[1]);

name += "_SHAPE_" +
std::to_string(input[2]) +
"X" + std::to_string(input[3]);

name += "_" + test::typeName(std::get<4>(info.param)) +
"_" + test::typeName(std::get<5>(info.param));
return name;
});
MakeGroupedFusedCastMXFP8TestName);
2 changes: 1 addition & 1 deletion tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ constexpr size_t scale_tensor_alignment_Y_colwise = 4;
constexpr size_t scale_tensor_alignment_X_colwise = 128;

inline size_t divide_round_up(const size_t N, const size_t M) {
  // Ceiling division: smallest k such that k * M >= N.
  // NOTE: N + M must not overflow size_t (true for all tensor-shape uses here).
  return ((N + M) - 1) / M;
}

inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) {
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
}

void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize);
using namespace transformer_engine;

constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
}

void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
Expand Down
Loading
Loading