mirror of https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Improve EP2/4 performance
This commit is contained in:
parent 55cdd9a64f
commit 1553fc42bf
@@ -174,8 +174,8 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 #undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
 }
 
-template<int kNumRanks>
-__global__ void __launch_bounds__(kNumRanks * 32, 1)
+template <int kNumRanks, int kNumThreads>
+__global__ void __launch_bounds__(kNumThreads, 1)
 dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
          int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
          const bool* is_token_in_rank, const int* channel_prefix_matrix,
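The key change in this first hunk: the block size is no longer tied to the rank count. Previously the kernel launched with kNumRanks * 32 threads, exactly one warp per peer rank, so an EP2 launch ran only two warps per block and left most of the SM idle; now the thread count is an independent template parameter (pinned to 512 by the host wrapper further down). A minimal sketch of the resulting mapping, using a stand-in kernel rather than DeepEP's dispatch, assuming kNumThreads = 512:

    #include <cstdio>

    // mapping_demo is a stand-in, not DeepEP's dispatch kernel
    template <int kNumRanks, int kNumThreads>
    __global__ void __launch_bounds__(kNumThreads, 1) mapping_demo() {
        const auto thread_id = static_cast<int>(threadIdx.x);
        const auto num_threads_per_rank = kNumThreads / kNumRanks;
        const auto responsible_rank = thread_id / num_threads_per_rank;
        if (thread_id % num_threads_per_rank == 0)
            printf("rank %d served by %d warps\n", responsible_rank, num_threads_per_rank / 32);
    }

    int main() {
        mapping_demo<2, 512><<<1, 512>>>();  // EP2: 8 warps per rank instead of 1
        cudaDeviceSynchronize();
        return 0;
    }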
@@ -187,11 +187,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
     const bool is_sender = sm_id % 2 == 0;
     EP_DEVICE_ASSERT(num_sms % 2 == 0);
 
-    // Each warp is responsible for a single rank
+    // Several warps are responsible for a single rank
+    const auto num_threads_per_rank = kNumThreads / kNumRanks;
     const auto num_channels = num_sms / 2;
-    const auto responsible_rank = (static_cast<int>(thread_id)) / 32;
-
-    // Even-numbered blocks for sending, odd-numbered blocks for receiving
+    const auto responsible_rank = (static_cast<int>(thread_id)) / num_threads_per_rank;
+    // Even-numbered blocks for sending, odd-numbered blocks for receiving.
     const auto responsible_channel = sm_id / 2;
 
     int num_experts_per_rank = num_experts / kNumRanks;
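This hunk introduces the arithmetic the rest of the kernel builds on: num_threads_per_rank = kNumThreads / kNumRanks, and responsible_rank = thread_id / num_threads_per_rank. With kNumThreads fixed at 512 (see the host wrapper below), the warp budget per rank can be checked at compile time; a small sketch:

    constexpr int kNumThreads = 512;

    constexpr int warps_per_rank(int num_ranks) {
        return (kNumThreads / num_ranks) / 32;
    }

    static_assert(warps_per_rank(2) == 8, "EP2: 8 warps per rank, previously 1");
    static_assert(warps_per_rank(4) == 4, "EP4: 4 warps per rank, previously 1");
    static_assert(warps_per_rank(8) == 2, "EP8: 2 warps per rank, previously 1");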
@@ -234,19 +234,20 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
     if (is_sender) {
         // Workers for sending
-        constexpr int num_send_warps = kNumRanks;
+        constexpr int num_send_warps = kNumThreads / 32;
+        constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
         const auto send_thread_id = thread_id;
-        const auto send_warp_id = send_thread_id / 32;
         const auto send_lane_id = send_thread_id % 32;
+        const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
         EP_DEVICE_ASSERT(kNumRanks <= 32);
-        EP_DEVICE_ASSERT(num_send_warps == kNumRanks and send_warp_id == responsible_rank);
+        EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
 
         // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
         // NOTES: this is for distinguishing zero tokens
-        if (send_lane_id == 0) {
-            int value = responsible_channel > 0 ? channel_prefix_matrix[send_warp_id * num_channels + responsible_channel - 1] : 0;
+        if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
+            int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
             st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
-            value = channel_prefix_matrix[send_warp_id * num_channels + responsible_channel];
+            value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
             st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
         }
         __syncwarp();
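The context lines above show the `-value - 1` encoding the queue handshake depends on: the receiver spins until channel_end_offset reads nonzero, so a legitimate offset of 0 must still be distinguishable from "not written yet". Mapping v to -v - 1 keeps every written word nonzero (0 -> -1, 1 -> -2, ...) and is trivially invertible, which is exactly what the receiver does later with `-total_offset - 1`. A tiny self-check of the encoding:

    #include <cassert>

    constexpr int encode_offset(int v) { return -v - 1; }
    constexpr int decode_offset(int e) { return -e - 1; }

    int main() {
        static_assert(encode_offset(0) == -1, "zero tokens still produce a nonzero word");
        static_assert(decode_offset(encode_offset(5)) == 5, "round trip");
        for (int v = 0; v < 1024; ++ v)
            assert(encode_offset(v) != 0 and decode_offset(encode_offset(v)) == v);
        return 0;
    }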
@@ -257,8 +258,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
         // Iterate over all tokens and send by chunks
         int cached_channel_tail_idx = 0;
-        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
+        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
             // Check destination queue emptiness, or wait a buffer to be released (rare cases)
+            // NOTES: the head index received by different warps may not be the same
             auto start_time = clock64();
             while (send_lane_id == 0) {
                 // NOTES: we only consider the worst case, because counting the real numbers is time-consuming
@@ -276,67 +278,73 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
             int chunk_token_idx = 0;
             while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
-                if (send_lane_id == 0)
-                    send_head[token_idx * kNumRanks + send_warp_id] = is_token_in_rank[token_idx * kNumRanks + send_warp_id] ? cached_channel_tail_idx : -1;
+                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
+                if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
+                    send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
 
                 // Skip if not selected
-                if (not is_token_in_rank[token_idx * kNumRanks + send_warp_id]) {
+                if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) {
                     token_idx ++;
                     continue;
                 }
 
                 // Get an empty slot
                 int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens;
-
-                // Copy data
-                auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
-                auto shifted_x = x + token_idx * hidden_int4;
-                UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
-                                   __ldg, st_na_global);
-
-                // Copy source index
-                if (send_lane_id == 0)
-                    channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
-
-                // Copy `topk_idx` and `topk_weights` with transformed index
-                if (send_lane_id < num_topk) {
-                    // Top-k index
-                    int recv_expert_begin = send_warp_id * num_experts_per_rank, recv_expert_end = (send_warp_id + 1) * num_experts_per_rank;
-                    auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
-                    idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
-                    channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
-
-                    // Top-k weights
-                    auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
-                    weight_value = (idx_value >= 0) ? weight_value : 0.0f;
-                    channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
-                }
-
-                // Copy `x_scales`
-                #pragma unroll
-                for (int i = send_lane_id; i < num_scales; i += 32)
-                    channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
+                if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) {
+                    // Copy data
+                    auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
+                    auto shifted_x = x + token_idx * hidden_int4;
+                    UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
+                                       __ldg, st_na_global);
+
+                    // Copy source index
+                    if (send_lane_id == 0)
+                        channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
+
+                    // Copy `topk_idx` and `topk_weights` with transformed index
+                    if (send_lane_id < num_topk) {
+                        // Top-k index
+                        int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
+                        auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
+                        idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
+                        channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
+
+                        // Top-k weights
+                        auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
+                        weight_value = (idx_value >= 0) ? weight_value : 0.0f;
+                        channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
+                    }
+
+                    // Copy `x_scales`
+                    #pragma unroll
+                    for (int i = send_lane_id; i < num_scales; i += 32)
+                        channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
+                }
 
                 // Move token index
                 chunk_token_idx ++, token_idx ++;
             }
 
             // Move tail index
-            __syncwarp();
-            if (send_lane_id == 0)
+            // NOTES: here all warps should share the same new tail
+            asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
+            if (send_warp_id_in_rank == 0 and send_lane_id == 0)
                 st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
         }
     } else {
         // Workers for receiving and copying into buffer
-        constexpr int num_recv_warps = kNumRanks;
+        constexpr int num_recv_warps = kNumThreads / 32;
+        constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
         const auto recv_thread_id = thread_id;
-        const auto recv_warp_id = recv_thread_id / 32;
         const auto recv_lane_id = recv_thread_id % 32;
-        EP_DEVICE_ASSERT(kNumRanks <= 32 and recv_warp_id == responsible_rank);
-        EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps == kNumRanks);
+        const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
+        const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
+        EP_DEVICE_ASSERT(kNumRanks <= 32);
+        EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);
 
         // Calculate offset first
         auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
-        int rank_offset = recv_warp_id > 0 ? rank_prefix_matrix[(recv_warp_id - 1) * kNumRanks + rank] : 0;
+        int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;
 
         // Receive channel offset
         int total_offset, num_tokens_to_recv;
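With several warps cooperating on one rank, `__syncwarp()` no longer covers all producers of a channel, and `__syncthreads()` would over-synchronize across unrelated rank groups. The hunk above therefore moves to PTX named barriers, `bar.sync id, count`, using the rank index as the barrier id so each group synchronizes independently. A sketch of that primitive in isolation (hardware allows 16 barrier ids per block, and the thread count must be a multiple of the warp size):

    __device__ __forceinline__ void rank_group_barrier(int barrier_id, int num_threads) {
        // Only the `num_threads` threads naming this id wait here;
        // the other rank groups in the block keep running.
        asm volatile("bar.sync %0, %1;" :: "r"(barrier_id), "r"(num_threads));
    }

    template <int kNumRanks, int kNumThreads>
    __global__ void __launch_bounds__(kNumThreads, 1) barrier_demo(int* out) {
        constexpr int num_threads_per_rank = kNumThreads / kNumRanks;
        const auto thread_id = static_cast<int>(threadIdx.x);
        const auto responsible_rank = thread_id / num_threads_per_rank;
        // ... all warps of `responsible_rank` produce data here ...
        rank_group_barrier(responsible_rank, num_threads_per_rank);
        if (thread_id % num_threads_per_rank == 0)
            out[responsible_rank] = 1;  // safe: the whole group reached the barrier
    }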
@@ -344,23 +352,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
         while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
         if (recv_lane_id == 0) {
             total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
-            recv_channel_offset[recv_warp_id * num_channels + responsible_channel] = total_offset;
+            if (recv_warp_id_in_rank == 0)
+                recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
             num_tokens_to_recv -= total_offset;
         }
         total_offset = __shfl_sync(0xffffffff, total_offset, 0);
         total_offset += rank_offset;
         num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);
 
+        // Shared tail indices for different warps
+        __shared__ volatile int shared_channel_tail_idx[kNumRanks];
+
         auto start_time = clock64();
         int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
         while (num_tokens_to_recv > 0) {
             // Check channel status by lane 0
-            while (recv_lane_id == 0) {
+            // NOTES: unlike the sender, the receiver must ensure that the tail indices held by different warps are the same
+            while (recv_thread_id_in_rank == 0) {
                 cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());
 
                 // Ready to copy
-                if (cached_channel_head_idx != cached_channel_tail_idx)
+                if (cached_channel_head_idx != cached_channel_tail_idx) {
+                    shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx;
                     break;
+                }
 
                 // Timeout check
                 if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
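The old receiver broadcast the polled tail with `__shfl_sync`, which only reaches lanes of the polling warp. With num_recv_warps_per_rank warps per rank, the tail now travels through a per-rank shared-memory slot plus the named barrier, so every warp of the group observes the same value; this is what the `shared_channel_tail_idx` array added above is for. The pattern, reduced to a hypothetical helper:

    template <int kNumRanks>
    __device__ int broadcast_in_rank_group(int polled_tail, bool is_poller,
                                           int responsible_rank, int num_threads_per_rank) {
        __shared__ volatile int shared_tail[kNumRanks];
        if (is_poller)
            shared_tail[responsible_rank] = polled_tail;  // one thread publishes
        asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
        return shared_tail[responsible_rank];              // all warps read the same tail
    }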
@@ -369,12 +383,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
                 }
             }
 
-            // Sync queue tail
-            cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
+            // Synchronize queue tail
+            asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
+            cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank];
 
             // Copy data
             int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
-            for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx) {
+            for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) {
                 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                 auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
                 auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
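This is where the extra warps pay off: the receive copy loop changes from a single warp draining every chunk (`chunk_idx = 0; ++ chunk_idx`) to a warp-strided split, so warp i of the rank group takes chunks i, i + num_recv_warps_per_rank, and so on, while the 32 lanes still cooperate within each token. The shape of the split as a standalone sketch:

    __device__ void copy_tokens_warp_strided(int4* dst, const int4* src,
                                             int num_tokens, int hidden_int4,
                                             int warp_id_in_rank, int num_warps_per_rank,
                                             int lane_id) {
        // Chunks are dealt round-robin to the warps of one rank group...
        for (int t = warp_id_in_rank; t < num_tokens; t += num_warps_per_rank)
            // ...while lanes stride over the hidden dimension of their token.
            for (int i = lane_id; i < hidden_int4; i += 32)
                dst[t * hidden_int4 + i] = src[t * hidden_int4 + i];
    }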
@@ -384,12 +399,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
             // Copy `src_idx`
             #pragma unroll 4
-            for (int chunk_idx = cached_channel_head_idx + recv_lane_id; chunk_idx < cached_channel_tail_idx; chunk_idx += 32)
+            for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank; chunk_idx < cached_channel_tail_idx; chunk_idx += 32 * num_recv_warps_per_rank)
                 recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);
 
             // Copy `topk_idx` and `topk_weights`
             #pragma unroll 4
-            for (int idx = recv_lane_id; idx < num_recv_tokens * num_topk; idx += 32) {
+            for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk; idx += 32 * num_recv_warps_per_rank) {
                 int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
                 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                 auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
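For small per-token payloads (src_idx here, topk and scales below), a whole warp per token would idle most lanes, so these loops flatten the rank group into one index space instead: each thread starts at recv_thread_id_in_rank and strides by 32 * num_recv_warps_per_rank, which is exactly num_threads_per_rank. Sketch:

    __device__ void copy_flat(int* dst, const int* src, int n,
                              int thread_id_in_rank, int num_warps_per_rank) {
        // All threads of the rank group share one flat index space
        for (int i = thread_id_in_rank; i < n; i += 32 * num_warps_per_rank)
            dst[i] = src[i];
    }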
@@ -400,7 +415,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 
             // Copy `x_scales`
             #pragma unroll 4
-            for (int i = recv_lane_id; i < num_recv_tokens * num_scales; i += 32) {
+            for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales; i += 32 * num_recv_warps_per_rank) {
                 int chunk_idx = i / num_scales, scales_idx = i % num_scales;
                 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                 recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =
@@ -410,8 +425,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
             // Move queue
             cached_channel_head_idx += num_recv_tokens;
             total_offset += num_recv_tokens;
-            __syncwarp();
-            if (recv_lane_id == 0)
+            asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
+            if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
                 st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
 
             // Exit
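Each channel is a single-producer/single-consumer ring: the sender publishes the tail with st_release_sys_global, and the receiver frees slots by publishing the head. With several consumer warps per rank, only one designated thread (here lane 0 of the last warp in the group) may publish, and only after the group barrier proves every warp has finished reading its slots. The handshake reduced to a sketch, assuming DeepEP-style store wrappers:

    __device__ void advance_head(int* head_ptr, int new_head, bool is_publisher,
                                 int responsible_rank, int num_threads_per_rank) {
        // Wait until every warp of the rank group is done with its slots...
        asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
        // ...then a single thread hands the slots back to the sender
        // (st_relaxed_sys_global is DeepEP's system-scope store wrapper).
        if (is_publisher)
            st_relaxed_sys_global(head_ptr, new_head);
    }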
@@ -426,8 +441,10 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
               int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
               void** buffer_ptrs, int rank, int num_ranks,
               cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
+    constexpr int kNumThreads = 512;
+
 #define DISPATCH_LAUNCH_CASE(ranks) \
-    LAUNCH_KERNEL(&cfg, dispatch<ranks>, \
+    LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
                   reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
                   send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
                   is_token_in_rank, channel_prefix_matrix, \
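On the host side, the thread count becomes a constexpr shared between the kernel's template argument and the launch configuration, so `__launch_bounds__(kNumThreads, 1)` and the actual block size cannot drift apart. A reduced sketch of the pattern; my_dispatch and launch are stand-ins, not DeepEP's LAUNCH_KERNEL machinery:

    template <int kNumRanks, int kNumThreads>
    __global__ void __launch_bounds__(kNumThreads, 1) my_dispatch() { /* ... */ }

    void launch(int num_ranks, int num_sms, cudaStream_t stream) {
        constexpr int kNumThreads = 512;  // one constant feeds the template and the launch
        switch (num_ranks) {
            case 2: my_dispatch<2, kNumThreads><<<num_sms, kNumThreads, 0, stream>>>(); break;
            case 4: my_dispatch<4, kNumThreads><<<num_sms, kNumThreads, 0, stream>>>(); break;
            case 8: my_dispatch<8, kNumThreads><<<num_sms, kNumThreads, 0, stream>>>(); break;
            default: break;  // unsupported intranode EP size
        }
    }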
@@ -438,7 +455,7 @@ break
 
     // Even-numbered blocks for sending, odd-numbered blocks for receiving.
     EP_HOST_ASSERT(num_sms % 2 == 0);
-    SETUP_LAUNCH_CONFIG(num_sms, num_ranks * 32, stream);
+    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
     SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
 #undef DISPATCH_LAUNCH_CASE
 }
@@ -160,12 +160,11 @@ class Buffer:
         Returns:
             config: the recommended config.
         """
-        # Intranode
-        if num_ranks <= 8:
-            return Config(Buffer.num_sms, 6, 256, 6, 128)
 
-        # Internode
         config_map = {
+            2: Config(Buffer.num_sms, 16, 256, 6, 128),
+            4: Config(Buffer.num_sms, 16, 256, 6, 128),
+            8: Config(Buffer.num_sms, 6, 256, 6, 128),
             16: Config(Buffer.num_sms, 16, 288, 20, 128),
             24: Config(Buffer.num_sms, 8, 288, 32, 128),
             32: Config(Buffer.num_sms, 8, 288, 32, 128),
@@ -188,12 +187,11 @@ class Buffer:
         Returns:
             config: the recommended config.
         """
-        # Intranode
-        if num_ranks <= 8:
-            return Config(Buffer.num_sms, 6, 256, 6, 128)
 
-        # Internode
         config_map = {
+            2: Config(Buffer.num_sms, 6, 256, 6, 128),
+            4: Config(Buffer.num_sms, 6, 256, 6, 128),
+            8: Config(Buffer.num_sms, 6, 256, 6, 128),
             16: Config(Buffer.num_sms, 2, 288, 28, 128),
             24: Config(Buffer.num_sms, 1, 288, 20, 128),
             32: Config(Buffer.num_sms, 1, 288, 20, 128),
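The Python hunks drop the shared intranode early-return, which pinned every EP size up to 8 to an NVL chunk size of 6, and give EP2/EP4 their own tuned entries: chunk 16 for dispatch, 6 for combine. The same selection logic as a C++ sketch; the five Config fields are labelled illustratively here and are not DeepEP's exact parameter names:

    #include <map>
    #include <stdexcept>

    struct Config {
        int num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size;
    };

    Config get_dispatch_config(int num_sms, int num_ranks) {
        // EP2/4 now carry their own tuned chunk size (16) instead of inheriting
        // the old intranode default of 6; EP8 keeps the previous values.
        static const std::map<int, Config> config_map = {
            {2, {0, 16, 256, 6, 128}},
            {4, {0, 16, 256, 6, 128}},
            {8, {0, 6, 256, 6, 128}},
        };
        const auto it = config_map.find(num_ranks);
        if (it == config_map.end())
            throw std::runtime_error("unsupported number of ranks");
        Config config = it->second;
        config.num_sms = num_sms;
        return config;
    }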
@@ -13,7 +13,6 @@ import test_low_latency
 
 def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
     # Settings
-    # TODO: fix EP2/4/8 performance
     num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
     assert num_experts % num_ranks == 0
     if local_rank == 0:
@@ -182,7 +181,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
 
     # Tune combine performance
     best_time, best_results = 1e10, None
-    for nvl_chunk_size in range(1, 5, 1):
+    for nvl_chunk_size in range(1, 7, 1):
         config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
         tune_args = {'x': recv_x, 'handle': handle, 'config': config}
         t = bench(lambda: buffer.combine(**tune_args))[0]
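Finally, the test widens the combine chunk sweep from range(1, 5) to range(1, 7) so the tuner can consider the larger chunks that now help EP2/4. The sweep is a plain argmin over timed runs; the same idea in CUDA-event form, with run_combine as a placeholder for the combine launch:

    #include <cfloat>

    // Returns the fastest nvl_chunk_size in [1, 7); run_combine is a placeholder
    // that launches combine on `stream` with the given chunk size.
    int tune_combine(void (*run_combine)(int, cudaStream_t), cudaStream_t stream) {
        float best_ms = FLT_MAX;
        int best_chunk = -1;
        cudaEvent_t begin, end;
        cudaEventCreate(&begin), cudaEventCreate(&end);
        for (int nvl_chunk_size = 1; nvl_chunk_size < 7; ++ nvl_chunk_size) {
            cudaEventRecord(begin, stream);
            run_combine(nvl_chunk_size, stream);
            cudaEventRecord(end, stream);
            cudaEventSynchronize(end);
            float ms;
            cudaEventElapsedTime(&ms, begin, end);
            if (ms < best_ms)
                best_ms = ms, best_chunk = nvl_chunk_size;
        }
        cudaEventDestroy(begin), cudaEventDestroy(end);
        return best_chunk;
    }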