diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 5e73675f4f..4241ada3ba 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable(test_operator test_memset.cu test_splits_to_offsets.cu test_multi_cast_transpose.cu + test_multi_tensor_adam_mxfp8.cu test_multi_padding.cu test_multi_unpadding.cu test_causal_softmax.cu diff --git a/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu b/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu new file mode 100644 index 0000000000..d0eed33781 --- /dev/null +++ b/tests/cpp/operator/test_multi_tensor_adam_mxfp8.cu @@ -0,0 +1,266 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +uint8_t fp8_to_u8(fp8e4m3 v) { + uint8_t out = 0; + std::memcpy(&out, &v, sizeof(uint8_t)); + return out; +} + +uint8_t fp8_to_u8(fp8e5m2 v) { + uint8_t out = 0; + std::memcpy(&out, &v, sizeof(uint8_t)); + return out; +} + +void run_mxfp8_adam_test(DType fp8_dtype) { + const std::vector shape1{64, 128}; + const std::vector shape2{32, 64}; + const float lr = 1e-3f; + const float beta1 = 0.9f; + const float beta2 = 0.999f; + const float eps = 1e-8f; + const int step = 1; + const int mode = 1; + const int bias_correction = 1; + const float weight_decay = 0.0f; + + // Run with 25 tensors > 24[MXFP8_MAX_TENSORS] to check + // the chunking logic + const size_t tensor_count = 25; + std::vector> shapes; + shapes.reserve(tensor_count); + for (size_t i = 0; i < tensor_count; ++i) { + shapes.push_back((i % 2 == 0) ? 
shape1 : shape2); + } + + std::vector names; + names.reserve(tensor_count * 11); + std::vector g; + std::vector p; + std::vector m; + std::vector v; + std::vector p_ref_t; + std::vector m_ref_t; + std::vector v_ref_t; + std::vector q_ref; + std::vector dq; + std::vector dq_ref; + std::vector q; + g.reserve(tensor_count); + p.reserve(tensor_count); + m.reserve(tensor_count); + v.reserve(tensor_count); + p_ref_t.reserve(tensor_count); + m_ref_t.reserve(tensor_count); + v_ref_t.reserve(tensor_count); + q_ref.reserve(tensor_count); + dq.reserve(tensor_count); + dq_ref.reserve(tensor_count); + q.reserve(tensor_count); + + for (size_t i = 0; i < tensor_count; ++i) { + const std::vector &shape = shapes[i]; + names.push_back("g" + std::to_string(i)); + g.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("p" + std::to_string(i)); + p.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("m" + std::to_string(i)); + m.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("v" + std::to_string(i)); + v.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + + fillUniform(&g.back()); + fillUniform(&p.back()); + std::fill_n(m.back().rowwise_cpu_dptr(), product(m.back().rowwise_shape()), 0.0f); + std::fill_n(v.back().rowwise_cpu_dptr(), product(v.back().rowwise_shape()), 0.0f); + m.back().from_cpu(); + v.back().from_cpu(); + + names.push_back("p_ref_" + std::to_string(i)); + p_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("m_ref_" + std::to_string(i)); + m_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("v_ref_" + std::to_string(i)); + v_ref_t.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + const size_t n = shape[0] * shape[1]; + std::memcpy(p_ref_t.back().rowwise_cpu_dptr(), p.back().rowwise_cpu_dptr(), + n * sizeof(float)); + std::memcpy(m_ref_t.back().rowwise_cpu_dptr(), m.back().rowwise_cpu_dptr(), + n * sizeof(float)); + std::memcpy(v_ref_t.back().rowwise_cpu_dptr(), v.back().rowwise_cpu_dptr(), + n * sizeof(float)); + p_ref_t.back().from_cpu(); + m_ref_t.back().from_cpu(); + v_ref_t.back().from_cpu(); + + names.push_back("q_ref_" + std::to_string(i)); + q_ref.emplace_back(names.back().c_str(), shape, fp8_dtype, true, true, NVTE_MXFP8_1D_SCALING); + q_ref.back().set_with_gemm_swizzled_scales(false); + + names.push_back("dq" + std::to_string(i)); + dq.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + names.push_back("dq_ref_" + std::to_string(i)); + dq_ref.emplace_back(names.back().c_str(), shape, DType::kFloat32, true, false); + + names.push_back("q" + std::to_string(i)); + q.emplace_back(names.back().c_str(), shape, fp8_dtype, true, true, NVTE_MXFP8_1D_SCALING); + q.back().set_with_gemm_swizzled_scales(false); + } + + Tensor noop("noop", std::vector{1}, DType::kInt32, true, false); + int zero = 0; + std::memcpy(noop.rowwise_cpu_dptr(), &zero, sizeof(int)); + noop.from_cpu(); + + std::vector> lists(8); + std::vector extra_wrappers; + extra_wrappers.reserve(tensor_count * 4); + + auto add_tensor = [&](Tensor &g, Tensor &p, Tensor &m, Tensor &v, Tensor &q) { + lists[0].push_back(g.data()); + lists[1].push_back(p.data()); + lists[2].push_back(m.data()); + lists[3].push_back(v.data()); + + extra_wrappers.emplace_back(q.rowwise_dptr(), q.rowwise_shape(), fp8_dtype); + lists[4].push_back(extra_wrappers.back().data()); + 
extra_wrappers.emplace_back(q.columnwise_dptr(), q.columnwise_shape(), fp8_dtype); + lists[5].push_back(extra_wrappers.back().data()); + extra_wrappers.emplace_back(q.rowwise_scale_inv_dptr(), q.rowwise_scale_inv_shape(), + DType::kByte); + lists[6].push_back(extra_wrappers.back().data()); + extra_wrappers.emplace_back(q.columnwise_scale_inv_dptr(), q.columnwise_scale_inv_shape(), + DType::kByte); + lists[7].push_back(extra_wrappers.back().data()); + }; + + for (size_t i = 0; i < tensor_count; ++i) { + add_tensor(g[i], p[i], m[i], v[i], q[i]); + } + + std::vector list_ptrs; + list_ptrs.reserve(lists.size()); + for (auto &l : lists) { + list_ptrs.push_back(l.data()); + } + + nvte_multi_tensor_adam_mxfp8_cuda(65536, noop.data(), list_ptrs.data(), lists.size(), + lists[0].size(), static_cast(fp8_dtype), lr, beta1, + beta2, eps, step, mode, bias_correction, weight_decay, 0); + + std::vector> ref_lists(4); + for (size_t i = 0; i < tensor_count; ++i) { + ref_lists[0].push_back(g[i].data()); + ref_lists[1].push_back(p_ref_t[i].data()); + ref_lists[2].push_back(m_ref_t[i].data()); + ref_lists[3].push_back(v_ref_t[i].data()); + } + std::vector ref_list_ptrs; + ref_list_ptrs.reserve(ref_lists.size()); + for (auto &l : ref_lists) { + ref_list_ptrs.push_back(l.data()); + } + + nvte_multi_tensor_adam_cuda(65536, noop.data(), ref_list_ptrs.data(), ref_lists.size(), + ref_lists[0].size(), lr, beta1, beta2, eps, step, mode, + bias_correction, weight_decay, 0); + + for (size_t i = 0; i < tensor_count; ++i) { + nvte_quantize(p_ref_t[i].data(), q_ref[i].data(), 0); + nvte_dequantize(q[i].data(), dq[i].data(), 0); + nvte_dequantize(q_ref[i].data(), dq_ref[i].data(), 0); + } + + cudaDeviceSynchronize(); + + for (size_t i = 0; i < tensor_count; ++i) { + q[i].to_cpu(); + p[i].to_cpu(); + m[i].to_cpu(); + v[i].to_cpu(); + q_ref[i].to_cpu(); + dq[i].to_cpu(); + dq_ref[i].to_cpu(); + p_ref_t[i].to_cpu(); + m_ref_t[i].to_cpu(); + v_ref_t[i].to_cpu(); + } + + for (size_t i = 0; i < lists[0].size(); ++i) { + const Tensor &g_i = g[i]; + const Tensor &p_i = p[i]; + const Tensor &m_i = m[i]; + const Tensor &v_i = v[i]; + Tensor &q_i = q[i]; + const Tensor &p_ref_t_i = p_ref_t[i]; + const Tensor &m_ref_t_i = m_ref_t[i]; + const Tensor &v_ref_t_i = v_ref_t[i]; + Tensor &q_ref_i = q_ref[i]; + + compareResults("p", p_i, p_ref_t_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, 0); + compareResults("m", m_i, m_ref_t_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, 0); + compareResults("v", v_i, v_ref_t_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, 0); + + const Tensor &dq_i = dq[i]; + const Tensor &dq_ref_i = dq_ref[i]; + compareResults("dequantized", dq_i, dq_ref_i.rowwise_cpu_dptr(), true, 0.0, 0.0, true, + 0); + + const size_t rs = q_i.rowwise_scale_inv_shape().data[1]; + const size_t cs = q_i.columnwise_scale_inv_shape().data[1]; + const size_t rowwise_scale_size = q_i.rowwise_scale_inv_shape().data[0] * rs; + const size_t colwise_scale_size = q_i.columnwise_scale_inv_shape().data[0] * cs; + compareResults("rowwise_scale", q_i.rowwise_cpu_scale_inv_ptr(), + q_ref_i.rowwise_cpu_scale_inv_ptr(), rowwise_scale_size, 0.0f); + compareResults("colwise_scale", q_i.columnwise_cpu_scale_inv_ptr(), + q_ref_i.columnwise_cpu_scale_inv_ptr(), colwise_scale_size, 0.0f); + + uint8_t *row_data = nullptr; + uint8_t *col_data = nullptr; + uint8_t *row_data_ref = nullptr; + uint8_t *col_data_ref = nullptr; + if (fp8_dtype == DType::kFloat8E4M3) { + row_data = reinterpret_cast(q_i.rowwise_cpu_dptr()); + col_data = 
reinterpret_cast(q_i.columnwise_cpu_dptr()); + row_data_ref = reinterpret_cast(q_ref_i.rowwise_cpu_dptr()); + col_data_ref = reinterpret_cast(q_ref_i.columnwise_cpu_dptr()); + } else { + row_data = reinterpret_cast(q_i.rowwise_cpu_dptr()); + col_data = reinterpret_cast(q_i.columnwise_cpu_dptr()); + row_data_ref = reinterpret_cast(q_ref_i.rowwise_cpu_dptr()); + col_data_ref = reinterpret_cast(q_ref_i.columnwise_cpu_dptr()); + } + const size_t data_size = q_i.rowwise_shape().data[0] * q_i.rowwise_shape().data[1]; + compareResults("rowwise_data", row_data, row_data_ref, data_size, 0.0f); + compareResults("colwise_data", col_data, col_data_ref, data_size, 0.0f); + } +} + +} // namespace + +TEST(MultiTensorAdamMXFP8, E4M3) { run_mxfp8_adam_test(DType::kFloat8E4M3); } + +TEST(MultiTensorAdamMXFP8, E5M2) { run_mxfp8_adam_test(DType::kFloat8E5M2); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 927407f478..eab181fa82 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -200,6 +200,16 @@ class Tensor { return tensor_.get_columnwise_data().data_ptr; } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().data_ptr; + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + template T *rowwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py index c39957cf13..0b345814d7 100644 --- a/tests/pytorch/distributed/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py @@ -36,6 +36,12 @@ def get_recipe_from_string(recipe): SEQ_LEN = 32 BATCH_PER_RANK = 2 NUM_STEPS = 3 +LOCAL_RANK = None + + +def dist_print(msg): + if LOCAL_RANK == 0: + print(msg) def save_custom_attrs(module): @@ -151,6 +157,8 @@ def test_fused_adam_fp8_master_weights(recipe=None): - Training loop completes without error - DTensor wrapping and QuantizedTensor local tensors are preserved """ + global LOCAL_RANK + LOCAL_RANK = int(os.environ["LOCAL_RANK"]) world_size, _, device = _setup() model = _build_model(fp8_init=True, recipe=recipe) @@ -183,7 +191,7 @@ def test_fused_adam_fp8_master_weights(recipe=None): loss = F.mse_loss(output, target) loss.backward() optimizer.step() - + dist_print(f"Step {step} completed with loss {loss.item()}") # Verify optimizer states for param in model.parameters(): state = optimizer.state[param] @@ -677,6 +685,98 @@ def test_dcp_output_parity(recipe=None, async_save=False): dist.destroy_process_group() +def test_benchmark_optimizer_step(recipe=None): + """Benchmark per-iteration timings for FusedAdam + FSDP2 + MXFP8. + + Reports forward, backward, and optimizer.step() times separately + using CUDA events for accurate GPU measurement. 
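+
+    Registered in TESTS as "benchmark_optimizer_step"; test_torch_fsdp2.py
+    drives it through _run_fused_adam_test("benchmark_optimizer_step", fp_recipe).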
+ """ + global LOCAL_RANK + LOCAL_RANK = int(os.environ["LOCAL_RANK"]) + world_size, _, device = _setup() + + WARMUP_STEPS = 5 + BENCH_STEPS = 200 + + model = _build_model(fp8_init=True, recipe=recipe) + model = _shard_model(model, world_size) + + optimizer = te.optimizers.FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + + fwd_times = [] + bwd_times = [] + opt_times = [] + total_times = [] + + for step in range(WARMUP_STEPS + BENCH_STEPS): + evt_start = torch.cuda.Event(enable_timing=True) + evt_fwd = torch.cuda.Event(enable_timing=True) + evt_bwd = torch.cuda.Event(enable_timing=True) + evt_opt = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() + evt_start.record() + + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + output = model(x) + loss = F.mse_loss(output, target) + evt_fwd.record() + + loss.backward() + evt_bwd.record() + + optimizer.step() + evt_opt.record() + + torch.cuda.synchronize() + + if step >= WARMUP_STEPS: + fwd_times.append(evt_start.elapsed_time(evt_fwd)) + bwd_times.append(evt_fwd.elapsed_time(evt_bwd)) + opt_times.append(evt_bwd.elapsed_time(evt_opt)) + total_times.append(evt_start.elapsed_time(evt_opt)) + + if LOCAL_RANK == 0: + import statistics + + def _stats(name, times): + avg = statistics.mean(times) + med = statistics.median(times) + mn = min(times) + mx = max(times) + std = statistics.stdev(times) if len(times) > 1 else 0.0 + print( + f" {name:12s}: avg={avg:8.3f}ms med={med:8.3f}ms " + f"min={mn:8.3f}ms max={mx:8.3f}ms std={std:7.3f}ms" + ) + + print(f"\n{'=' * 72}") + print(f"Benchmark: {BENCH_STEPS} iterations (after {WARMUP_STEPS} warmup)") + print( + f"Model: TransformerLayer(h={HIDDEN_SIZE}, ffn={FFN_HIDDEN_SIZE}, " + f"heads={NUM_ATTENTION_HEADS}) x {NUM_LAYERS}" + ) + print(f"Recipe: {type(recipe).__name__}") + print(f"World size: {world_size}") + print(f"{'=' * 72}") + _stats("Forward", fwd_times) + _stats("Backward", bwd_times) + _stats("Optim step", opt_times) + _stats("Total", total_times) + print(f"{'=' * 72}\n") + + dist.destroy_process_group() + + TESTS = { "fused_adam_fp8_master_weights": test_fused_adam_fp8_master_weights, "fused_adam_fp8_master_weights_no_meta": test_fused_adam_fp8_master_weights_no_meta, @@ -687,6 +787,7 @@ def test_dcp_output_parity(recipe=None, async_save=False): "dcp_output_parity": functools.partial(test_dcp_output_parity, async_save=False), "dcp_output_parity_async": functools.partial(test_dcp_output_parity, async_save=True), "safetensors_fp32_export": test_safetensors_fp32_export, + "benchmark_optimizer_step": test_benchmark_optimizer_step, } diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 60d7cd2023..a840ba7912 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -33,9 +33,9 @@ def dist_print(msg): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. 
heads") - parser.add_argument("--head-dim", type=int, default=64, help="Attention head size") + parser.add_argument("--head-dim", type=int, default=256, help="Attention head size") parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input") - parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input") + parser.add_argument("--seq-length", type=int, default=2048, help="Sequence length of input") parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.") parser.add_argument( "--fp8-init", diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 02e45d99cb..ed70c3cb9f 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -166,6 +166,12 @@ def test_fsdp2_fused_adam_bf16_store_param_remainders(fp_recipe): _run_fused_adam_test("fused_adam_bf16_store_param_remainders", fp_recipe) +@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") +def test_fsdp2_benchmark_optimizer_step(fp_recipe): + """Benchmark per-iteration timings for FusedAdam + FSDP2.""" + _run_fused_adam_test("benchmark_optimizer_step", fp_recipe) + + @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") def test_fsdp2_dcp_output_parity(fp_recipe): """DCP save/load round-trip into a fresh model produces identical outputs.""" diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 09ab260f15..a50be47b95 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -149,6 +149,43 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float weight_decay, const NVTEDType fp8_dtype, cudaStream_t stream); +/*! \brief Compute and apply gradient update to parameters for Adam optimizer + * when model parameters are in MXFP8 precision. + * + * The update is applied to FP32 master parameters, then the master + * parameters are quantized to MXFP8 rowwise and columnwise data + * (both are always required). + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in,out] tensor_lists 2D array of input tensors with 8 lists in order: + * (0) gradients, (1) FP32 master params, (2) first moment, + * (3) second moment, (4) rowwise MXFP8 data, + * (5) columnwise MXFP8 data, (6) rowwise scale-inv, + * (7) columnwise scale-inv. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. Must be 8. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] fp8_dtype MXFP8 element type for quantization (E4M3/E5M2). + * \param[in] lr Learning rate. + * \param[in] beta1 Coefficient for first moment of gradient. + * \param[in] beta2 Coefficient for second moment of gradient. + * \param[in] epsilon Term added to the denominator for numerical stability. + * \param[in] step Iteration counter. + * \param[in] mode Whether to use AdamW (L2 penalty applied to params). + * \param[in] bias_correction Whether to apply correction factor for moment estimates. + * \param[in] weight_decay L2 penalty for weight decay. + * \param[in] stream CUDA stream used for this operation. 
+ */ +void nvte_multi_tensor_adam_mxfp8_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, const size_t num_tensor_lists, + const size_t num_tensors_per_list, const NVTEDType fp8_dtype, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay, + cudaStream_t stream); + /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support and LR scheduling. * diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 29a073be84..8fc90de949 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -9,6 +9,11 @@ #include #include +#include + +#include "../common.h" +#include "../util/math.h" +#include "../util/ptx.cuh" #include "../utils.cuh" #include "multi_tensor_apply.cuh" @@ -27,6 +32,7 @@ typedef enum { using MATH_T = float; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using e8m0_t = transformer_engine::e8m0_t; template struct is_fp8 : std::false_type {}; @@ -49,6 +55,31 @@ struct FP8Data { template <> struct FP8Data {}; +template +__device__ __forceinline__ void adam_update(T &r_g, T &r_p, T &r_m, T &r_v, const float beta1, + const float beta2, const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + if (mode == ADAM_MODE_0) { // L2 + r_g = r_g + (decay * r_p); + r_m = beta1 * r_m + (1 - beta1) * r_g; + r_v = beta2 * r_v + (1 - beta2) * r_g * r_g; + T next_m_unbiased = r_m / beta1_correction; + T next_v_unbiased = r_v / beta2_correction; + T denom = sqrtf(next_v_unbiased) + epsilon; + T update = next_m_unbiased / denom; + r_p = r_p - (lr * update); + } else { // weight decay + r_m = beta1 * r_m + (1 - beta1) * r_g; + r_v = beta2 * r_v + (1 - beta2) * r_g * r_g; + T next_m_unbiased = r_m / beta1_correction; + T next_v_unbiased = r_v / beta2_correction; + T denom = sqrtf(next_v_unbiased) + epsilon; + T update = (next_m_unbiased / denom) + (decay * r_p); + r_p = r_p - (lr * update); + } +} + template struct AdamFunctorMaster { static constexpr bool is_fp8_type = is_fp8::value; @@ -122,24 +153,8 @@ struct AdamFunctorMaster { } #pragma unroll for (int ii = 0; ii < ILP; ii++) { - if (mode == ADAM_MODE_0) { // L2 - r_g[ii] = r_g[ii] + (decay * r_p[ii]); - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - MATH_T update = next_m_unbiased / denom; - r_p[ii] = r_p[ii] - (lr * update); - } else { // weight decay - r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); - r_p[ii] = r_p[ii] - (lr * update); - } + adam_update(r_g[ii], r_p[ii], r_m[ii], r_v[ii], beta1, beta2, beta1_correction, + beta2_correction, epsilon, lr, mode, decay); } #pragma unroll @@ -572,6 +587,192 @@ struct AdamCapturableMasterFunctor { } }; +template +__device__ __forceinline__ FP8_T cast_to_fp8(float x) { + return static_cast(x); +} + +__device__ 
__forceinline__ float fp8_max_norm_rcp(uint8_t fp8_dtype) { + if (fp8_dtype == static_cast(transformer_engine::DType::kFloat8E4M3)) { + return transformer_engine::Quantized_Limits::max_norm_rcp; + } + return transformer_engine::Quantized_Limits::max_norm_rcp; +} + +template +__global__ void adam_mxfp8_fused_kernel(int64_t chunk_size, volatile int *noop_gmem, + MXFP8TensorListMetadata tl, float beta1, float beta2, + float beta1_correction, float beta2_correction, + float epsilon, float lr, int mode, float weight_decay) { + if (noop_gmem != nullptr && *noop_gmem == 1) { + return; + } + + const int tensor_idx = tl.block_to_tensor[blockIdx.x]; + const int start_tile = tl.block_to_tile[blockIdx.x]; + const int64_t rows_val = tl.rows[tensor_idx]; + const int64_t cols_val = tl.cols[tensor_idx]; + if (rows_val == 0 || cols_val == 0) { + return; + } + + const int64_t tiles_per_row = (cols_val + MXFP8_TILE - 1) / MXFP8_TILE; + const int tiles_y = static_cast((rows_val + MXFP8_TILE - 1) / MXFP8_TILE); + const int total_tiles = tiles_y * static_cast(tiles_per_row); + const int tiles_per_block = max(1, static_cast(chunk_size / MXFP8_TILE_ELEMS)); + const int end_tile = min(start_tile + tiles_per_block, total_tiles); + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_idx]); + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_idx]); + MOMENT_T *m = reinterpret_cast(tl.addresses[2][tensor_idx]); + MOMENT_T *v = reinterpret_cast(tl.addresses[3][tensor_idx]); + + auto *rowwise_data = reinterpret_cast(tl.addresses[4][tensor_idx]); + auto *colwise_data = reinterpret_cast(tl.addresses[5][tensor_idx]); + auto *rowwise_scale_inv = reinterpret_cast(tl.addresses[6][tensor_idx]); + auto *colwise_scale_inv = reinterpret_cast(tl.addresses[7][tensor_idx]); + + constexpr int64_t kRowwiseScaleAlign = 4; + const int64_t row_stride = DIVUP_TO_MULTIPLE(tiles_per_row, kRowwiseScaleAlign); + constexpr int64_t kColwiseScaleAlign = 128; + const int64_t col_stride = DIVUP_TO_MULTIPLE(cols_val, kColwiseScaleAlign); + const uint8_t dtype = tl.fp8_dtype[tensor_idx]; + const auto adam_mode = static_cast(mode); + const float max_norm_rcp = fp8_max_norm_rcp(dtype); + + constexpr int NUM_WARPS = MXFP8_BLOCK_THREADS / THREADS_PER_WARP; + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int lane_id = threadIdx.x % THREADS_PER_WARP; + + __shared__ float col_partial[NUM_WARPS][MXFP8_TILE]; + + for (int tile_idx = start_tile; tile_idx < end_tile; ++tile_idx) { + const int64_t tile_row = tile_idx / tiles_per_row; + const int64_t tile_col = tile_idx % tiles_per_row; + const int64_t row_base = tile_row * MXFP8_TILE; + const int64_t col_base = tile_col * MXFP8_TILE; + + // ── Adam update: keep r_p in registers, write only m/v ─────────── + float r_p[ILP]; + float abs_p[ILP]; + index_t idx[ILP]; + bool valid[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + const int t = threadIdx.x + ii * blockDim.x; + const int local_r = t / MXFP8_TILE; + const int local_c = t % MXFP8_TILE; + const int64_t r = row_base + local_r; + const int64_t c = col_base + local_c; + valid[ii] = (t < MXFP8_TILE_ELEMS && r < rows_val && c < cols_val); + if (valid[ii]) { + idx[ii] = static_cast(r * cols_val + c); + float r_g = static_cast(g[idx[ii]]); + r_p[ii] = static_cast(p[idx[ii]]); + float r_m = static_cast(m[idx[ii]]); + float r_v = static_cast(v[idx[ii]]); + transformer_engine::multi_tensor_adam::adam_update(r_g, r_p[ii], r_m, r_v, beta1, beta2, + beta1_correction, beta2_correction, + epsilon, lr, adam_mode, weight_decay); + m[idx[ii]] 
= static_cast(r_m); + v[idx[ii]] = static_cast(r_v); + abs_p[ii] = fabsf(r_p[ii]); + } else { + r_p[ii] = 0.0f; + abs_p[ii] = 0.0f; + } + } + + // ── Row amax via warp shuffles (no sync needed) ────────────────── + float row_amax[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + row_amax[ii] = abs_p[ii]; + for (int mask = 16; mask >= 1; mask >>= 1) { + row_amax[ii] = fmaxf(row_amax[ii], __shfl_xor_sync(0xFFFFFFFF, row_amax[ii], mask)); + } + } + + // ── Col amax via cross-warp shared memory reduction ────────────── + float col_local = fmaxf(fmaxf(abs_p[0], abs_p[1]), fmaxf(abs_p[2], abs_p[3])); + col_partial[warp_id][lane_id] = col_local; + __syncthreads(); + + float col_amax; + if (warp_id == 0) { + col_amax = col_partial[0][lane_id]; +#pragma unroll + for (int w = 1; w < NUM_WARPS; w++) { + col_amax = fmaxf(col_amax, col_partial[w][lane_id]); + } + col_partial[0][lane_id] = col_amax; + } + __syncthreads(); + col_amax = col_partial[0][lane_id]; + + // ── Compute scales from registers (no global read) ─────────────── + ::transformer_engine::e8m0_t row_biased[ILP]; + float rsi[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + row_biased[ii] = transformer_engine::ptx::float_to_e8m0(row_amax[ii] * max_norm_rcp); + rsi[ii] = transformer_engine::ptx::exp2f_rcp(row_biased[ii]); + } + const ::transformer_engine::e8m0_t col_biased_e8m0 = + transformer_engine::ptx::float_to_e8m0(col_amax * max_norm_rcp); + const float csi = transformer_engine::ptx::exp2f_rcp(col_biased_e8m0); + + // ── Write scale_inv to global (for forward pass) ───────────────── + if (lane_id == 0) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + const int64_t row = row_base + warp_id + ii * NUM_WARPS; + if (row < rows_val) { + rowwise_scale_inv[static_cast(row * row_stride + tile_col)] = + reinterpret_cast(row_biased[ii]); + } + } + } + if (warp_id == 0) { + const int64_t col = col_base + lane_id; + if (col < cols_val) { + colwise_scale_inv[static_cast(tile_row * col_stride + col)] = + reinterpret_cast(col_biased_e8m0); + } + } + + // ── Write p + quantize to MXFP8 (r_p still in registers) ──────── +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (valid[ii]) { + p[idx[ii]] = static_cast(r_p[ii]); + if (dtype == static_cast(transformer_engine::DType::kFloat8E4M3)) { + reinterpret_cast(rowwise_data)[idx[ii]] = + cast_to_fp8(r_p[ii] * rsi[ii]); + reinterpret_cast(colwise_data)[idx[ii]] = cast_to_fp8(r_p[ii] * csi); + } else { + reinterpret_cast(rowwise_data)[idx[ii]] = + cast_to_fp8(r_p[ii] * rsi[ii]); + reinterpret_cast(colwise_data)[idx[ii]] = cast_to_fp8(r_p[ii] * csi); + } + } + } + } +} + +inline bool requires_64bit_indexing(const std::vector> &tensor_lists) { + const size_t num_tensor_lists = tensor_lists.size(); + const size_t num_tensors_per_list = tensor_lists[0].size(); + for (size_t i = 0; i < num_tensor_lists; ++i) { + for (size_t j = 0; j < num_tensors_per_list; ++j) { + if (tensor_lists[i][j]->numel() >= INT_MAX) { + return true; + } + } + } + return false; +} + void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, @@ -624,25 +825,13 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, } } - // Check if 64-bit indices are required - bool requires_64bit_indexing = false; - for (size_t i = 0; i < num_tensor_lists; i++) { - for (size_t j = 0; j < num_tensors_per_list; j++) { - if (tensor_lists[i][j]->numel() >= INT_MAX) { - requires_64bit_indexing = true; - 
break; - } - } - if (requires_64bit_indexing) { - break; - } - } + const bool use_64bit_indexing = requires_64bit_indexing(tensor_lists); // Get moment dtype (m and v have the same dtype, already validated above) const auto moment_type_te = tensor_lists[2][0]->dtype(); // Launch kernel - if (requires_64bit_indexing) { + if (use_64bit_indexing) { if (num_tensor_lists == 4) { // g, p, m, v TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -766,28 +955,40 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); } -void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float beta1, const float beta2, const float epsilon, - const int step, const int mode, const int bias_correction, - const float weight_decay, const DType fp8_dtype, - cudaStream_t stream) { - // Handle bias correction mode - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; +inline std::pair compute_bias_correction(int bias_correction, float beta1, + float beta2, int step) { + float bias_correction1 = 1.0f; + float bias_correction2 = 1.0f; if (bias_correction == 1) { bias_correction1 = 1 - std::pow(beta1, step); bias_correction2 = 1 - std::pow(beta2, step); } + return {bias_correction1, bias_correction2}; +} - // Check tensor list sizes - // 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv +inline void check_tensor_list_sizes(const std::vector> &tensor_lists, + size_t expected_lists) { const size_t num_tensor_lists = tensor_lists.size(); - NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists); + NVTE_CHECK(num_tensor_lists == expected_lists, "Expected ", expected_lists, + " tensor lists, but found ", num_tensor_lists); const size_t num_tensors_per_list = tensor_lists[0].size(); - for (size_t i = 1; i < num_tensor_lists; i++) { + for (size_t i = 1; i < num_tensor_lists; ++i) { NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); } +} + +void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, const DType fp8_dtype, + cudaStream_t stream) { + auto [bias_correction1, bias_correction2] = + compute_bias_correction(bias_correction, beta1, beta2, step); + check_tensor_list_sizes(tensor_lists, 8); + const size_t num_tensor_lists = tensor_lists.size(); + const size_t num_tensors_per_list = tensor_lists[0].size(); // Check tensor dtypes const auto g_in_type_te = tensor_lists[0][0]->dtype(); @@ -819,22 +1020,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ", but expected dtype=", to_string(DType::kFloat32)); } - // Check if 64-bit indices are required - bool requires_64bit_indexing = false; - for (size_t i = 0; i < num_tensor_lists; i++) { - for (size_t j = 0; j < num_tensors_per_list; j++) { - if (tensor_lists[i][j]->numel() >= INT_MAX) { - requires_64bit_indexing = true; - break; - } - } - if (requires_64bit_indexing) { - break; - } - } + const bool use_64bit_indexing = requires_64bit_indexing(tensor_lists); // Launch kernel - if (requires_64bit_indexing) { + if (use_64bit_indexing) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -856,6 +1045,76 @@ void 
multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); } +void multi_tensor_adam_mxfp8_cuda(int chunk_size, Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, const DType fp8_dtype, + cudaStream_t stream) { + auto [bias_correction1, bias_correction2] = + compute_bias_correction(bias_correction, beta1, beta2, step); + check_tensor_list_sizes(tensor_lists, 8); + const size_t num_tensor_lists = tensor_lists.size(); + const size_t num_tensors_per_list = tensor_lists[0].size(); + + NVTE_CHECK(fp8_dtype == DType::kFloat8E4M3 || fp8_dtype == DType::kFloat8E5M2, + "fp8_dtype must be E4M3 or E5M2 for MXFP8 fused Adam."); + + // Check tensor dtypes + const auto g_in_type_te = tensor_lists[0][0]->dtype(); + const auto p_in_type_te = tensor_lists[1][0]->dtype(); + const auto moment_type_te = tensor_lists[2][0]->dtype(); + for (size_t j = 0; j < num_tensors_per_list; ++j) { + NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j, + " has dtype=", to_string(tensor_lists[0][j]->dtype()), + ", but expected dtype=", to_string(g_in_type_te)); + NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j, + " has dtype=", to_string(tensor_lists[1][j]->dtype()), + ", but expected dtype=", to_string(p_in_type_te)); + { + const bool m_is_fp32 = tensor_lists[2][j]->dtype() == DType::kFloat32; + const bool m_is_bf16 = tensor_lists[2][j]->dtype() == DType::kBFloat16; + const bool v_is_fp32 = tensor_lists[3][j]->dtype() == DType::kFloat32; + const bool v_is_bf16 = tensor_lists[3][j]->dtype() == DType::kBFloat16; + NVTE_CHECK((m_is_fp32 && v_is_fp32) || (m_is_bf16 && v_is_bf16), + "First and second moment tensors must both be Float32 or both be BFloat16, but " + "tensor ", + j, " has first moment dtype=", to_string(tensor_lists[2][j]->dtype()), + " and second moment dtype=", to_string(tensor_lists[3][j]->dtype())); + } + } + + const bool use_64bit_indexing = requires_64bit_indexing(tensor_lists); + + if (use_64bit_indexing) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + p_in_type_te, p_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + g_in_type_te, g_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply_mxfp8< + transformer_engine::multi_tensor_adam::adam_mxfp8_fused_kernel< + p_in_type, g_in_type, moment_type, int64_t>>( + chunk_size, noop_flag, tensor_lists, static_cast(fp8_dtype), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, mode, + weight_decay);))); + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + p_in_type_te, p_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + g_in_type_te, g_in_type, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16( + moment_type_te, moment_type, + multi_tensor_apply_mxfp8< + transformer_engine::multi_tensor_adam::adam_mxfp8_fused_kernel< + p_in_type, g_in_type, moment_type, int32_t>>( + chunk_size, noop_flag, tensor_lists, static_cast(fp8_dtype), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, mode, + weight_decay);))); + } +} + void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, Tensor lr, const float beta1, const float beta2, const float epsilon, @@ -1018,6 +1277,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, epsilon, step, mode, bias_correction, 
weight_decay, static_cast(fp8_dtype), stream); } +void nvte_multi_tensor_adam_mxfp8_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, const size_t num_tensor_lists, + const size_t num_tensors_per_list, const NVTEDType fp8_dtype, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_adam_mxfp8_cuda); + using namespace transformer_engine; + multi_tensor_adam::multi_tensor_adam_mxfp8_cuda( + chunk_size, *convertNVTETensorCheck(noop_flag), + convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, + epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), stream); +} + void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index 3062ead551..bea89a02ca 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -35,6 +35,22 @@ struct TensorListMetadata : public TensorListMetadataBase { void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]]; }; +constexpr int MXFP8_TILE = 32; +constexpr int MXFP8_TILE_ELEMS = MXFP8_TILE * MXFP8_TILE; +constexpr int MXFP8_BLOCK_THREADS = 256; +constexpr int MXFP8_MAX_TENSORS = 24; +constexpr int MXFP8_MAX_BLOCKS = 320; + +struct MXFP8TensorListMetadata { + void *addresses[8][MXFP8_MAX_TENSORS]; + int rows[MXFP8_MAX_TENSORS]; + int cols[MXFP8_MAX_TENSORS]; + uint8_t fp8_dtype[MXFP8_MAX_TENSORS]; + unsigned char block_to_tensor[MXFP8_MAX_BLOCKS]; + int block_to_tile[MXFP8_MAX_BLOCKS]; + int start_tensor_this_launch; +}; + template __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl, U callable, ArgTypes... args) { @@ -113,3 +129,81 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, } } } + +template +void multi_tensor_apply_mxfp8(int64_t chunk_size, const transformer_engine::Tensor &noop_flag, + std::vector> tensor_lists, + uint8_t fp8_dtype, cudaStream_t stream, ArgTypes... 
args) { + constexpr size_t kNumTensorLists = 8; + constexpr int TileElems = TileRows * TileCols; + NVTE_CHECK(tensor_lists.size() == kNumTensorLists, + "Expected 8 tensor lists for MXFP8, but found ", tensor_lists.size()); + + const size_t num_tensors_per_list = tensor_lists[0].size(); + if (num_tensors_per_list == 0) { + return; + } + for (size_t i = 1; i < tensor_lists.size(); ++i) { + NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i, + " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list); + } + + const int tiles_per_block = std::max(1, static_cast(chunk_size / TileElems)); + + MXFP8TensorListMetadata tl; + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + + for (size_t t = 0; t < num_tensors_per_list; ++t) { + const auto &rowwise_data = tensor_lists[4][t]; + + const int rows_val = static_cast(rowwise_data->data.shape[0]); + const int cols_val = static_cast(rowwise_data->data.shape[1]); + + tl.rows[loc_tensor_info] = rows_val; + tl.cols[loc_tensor_info] = cols_val; + tl.fp8_dtype[loc_tensor_info] = fp8_dtype; + + for (int d = 0; d < kNumTensorLists; ++d) { + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr; + } + loc_tensor_info++; + + const int tiles_y = (rows_val + TileRows - 1) / TileRows; + const int tiles_x = (cols_val + TileCols - 1) / TileCols; + const int tiles_this_tensor = tiles_y * tiles_x; + const int blocks_this_tensor = (tiles_this_tensor + tiles_per_block - 1) / tiles_per_block; + + for (int block = 0; block < blocks_this_tensor; ++block) { + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_tile[loc_block_info] = block * tiles_per_block; + loc_block_info++; + + const bool blocks_full = (loc_block_info == MXFP8_MAX_BLOCKS); + const bool tensors_full = + (loc_tensor_info == MXFP8_MAX_TENSORS && block == blocks_this_tensor - 1); + const bool last_block = (t == num_tensors_per_list - 1 && block == blocks_this_tensor - 1); + if (blocks_full || tensors_full || last_block) { + Kernel<<>>( + chunk_size, reinterpret_cast(noop_flag.data.dptr), tl, args...); + NVTE_CHECK_CUDA(cudaGetLastError()); + loc_block_info = 0; + if (block == blocks_this_tensor - 1) { + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + tl.rows[0] = tl.rows[loc_tensor_info - 1]; + tl.cols[0] = tl.cols[loc_tensor_info - 1]; + tl.fp8_dtype[0] = tl.fp8_dtype[loc_tensor_info - 1]; + for (int d = 0; d < kNumTensorLists; ++d) { + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..5096cfb252 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -517,6 +517,12 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, const int step, const int mode, const int bias_correction, const float weight_decay, DType fp8_dtype); +void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype); + void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, const float beta1, const float beta2, diff --git 
a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 145e1d4b40..6df9807f82 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -51,6 +51,25 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, at::cuda::getCurrentCUDAStream()); } +void multi_tensor_adam_mxfp8_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, const float lr, + const float beta1, const float beta2, const float epsilon, + const int step, const int mode, const int bias_correction, + const float weight_decay, DType fp8_dtype) { + auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = + makeTransformerEngineTensorList(tensor_lists); + + NVTE_CHECK(num_lists == 8, + "Expected 8 tensor lists (g, p_master, m, v, rowwise_data, colwise_data, " + "rowwise_scale_inv, colwise_scale_inv), but found ", + num_lists); + nvte_multi_tensor_adam_mxfp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), + num_lists, num_tensors, static_cast(fp8_dtype), lr, + beta1, beta2, epsilon, step, mode, bias_correction, + weight_decay, at::cuda::getCurrentCUDAStream()); +} + void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, const float beta1, const float beta2, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c590a3c9e2..bb2a8b6227 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -525,6 +525,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_mxfp8", &transformer_engine::pytorch::multi_tensor_adam_mxfp8_cuda, + "Compute and apply gradient update to parameters for Adam optimizer", + py::call_guard()); m.def("multi_tensor_adam_capturable", &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda, "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index bcfd2bef19..1d64b18341 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -14,6 +14,7 @@ from torch.distributed._tensor import DTensor import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier @@ -189,6 +190,7 @@ def __init__( self.multi_tensor_adam = tex.multi_tensor_adam self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 + self.multi_tensor_adam_mxfp8 = tex.multi_tensor_adam_mxfp8 self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master @@ -544,18 +546,27 @@ def step(self, closure=None, grad_scaler=None): # create 
lists for multi-tensor apply
             p_main_of_fp8_model = []
             p_main_of_f16_model = []
+            p_main_of_mxfp8_model = []
             g_of_fp8_model = []
             g_of_f16_model = []
             g_of_f32_model = []
+            g_of_mxfp8_model = []
             m_of_fp8_model = []
             m_of_f16_model = []
             m_of_f32_model = []
+            m_of_mxfp8_model = []
             v_of_fp8_model = []
             v_of_f16_model = []
             v_of_f32_model = []
+            v_of_mxfp8_model = []
             p_fp8_model = []
             p_f16_model = []
             p_f32_model = []
+            # mxfp8 meta
+            p_mxfp8_rowwise = []
+            p_mxfp8_colwise = []
+            p_mxfp8_rowwise_scale_inv = []
+            p_mxfp8_colwise_scale_inv = []
             # fp8 meta
             scales = []
             amaxes = []
@@ -623,14 +634,38 @@
                     g_of_fp8_model.append(p_grad.data)
                     m_of_fp8_model.append(unscaled_state["exp_avg"])
                     v_of_fp8_model.append(unscaled_state["exp_avg_sq"])
+                elif isinstance(p, MXFP8Tensor) or (
+                    isinstance(p, DTensor) and isinstance(p._local_tensor, MXFP8Tensor)
+                ):
+                    p = p._local_tensor if isinstance(p, DTensor) else p
+                    if p._rowwise_data is None or p._columnwise_data is None:
+                        raise RuntimeError(
+                            "MXFP8Tensor does not have one of rowwise/columnwise data."
+                        )
+                    if self.capturable:
+                        raise RuntimeError(
+                            "FusedAdam does not support MXFP8 model weights with capturable=True."
+                        )
+                    if not self.master_weights:
+                        raise RuntimeError(
+                            "FusedAdam without master_weights does not support "
+                            "MXFP8 model weights. Use master_weights=True."
+                        )
+                    p_main_of_mxfp8_model.append(unscaled_state["master_param"].data)
+                    g_of_mxfp8_model.append(p_grad.data)
+                    m_of_mxfp8_model.append(unscaled_state["exp_avg"])
+                    v_of_mxfp8_model.append(unscaled_state["exp_avg_sq"])
+                    p_mxfp8_rowwise.append(p._rowwise_data)
+                    p_mxfp8_colwise.append(p._columnwise_data)
+                    p_mxfp8_rowwise_scale_inv.append(p._rowwise_scale_inv)
+                    p_mxfp8_colwise_scale_inv.append(p._columnwise_scale_inv)
+                    out_dtype = p._fp8_dtype
                 elif isinstance(p, QuantizedTensor) or (
                     isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
                 ):
-                    # Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor,
-                    # NVFP4Tensor). Operate on FP32 master weights, requantize back after
-                    # Adam update.
-                    # Note: a fused Adam+requantize kernel (like multi_tensor_adam_fp8
-                    # for Float8Tensor) would avoid the FP32 round-trip here.
+                    # Note: Fused Adam support for other quantized params (Float8BlockwiseQTensor,
+                    # NVFP4Tensor) is not available yet, so fall back to an unfused path:
+                    # operate on FP32 master weights and requantize after the Adam update.
                     if not self.master_weights:
                         local_p = p._local_tensor if isinstance(p, DTensor) else p
                         raise RuntimeError(
@@ -797,6 +832,18 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N
                     scale_invs,
                 ]
                 apply_multi_tensor_adam(self.multi_tensor_adam_fp8, tensor_lists, out_dtype)
+            if len(p_mxfp8_rowwise) > 0 and len(p_mxfp8_colwise) > 0:
+                tensor_lists = [
+                    g_of_mxfp8_model,
+                    p_main_of_mxfp8_model,
+                    m_of_mxfp8_model,
+                    v_of_mxfp8_model,
+                    p_mxfp8_rowwise,
+                    p_mxfp8_colwise,
+                    p_mxfp8_rowwise_scale_inv,
+                    p_mxfp8_colwise_scale_inv,
+                ]
+                apply_multi_tensor_adam(self.multi_tensor_adam_mxfp8, tensor_lists, out_dtype)
             if len(p_f32_model) > 0:
                 tensor_lists = [
                     g_of_f32_model,