mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Optimize intranode combine. (#247)
* Increase the test round. * Add warp synchronization. * Shuffle the send warps. * Add time elapsed into bench result.
This commit is contained in:
@@ -618,8 +618,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_warp_id = send_thread_id / 32;
|
||||
const auto send_rank_id = thread_id / num_threads_per_rank;
|
||||
const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
|
||||
const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks;
|
||||
const auto send_warp_id_in_rank = send_warp_id / kNumRanks;
|
||||
EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
@@ -777,7 +777,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
|
||||
|
||||
auto start_time = clock64();
|
||||
while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
|
||||
while (__any_sync(0xffffffff, channel_tail_idx[lane_id] <= expected_head and expected_head >= 0)) {
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
|
||||
|
||||
Reference in New Issue
Block a user