Fix bugs in intranode EP kernels

This commit is contained in:
Chenggang Zhao 2025-03-14 16:09:23 +08:00
parent 043fa5fa99
commit 82dcf48fd3

View File

@ -487,6 +487,8 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
const auto thread_id = static_cast<int>(threadIdx.x);
const auto rank_id = thread_id / 32;
const auto lane_id = thread_id % 32;
if (rank_id >= kNumRanks)
return;
int token_start_idx, token_end_idx;
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
@ -714,10 +716,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
// Read expected head
int expected_head = -1;
if (recv_lane_id < kNumRanks) {
if (recv_lane_id < kNumRanks)
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
}
auto start_time = clock64();
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
// Timeout check
@ -775,6 +776,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
}
// Update head
if (recv_lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
}
// Retired