#include "configs.cuh"
|
|
#include "buffer.cuh"
|
|
#include "exception.cuh"
|
|
#include "launch.cuh"
|
|
#include "utils.cuh"
|
|
|
|
namespace deep_ep {
|
|
|
|
namespace internode {
|
|
|
|
extern nvshmem_team_t cpu_rdma_team;
|
|
|
|
template<int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
|
|
__global__ void __launch_bounds__(kNumThreads, 1)
|
|
get_dispatch_layout(const int64_t* topk_idx,
|
|
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
|
int* num_tokens_per_expert, bool* is_token_in_rank,
|
|
int num_tokens, int num_topk, int num_ranks, int num_experts) {
|
|
auto sm_id = static_cast<int>(blockIdx.x);
|
|
auto thread_id = static_cast<int>(threadIdx.x);
|
|
|
|
// Count expert statistics
|
|
__shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
|
|
int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
|
|
if (expert_begin_idx < expert_end_idx) {
|
|
// Per-thread count
|
|
#pragma unroll
|
|
for (int i = 0; i < kNumExpertsPerSM; ++ i)
|
|
num_tokens_per_expert_per_thread[thread_id][i] = 0;
|
|
#pragma unroll
|
|
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
|
|
auto shifted_topk_idx = topk_idx + i * num_topk;
|
|
#pragma unroll
|
|
for (int j = 0, expert_idx; j < num_topk; ++ j) {
|
|
expert_idx = static_cast<int>(shifted_topk_idx[j]);
|
|
if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
|
|
++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
|
|
}
|
|
}
|
|
__syncthreads();
|
|
|
|
// Sum up
|
|
EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
|
|
if (expert_begin_idx + thread_id < expert_end_idx) {
|
|
int sum = 0;
|
|
#pragma unroll
|
|
for (int i = 0; i < kNumThreads; ++ i)
|
|
sum += num_tokens_per_expert_per_thread[i][thread_id];
|
|
num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (num_tokens_per_rdma_rank != nullptr)
|
|
EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
|
|
|
|
// Count rank statistics
|
|
constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
|
|
__shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
|
|
__shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
|
|
auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
|
|
int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
|
|
int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
|
|
if (rank_begin_idx < rank_end_idx) {
|
|
const auto num_expert_per_rank = num_experts / num_ranks;
|
|
auto expert_begin = rank_begin_idx * num_expert_per_rank;
|
|
auto expert_end = rank_end_idx * num_expert_per_rank;
|
|
|
|
// Per-thread count
|
|
#pragma unroll
|
|
for (int i = 0; i < kNumRanksPerSM; ++ i)
|
|
num_tokens_per_rank_per_thread[thread_id][i] = 0;
|
|
#pragma unroll
|
|
for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
|
|
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
|
|
#pragma unroll
|
|
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
|
|
auto shifted_topk_idx = topk_idx + i * num_topk;
|
|
int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
|
|
#pragma unroll
|
|
for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
|
|
expert_idx = static_cast<int>(shifted_topk_idx[j]);
|
|
if (expert_begin <= expert_idx and expert_idx < expert_end) {
|
|
// Count single rank
|
|
rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
|
|
is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
|
|
}
|
|
}
|
|
|
|
auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
|
|
#pragma unroll
|
|
for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
|
|
shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
|
|
num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
|
|
}
|
|
|
|
#pragma unroll
|
|
for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
|
|
num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
|
|
}
|
|
__syncthreads();
|
|
|
|
// Sum up
|
|
EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
|
|
if (rank_begin_idx + thread_id < rank_end_idx) {
|
|
int sum = 0;
|
|
#pragma unroll
|
|
for (int i = 0; i < kNumThreads; ++ i)
|
|
sum += num_tokens_per_rank_per_thread[i][thread_id];
|
|
num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
|
|
}
|
|
|
|
if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
|
|
int sum = 0;
|
|
#pragma unroll
|
|
for (int i = 0; i < kNumThreads; ++ i)
|
|
sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
|
|
num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
|
|
}
|
|
}
|
|
}
|
|
|
|
void get_dispatch_layout(const int64_t* topk_idx,
|
|
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
|
int* num_tokens_per_expert, bool* is_token_in_rank,
|
|
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
|
cudaStream_t stream) {
|
|
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
|
|
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
|
|
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
|
|
|
|
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
|
|
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
|
|
topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
|
|
num_tokens, num_topk, num_ranks, num_experts);
|
|
}
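
// NOTES: the layout kernel splits its grid into two groups of SMs: the first
// `ceil(num_experts / kNumExpertsPerSM)` blocks produce `num_tokens_per_expert`, and the
// remaining `ceil(num_ranks / kNumRanksPerSM)` blocks produce `num_tokens_per_rank` /
// `num_tokens_per_rdma_rank` and fill `is_token_in_rank`. Illustrative sizing only (not
// taken from this file): with num_experts = 256 and num_ranks = 64, the launch above uses
// 256 / 32 + 64 / 8 = 8 + 8 = 16 blocks of 256 threads.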

struct SourceMeta {
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) {
        src_rdma_rank = rdma_rank;
        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
        #pragma unroll
        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++ i)
            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
    }

    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
    }
};

EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
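
// NOTES: the constructor packs the eight `is_token_in_nvl_ranks` flags into one bit mask.
// Illustrative values only: flags {1, 0, 1, 0, 0, 0, 0, 0} encode to
// `is_token_in_nvl_rank_bits = 0b101 = 5`, so `is_token_in_nvl_rank(0)` and
// `is_token_in_nvl_rank(2)` return true while the other NVL ranks return false.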

int get_source_meta_bytes() {
    return sizeof(SourceMeta);
}

__host__ __device__ __forceinline__
int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
    return static_cast<int>(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
}
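
// NOTES: each RDMA token is laid out as
//   [hidden: hidden_int4 * 16 B][SourceMeta: 8 B][scales: num_scales * 4 B]
//   [topk_idx: num_topk_idx * 4 B][topk_weights: num_topk_weights * 4 B],
// padded up to a multiple of 16 B (sizeof(int4)). Illustrative values only:
// hidden_int4 = 512, num_scales = 64, num_topk_idx = num_topk_weights = 8 gives
// 8192 + 8 + 256 + 32 + 32 = 8520 B, aligned up to 8528 B per token.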

__host__ __device__ __forceinline__
std::pair<int, int> get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    // Return `int32_t` offset and count to clean
    return {
        (get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int),
        (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms
    };
}

__host__ __device__ __forceinline__
std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_sms) {
    // Return `int32_t` offset and count to clean
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
    return {
        (num_nvl_recv_buffer_tokens * (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_nvl_ranks * num_sms) / sizeof(int),
        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
    };
}
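
// NOTES: for both helpers, `.first` is the `int` offset at which the control region
// (channel metadata plus head/tail counters) starts, i.e. right after the token ring
// buffers, and `.second` is how many `int`s of that control region must be zeroed
// before the next dispatch. The token data itself is never cleaned, only overwritten.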

template <bool kLowLatencyMode>
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) {
    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}
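
// NOTES: in low-latency mode NVSHMEM is initialized over every GPU, so the destination
// PE is the GPU on node `dst_rdma_rank` that has the same local index `nvl_rank` as the
// caller; otherwise the same-GPU-index team already numbers PEs by RDMA rank, and the
// index can be used directly.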

template <bool kLowLatencyMode>
__forceinline__ __device__ void nvshmem_barrier_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
    kLowLatencyMode ? void(nvshmem_barrier(rdma_team)) : nvshmem_barrier_all();
}

template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
                const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                const bool* is_token_in_rank, int num_tokens, int num_channels, int expert_alignment,
                const int rdma_clean_offset, const int rdma_num_int_clean,
                const int nvl_clean_offset, const int nvl_num_int_clean,
                int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
                int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
                void* rdma_buffer_ptr,
                void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
                const nvshmem_team_t rdma_team) {
    auto sm_id = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
    auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;

    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS;

    if (sm_id == 0) {
        // Communication with others
        // Global barrier: the first warp does intra-node sync, the second warp does inter-node sync
        EP_DEVICE_ASSERT(num_warps > 1);
        EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
        if (thread_id == 32)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
        move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
        __syncthreads();

        // Send numbers of tokens per rank/expert to RDMA ranks
        auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
        auto rdma_recv_num_tokens_mixed = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);

        // Clean up for later data dispatch
        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int));
        #pragma unroll
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Copy to send buffer
        #pragma unroll
        for (int i = thread_id; i < num_ranks; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i];
        #pragma unroll
        for (int i = thread_id; i < num_experts; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = num_tokens_per_expert[i];
        if (thread_id < kNumRDMARanks)
            rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id];
        __syncthreads();

        // Issue send
        // TODO: use a lighter fence, barrier or signaling
        // TODO: overlap EP barrier and NVL cleaning
        if (thread_id < kNumRDMARanks) {
            nvshmem_int_put_nbi(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.send_buffer(thread_id),
                                NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                                translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
        }
        __syncthreads();
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        __syncthreads();

        // NVL buffers
        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
        auto nvl_reduced_num_tokens_per_expert = Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
        auto nvl_send_num_tokens_per_rank = AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_send_num_tokens_per_expert = AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_rank = AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_expert = AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

        // Clean up for later data dispatch
        auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
                         nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int));
        #pragma unroll
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;

        // Reduce number of tokens per expert into the NVL send buffer
        // TODO: may use NVSHMEM reduction
        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
        if (thread_id < num_rdma_experts) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++ i)
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
            nvl_reduced_num_tokens_per_expert[thread_id] = sum;
        }
        __syncthreads();

        // Reduce RDMA received tokens
        if (thread_id == 0) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++ i) {
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
                recv_rdma_rank_prefix_sum[i] = sum;
            }
            while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1);
            *moe_recv_rdma_counter_mapped = sum;
        }

        // Send numbers of tokens per rank/expert to NVL ranks
        EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
        if (thread_id < NUM_MAX_NVL_PEERS) {
            #pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++ i)
                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
            #pragma unroll
            for (int i = 0; i < num_nvl_experts; ++ i)
                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
        }
        memory_fence();
        __syncthreads();
        barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
        move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
        __syncthreads();

        // Reduce number of tokens per rank/expert
        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
        if (thread_id == 0) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < num_ranks; ++ i) {
                int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
                recv_gbl_rank_prefix_sum[i] = sum;
            }
            while (ld_volatile_global(moe_recv_counter_mapped) != -1);
            *moe_recv_counter_mapped = sum;
        }
        if (thread_id < num_nvl_experts) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
            while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1);
            moe_recv_expert_counter_mapped[thread_id] = sum;
        }

        // Final barrier
        __syncthreads();
        if (thread_id == 32)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
        move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
    } else {
        // Calculate metadata
        int dst_rdma_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

            // Iterate over tokens
            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) {
                EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
                auto is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
                auto is_token_in_rank_values = reinterpret_cast<const bool*>(&is_token_in_rank_uint64);
                #pragma unroll
                for (int j = 0; j < NUM_MAX_NVL_PEERS; ++ j)
                    per_nvl_rank_count[j] += is_token_in_rank_values[j];
                total_count += (is_token_in_rank_uint64 != 0);
            }

            // Warp reduce
            total_count = warp_reduce_sum(total_count);
            #pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
                per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);

            // Write into channel matrix
            if (lane_id == 0) {
                #pragma unroll
                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i];
                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
            }
        }

        // Calculate prefix sum
        __syncthreads();
        if (thread_id == 0) {
            auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
            #pragma unroll
            for (int i = 1; i < num_channels; ++ i)
                prefix_row[i] += prefix_row[i - 1];
        }

        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
        if (thread_id < NUM_MAX_NVL_PEERS) {
            auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
            #pragma unroll
            for (int i = 1; i < num_channels; ++ i)
                prefix_row[i] += prefix_row[i - 1];
        }
    }
}

void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                     const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                     const bool* is_token_in_rank, int num_tokens, int num_channels,
                     int hidden_int4, int num_scales, int num_topk, int expert_alignment,
                     int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
                     int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
                     void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
                     void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
                     int** task_fifo_ptrs, int head, int rank,
                     cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                     bool low_latency_mode) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
    auto notify_dispatch_func = low_latency_mode ? \
        notify_dispatch<true, num_rdma_ranks> : notify_dispatch<false, num_rdma_ranks>; \
    LAUNCH_KERNEL(&cfg, notify_dispatch_func, \
                  num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \
                  num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \
                  num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
                  is_token_in_rank, num_tokens, num_channels, expert_alignment, \
                  rdma_clean_meta.first, rdma_clean_meta.second, \
                  nvl_clean_meta.first, nvl_clean_meta.second, \
                  rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
                  gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
                  rdma_buffer_ptr, \
                  buffer_ptrs, task_fifo_ptrs, head, rank, \
                  cpu_rdma_team); } break

    const int kNumThreads = std::max(256, 32 * num_channels);
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
    auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());

    // Launch kernel
    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
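
// NOTES: the grid above uses `1 + num_rdma_ranks` blocks: block 0 exchanges the
// per-rank/per-expert counters over NVSHMEM and NVLink, while block `1 + r` scans
// `is_token_in_rank` to build the per-channel prefix matrices for RDMA rank `r`.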

// At most 8 RDMA ranks to be sent
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    return num_rdma_ranks < 8 ? num_rdma_ranks : 8;
}

template <bool kLowLatencyMode, int kNumRDMARanks, bool kCachedMode,
          int kNumDispatchRDMASenderWarps, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1)
dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta,
         const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
         int* send_rdma_head, int* send_nvl_head,
         int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
         const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
         const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
         int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
         const bool* is_token_in_rank,
         void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
         void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
         int rank, int num_ranks) {
    enum class WarpRole {
        kRDMASender,
        kRDMASenderCoordinator,
        kRDMAAndNVLForwarder,
        kForwarderCoordinator,
        kNVLReceivers
    };

    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
    const auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
    const bool is_forwarder = sm_id % 2 == 0;
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (is_forwarder) {
            if (warp_id < NUM_MAX_NVL_PEERS) {
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else if (warp_id < kNumDispatchRDMASenderWarps) {
            return {WarpRole::kRDMASender, -1};
        } else if (warp_id == kNumDispatchRDMASenderWarps) {
            return {WarpRole::kRDMASenderCoordinator, -1};
        } else {
            return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
        }
    }();
    auto warp_role = role_meta.first;
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
    EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS);
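
    // NOTES: each channel owns two SMs. On the even (forwarder) SM, warps 0..7 are
    // RDMA-to-NVL forwarders (one per NVL peer) and the remaining warps are forwarder
    // coordinators; on the odd SM, warps 0..kNumDispatchRDMASenderWarps - 1 are RDMA
    // senders, the next warp is the sender coordinator, and the last 8 warps are NVL
    // receivers. With the launch configuration below (kNumDispatchRDMASenderWarps = 7),
    // that is (7 + 1 + 8) * 32 = 512 threads per SM.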

    // Data checks
    EP_DEVICE_ASSERT(num_topk <= 32);

    // RDMA symmetric layout
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
    auto hidden_bytes = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);

    // NVL buffer layouts
    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
    int rs_wr_rank = 0, ws_rr_rank = 0;
    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
    if (warp_role == WarpRole::kNVLReceivers)
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;

    // Allocate buffers
    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);

    // RDMA sender warp synchronization
    __shared__ volatile int rdma_send_next_token_idx;
    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
    auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };

    // Forward warp synchronization
    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
    auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); };

    if (warp_role == WarpRole::kRDMASender) {
        // Get tasks
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // Clean shared memory
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
        (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
        (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
        (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;

        // Send number of tokens in this channel by `-value - 1`
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
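        // NOTES: the receiving forwarder only treats a metadata slot as valid once it
        // turns negative, because the buffer is zero-initialized by the preceding clean.
        // Encoding a prefix value `v` as `-v - 1` keeps even `v = 0` distinguishable
        // from "not yet written"; the receiver decodes it with `-x - 1` again.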
        for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            if (lane_id < NUM_MAX_NVL_PEERS) {
                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
            }
            nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                                      translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
        }
        nvshmem_fence();
        sync_rdma_sender_smem();

        // Iterate over tokens and copy into buffer
        int64_t token_idx;
        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // Read RDMA rank existence
            uint64_t is_token_in_rank_uint64 = 0;
            if (lane_id < kNumRDMARanks)
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);

            // Acquire sequential lock
            while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
            __syncwarp();

            // Acquire next tail
            int rdma_tail_idx = -1;
            if (is_token_in_rank_uint64 != 0) {
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;
                while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
            }
            __syncwarp();
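
            // NOTES: `rdma_send_channel_next_tail` hands out ring-buffer slots, while
            // `rdma_channel_head` is advanced remotely by the forwarder coordinator of
            // the destination rank. The spin above therefore blocks a sender lane whenever
            // its queue already holds `num_max_rdma_chunked_recv_tokens` unconsumed tokens.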

            // Store RDMA head for combine
            if (lane_id < kNumRDMARanks and not kCachedMode)
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;

            // Update last token tail
            if (last_rdma_tail_idx >= 0)
                st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            last_rdma_tail_idx = rdma_tail_idx;

            // Release sequential lock
            lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;

            // Broadcast tails
            SourceMeta src_meta;
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            #pragma unroll
            for (int i = 0, slot_idx; i < kNumRDMARanks; ++ i) if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) {
                slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;
                topk_ranks[num_topk_ranks] = i;
                auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
                auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);
                if (lane_id == num_topk_ranks)
                    src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
                dst_send_buffers[num_topk_ranks ++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
            }
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

            // Copy `x` into symmetric send buffer
            auto st_broadcast = [=](const int key, const int4& value) {
                #pragma unroll
                for (int j = 0; j < num_topk_ranks; ++ j)
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
            };
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
            #pragma unroll
            for (int i = 0; i < num_topk_ranks; ++ i)
                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;

            // Copy source metadata into symmetric send buffer
            if (lane_id < num_topk_ranks)
                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
            #pragma unroll
            for (int i = 0; i < num_topk_ranks; ++ i)
                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;

            // Copy `x_scales` into symmetric send buffer
            #pragma unroll
            for (int i = lane_id; i < num_scales; i += 32) {
                auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
                #pragma unroll
                for (int j = 0; j < num_topk_ranks; ++ j)
                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
            }
            #pragma unroll
            for (int i = 0; i < num_topk_ranks; ++ i)
                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;

            // Copy `topk_idx` and `topk_weights` into symmetric send buffer
            #pragma unroll
            for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) {
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
            }
        }

        // Epilogue
        // Acquire sequential lock
        while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
        __syncwarp();

        // Update last token tail
        if (last_rdma_tail_idx >= 0)
            st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);

        // Release sequential lock
        lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
    } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
        // NOTES: the chunk sizes below keep an issued put from being split across the end of the ring buffer
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
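
        // NOTES: full chunks always start at a multiple of `num_max_rdma_chunked_send_tokens`
        // inside the ring, so with the divisibility above they can never wrap around the end;
        // only the final, possibly partial chunk can end early, and the slot-range assert in
        // the issue loop below checks that case. Illustrative values only: with a 128-token
        // ring and 32-token chunks, issue offsets are 0, 32, 64, 96, 0, ...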

        // Synchronize shared memory
        sync_rdma_sender_smem();

        // Get number of tokens to send for each RDMA rank
        int num_tokens_to_send = 0;
        if (lane_id < kNumRDMARanks) {
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
            if (channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
        }

        // Iterate all RDMA ranks
        int last_issued_tail = 0;
        while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
            for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;
                synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank);
                if (synced_num_tokens_to_send == 0)
                    continue;

                // Read progress
                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
                auto processed_tail = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
                auto num_tokens_processed = processed_tail - synced_last_issued_tail;
                if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
                    continue;

                // Issue RDMA send
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send);
                if (dst_rdma_rank != rdma_rank) {
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
                    nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
                                               rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
                                               num_bytes_per_rdma_token * num_tokens_to_issue,
                                               translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                    nvshmem_fence();
                } else {
                    // Lighter fence for local RDMA rank
                    memory_fence();
                }

                // Update tails
                __syncwarp();
                if (lane_id == dst_rdma_rank) {
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
                    nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                }
            }
        }
    } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
        // RDMA consumers and NVL producers
        const auto dst_nvl_rank = target_rank;
        const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;
        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);
        const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks);

        // Wait for counters to arrive
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
        EP_DEVICE_ASSERT(kNumRDMARanks <= 32);
        auto start_time = clock64();
        if (lane_id < kNumRDMARanks) {
            while (true) {
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);
                if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
                    // Notify NVL ranks
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
                    EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

                    // Save RDMA channel received token count
                    src_rdma_channel_prefix = -meta_2 - 1;
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;
                    if (not kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
                    trap();
                }
            }
        }
        __syncwarp();

        // Shift cached head
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

        // Wait for shared memory to be cleaned
        sync_forwarder_smem();

        // Forward tokens from RDMA buffer
        // NOTES: always start from the local rank
        int src_rdma_rank = sm_id % kNumRDMARanks;
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
        while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {
            // Check destination queue emptiness, or wait for a buffer to be released
            start_time = clock64();
            while (lane_id == 0) {
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
                if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
                    break;
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
                    trap();
                }
            }
            __syncwarp();

            // Find next source RDMA rank (round-robin)
            start_time = clock64();
            while (true) {
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
                if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
                    if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))
                        break;
                }

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
                    trap();
                }
            }
            auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);

            // Iterate over every token from the RDMA buffer
            for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
                auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
                lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
                if (lane_id == src_rdma_rank) {
                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
                    if (not kCachedMode)
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
                if (not is_in_dst_nvl_rank)
                    continue;

                // Get an empty slot
                int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;

                // Copy data
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                                   nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
                                   reinterpret_cast<int4*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;

                // Copy source meta
                if (lane_id == 0)
                    st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
                shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;

                // Copy `x_scales`
                UNROLLED_WARP_COPY(1, lane_id, num_scales,
                                   nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
                                   reinterpret_cast<float*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<float*>(shifted) + num_scales;

                // Copy `topk_idx` and `topk_weights`
                // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
                if (lane_id < num_topk) {
                    // Read
                    auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
                    shifted = reinterpret_cast<int*>(shifted) + num_topk;
                    auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);

                    // Transform and write
                    idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
                    st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
                    weight_value = idx_value >= 0 ? weight_value : 0.0f;
                    st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
                }

                // In case of insufficient NVL buffers, early stopping
                if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens)
                    src_rdma_tail = i + 1;
            }

            // Sync head index
            if (lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);

            // Move tail index
            __syncwarp();
            if (lane_id == 0)
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
        }

        // Retired
        __syncwarp();
        if (lane_id == 0)
            forward_channel_retired[dst_nvl_rank] = true;
    } else if (warp_role == WarpRole::kForwarderCoordinator) {
        // Extra warps for forwarder coordinator should exit directly
        if (target_rank > 0)
            return;

        // Forward warp coordinator
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");

        // Clean shared memory
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
        #pragma unroll
        for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
        if (lane_id < NUM_MAX_NVL_PEERS)
            forward_channel_retired[lane_id] = false;
        sync_forwarder_smem();

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
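
        // NOTES: each lane tracks one source RDMA rank. The loop below takes the minimum
        // forwarding progress across all (non-retired) NVL forwarder warps and, once that
        // minimum has advanced by at least one send chunk, adds the delta to the remote
        // sender's `rdma_channel_head`, releasing those ring-buffer slots for reuse.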
        while (true) {
            // Find minimum head
            int min_head = std::numeric_limits<int>::max();
            #pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i])
                min_head = min(min_head, forward_channel_head[i][target_rdma]);
            if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max()))
                break;

            // Update remote head
            if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
                nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
                                   translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
                last_head = min_head;
            }

            // Nanosleep and let other warps work
            __nanosleep(NUM_WAIT_NANOSECONDS);
        }
    } else {
        // NVL consumers
        // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
        int src_nvl_rank = target_rank, total_offset = 0;
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
        if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

        // Receive channel offsets
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
        auto start_time = clock64();
        while (lane_id < kNumRDMARanks) {
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
            end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
            if (start_offset < 0 and end_offset < 0) {
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }

            // Timeout check
            if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                       channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
                trap();
            }
        }
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

        // Save for combine usage
        if (lane_id < kNumRDMARanks and not kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
        __syncwarp();

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
        while (num_tokens_to_recv > 0) {
            // Check channel status by lane 0
            start_time = clock64();
            while (lane_id == 0) {
                // Ready to copy
                if (cached_channel_head_idx != cached_channel_tail_idx)
                    break;
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
                    trap();
                }
            }

            // Sync queue tail
            cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);

            // Copy data
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
            for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
                auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
                int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

                // Copy data
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                                   recv_x + recv_token_idx * hidden_int4,
                                   nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                                   ld_nc_global, st_na_global);

                // Copy source meta
                if (lane_id == 0 and not kCachedMode)
                    st_na_global(recv_src_meta + recv_token_idx, meta);

                // Copy scales
                UNROLLED_WARP_COPY(1, lane_id, num_scales,
                                   recv_x_scales + recv_token_idx * num_scales,
                                   nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                                   ld_nc_global, st_na_global);

                // Copy `topk_idx` and `topk_weights`
                if (lane_id < num_topk) {
                    auto recv_idx = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
                }
            }

            // Move queue
            __syncwarp();
            if (lane_id == 0)
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
        }
    }
}

void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
              const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
              int* send_rdma_head, int* send_nvl_head,
              int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
              const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
              const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
              int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
              const bool* is_token_in_rank,
              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
              void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
              int rank, int num_ranks, bool is_cached_dispatch,
              cudaStream_t stream, int num_channels, bool low_latency_mode) {
    constexpr int kNumDispatchRDMASenderWarps = 7;

#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
    auto dispatch_func = low_latency_mode ? \
        (is_cached_dispatch ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>) : \
        (is_cached_dispatch ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps> : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>); \
    LAUNCH_KERNEL(&cfg, dispatch_func, \
                  reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast<SourceMeta*>(recv_src_meta), \
                  reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
                  send_rdma_head, send_nvl_head, \
                  recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \
                  rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
                  gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
                  num_tokens, hidden_int4, num_scales, num_topk, num_experts, \
                  is_token_in_rank, \
                  rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
                  buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
                  rank, num_ranks); } break

    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

    SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
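
// NOTES: the dispatch grid uses `num_channels * 2` SMs of 512 threads each: even SMs run
// the RDMA-to-NVL forwarder/coordinator warps and odd SMs run the RDMA sender and NVL
// receiver warps for the same channel (see the role assignment at the top of the kernel).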

template <bool kLowLatencyMode>
__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean,
                              const int nvl_clean_offset, const int nvl_num_int_clean,
                              int* combined_rdma_head, int num_combined_tokens, int num_channels,
                              const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
                              void* rdma_buffer_ptr,
                              void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks,
                              bool is_cached_dispatch, const nvshmem_team_t rdma_team) {
    auto sm_id = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x);
    auto num_threads = static_cast<int>(blockDim.x);
    auto num_warps = num_threads / 32;
    auto warp_id = thread_id / 32;
    auto lane_id = get_lane_id();

    auto nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Use the first two SMs to clean the RDMA and NVL buffers respectively
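    // NOTES: SM 0 barriers over RDMA and zeroes the RDMA control region, SM 1 barriers over
    // NVLink and zeroes the NVL control region, SM 2 rewrites `combined_rdma_head`, and the
    // remaining SMs rewrite `combined_nvl_head` (the latter two only for non-cached dispatch).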
    if (sm_id == 0) {
        // Barrier for RDMA
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        __syncthreads();

        // Clean
        auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
        #pragma unroll
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
        nvshmem_fence();
        __syncthreads();

        // Barrier again
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    } else if (sm_id == 1) {
        // Barrier for NVL
        barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
        move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
        __syncthreads();

        // Clean
        auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
        #pragma unroll
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
        memory_fence();
        __syncthreads();

        // Barrier again
        barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
        move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
    } else if (sm_id == 2) {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(num_rdma_ranks <= 32);

        // Iterate in reverse order
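        // NOTES: walking the channel backwards lets a token that was not sent to a rank
        // (head < 0) be overwritten with `-last_head - 1`, i.e. the encoded head of the
        // next later token that was sent, so the combine kernel can still track the queue
        // head monotonically across all tokens.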
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);

            // NOTES: `1 << 25` is a heuristic large number
            int last_head = 1 << 25;
            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
                auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
                if (current_head < 0) {
                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                } else {
                    last_head = current_head;
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");

        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
            for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) {
                // Iterate in reverse order
                int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                token_start_idx += shift, token_end_idx += shift;

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
                #pragma unroll
                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
                    auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
                    if (current_head < 0) {
                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                    } else {
                        last_head = current_head;
                    }
                }
            }
        }
    }
}
|
|
|
|
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
|
|
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
|
|
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
|
|
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
|
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
|
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
|
|
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
|
bool is_cached_dispatch, bool low_latency_mode) {
|
|
const int num_threads = std::max(128, 32 * num_channels);
|
|
const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
|
|
|
|
// Get clean meta
|
|
auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels);
|
|
auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
|
|
EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
|
|
EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
|
|
EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
|
|
EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
|
|
EP_HOST_ASSERT(num_channels * 2 > 3);
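    // NOTES: the kernel below runs on `num_channels * 2` SMs: SM 0 cleans the RDMA buffer, SM 1
    // cleans the NVL buffer, SM 2 rewrites `combined_rdma_head`, and the remaining SMs rewrite
    // `combined_nvl_head`, hence the requirement of more than 3 SMs above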

    // Launch kernel
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
    LAUNCH_KERNEL(&cfg, cached_notify_func,
                  rdma_clean_meta.first, rdma_clean_meta.second,
                  nvl_clean_meta.first, nvl_clean_meta.second,
                  combined_rdma_head, num_combined_tokens, num_channels,
                  rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
                  rdma_buffer_ptr,
                  buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks,
                  is_cached_dispatch, cpu_rdma_team);
}

template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks");
int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
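    // NOTES: every lane runs the shuffles below, so each lane ends up with the full list of
    // contributing ranks and their ring-buffer slot indices for the current token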
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (__shfl_sync(0xffffffff, is_token_in_rank, i)) {
        slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks ++] = i;
    }
    EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);

    // Reduce data
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += 32) {
        // Read buffers
        // TODO: maybe too many registers here
        int4 recv_value_int4[kMaxNumRanks];
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j)
            recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

// Reduce all-to-all results
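        // NOTES: partial sums are accumulated in fp32 and only cast back to `dtype_t` afterwards,
        // which limits rounding error when reducing low-precision values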
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j) {
            auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++ k)
                values[k] += static_cast<float>(recv_value_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast<dtype_t>(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`
    if (lane_id < num_topk) {
        float value = 0;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Return the minimum top-k rank
    return topk_ranks[0];
}

template<bool kLowLatencyMode,
         int kNumRDMARanks, typename dtype_t,
         int kNumCombineForwarderWarps,
         int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks),
         int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
         int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder,
         int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS>
__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1)
combine(int4* combined_x, float* combined_topk_weights,
        const bool* is_combined_token_in_rank,
        const int4* x, const float* topk_weights,
        const int* combined_rdma_head, const int* combined_nvl_head,
        const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
        int num_tokens, int num_combined_tokens, int hidden, int num_topk,
        void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
        void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
        int rank, int num_ranks) {
    enum class WarpRole {
        kNVLSender,
        kNVLAndRDMAForwarder,
        kRDMAReceiver,
        kCoordinator
    };

    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
    const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / 2, channel_id = sm_id / 2;
const bool is_rdma_receiver_sm = sm_id % 2 == 1;
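    // NOTES: each channel owns two consecutive SMs: the even SM hosts the NVL sender warps, the
    // NVL-and-RDMA forwarder warps and one coordinator warp, while the odd SM hosts the RDMA
    // receiver warps and another coordinator warp (see the role assignment below)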

    EP_DEVICE_ASSERT(num_topk <= 32);
    EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0);
    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));

    // NOTES: we decouple a channel into 2 SMs
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto role_meta = [=]() -> std::pair<WarpRole, int> {
        auto warp_id = thread_id / 32;
        if (not is_rdma_receiver_sm) {
if (warp_id < NUM_MAX_NVL_PEERS) {
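                // NOTES: rotate the destination by `channel_id` so that concurrent channels start
                // from different NVL peers, which should spread NVLink traffic more evenly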
                auto shuffled_warp_id = warp_id;
                shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;
                return {WarpRole::kNVLSender, shuffled_warp_id};
            } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
                auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS;
                shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
                return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id};
            } else {
                return {WarpRole::kCoordinator, 0};
            }
        } else {
            if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
                return {WarpRole::kRDMAReceiver, warp_id};
            } else {
                return {WarpRole::kCoordinator, 0};
            }
        }
    }();
    auto warp_role = role_meta.first;
    auto warp_id = role_meta.second;

    EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1);
    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;

    if (warp_role == WarpRole::kNVLSender) {
        // NVL producers
        const auto dst_nvl_rank = warp_id;

        // NVL layouts
// NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
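        // NOTES: the queue data, source meta, top-k weights and the tail counter live in the
        // destination rank's buffer, while the head counter lives in the local buffer and is
        // advanced remotely by the destination SM's coordinator warp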
        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
        auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);

        // Get tasks for each RDMA lane
        int token_start_idx = 0, token_end_idx = 0;
        if (lane_id < kNumRDMARanks) {
            int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
            token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
        }
        __syncwarp();

        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");

        // Iterate over all tokens and send by chunks
        while (true) {
            // Exit if possible
            if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
                break;

// Decide next RDMA buffer to send
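            // NOTES: a lane becomes ready only when its slice of the destination's receive buffer
            // (of size `num_max_nvl_chunked_recv_tokens_per_rdma`) can hold one more full chunk of
            // `num_max_nvl_chunked_send_tokens` tokens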
            bool is_lane_ready = false;
            auto start_time = clock64();
            while (true) {
                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
                is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;
                if (__any_sync(0xffffffff, is_lane_ready))
                    break;

                // Retry
                if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx,
                           token_start_idx, token_end_idx);
                    trap();
                }
            }

            // Sync token start index and count
            for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) {
                if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
                    continue;

                // Sync token start index
                auto token_idx = static_cast<int64_t>(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx));
                int num_tokens_in_chunk = __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);

                // Send by chunk
                for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++ chunk_idx, ++ token_idx) {
                    // Get an empty slot
                    int dst_slot_idx = 0;
                    if (lane_id == current_rdma_idx) {
                        dst_slot_idx = (cached_channel_tail_idx ++) % num_max_nvl_chunked_recv_tokens_per_rdma;
                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
                    }
                    dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx);

                    // Copy data
                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x = x + token_idx * hidden_int4;
                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);

                    // Copy source meta
                    if (lane_id == 0)
                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));

                    // Copy `topk_weights`
                    if (lane_id < num_topk)
                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
                }
                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
            }

            // Move queue tail
            __syncwarp();
            if (lane_id < kNumRDMARanks and is_lane_ready)
                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
        }
    } else {
        // Combiners and coordinators
        // RDMA symmetric layout
        auto hidden_bytes = hidden_int4 * sizeof(int4);
        auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
        auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);

        // NVL layouts
        void* local_nvl_buffer = buffer_ptrs[nvl_rank];
        void* nvl_buffers[NUM_MAX_NVL_PEERS];
        #pragma unroll
        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
            nvl_buffers[i] = buffer_ptrs[i];
        auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);

        // Combiner warp synchronization
        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
        __shared__ volatile bool forwarder_retired[kNumForwarders];
        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
        auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); };
        auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); };

        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
            // Receive from NVL ranks and forward to RDMA ranks
            // NOTES: this part is using "large warps" for each RDMA rank
            const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder;
            const auto sub_warp_id = warp_id % kNumWarpsPerForwarder;
            auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
            auto sync_large_warp = [=]() {
                if (kNumWarpsPerForwarder == 1) {
                    __syncwarp();
                } else {
                    asm volatile("bar.sync %0, %1;" :: "r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32));
                }
            };
EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough");
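            // NOTES: hardware provides 16 named barriers per block; IDs 0 and 1 are taken by the
            // forwarder/receiver shared-memory syncs, so each "large warp" uses barrier
            // `dst_rdma_rank + 2`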

            // Advance to the corresponding NVL buffer
            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
            nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
            nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
            nvl_channel_head.advance(dst_rdma_rank);
            nvl_channel_tail.advance(dst_rdma_rank);

            // Clean shared memory and sync
            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
            lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0;
            lane_id == 0 ? (forwarder_retired[warp_id] = false) : false;
            sync_forwarder_smem();

            // Get count and cached head
            int cached_nvl_channel_tail_idx = 0;
            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
            num_tokens_to_combine -= num_tokens_prefix;
            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;

            // Iterate over all tokens and combine by chunks
            for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
                // Check destination queue emptiness, or wait for a buffer to be released
                auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
                auto num_chunked_tokens = token_end_idx - token_start_idx;
                auto start_time = clock64();
                while (sub_warp_id == 0 and lane_id == 0) {
                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
                    // Here, `token_start_idx` is the actual tail
                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));
                    if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
                        break;

                    // Timeout check
                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
                               channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
                        trap();
                    }
                }
                sync_large_warp();

                // Combine and write to the RDMA buffer
                for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
                    // Read expected head
                    EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
                    int expected_head = -1;
                    if (lane_id < NUM_MAX_NVL_PEERS)
                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);

                    // Wait for lanes to be ready
                    start_time = clock64();
                    while (cached_nvl_channel_tail_idx <= expected_head) {
                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));

                        // Timeout check
                        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
                            printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
                                   channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
                            trap();
                        }
                    }

                    // Combine current token
                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                    auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
                    combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                                 expected_head, lane_id,
                                                                                 hidden_int4, num_topk,
                                                                                 reinterpret_cast<int4*>(shifted),
                                                                                 reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
                                                                                 num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);

// Update head
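                    // NOTES: a negative `expected_head` comes from the `cached_notify` rewrite and
                    // decodes back to the head of the next token actually received from that NVL
                    // rank, while a non-negative head advances one past the slot just consumed;
                    // both keep the recorded per-source head non-decreasing for the coordinator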
                    if (lane_id < NUM_MAX_NVL_PEERS)
                        expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1);
                }
                sync_large_warp();

// Issue RDMA send
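                // NOTES: the non-blocking put is fenced before the tail signal below, so the remote
                // side only observes the tail advancing after the chunk data is visible; for the
                // local RDMA rank the data was combined in place, so a memory fence suffices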
                if (sub_warp_id == kNumWarpsPerForwarder - 1) {
                    if (dst_rdma_rank != rdma_rank) {
                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
                        nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
                                                   rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
                                                   num_chunked_tokens * num_bytes_per_rdma_token,
                                                   translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                        nvshmem_fence();
                    } else {
                        memory_fence();
                    }

                    // Write new RDMA tail
                    __syncwarp();
                    if (lane_id == 0)
                        nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
                                           translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                }
            }

            // Retired
            __syncwarp();
            if (lane_id == 0)
                forwarder_retired[warp_id] = true;
        } else if (warp_role == WarpRole::kRDMAReceiver) {
            // Receive from RDMA ranks and write to the output tensor
            // Clean shared memory and sync
            EP_DEVICE_ASSERT(kNumRDMARanks <= 32);
            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0;
            lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0;
            sync_rdma_receiver_smem();

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

            // Iterate over all tokens and combine
int cached_channel_tail_idx = 0;
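            // NOTES: each lane caches the tail of one source RDMA rank; negative entries in
            // `combined_rdma_head` mark tokens that have no copy from that source and already
            // encode the next live head (see `cached_notify`)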
            for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
                // Read expected head
                EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
                int expected_head = -1;
                if (lane_id < kNumRDMARanks) {
                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
                    (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head);
                }

                // Wait for lanes to be ready
                auto start_time = clock64();
                while (cached_channel_tail_idx <= expected_head) {
                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));

                    // Timeout check
                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
                               channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
                        trap();
                    }
                }
                __syncwarp();

                // Combine current token
                auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
                combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                         expected_head, lane_id,
                                                                         hidden_int4, num_topk,
                                                                         combined_x + token_idx * hidden_int4,
                                                                         combined_topk_weights + token_idx * num_topk,
                                                                         num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
            }

            // Retired
            __syncwarp();
            if (lane_id == 0)
                rdma_receiver_retired[warp_id] = true;
        } else {
            // Coordinator
            // Sync shared memory status
            is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem();
            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

            int last_rdma_head = 0;
            int last_nvl_head[kNumRDMARanks] = {0};
            int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0;
            int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps");
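            // NOTES: the coordinator polls the per-warp progress recorded in shared memory, takes
            // the minimum head per destination, and publishes it lazily: RDMA heads are only
            // signalled in batches of at least `num_max_rdma_chunked_send_tokens` to reduce
            // signalling traffic, NVL heads whenever they advance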
            while (true) {
                // Retired
                if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
                    break;
                if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
                    break;

                // Find minimum head for RDMA ranks
                if (is_rdma_receiver_sm) {
                    int min_head = std::numeric_limits<int>::max();
                    #pragma unroll
                    for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
                        min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
                        nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
                                           translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                        last_rdma_head = min_head;
                    }
                } else {
                    // Find minimum head for NVL ranks
                    #pragma unroll
                    for (int i = 0; i < kNumRDMARanks; ++ i) {
                        int min_head = std::numeric_limits<int>::max();
                        #pragma unroll
                        for (int j = 0; j < num_warps_per_rdma_rank; ++ j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j])
                            min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);
                        if (min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS)
                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
                    }
                }

                // Nanosleep and let other warps work
                __nanosleep(NUM_WAIT_NANOSECONDS);
            }
        }
    }
}

void combine(cudaDataType_t type,
             void* combined_x, float* combined_topk_weights,
             const bool* is_combined_token_in_rank,
             const void* x, const float* topk_weights,
             const int* combined_rdma_head, const int* combined_nvl_head,
             const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
             int num_tokens, int num_combined_tokens, int hidden, int num_topk,
             void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
             void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
             int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) {
    constexpr int kNumCombineForwarderWarps = 16;

#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \
    auto combine_func = low_latency_mode ? \
        combine<true, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps> : combine<false, num_rdma_ranks, nv_bfloat16, kNumCombineForwarderWarps>; \
    LAUNCH_KERNEL(&cfg, combine_func, \
                  reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
                  reinterpret_cast<const int4*>(x), topk_weights, \
                  combined_rdma_head, combined_nvl_head, \
                  reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
                  num_tokens, num_combined_tokens, hidden, num_topk, \
                  rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
                  buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \
                  rank, num_ranks); } break

    int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
    int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
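    // NOTES: the NVL receive buffer is split evenly across RDMA sources, and each slice must be
    // strictly larger than both send chunk sizes so a sender can always make progress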
    EP_HOST_ASSERT(type == CUDA_R_16BF);

    SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream);
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}

} // namespace internode

} // namespace deep_ep