From 9eb2f84b3eae6b1a9b9b2e884f848ae202176009 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Tue, 24 Jun 2025 09:10:23 +0800 Subject: [PATCH] Optimize intranode combine. (#247) * Increase the test round. * Add warp synchronization. * Shuffle the send warps. * Add time elapsed into bench result. --- csrc/kernels/intranode.cu | 6 +++--- tests/test_intranode.py | 8 ++++---- tests/utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 52ba9e3..0f3cb7e 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -618,8 +618,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights, const auto num_threads_per_rank = num_send_warps_per_rank * 32; const auto send_thread_id = thread_id; const auto send_warp_id = send_thread_id / 32; - const auto send_rank_id = thread_id / num_threads_per_rank; - const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank; + const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks; + const auto send_warp_id_in_rank = send_warp_id / kNumRanks; EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count"); // Calculate pointers by the specific layout @@ -777,7 +777,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights, expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id); auto start_time = clock64(); - while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) { + while (__any_sync(0xffffffff, channel_tail_idx[lane_id] <= expected_head and expected_head >= 0)) { // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head); diff --git a/tests/test_intranode.py b/tests/test_intranode.py index 14c81cf..887a9ee 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -184,9 +184,9 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' - f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) + f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True) if local_rank == 0: - print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) + print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) print('', flush=True) # Gather the best config from rank 0 and the first test setting @@ -215,12 +215,12 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: t = bench(lambda: buffer.combine(**tune_args))[0] if local_rank == 0: print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' - f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True) if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) + print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) print('', flush=True) diff --git a/tests/utils.py b/tests/utils.py index da2e12c..2316a57 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -80,7 +80,7 @@ def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_gro return (scores * mask).view(num_tokens, num_experts) -def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): +def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')