Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-16 19:29:14 +00:00
Add automatic warp count control for low-latency kernels (#213)
* Add automatic warp count control for low-latency dispatch
* Add automatic warp count control for low-latency combine
* More assertions
Parent: 4e923188f7
Commit: 1b92be8a71
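The commit replaces the compile-time `kNumWarpGroups` / `kNumWarpsPerGroup` constants with values computed on the host from the expert count and the GPU's SM count, then threads them through the launchers into the kernels. A minimal host-side sketch of that arithmetic follows; the `main()` wrapper and the example figures (288 experts, 132 SMs) are hypothetical, chosen only to show the derivation.

    // Sketch of the automatic warp-count selection added by this commit.
    // Names mirror the diff below; the inputs are illustrative examples.
    #include <cassert>
    #include <cstdio>

    template <typename dtype_t>
    dtype_t ceil_div(dtype_t a, dtype_t b) {
        return (a + b - 1) / b;
    }

    int main() {
        const int num_experts = 288;     // hypothetical global expert count
        const int num_device_sms = 132;  // hypothetical SM count (H100-class GPU)

        // Each block (SM) serves `num_warp_groups` experts; the 32-warp budget
        // of a block is split evenly across those groups.
        const int num_warp_groups = ceil_div(num_experts, num_device_sms);  // 3
        const int num_warps_per_group = 32 / num_warp_groups;               // 10
        assert(num_warp_groups > 0 && num_warps_per_group > 0);

        const int num_warps = num_warp_groups * num_warps_per_group;        // 30 warps/block
        const int num_sms = ceil_div(num_experts, num_warp_groups);         // 96 blocks
        std::printf("%d warps per block, %d blocks\n", num_warps, num_sms);
        return 0;
    }

With these example inputs the formula lands on the same 3 × 10 split that the removed constants hard-coded; other expert-to-SM ratios now get a fitted split automatically, subject to the new host and device assertions in the diff.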
@@ -6,13 +6,13 @@
 namespace deep_ep {
 
 template <typename dtype_t>
-dtype_t cell_div(dtype_t a, dtype_t b) {
+dtype_t ceil_div(dtype_t a, dtype_t b) {
     return (a + b - 1) / b;
 }
 
 template <typename dtype_t>
 dtype_t align(dtype_t a, dtype_t b) {
-    return cell_div<dtype_t>(a, b) * b;
+    return ceil_div<dtype_t>(a, b) * b;
 }
 
 struct Config {
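For reference, the renamed helpers behave as in this small self-contained sketch; the assertions are illustrative, not part of the commit.

    #include <cassert>

    template <typename dtype_t>
    dtype_t ceil_div(dtype_t a, dtype_t b) { return (a + b - 1) / b; }

    template <typename dtype_t>
    dtype_t align(dtype_t a, dtype_t b) { return ceil_div<dtype_t>(a, b) * b; }

    int main() {
        assert(ceil_div(10, 4) == 3);  // 10 / 4 rounded up
        assert(align(10, 4) == 12);    // next multiple of 4 at or above 10
        assert(align(12, 4) == 12);    // already-aligned values are unchanged
        return 0;
    }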
@@ -41,6 +41,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
     // Get device info
     cudaDeviceProp device_prop = {};
     CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
+    num_device_sms = device_prop.multiProcessorCount;
 
     if (num_nvl_bytes > 0) {
         // Local IPC: alloc local memory and set local IPC handles
@@ -1142,7 +1143,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                 num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                 num_topk, num_experts, rank, num_ranks,
                 use_fp8, round_scale, use_ue8m0,
-                workspace, low_latency_usage_flag_mapped, launch_stream, phases);
+                workspace, low_latency_usage_flag_mapped,
+                num_device_sms, launch_stream,
+                phases);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
 
@@ -1234,7 +1237,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                 next_clean_meta.first, next_clean_meta.second,
                 num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
                 num_topk, num_experts, rank, num_ranks,
-                workspace, low_latency_usage_flag_mapped, launch_stream,
+                workspace, low_latency_usage_flag_mapped,
+                num_device_sms, launch_stream,
                 phases, zero_copy);
     };
     launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -41,6 +41,7 @@ private:
 
     // Device info and communication
     int device_id;
+    int num_device_sms;
     int rank, rdma_rank, nvl_rank;
     int num_ranks, num_rdma_ranks, num_nvl_ranks;
     cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
@@ -148,7 +148,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
               void* workspace, int* usage_flag,
-              cudaStream_t stream, int phases);
+              int num_device_sms, cudaStream_t stream,
+              int phases);
 
 void combine(void* combined_x,
              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
@@ -158,7 +159,8 @@ void combine(void* combined_x,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              void* workspace, int* usage_flag,
-             cudaStream_t stream, int phases, bool zero_copy);
+             int num_device_sms, cudaStream_t stream,
+             int phases, bool zero_copy);
 
 } // namespace internode_ll
 
@@ -36,9 +36,8 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
         clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0,
-          int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
-__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
+template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+__global__ __launch_bounds__(1024, 1) void
 dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
@@ -49,16 +48,18 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* next_clean, int num_next_clean_int,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
-         bool round_scale, int* usage_flag, int phases) {
+         bool round_scale, int* usage_flag,
+         int num_warp_groups, int num_warps_per_group,
+         int phases) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
     const auto num_sms = static_cast<int>(gridDim.x);
-    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+    const auto num_warps = num_warp_groups * num_warps_per_group;
     const auto num_local_experts = num_experts / num_ranks;
-    const auto warp_group_id = warp_id / kNumWarpsPerGroup;
-    const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
-    const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
+    const auto warp_group_id = warp_id / num_warps_per_group;
+    const auto sub_warp_id = warp_id % num_warps_per_group;
+    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
 
     // May extract UE8M0 from the scales
     using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
@@ -78,13 +79,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
+    // Expert counts
+    constexpr int kNumMaxWarpGroups = 32;
+    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
+
     // Sending phase
     if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
         goto LOW_LATENCY_DISPATCH_RECV;
 
-    // Expert counts
-    __shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
-
     // There are 2 kinds of warps in this part:
     // 1. The first-kind warps for FP8 cast and sending top-k tokens
     // 2. The last warp for reading `topk_idx` and count for per-expert information
@@ -96,8 +98,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
 
     for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
-        const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
-        const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+        const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
+        const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
         const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
         const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
@@ -194,9 +196,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     }
 
     // This SM should be responsible for some destination experts, read `topk_idx` for them
-    int expert_count[kNumWarpGroups] = {0};
-    const auto expert_begin_idx = sm_id * kNumWarpGroups;
-    const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
+    int expert_count[kNumMaxWarpGroups] = {0};
+    const auto expert_begin_idx = sm_id * num_warp_groups;
+    const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
 
     // Per lane count
     #pragma unroll 8
@@ -222,7 +224,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
         const auto dst_rank = responsible_expert_idx / num_local_experts;
         const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
-        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
+        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
 
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
@@ -257,23 +259,23 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     if (responsible_expert_idx < num_experts) {
         const auto src_rank = responsible_expert_idx / num_local_experts;
         const auto local_expert_idx = responsible_expert_idx % num_local_experts;
-        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
+        const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                 src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
-        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
+        const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
         const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
         const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
         const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
-        const auto recv_x_scales = reinterpret_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
+        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
 
         // Shared between sub-warps in warp groups
-        __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
+        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
 
         // Wait tokens to arrive
         // NOTES: using sub-warp 1 to overlap with sub-warp 0
         int num_recv_tokens, recv_token_begin_idx;
-        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
+        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
        if (sub_warp_id == 1 and lane_id == 0) {
            while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
            num_recv_tokens = -num_recv_tokens - 1;
@@ -284,13 +286,13 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             if (cumulative_local_expert_recv_stats != nullptr)
                 atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
         }
-        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
+        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 
         // Copy tokens
         EP_DEVICE_ASSERT(num_scales <= 64);
-        for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
+        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
             const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
             if (lane_id == 0)
@@ -340,14 +342,16 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
               void* workspace, int* usage_flag,
-              cudaStream_t stream, int phases) {
+              int num_device_sms, cudaStream_t stream,
+              int phases) {
     constexpr int kNumMaxTopK = 9;
-    constexpr int kNumWarpsPerGroup = 10;
-    constexpr int kNumWarpGroups = 3;
-    EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
+    const int num_warps_per_group = 32 / num_warp_groups;
+    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
+    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
 
-    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-    const auto num_sms = cell_div(num_experts, kNumWarpGroups);
+    const auto num_warps = num_warp_groups * num_warps_per_group;
+    const auto num_sms = ceil_div(num_experts, num_warp_groups);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
 
     // Workspace checks
@@ -360,11 +364,11 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = dispatch<false, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+auto dispatch_func = dispatch<false, false, hidden>; \
 if (use_fp8 and not use_ue8m0) \
-    dispatch_func = dispatch<true, false, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+    dispatch_func = dispatch<true, false, hidden>; \
 if (use_fp8 and use_ue8m0) \
-    dispatch_func = dispatch<true, true, kNumWarpGroups, kNumWarpsPerGroup, hidden>; \
+    dispatch_func = dispatch<true, true, hidden>; \
 LAUNCH_KERNEL(&cfg, dispatch_func, \
         packed_recv_x, packed_recv_x_scales, \
         packed_recv_src_info, packed_recv_layout_range, \
@@ -376,15 +380,17 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
         next_clean, num_next_clean_int, \
         num_tokens, num_max_dispatch_tokens_per_rank, \
         num_topk, num_experts, rank, num_ranks, \
-        round_scale, usage_flag, phases); } break
+        round_scale, usage_flag, \
+        num_warp_groups, num_warps_per_group, \
+        phases); } break
 
     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
     SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
 #undef DISPATCH_LAUNCH_CASE
 }
 
-template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
-__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
+template <int kHidden, int kNumMaxTopk>
+__global__ __launch_bounds__(1024, 1) void
 combine(void* combined_x,
         void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
         const void* x, const int64_t* topk_idx, const float* topk_weights,
@@ -394,16 +400,18 @@ combine(void* combined_x,
         int num_combined_tokens, int hidden, int num_topk,
         int num_max_dispatch_tokens_per_rank,
         int num_experts, int rank, int num_ranks,
-        int* usage_flag, int phases, bool zero_copy) {
+        int* usage_flag,
+        int num_warp_groups, int num_warps_per_group,
+        int phases, bool zero_copy) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto num_sms = static_cast<int>(gridDim.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto num_threads = static_cast<int>(blockDim.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
     const auto num_local_experts = num_experts / num_ranks;
-    const auto warp_group_id = warp_id / kNumWarpsPerGroup;
-    const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
-    const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
+    const auto warp_group_id = warp_id / num_warps_per_group;
+    const auto sub_warp_id = warp_id % num_warps_per_group;
+    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
 
     // Data type staffs
     constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
@@ -435,10 +443,10 @@ combine(void* combined_x,
         const auto local_expert_idx = responsible_expert_idx % num_local_experts;
         const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
         const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
-        const auto local_x = reinterpret_cast<const int4*>(x) +
+        const auto local_x = static_cast<const int4*>(x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
         const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
-        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
+        const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
 
         // Unpack layout
@@ -446,7 +454,7 @@ combine(void* combined_x,
         unpack2(layout, num_tokens_to_send, offset);
 
         // Issue IBGDA send
-        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
+        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
             const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
             const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
             const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
@@ -467,9 +475,9 @@ combine(void* combined_x,
             }
         }
 
-        // Put finishing flag
-        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
-        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
+        // Put the finishing flag
+        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
+        asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
         if (sub_warp_id == 1 and lane_id == 0) {
             while (ld_acquire_global(atomic_clean_flag) == 0);
             auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
@@ -491,7 +499,7 @@ combine(void* combined_x,
 
     // Wait all ranks to arrive and notify usages
     if (responsible_expert_idx < num_experts) {
-        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
+        EP_DEVICE_ASSERT(num_warps_per_group > 1);
         if (sub_warp_id == 0 and lane_id == 0) {
             while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
         } else if (sm_id == 0 and sub_warp_id == 1 and lane_id == 0) {
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
|
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
|
||||||
// Read from sources
|
// Read from sources
|
||||||
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
|
auto rdma_buffer_type = reinterpret_cast<const int*>(static_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
|
||||||
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
|
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
|
||||||
|
|
||||||
// Reduce
|
// Reduce
|
||||||
@ -535,7 +543,7 @@ combine(void* combined_x,
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < kNumElemsPerInt4; ++ j)
|
for (int j = 0; j < kNumElemsPerInt4; ++ j)
|
||||||
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
|
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
|
||||||
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
|
(static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -548,21 +556,23 @@ void combine(void* combined_x,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              void* workspace, int* usage_flag,
-             cudaStream_t stream, int phases, bool zero_copy) {
-    constexpr int kNumWarpsPerGroup = 10;
-    constexpr int kNumWarpGroups = 3;
+             int num_device_sms, cudaStream_t stream,
+             int phases, bool zero_copy) {
     constexpr int kNumMaxTopk = 9;
+    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
+    const int num_warps_per_group = 32 / num_warp_groups;
+    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
 
-    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-    const auto num_sms = cell_div(num_experts, kNumWarpGroups);
+    const auto num_warps = num_warp_groups * num_warps_per_group;
+    const auto num_sms = ceil_div(num_experts, num_warp_groups);
 
     // Check workspace
-    auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
+    auto atomic_clean_flag = static_cast<int*>(workspace);
     EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
 
 #define COMBINE_LAUNCH_CASE(hidden) { \
-auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
+auto combine_func = combine<hidden, kNumMaxTopk>; \
 LAUNCH_KERNEL(&cfg, combine_func, \
         combined_x, \
         rdma_recv_x, rdma_recv_flag, rdma_send_x, \
@@ -573,6 +583,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
         num_max_dispatch_tokens_per_rank, \
         num_experts, rank, num_ranks, \
         usage_flag, \
+        num_warp_groups, num_warps_per_group, \
        phases, zero_copy); } break
 
     SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
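Callers of the low-latency launchers now have to supply the device's SM count. A minimal sketch of where that value comes from, assuming the plain CUDA runtime API rather than the repository's `CUDA_CHECK` wrapper:

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        int device_id = 0;
        cudaDeviceProp device_prop = {};
        if (cudaGetDeviceProperties(&device_prop, device_id) != cudaSuccess)
            return 1;
        // Buffer::Buffer stores this as `num_device_sms` and forwards it to
        // internode_ll::dispatch / internode_ll::combine (see the diff above).
        int num_device_sms = device_prop.multiProcessorCount;
        std::printf("num_device_sms = %d\n", num_device_sms);
        return 0;
    }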
@@ -333,18 +333,18 @@ __device__ __forceinline__ void tma_store_wait() {
 #endif
 
 template <typename dtype_t>
-__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
+__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
     return (a + b - 1) / b;
 }
 
 template <typename dtype_t>
 __host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
-    return cell_div<dtype_t>(a, b) * b;
+    return ceil_div<dtype_t>(a, b) * b;
 }
 
 __forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
                                                        int& token_start_idx, int& token_end_idx) {
-    int num_tokens_per_sm = cell_div(num_tokens, num_sms);
+    int num_tokens_per_sm = ceil_div(num_tokens, num_sms);
     token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
     token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
 }