diff --git a/csrc/config.hpp b/csrc/config.hpp
index b6ffd60..8a1a8ba 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -6,13 +6,13 @@
 namespace deep_ep {
 
 template <typename dtype_t>
-dtype_t cell_div(dtype_t a, dtype_t b) {
+dtype_t ceil_div(dtype_t a, dtype_t b) {
     return (a + b - 1) / b;
 }
 
 template <typename dtype_t>
 dtype_t align(dtype_t a, dtype_t b) {
-    return cell_div(a, b) * b;
+    return ceil_div(a, b) * b;
 }
 
 struct Config {
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 041db5b..981753b 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -41,6 +41,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
     // Get device info
     cudaDeviceProp device_prop = {};
     CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
+    num_device_sms = device_prop.multiProcessorCount;
 
     if (num_nvl_bytes > 0) {
         // Local IPC: alloc local memory and set local IPC handles
@@ -1142,7 +1143,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
-                               workspace, low_latency_usage_flag_mapped, launch_stream, phases);
+                               workspace, low_latency_usage_flag_mapped,
+                               num_device_sms, launch_stream,
+                               phases);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
 
@@ -1234,7 +1237,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                               next_clean_meta.first, next_clean_meta.second,
                               num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
                               num_topk, num_experts, rank, num_ranks,
-                              workspace, low_latency_usage_flag_mapped, launch_stream,
+                              workspace, low_latency_usage_flag_mapped,
+                              num_device_sms, launch_stream,
                               phases, zero_copy);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 7b8ab75..40e0b33 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -41,6 +41,7 @@ private:
 
     // Device info and communication
     int device_id;
+    int num_device_sms;
    int rank, rdma_rank, nvl_rank;
     int num_ranks, num_rdma_ranks, num_nvl_ranks;
     cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index 570fa17..f033c65 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -148,7 +148,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
               void* workspace, int* usage_flag,
-              cudaStream_t stream, int phases);
+              int num_device_sms, cudaStream_t stream,
+              int phases);
 
 void combine(void* combined_x,
              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
@@ -158,7 +159,8 @@ void combine(void* combined_x,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              void* workspace, int* usage_flag,
-             cudaStream_t stream, int phases, bool zero_copy);
+             int num_device_sms, cudaStream_t stream,
+             int phases, bool zero_copy);
 
 } // namespace internode_ll
 
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 73e6058..9b9678b 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -36,9 +36,8 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                   clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden,
-          int kNumWarpGroups, int kNumWarpsPerGroup>
-__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
+template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+__global__ __launch_bounds__(1024, 1) void
 dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
@@ -49,16 +48,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* next_clean, int num_next_clean_int,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
-         bool round_scale, int* usage_flag, int phases) {
+         bool round_scale, int* usage_flag,
+         int num_warp_groups, int num_warps_per_group,
+         int phases) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
     const auto num_sms = static_cast<int>(gridDim.x);
-    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+    const auto num_warps = num_warp_groups * num_warps_per_group;
     const auto num_local_experts = num_experts / num_ranks;
-    const auto warp_group_id = warp_id / kNumWarpsPerGroup;
-    const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
-    const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
+    const auto warp_group_id = warp_id / num_warps_per_group;
+    const auto sub_warp_id = warp_id % num_warps_per_group;
+    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
 
     // May extract UE8M0 from the scales
     using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
@@ -78,13 +79,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
+    // Expert counts
+    constexpr int kNumMaxWarpGroups = 32;
+    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
+
     // Sending phase
     if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
         goto LOW_LATENCY_DISPATCH_RECV;
 
-    // Expert counts
-    __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
-
     // There are 2 kinds of warps in this part:
     // 1. The first-kind warps for FP8 cast and sending top-k tokens
     // 2. The last warp for reading `topk_idx` and count for per-expert information
@@ -96,8 +98,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
 
         for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
-            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
-            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+            const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
+            const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
             const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
             const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
@@ -194,9 +196,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         }
 
         // This SM should be responsible for some destination experts, read `topk_idx` for them
-        int expert_count[kNumWarpGroups] = {0};
-        const auto expert_begin_idx = sm_id * kNumWarpGroups;
-        const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
+        int expert_count[kNumMaxWarpGroups] = {0};
+        const auto expert_begin_idx = sm_id * num_warp_groups;
+        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
 
         // Per lane count
         #pragma unroll 8
@@ -222,7 +224,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
         const auto dst_rank = responsible_expert_idx / num_local_experts;
         const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
-        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
+        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
 
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
@@ -257,23 +259,23 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     if (responsible_expert_idx < num_experts) {
         const auto src_rank = responsible_expert_idx / num_local_experts;
         const auto local_expert_idx = responsible_expert_idx % num_local_experts;
-        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
+        const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                 src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
-        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
+        const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
         const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
         const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
         const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
-        const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
+        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
 
         // Shared between sub-warps in warp groups
-        __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
+        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
 
         // Wait tokens to arrive
         // NOTES: using sub-warp 1 to overlap with sub-warp 0
         int num_recv_tokens, recv_token_begin_idx;
-        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
+        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
         if (sub_warp_id == 1 and lane_id == 0) {
             while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
             num_recv_tokens = -num_recv_tokens - 1;
@@ -284,13 +286,13 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             if (cumulative_local_expert_recv_stats != nullptr)
                 atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
         }
-        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
+        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 
         // Copy tokens
         EP_DEVICE_ASSERT(num_scales <= 64);
-        for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
+        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
             const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
             if (lane_id == 0)
@@ -340,14 +342,16 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
               void* workspace, int* usage_flag,
-              cudaStream_t stream, int phases) {
+              int num_device_sms, cudaStream_t stream,
+              int phases) {
     constexpr int kNumMaxTopK = 9;
-    constexpr int kNumWarpsPerGroup = 10;
-    constexpr int kNumWarpGroups = 3;
-    EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
+    const int num_warps_per_group = 32 / num_warp_groups;
+    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
+    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
 
-    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-    const auto num_sms = cell_div(num_experts, kNumWarpGroups);
+    const auto num_warps = num_warp_groups * num_warps_per_group;
+    const auto num_sms = ceil_div(num_experts, num_warp_groups);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
 
     // Workspace checks
@@ -360,11 +364,11 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = dispatch
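
The substantive change in the launcher above is that the fixed kNumWarpGroups = 3 / kNumWarpsPerGroup = 10 pair is replaced by values derived from the device's SM count. Below is a minimal standalone sketch of that host-side arithmetic, compilable on its own; assumptions: assert() stands in for EP_HOST_ASSERT, and the (num_experts, SM count) pairs are illustrative examples, not values taken from the patch.

// Standalone sketch of the launch-geometry arithmetic introduced above.
// Assumptions: assert() replaces EP_HOST_ASSERT; the SM counts below are
// illustrative examples, not values from the patch.
#include <cassert>
#include <cstdio>

template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
    return (a + b - 1) / b;
}

int main() {
    constexpr int kNumMaxTopK = 9;  // same top-k bound the dispatch launcher asserts against

    // (num_experts, num_device_sms) pairs to size the launch for
    const int cases[][2] = {{256, 132}, {288, 78}, {384, 132}};
    for (const auto& c : cases) {
        const int num_experts = c[0], num_device_sms = c[1];

        // One warp group per expert an SM has to serve, then split the
        // 32 warps of a 1024-thread block evenly across those groups
        const int num_warp_groups = ceil_div(num_experts, num_device_sms);
        const int num_warps_per_group = 32 / num_warp_groups;
        assert(num_warp_groups > 0 and num_warps_per_group > 0);
        assert(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);

        const int num_warps = num_warp_groups * num_warps_per_group;
        const int num_sms = ceil_div(num_experts, num_warp_groups);
        std::printf("experts=%3d SMs=%3d -> %d groups x %2d warps = %2d warps/block, %3d blocks\n",
                    num_experts, num_device_sms, num_warp_groups,
                    num_warps_per_group, num_warps, num_sms);
    }
    return 0;
}

For 256 experts on a 132-SM device this yields 2 warp groups of 16 warps each and a 128-block grid, whereas the removed code always launched 3 groups of 10 warps regardless of the device.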