From bf4a4a21d282026b293ed61668aeb807540a3dba Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Wed, 18 Jun 2025 14:43:38 +0800 Subject: [PATCH 01/10] Set `device_id` to suppress pytorch warning. --- tests/utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 1a9c176..7af4947 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,12 +14,17 @@ def init_dist(local_rank: int, num_local_ranks: int): node_rank = int(os.getenv('RANK', 0)) assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - dist.init_process_group( - backend='nccl', - init_method=f'tcp://{ip}:{port}', - world_size=num_nodes * num_local_ranks, - rank=node_rank * num_local_ranks + local_rank - ) + import inspect + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + params['device_id'] = torch.device(f"cuda:{local_rank}") + dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank) From cd371d31fc7d62d6d28b5a803a31a3a1accc3d35 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Wed, 18 Jun 2025 14:52:04 +0800 Subject: [PATCH 02/10] Move import. --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 7af4947..57a38b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.distributed as dist from typing import Optional +import inspect def init_dist(local_rank: int, num_local_ranks: int): @@ -14,7 +15,6 @@ def init_dist(local_rank: int, num_local_ranks: int): node_rank = int(os.getenv('RANK', 0)) assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - import inspect sig = inspect.signature(dist.init_process_group) params = { 'backend': 'nccl', From b56f7c2c8c1fd032133d00d0df7cc4edfd367b79 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 18 Jun 2025 15:50:06 +0800 Subject: [PATCH 03/10] Adjust import order --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 57a38b2..51ee18e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,10 @@ +import inspect import os import sys import numpy as np import torch import torch.distributed as dist from typing import Optional -import inspect def init_dist(local_rank: int, num_local_ranks: int): From 9d4f7ef8eeedd9970c2bb8efe998e07e7bb9df5b Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 18 Jun 2025 16:04:42 +0800 Subject: [PATCH 04/10] Surpass type checks --- tests/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 51ee18e..da2e12c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ import inspect +import numpy as np import os import sys -import numpy as np import torch import torch.distributed as dist from typing import Optional @@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int): 'rank': node_rank * num_local_ranks + local_rank, } if 'device_id' in sig.parameters: - params['device_id'] = torch.device(f"cuda:{local_rank}") + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') From 7b0c25f864cc7a2d95d7e3c3c2b7c13b22116063 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 20 Jun 2025 16:37:28 +0800 Subject: [PATCH 05/10] Support more hidden size --- csrc/kernels/launch.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh index 5b398bf..c9cca9b 100644 --- a/csrc/kernels/launch.cuh +++ b/csrc/kernels/launch.cuh @@ -84,6 +84,7 @@ cfg.dynamicSmemBytes = smem_size; #define SWITCH_HIDDEN(case_macro) \ switch (hidden) { \ + case 2048: case_macro(2048); \ case 2560: case_macro(2560); \ case 4096: case_macro(4096); \ case 5120: case_macro(5120); \ From c95997f8c4bf086763d71496741b2284e97da64a Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 23 Jun 2025 11:44:06 +0800 Subject: [PATCH 06/10] Update deep_ep.cpp (#242) --- csrc/deep_ep.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 9c90178..75906f6 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -35,7 +35,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); #ifdef DISABLE_NVSHMEM - EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disable during compilation"); + EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation"); #endif // Get device info @@ -151,7 +151,7 @@ pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { auto unique_id = internode::get_unique_id(); return {reinterpret_cast(unique_id.data()), unique_id.size()}; #else - EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); #endif } @@ -895,7 +895,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional Date: Mon, 23 Jun 2025 15:18:10 +0800 Subject: [PATCH 07/10] Update internode_ll.cu (#246) --- csrc/kernels/internode_ll.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index b162492..db15bf5 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -498,7 +498,7 @@ combine(void* combined_x, } cg::this_grid().sync(); - // Reduce tokens with FP8 cast + // Reduce tokens EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads); EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); if (thread_id < hidden_bf16_int4) { From 9eb2f84b3eae6b1a9b9b2e884f848ae202176009 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Tue, 24 Jun 2025 09:10:23 +0800 Subject: [PATCH 08/10] Optimize intranode combine. (#247) * Increase the test round. * Add warp synchronization. * Shuffle the send warps. * Add time elapsed into bench result. --- csrc/kernels/intranode.cu | 6 +++--- tests/test_intranode.py | 8 ++++---- tests/utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 52ba9e3..0f3cb7e 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -618,8 +618,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights, const auto num_threads_per_rank = num_send_warps_per_rank * 32; const auto send_thread_id = thread_id; const auto send_warp_id = send_thread_id / 32; - const auto send_rank_id = thread_id / num_threads_per_rank; - const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank; + const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks; + const auto send_warp_id_in_rank = send_warp_id / kNumRanks; EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count"); // Calculate pointers by the specific layout @@ -777,7 +777,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights, expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id); auto start_time = clock64(); - while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) { + while (__any_sync(0xffffffff, channel_tail_idx[lane_id] <= expected_head and expected_head >= 0)) { // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head); diff --git a/tests/test_intranode.py b/tests/test_intranode.py index 14c81cf..887a9ee 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -184,9 +184,9 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' - f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) + f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True) if local_rank == 0: - print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) + print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) print('', flush=True) # Gather the best config from rank 0 and the first test setting @@ -215,12 +215,12 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: t = bench(lambda: buffer.combine(**tune_args))[0] if local_rank == 0: print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' - f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True) if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) + print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) print('', flush=True) diff --git a/tests/utils.py b/tests/utils.py index da2e12c..2316a57 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -80,7 +80,7 @@ def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_gro return (scores * mask).view(num_tokens, num_experts) -def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): +def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') From bc118b248a07316c6b1f61d8e5b2dde069dae7e6 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 24 Jun 2025 09:12:40 +0800 Subject: [PATCH 09/10] Add the transaction window data structure for RDMA senders (#245) * Add draft * Add fast-debugging flags * Fix several bugs * Add sender timeout checks * Fix stuck * Fix bugs * Fix bugs --- csrc/CMakeLists.txt | 1 + csrc/kernels/configs.cuh | 8 ++- csrc/kernels/internode.cu | 118 ++++++++++++++++++++++---------------- csrc/kernels/launch.cuh | 4 -- csrc/kernels/utils.cuh | 22 +++++++ tests/test_internode.py | 2 +- 6 files changed, 101 insertions(+), 54 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 005607a..3f51c27 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -6,6 +6,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") set(CUDA_SEPARABLE_COMPILATION ON) +list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") list(APPEND CUDA_NVCC_FLAGS "-O3") list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index 8893b79..a7b960b 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -7,9 +7,15 @@ #define NUM_BUFFER_ALIGNMENT_BYTES 128 #define FINISHED_SUM_TAG 1024 +#define NUM_WAIT_NANOSECONDS 500 + +#ifndef ENABLE_FAST_DEBUG #define NUM_CPU_TIMEOUT_SECS 100 #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s -#define NUM_WAIT_NANOSECONDS 500 +#else +#define NUM_CPU_TIMEOUT_SECS 10 +#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s +#endif #define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_RECV_PHASE 2 diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index a49c430..1e0a28e 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -365,7 +365,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv const bool is_forwarder = sm_id % 2 == 0; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms); + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms); const auto role_meta = [=]() -> std::pair { if (is_forwarder) { @@ -419,9 +419,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv auto nvl_channel_tail = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); // RDMA sender warp synchronization - __shared__ volatile int rdma_send_next_token_idx; - __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; - __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; + // NOTES: `rdma_send_channel_tail` means the latest released tail + // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status + __shared__ int rdma_send_channel_lock[kNumRDMARanks]; + __shared__ int rdma_send_channel_tail[kNumRDMARanks]; + __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; // Forward warp synchronization @@ -434,12 +436,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv int token_start_idx, token_end_idx; get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - // Clean shared memory - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); - (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0; - (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; - (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0; - // Send number of tokens in this channel by `-value - 1` EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { @@ -468,24 +464,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Iterate over tokens and copy into buffer int64_t token_idx; - int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; + int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0; auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); - for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) { + for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) { // Read RDMA rank existence uint64_t is_token_in_rank_uint64 = 0; - if (lane_id < kNumRDMARanks) - is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); - - // Acquire sequential lock - while (lane_id == 0 and rdma_send_next_token_idx != token_idx); + if (lane_id < kNumRDMARanks) { + is_token_in_rank_uint64 = __ldg(reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS)); + global_rdma_tail_idx += (is_token_in_rank_uint64 != 0); + } __syncwarp(); - // Acquire next tail - int rdma_tail_idx = -1; - if (is_token_in_rank_uint64 != 0) { - rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++; - while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) - cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); + // Skip the token which does not belong to this warp + if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id) + continue; + auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1; + + // Wait the remote buffer to be released + auto start_time = clock64(); + while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { + cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); + + // Timeout check + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx); + trap(); + } } __syncwarp(); @@ -493,14 +498,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (lane_id < kNumRDMARanks and not kCachedMode) send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; - // Update last token tail - if (last_rdma_tail_idx >= 0) - st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); - last_rdma_tail_idx = rdma_tail_idx; - - // Release sequential lock - lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; - // Broadcast tails SourceMeta src_meta; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; @@ -557,24 +554,46 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); } + __syncwarp(); + + // Release the transaction in the window + if (is_token_in_rank_uint64 != 0) { + // Acquire lock first + acquire_lock(rdma_send_channel_lock + lane_id); + + // Release the transaction slot + auto rdy_window = rdma_send_channel_window[lane_id]; + auto latest_tail = rdma_send_channel_tail[lane_id]; + auto offset = rdma_tail_idx - latest_tail; + + // The same effect with `EP_DEVICE_ASSERT(offset < 32);` + EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps"); + + // Erase bit and move the ones if possible + rdy_window ^= 1u << offset; + if (offset == 0) { + EP_DEVICE_ASSERT(rdy_window & 1); + auto num_empty_slots = __ffs(~rdy_window) - 1; + st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots); + rdy_window >>= num_empty_slots; + } + rdma_send_channel_window[lane_id] = rdy_window; + + // Release lock + release_lock(rdma_send_channel_lock + lane_id); + } + __syncwarp(); } - - // Epilogue - // Acquire sequential lock - while (lane_id == 0 and rdma_send_next_token_idx != token_idx); - __syncwarp(); - - // Update last token tail - if (last_rdma_tail_idx >= 0) - st_release_cta(const_cast(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1); - __syncwarp(); - - // Release sequential lock - lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; } else if (warp_role == WarpRole::kRDMASenderCoordinator) { // NOTES: in case of splitting, the issued put at the end of the buffer EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0; + // Synchronize shared memory sync_rdma_sender_smem(); @@ -592,10 +611,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { - printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n", + printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n", channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send); trap(); } + + // TODO: try thread-level `put_nbi`? for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) { // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; @@ -603,9 +624,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (synced_num_tokens_to_send == 0) continue; - // Read progress - auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); + // Read the latest progress + // NOTES: `rdma_send_channel_tail` does not need to be protected by lock auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)), 0); + auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); auto num_tokens_processed = processed_tail - synced_last_issued_tail; if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens) continue; @@ -625,9 +647,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Lighter fence for local RDMA rank memory_fence(); } + __syncwarp(); // Update tails - __syncwarp(); if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh index c9cca9b..8b95eed 100644 --- a/csrc/kernels/launch.cuh +++ b/csrc/kernels/launch.cuh @@ -58,12 +58,9 @@ cfg.dynamicSmemBytes = smem_size; #define SWITCH_RDMA_RANKS(case_macro) \ switch (num_ranks / NUM_MAX_NVL_PEERS) { \ case 2: case_macro(2); \ - case 3: case_macro(3); \ case 4: case_macro(4); \ case 8: case_macro(8); \ case 16: case_macro(16); \ - case 18: case_macro(18); \ - case 20: case_macro(20); \ default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ } while (false) @@ -78,7 +75,6 @@ cfg.dynamicSmemBytes = smem_size; #define SWITCH_TYPES(case_macro) \ switch (type) { \ case CUDA_R_16BF: case_macro(nv_bfloat16); \ - case CUDA_R_32F: case_macro(float); \ default: EP_HOST_ASSERT(false && "Unsupported type"); \ } while (false) diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 796a6f9..ac97896 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -466,4 +466,26 @@ barrier_block(int** barrier_signal_ptrs, int rank) { __syncthreads(); } +__forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) { + int ret; + asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(addr), "r"(x), "r"(y) : "memory"); + return ret; +} + +__forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) { + int ret; + asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(x) : "memory"); + return ret; +} + +__forceinline__ __device__ void acquire_lock(int* mutex) { + // To make later memory operations valid, we must use `acquire` for memory semantics + while (atomic_cas_cta_acquire(mutex, 0, 1) != 0); +} + +__forceinline__ __device__ void release_lock(int* mutex) { + // To make previous memory operations visible to other threads, we must use `release` for memory semantics + atomic_exch_cta_release(mutex, 0); +} + } // namespace deep_ep diff --git a/tests/test_internode.py b/tests/test_internode.py index 4aeca49..e84f4eb 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -220,7 +220,7 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in def test_loop(local_rank: int, num_local_ranks: int): num_nodes = int(os.getenv('WORLD_SIZE', 1)) rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - test_ll_compatibility = True + test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False) if test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 From a15faa9ff07a2942c7f27bd3aa4615909a53cb9f Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 24 Jun 2025 09:21:35 +0800 Subject: [PATCH 10/10] Remove useless assertion --- csrc/kernels/internode.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 1e0a28e..da1d203 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -562,7 +562,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv acquire_lock(rdma_send_channel_lock + lane_id); // Release the transaction slot - auto rdy_window = rdma_send_channel_window[lane_id]; + auto window = rdma_send_channel_window[lane_id]; auto latest_tail = rdma_send_channel_tail[lane_id]; auto offset = rdma_tail_idx - latest_tail; @@ -570,14 +570,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps"); // Erase bit and move the ones if possible - rdy_window ^= 1u << offset; + window ^= 1u << offset; if (offset == 0) { - EP_DEVICE_ASSERT(rdy_window & 1); - auto num_empty_slots = __ffs(~rdy_window) - 1; + auto num_empty_slots = __ffs(~window) - 1; st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots); - rdy_window >>= num_empty_slots; + window >>= num_empty_slots; } - rdma_send_channel_window[lane_id] = rdy_window; + rdma_send_channel_window[lane_id] = window; // Release lock release_lock(rdma_send_channel_lock + lane_id);