From ed7487c15e219161f86656ed6604aa9877bf72bc Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 10 Mar 2025 17:24:41 +0800
Subject: [PATCH] Support BF16 for low-latency kernels

---
 README.md                    |   2 +-
 csrc/config.hpp              |   2 +-
 csrc/deep_ep.cpp             |  22 ++++---
 csrc/deep_ep.hpp             |   4 +-
 csrc/kernels/api.cuh         |   2 +-
 csrc/kernels/internode_ll.cu | 112 +++++++++++++++++++----------------
 deep_ep/buffer.py            |  16 +++--
 tests/test_low_latency.py    |  89 +++++++++++++++-------------
 8 files changed, 138 insertions(+), 111 deletions(-)

diff --git a/README.md b/README.md
index 6ecb7db..dd20a79 100644
--- a/README.md
+++ b/README.md
@@ -282,7 +282,7 @@ For two micro-batch overlapping, you can refer to the following figure. With our
 - [x] AR support
 - [ ] Refactor low-latency mode AR code
 - [ ] A100 support (intranode only)
-- [ ] Support BF16 for the low-latency dispatch kernel
+- [x] Support BF16 for the low-latency dispatch kernel
 - [ ] Support NVLink protocol for intranode low-latency kernels
 - [ ] SM-free normal kernels
 
diff --git a/csrc/config.hpp b/csrc/config.hpp
index 11a09c7..37be06b 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -128,7 +128,7 @@ struct LowLatencyLayout {
 
         // Message sizes
         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
-        size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
+        size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
         size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
 
         // Send buffer
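A BF16 payload is twice as wide as FP8 but ships no trailing scales, so the dispatch message is now sized for whichever format is larger and one buffer layout serves both modes. A minimal standalone sketch of the arithmetic (plain C++; `hidden = 7168` is an illustrative value consistent with the "7K hidden" note in the kernel, and the 16-byte header mirrors `sizeof(int4)`):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t hidden = 7168, num_scales = hidden / 128;
    const size_t header = 16;  // sizeof(int4): src token index + reserved fields
    // FP8 payload: 1 byte per element plus one float scale per 128 channels
    const size_t fp8_bytes = hidden + num_scales * sizeof(float);
    // BF16 payload: 2 bytes per element, no scales
    const size_t bf16_bytes = hidden * 2;  // sizeof(nv_bfloat16)
    // One buffer serves both modes, so size for the larger of the two
    const size_t msg = header + std::max(bf16_bytes, fp8_bytes);
    std::printf("fp8: %zu, bf16: %zu, msg: %zu\n", fp8_bytes, bf16_bytes, msg);
}
```

For these sizes the BF16 payload (14336 bytes) dominates the FP8 payload plus scales (7392 bytes), so the message grows to 14352 bytes either way.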
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 94ae57b..d43f3e5 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1011,10 +1011,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
                                           at::cuda::getCurrentCUDAStream());
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
-                             bool async, bool return_recv_hook) {
+                             bool use_fp8, bool async, bool return_recv_hook) {
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
@@ -1045,20 +1045,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
+                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
 
     // Allocate column-majored scales
-    EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
-    auto packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
-    packed_recv_x_scales = torch::transpose(packed_recv_x_scales, 1, 2);
+    auto packed_recv_x_scales = std::optional<torch::Tensor>();
+    float* packed_recv_x_scales_ptr = nullptr;
+    if (use_fp8) {
+        EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
+        packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
+    }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
-        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
+        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                                packed_recv_count.data_ptr<int>(),
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
@@ -1066,7 +1072,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
-                               num_topk, num_experts, rank, num_ranks,
+                               num_topk, num_experts, rank, num_ranks, use_fp8,
                                workspace, launch_stream, phases);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
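With this change the second element of the return tuple becomes optional: it is engaged only on the FP8 path, while BF16 dispatch returns the payload alone. A minimal sketch of that pattern (plain C++; a vector stands in for `torch::Tensor`, and `dispatch_like` is an illustrative name, not a repo function):

```cpp
#include <cstdio>
#include <optional>
#include <tuple>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for torch::Tensor

std::tuple<Tensor, std::optional<Tensor>> dispatch_like(bool use_fp8) {
    Tensor recv_x(8, 1.0f);               // payload (FP8 or BF16)
    std::optional<Tensor> recv_x_scales;  // disengaged by default
    if (use_fp8)
        recv_x_scales = Tensor(2, 0.5f);  // per-channel-group scales
    return {recv_x, recv_x_scales};
}

int main() {
    auto [x, scales] = dispatch_like(false);
    std::printf("BF16 path returns scales? %s\n", scales.has_value() ? "yes" : "no");
}
```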
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 33f0c81..3fff3ae 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -134,10 +134,10 @@ public:
 
     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 
-    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
-                         bool async, bool return_recv_hook);
+                         bool use_fp8, bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 7b2159e..74c962b 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -137,7 +137,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-              int num_topk, int num_experts, int rank, int num_ranks,
+              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
               void* workspace, cudaStream_t stream, int phases);
 
 void combine(void* combined_x,
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index f60e933..426c7bc 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                   clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
+template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -62,11 +62,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     constexpr int kNumPerChannels = 128;
     constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
     const int num_scales = kHidden / kNumPerChannels;
-    const size_t hidden_int4 = kHidden / sizeof(int4);
+    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
+    const size_t hidden_int4 = hidden_bytes / sizeof(int4);
 
     // Message package: hidden data, FP8 scales, index at source
     // NOTES: currently we have 3 reserved int fields for future use
-    const size_t num_bytes_per_msg = kHidden + num_scales * sizeof(float) + sizeof(int4);
+    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
+    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
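The message layout also changes here: the source-index header moves from the tail of the message to the head, so the payload always starts at a fixed `sizeof(int4)` offset in both formats, with FP8 scales trailing the payload only when present. A sketch of the resulting offsets (plain C++; `hidden = 7168` is again illustrative):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    const size_t hidden = 7168, num_scales = hidden / 128, header = 16;
    for (int use_fp8 = 1; use_fp8 >= 0; --use_fp8) {
        const size_t elem_bytes = use_fp8 ? 1 : 2;  // FP8 vs BF16 element
        const size_t payload_off = header;          // int4 header comes first
        const size_t payload_end = payload_off + hidden * elem_bytes;
        const size_t total = use_fp8 ? payload_end + num_scales * sizeof(float)
                                     : payload_end;  // BF16 ships no scales
        std::printf("use_fp8=%d: payload@%zu..%zu total=%zu\n",
                    use_fp8, payload_off, payload_end, total);
    }
}
```

Putting the header first is what lets the receive path read the source index with one aligned load before touching the format-dependent payload.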
@@ -89,9 +91,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
             const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
-            const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
-            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
-            const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
+            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
+            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
             // Overlap top-k index read and source token index write
             auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
 
@@ -100,32 +102,39 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             // FP8 cast
             #pragma unroll
             for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
-                // Read and calculate local amax
+                // Read
                 auto int4_value = __ldg(x_int4 + i);
-                auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
-                float fp32_values[kNumElemsPerRead];
-                float amax = kFP8Margin, scale, scale_inv;
-                #pragma unroll
-                for (int j = 0; j < kNumElemsPerRead; ++ j) {
-                    fp32_values[j] = static_cast<float>(bf16_values[j]);
-                    amax = fmaxf(amax, fabsf(fp32_values[j]));
-                }
-
-                // Reduce amax and scale
-                EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
-                amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
-                if (lane_id == 0 or lane_id == 16)
-                    rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+                if (kUseFP8) {
+                    // Calculate local amax
+                    auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
+                    float fp32_values[kNumElemsPerRead];
+                    float amax = kFP8Margin, scale, scale_inv;
+                    #pragma unroll
+                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
+                        fp32_values[j] = static_cast<float>(bf16_values[j]);
+                        amax = fmaxf(amax, fabsf(fp32_values[j]));
+                    }
 
-                // Cast into send buffer
-                int2 int2_value;
-                auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
-                #pragma unroll
-                for (int j = 0; j < kNumElemsPerRead; j += 2) {
-                    float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
-                    fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
+                    // Reduce amax and scale
+                    EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
+                    amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
+                    if (lane_id == 0 or lane_id == 16)
+                        rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+
+                    // Cast into send buffer
+                    vec_t int2_value;
+                    auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
+                    #pragma unroll
+                    for (int j = 0; j < kNumElemsPerRead; j += 2) {
+                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
+                        fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
+                    }
+                    rdma_x_vec[i] = int2_value;
+                } else {
+                    // Reinterpret-cast is for C++14 compatibility
+                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                 }
-                rdma_x_int2[i] = int2_value;
             }
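The FP8 path above scales each 128-channel group so its absolute maximum maps onto the E4M3 range, and transmits `scale_inv` so the receiver can dequantize with one multiply. A reference sketch of that math (plain C++, with floats standing in for BF16 values; the actual `__nv_cvt_float2_to_fp8x2` cast is elided):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const int kNumPerChannels = 128;
    const float kFP8Margin = 1e-4f, kFP8Amax = 448.0f;  // E4M3 max finite value
    float channel[kNumPerChannels];
    for (int j = 0; j < kNumPerChannels; ++j)
        channel[j] = 0.01f * (j - 64);                  // toy channel group

    float amax = kFP8Margin;                            // margin avoids div-by-zero
    for (int j = 0; j < kNumPerChannels; ++j)
        amax = std::max(amax, std::fabs(channel[j]));
    const float scale = kFP8Amax / amax;                // applied before the cast
    const float scale_inv = amax / kFP8Amax;            // stored in the message

    const float sent = channel[1] * scale;              // would then be cast to E4M3
    std::printf("orig=%f recovered=%f\n", channel[1], sent * scale_inv);
}
```

The BF16 branch skips all of this and copies the loaded `int4` straight into the send buffer, which is why `vec_t` widens from `int2` to `int4` in that instantiation.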
1, %0;" :: "r"(num_threads)); @@ -135,7 +144,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); const auto dst_rank = dst_expert_idx / num_local_experts; const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; - const auto src_ptr = reinterpret_cast(rdma_x_int2); + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); const auto dst_ptr = reinterpret_cast(rdma_recv_x) + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + @@ -273,26 +282,28 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // Copy tokens EP_DEVICE_ASSERT(num_scales <= 64); for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) { - // Copy data - // NOTES: only 2 load iterations for 7K hidden with 7 unrolls - const auto src = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); - const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; - UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global); - - // Copy scales - const auto src_scales = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden); - const auto dst_scales = reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); - const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; - auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0; - auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0; - lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f; - (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f; - // Copy source info - const auto src_src_idx = reinterpret_cast(src_scales + num_scales); + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); if (lane_id == 0) recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); __syncwarp(); + + // Copy data + // NOTES: only 2 load iterations for 7K hidden with 7 unrolls + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + // Copy scales + if (kUseFP8) { + const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); + const auto dst_scales = reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); + const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; + auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0; + auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0; + lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f; + (lane_id + 32) < num_scales ? 
@@ -304,7 +315,7 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-              int num_topk, int num_experts, int rank, int num_ranks,
+              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
               void* workspace, cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
     constexpr int kNumWarpsPerGroup = 10;
@@ -314,15 +325,16 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
     const auto num_sms = ceil_div(num_experts, kNumWarpGroups);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
-    EP_HOST_ASSERT(ceil_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
 
     // Workspace checks
     auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
     auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
 
-#define DISPATCH_LAUNCH_CASE(hidden) \
-LAUNCH_KERNEL(&cfg, dispatch
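The launch macro is truncated above, but the new `use_fp8` parameter implies the host must now map a runtime flag onto a compile-time `bool` template argument of the kernel. A generic sketch of that selection pattern (standalone C++; `dispatch_stub` and `launch` are illustrative names, not the repo's macro):

```cpp
#include <cstdio>

template <bool kUseFP8, int kHidden>
void dispatch_stub() {  // stand-in for the templated CUDA kernel
    std::printf("kUseFP8=%d, kHidden=%d\n", kUseFP8, kHidden);
}

void launch(bool use_fp8, int hidden) {
    // The runtime switch picks one of a fixed set of instantiations
    switch (hidden) {
        case 7168:
            (use_fp8 ? dispatch_stub<true, 7168> : dispatch_stub<false, 7168>)();
            break;
        default:
            std::printf("unsupported hidden size\n");
    }
}

int main() { launch(true, 7168); launch(false, 7168); }
```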