Optimize intranode combine. (#247)

* Increase the test round.

* Add warp synchronization.

* Shuffle the send warps.

* Add time elapsed into bench result.
This commit is contained in:
Shangyan Zhou
2025-06-24 09:10:23 +08:00
committed by GitHub
parent fbcf430006
commit 9eb2f84b3e
3 changed files with 8 additions and 8 deletions

View File

@@ -618,8 +618,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
const auto send_thread_id = thread_id;
const auto send_warp_id = send_thread_id / 32;
const auto send_rank_id = thread_id / num_threads_per_rank;
const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks;
const auto send_warp_id_in_rank = send_warp_id / kNumRanks;
EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
// Calculate pointers by the specific layout
@@ -777,7 +777,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
auto start_time = clock64();
while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
while (__any_sync(0xffffffff, channel_tail_idx[lane_id] <= expected_head and expected_head >= 0)) {
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);