#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"

namespace deep_ep {

namespace layout {

// Computes the dispatch layout for a batch of routed tokens.
//
// Launch layout (1-D grid, `kNumThreads` threads per block, static shared memory only):
//   * Blocks [0, ceil(num_experts / kNumExpertsPerSM)) each own `kNumExpertsPerSM` consecutive
//     experts and write `num_tokens_per_expert` for them, then return.
//   * Every following block owns `kNumRanksPerSM` consecutive ranks and writes
//     `num_tokens_per_rank`, `num_tokens_per_rdma_rank` (when non-null) and `is_token_in_rank`
//     for those ranks.
// Each thread accumulates into its own shared-memory counter row; after a block barrier,
// thread `t` reduces column `t` over all rows.
//
// Inputs:
//   topk_idx  row-major [num_tokens, num_topk] expert ids. Ids outside [0, num_experts)
//             (e.g. negative ids) are never counted.
// Outputs:
//   num_tokens_per_expert[e]            number of top-k entries equal to `e`.
//   num_tokens_per_rank[r]              number of tokens with at least one top-k expert on rank
//                                       `r`; rank `r` owns experts [r * E, (r + 1) * E) with
//                                       E = num_experts / num_ranks (NOTE(review): assumes
//                                       num_experts % num_ranks == 0 -- trailing experts are
//                                       otherwise never attributed to any rank).
//   num_tokens_per_rdma_rank[g]         optional (may be nullptr): number of tokens with at least
//                                       one top-k expert on ranks
//                                       [g * NUM_MAX_NVL_PEERS, (g + 1) * NUM_MAX_NVL_PEERS).
//   is_token_in_rank[i * num_ranks + r] whether token `i` has at least one top-k expert on rank `r`.
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void __launch_bounds__(kNumThreads, 1)
get_dispatch_layout(const int64_t* topk_idx,
                    int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
                    int* num_tokens_per_expert, bool* is_token_in_rank,
                    int num_tokens, int num_topk, int num_ranks, int num_experts) {
    auto sm_id = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x);

    // Count expert statistics
    __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
    int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
    // The branch condition depends only on `blockIdx.x`, so the whole block takes the same path
    // and the `__syncthreads()` inside is reached by every thread
    if (expert_begin_idx < expert_end_idx) {
        // Per-thread count: each thread owns one counter row, so no atomics are needed
        #pragma unroll
        for (int i = 0; i < kNumExpertsPerSM; ++ i)
            num_tokens_per_expert_per_thread[thread_id][i] = 0;
        #pragma unroll
        for (int i = thread_id; i < num_tokens; i += kNumThreads) {
            auto shifted_topk_idx = topk_idx + i * num_topk;
            #pragma unroll
            for (int j = 0, expert_idx; j < num_topk; ++ j) {
                expert_idx = static_cast<int>(shifted_topk_idx[j]);
                if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
                    ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
            }
        }
        __syncthreads();

        // Sum up: thread `t` reduces column `t` (one expert) across all per-thread rows
        EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
        if (expert_begin_idx + thread_id < expert_end_idx) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumThreads; ++ i)
                sum += num_tokens_per_expert_per_thread[i][thread_id];
            num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
        }
        return;
    }

    if (num_tokens_per_rdma_rank != nullptr)
        EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);

    // Count rank statistics
    // Each rank block must cover whole groups of `NUM_MAX_NVL_PEERS` ranks: otherwise
    // `kNumRDMARanksPerSM` truncates and the local group index `rank_idx / NUM_MAX_NVL_PEERS`
    // would no longer line up with `rdma_rank_begin_idx`
    EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");
    constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
    __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
    __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
    // Rank blocks are numbered after the expert blocks
    auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
    int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
    int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
    // Block-uniform condition, so the `__syncthreads()` below is safe
    if (rank_begin_idx < rank_end_idx) {
        const auto num_expert_per_rank = num_experts / num_ranks;
        auto expert_begin = rank_begin_idx * num_expert_per_rank;
        auto expert_end = rank_end_idx * num_expert_per_rank;

        // Per-thread count
        #pragma unroll
        for (int i = 0; i < kNumRanksPerSM; ++ i)
            num_tokens_per_rank_per_thread[thread_id][i] = 0;
        #pragma unroll
        for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
            num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
        #pragma unroll
        for (int i = thread_id; i < num_tokens; i += kNumThreads) {
            auto shifted_topk_idx = topk_idx + i * num_topk;
            // Hit counts of token `i` for the ranks / rank groups owned by this block (local indices)
            int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
            #pragma unroll
            for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
                expert_idx = static_cast<int>(shifted_topk_idx[j]);
                if (expert_begin <= expert_idx and expert_idx < expert_end) {
                    // Count single rank
                    rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
                    is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
                }
            }

            // A token counts at most once per rank (and per rank group), however many of its
            // top-k experts live there
            auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
            #pragma unroll
            for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
                shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
                num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
            }

            #pragma unroll
            for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
                num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
        }
        __syncthreads();

        // Sum up: thread `t` reduces column `t` (one rank / one rank group) across all rows
        EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
        if (rank_begin_idx + thread_id < rank_end_idx) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumThreads; ++ i)
                sum += num_tokens_per_rank_per_thread[i][thread_id];
            num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
        }

        if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumThreads; ++ i)
                sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
            num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
        }
    }
}

// Host entry point: launches the `get_dispatch_layout` kernel on `stream` with one block per
// `kNumExpertsPerSM` experts followed by one block per `kNumRanksPerSM` ranks.
// Arguments are forwarded unchanged; see the kernel comment for their layouts and meaning.
void get_dispatch_layout(const int64_t* topk_idx,
                         int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
                         int* num_tokens_per_expert, bool* is_token_in_rank,
                         int num_tokens, int num_topk, int num_ranks, int num_experts,
                         cudaStream_t stream) {
    constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
    int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
    // The kernel splits each rank block into whole groups of `NUM_MAX_NVL_PEERS` ranks
    EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM");

    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
                  topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
                  num_tokens, num_topk, num_ranks, num_experts);
}

} // namespace layout

} // namespace deep_ep