From 42494864ba64b6520337a6d084dc209ac6be4905 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 7 Apr 2025 09:55:39 +0800
Subject: [PATCH] Remove useless control metadata for low-latency combine

---
 csrc/config.hpp              | 6 +++---
 csrc/kernels/internode_ll.cu | 9 ++++-----
 tests/test_low_latency.py    | 2 +-
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/csrc/config.hpp b/csrc/config.hpp
index 1c7a681..ec74564 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -122,7 +122,6 @@ struct LowLatencyLayout {
 
     LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
         const int num_scales = hidden / 128;
-        const int num_local_experts = num_experts / num_ranks;
 
         // Dispatch and combine layout:
         //  - 2 symmetric odd/even send buffer
@@ -130,9 +129,10 @@ struct LowLatencyLayout {
         //  - 2 symmetric odd/even signaling buffers
 
         // Message sizes
+        // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
         size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
-        size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
+        size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
 
         // Send buffer
         size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
@@ -167,7 +167,7 @@ struct LowLatencyLayout {
                 advance(rdma_buffer, send_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
-                advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
+                advance(rdma_buffer, send_buffer_bytes * i),
                 num_bytes_per_combine_msg
             };
         }
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 03be9ba..c33e062 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -369,8 +369,7 @@ combine(void* combined_x,
     const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
 
     // Message package
-    // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
-    constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
+    constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16);
     EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
 
     // Sending phase
@@ -409,12 +408,12 @@ combine(void* combined_x,
         for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
             const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
             const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
-            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
+            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
 
             // Copy directly to local rank, or copy to buffer and issue RDMA
             auto src_idx = __ldg(local_src_info + token_idx);
             const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
-            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
+            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
             if (dst_rank == rank) {
                 const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                 UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
@@ -473,7 +472,7 @@ combine(void* combined_x,
         for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
             // Read from sources
             auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
-            auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
+            auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
 
             // Reduce
             auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py
index 6cf852d..ed7b32e 100644
--- a/tests/test_low_latency.py
+++ b/tests/test_low_latency.py
@@ -84,7 +84,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         if do_check:
             diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
             assert torch.isnan(combined_x).sum().item() == 0
-            assert diff < 1e-5, f'Error: diff={diff}'
+            assert diff < 1e-5, f'Error: {diff=}, {zero_copy=}'
             hash_value ^= hash_tensor(combined_x)
 
     def create_test_cast_with_outliers(num_outliers):
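
Below is a minimal host-side C++ sketch of the slot-size arithmetic this patch
changes; it is illustrative only and not part of the patch. Before the change,
every combine slot reserved a leading control `int4` (16 bytes) that the
pure-BF16 combine path never read, so both sender and receiver had to skip
`sizeof(int4)` to reach the payload; after the change, the payload starts at
byte 0 of the slot. The helper names and the `hidden = 7168` value are
assumptions for this example; the expressions mirror `num_bytes_per_combine_msg`
in csrc/config.hpp and `num_bytes_per_slot` in csrc/kernels/internode_ll.cu.

// slot_math_sketch.cpp -- hypothetical example, not part of the patch.
// Build: g++ -std=c++17 slot_math_sketch.cpp
#include <cstddef>
#include <cstdio>

// Host-side stand-ins sized like the CUDA types used in csrc/config.hpp.
struct Int4 { int x, y, z, w; };   // 16 bytes, like CUDA's int4
using Bf16 = unsigned short;       // 2 bytes, like nv_bfloat16

// Before the patch: control int4 + BF16 payload per combine slot.
constexpr size_t combine_msg_bytes_before(size_t hidden) {
    return sizeof(Int4) + hidden * sizeof(Bf16);
}

// After the patch: pure BF16 payload, starting at byte 0 of the slot.
constexpr size_t combine_msg_bytes_after(size_t hidden) {
    return hidden * sizeof(Bf16);
}

int main() {
    const size_t hidden = 7168;  // assumed example hidden size
    // Receiver-side slot address for (expert e, source token s), mirroring
    // the dst_ptr computation in internode_ll.cu:
    //   before: rdma_recv_x + (e * num_max_dispatch_tokens_per_rank + s) * slot_bytes + sizeof(int4)
    //   after:  rdma_recv_x + (e * num_max_dispatch_tokens_per_rank + s) * slot_bytes
    std::printf("per-slot bytes before: %zu\n", combine_msg_bytes_before(hidden));  // 16 + 14336 = 14352
    std::printf("per-slot bytes after:  %zu\n", combine_msg_bytes_after(hidden));   // 14336
    return 0;
}

The retained EP_STATIC_ASSERT still holds after the change: kHidden *
sizeof(nv_bfloat16) is a multiple of sizeof(int4) whenever kHidden is divisible
by 8. And per the new NOTES comment in config.hpp, a control `int4` would need
to be reintroduced if combine messages ever carry transformed data.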