mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-05-03 11:41:13 +00:00
Fix bugs for intranode EP kernels
This commit is contained in:
parent
043fa5fa99
commit
82dcf48fd3
@ -487,6 +487,8 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
|
|||||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||||
const auto rank_id = thread_id / 32;
|
const auto rank_id = thread_id / 32;
|
||||||
const auto lane_id = thread_id % 32;
|
const auto lane_id = thread_id % 32;
|
||||||
|
if (rank_id >= kNumRanks)
|
||||||
|
return;
|
||||||
|
|
||||||
int token_start_idx, token_end_idx;
|
int token_start_idx, token_end_idx;
|
||||||
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
|
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
|
||||||
@ -714,10 +716,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
|||||||
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
|
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
|
||||||
// Read expected head
|
// Read expected head
|
||||||
int expected_head = -1;
|
int expected_head = -1;
|
||||||
if (recv_lane_id < kNumRanks) {
|
if (recv_lane_id < kNumRanks)
|
||||||
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
|
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
|
||||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
|
|
||||||
}
|
|
||||||
auto start_time = clock64();
|
auto start_time = clock64();
|
||||||
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
|
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
|
||||||
// Timeout check
|
// Timeout check
|
||||||
@ -775,6 +776,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
|||||||
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
|
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
|
||||||
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
|
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update head
|
||||||
|
if (recv_lane_id < kNumRanks)
|
||||||
|
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retired
|
// Retired
|
||||||
|
Loading…
Reference in New Issue
Block a user