From 82dcf48fd315d7b83f2cd2b4f1d1f1fda6af8ed2 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 14 Mar 2025 16:09:23 +0800
Subject: [PATCH] Fix bugs for intranode EP kernels

---
 csrc/kernels/intranode.cu | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu
index 53911e3..4280166 100644
--- a/csrc/kernels/intranode.cu
+++ b/csrc/kernels/intranode.cu
@@ -487,6 +487,8 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
         const auto thread_id = static_cast<int>(threadIdx.x);
         const auto rank_id = thread_id / 32;
         const auto lane_id = thread_id % 32;
+        if (rank_id >= kNumRanks)
+            return;
 
         int token_start_idx, token_end_idx;
         get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
@@ -714,10 +716,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
             for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
                 // Read expected head
                 int expected_head = -1;
-                if (recv_lane_id < kNumRanks) {
+                if (recv_lane_id < kNumRanks)
                     expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
-                    warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
-                }
+
                 auto start_time = clock64();
                 while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
                     // Timeout check
@@ -775,6 +776,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                         value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
                     recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
                 }
+
+                // Update head
+                if (recv_lane_id < kNumRanks)
+                    warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
             }
 
             // Retired