Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Use TMA instead of LD/ST for intra-node normal kernels (#191)
* Update CMake files
* Use TMA instead of LD/ST for intranode dispatch
* Use TMA instead of LD/ST for intranode combine
* Adjust configs
* Test default configs as well
* More warps for combine
* Add inter-thread fence
* Enable more warps
* Do not use TMA for senders
* Update configs
* Remove useless wait
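The core of the change, visible in the receiver paths of the diff below, is to stop copying token data with per-lane LD/ST loops and instead stage each hidden vector in shared memory and move it with the Hopper TMA engine, driven through DeepEP's own tma_load_1d / tma_store_1d / mbarrier_* helpers. As rough orientation only, here is a self-contained sketch of that stage-through-shared-memory idea written against the standard libcu++ cuda::memcpy_async / cuda::barrier API (which is not the bulk-TMA path DeepEP uses); the kernel name, tile size and launch shape are illustrative assumptions, not code from this repository:

    #include <cooperative_groups.h>
    #include <cuda/barrier>
    #include <cuda_runtime.h>
    #include <cstdio>

    namespace cg = cooperative_groups;

    // One block stages `num_int4` 16-byte chunks in shared memory, then writes them out.
    __global__ void stage_through_smem(const int4* src, int4* dst, int num_int4) {
        extern __shared__ __align__(1024) uint8_t smem_buffer[];
        auto* staged = reinterpret_cast<int4*>(smem_buffer);

        __shared__ cuda::barrier<cuda::thread_scope_block> bar;
        auto block = cg::this_thread_block();
        if (block.thread_rank() == 0)
            init(&bar, block.size());      // one expected arrival per thread in the block
        block.sync();

        // Asynchronous global -> shared copy whose completion is signalled on `bar`
        // (the analogue of tma_load_1d + mbarrier_arrive_and_expect_tx in the diff).
        cuda::memcpy_async(block, staged, src, sizeof(int4) * num_int4, bar);
        bar.arrive_and_wait();             // analogue of mbarrier_wait

        // Push the staged tile to its destination; DeepEP instead issues a second
        // TMA transfer (shared -> global, tma_store_1d) here rather than a plain loop.
        for (int i = block.thread_rank(); i < num_int4; i += block.size())
            dst[i] = staged[i];
    }

    int main() {
        constexpr int num_int4 = 448;      // illustrative tile: 448 * 16 B = 7 KiB per stage
        constexpr size_t bytes = num_int4 * sizeof(int4);
        int4 *src, *dst;
        cudaMalloc(&src, bytes);
        cudaMalloc(&dst, bytes);
        stage_through_smem<<<1, 128, bytes>>>(src, dst, num_int4);
        cudaDeviceSynchronize();
        printf("status: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(src);
        cudaFree(dst);
        return 0;
    }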
@@ -174,7 +174,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 #undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
 }

-template <int kNumRanks, int kNumThreads>
+template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
 __global__ void __launch_bounds__(kNumThreads, 1)
 dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
 int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
@@ -183,7 +183,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 void **buffer_ptrs, int rank,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
-const auto thread_id = static_cast<int>(threadIdx.x);
+const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
 const bool is_sender = sm_id % 2 == 0;
 EP_DEVICE_ASSERT(num_sms % 2 == 0);

@@ -232,19 +232,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
 auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);

+// TMA stuffs
+extern __shared__ __align__(1024) uint8_t smem_buffer[];
+auto half_hidden_int4 = hidden_int4 / 2;
+auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
+auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
+auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);
+uint32_t tma_phase = 0;
+if (lane_id == 0) {
+mbarrier_init(tma_mbarrier, 1);
+fence_view_async_shared();
+fence_barrier_init();
+EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
+}
+__syncwarp();
+
 if (is_sender) {
 // Workers for sending
 constexpr int num_send_warps = kNumThreads / 32;
 constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
 const auto send_thread_id = thread_id;
-const auto send_lane_id = send_thread_id % 32;
 const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
 EP_DEVICE_ASSERT(kNumRanks <= 32);
 EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);

 // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
 // NOTES: this is for distinguishing zero tokens
-if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
+if (lane_id == 0 and send_warp_id_in_rank == 0) {
 int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
 st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
 value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
@@ -262,7 +276,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Check destination queue emptiness, or wait a buffer to be released (rare cases)
 // NOTES: the head index received by different warps may not be the same
 auto start_time = clock64();
-while (send_lane_id == 0) {
+while (lane_id == 0) {
 // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
 int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
 if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
@@ -279,7 +293,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 int chunk_token_idx = 0;
 while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
 // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
-if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
+if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
 send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;

 // Skip if not selected
@@ -294,30 +308,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Copy data
 auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
 auto shifted_x = x + token_idx * hidden_int4;
-UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
-__ldg, st_na_global);
+UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);

 // Copy source index
-if (send_lane_id == 0)
+if (lane_id == 0)
 channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);

 // Copy `topk_idx` and `topk_weights` with transformed index
-if (send_lane_id < num_topk) {
+if (lane_id < num_topk) {
 // Top-k index
 int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
-auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
+auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);
 idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
-channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
+channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;

 // Top-k weights
-auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
+auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
 weight_value = (idx_value >= 0) ? weight_value : 0.0f;
-channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
+channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = weight_value;
 }

 // Copy `x_scales`
 #pragma unroll
-for (int i = send_lane_id; i < num_scales; i += 32)
+for (int i = lane_id; i < num_scales; i += 32)
 channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
 }

@@ -328,7 +341,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 // Move tail index
 // NOTES: here all warps should share the same new tail
 asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
-if (send_warp_id_in_rank == 0 and send_lane_id == 0)
+if (send_warp_id_in_rank == 0 and lane_id == 0)
 st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
 }
 } else {
@@ -336,7 +349,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 constexpr int num_recv_warps = kNumThreads / 32;
 constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
 const auto recv_thread_id = thread_id;
-const auto recv_lane_id = recv_thread_id % 32;
 const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
 const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
 EP_DEVICE_ASSERT(kNumRanks <= 32);
@@ -348,9 +360,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to

 // Receive channel offset
 int total_offset, num_tokens_to_recv;
-while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
-while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
-if (recv_lane_id == 0) {
+while (lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
+while (lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
+if (lane_id == 0) {
 total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
 if (recv_warp_id_in_rank == 0)
 recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
@@ -393,8 +405,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
 auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
 auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
-UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
-ld_nc_global, st_na_global);
+#pragma unroll
+for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
+tma_store_wait();
+tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
+mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
+mbarrier_wait(tma_mbarrier, tma_phase);
+tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
+}
+__syncwarp();
 }

 // Copy `src_idx`
@@ -426,12 +445,16 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
 cached_channel_head_idx += num_recv_tokens;
 total_offset += num_recv_tokens;
 asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
-if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
+if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
 st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);

 // Exit
 num_tokens_to_recv -= num_recv_tokens;
 }
+
+// Make TMA store visible to the next kernel
+if (lane_id == 0)
+tma_store_wait();
 }
 }

@@ -441,17 +464,22 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
 int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
 void** buffer_ptrs, int rank, int num_ranks,
 cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
-constexpr int kNumThreads = 512;
+constexpr int kNumThreads = 768;
+constexpr int kNumTMABytesPerWarp = 8192;
+constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);

-#define DISPATCH_LAUNCH_CASE(ranks) \
-LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
-reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
-send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
-is_token_in_rank, channel_prefix_matrix, \
-num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
-buffer_ptrs, rank, \
-num_max_send_tokens, num_recv_buffer_tokens); \
-break
+#define DISPATCH_LAUNCH_CASE(ranks) { \
+auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
+EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
+cfg.dynamicSmemBytes = smem_size; \
+LAUNCH_KERNEL(&cfg, kernel, \
+reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
+send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
+is_token_in_rank, channel_prefix_matrix, \
+num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+buffer_ptrs, rank, \
+num_max_send_tokens, num_recv_buffer_tokens); \
+} break

 // Even-numbered blocks for sending, odd-numbered blocks for receiving.
 EP_HOST_ASSERT(num_sms % 2 == 0);
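The rewritten DISPATCH_LAUNCH_CASE above is the host-side half of the change: it raises the kernel's dynamic shared-memory ceiling with cudaFuncSetAttribute and passes the same smem_size to the launch configuration. LAUNCH_KERNEL and cfg are DeepEP's own helpers; the hedged, stand-alone sketch below reproduces the same sequence with nothing but the CUDA runtime API, and the placeholder kernel and grid size are illustrative assumptions:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Placeholder kernel that only touches its dynamic shared-memory buffer.
    __global__ void dummy_dispatch(int* out) {
        extern __shared__ __align__(1024) uint8_t smem_buffer[];
        if (threadIdx.x == 0 && blockIdx.x == 0)
            out[0] = static_cast<int>(smem_buffer[0]);
    }

    int main() {
        constexpr int kNumThreads = 768, kNumTMABytesPerWarp = 8192;
        constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);  // 192 KiB, Hopper-class budget

        // Kernels requesting more than 48 KiB of dynamic shared memory must opt in first.
        cudaError_t err = cudaFuncSetAttribute(
            dummy_dispatch, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        if (err != cudaSuccess) {
            printf("opt-in failed: %s\n", cudaGetErrorString(err));
            return 1;
        }

        int* out;
        cudaMalloc(&out, sizeof(int));

        cudaLaunchConfig_t cfg = {};
        cfg.gridDim = dim3(24);            // e.g. num_sms = 24: even blocks send, odd blocks receive
        cfg.blockDim = dim3(kNumThreads);
        cfg.dynamicSmemBytes = smem_size;  // must not exceed the attribute set above
        cudaLaunchKernelEx(&cfg, dummy_dispatch, out);

        cudaDeviceSynchronize();
        printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(out);
        return 0;
    }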
@@ -532,7 +560,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
 #undef CACHED_NOTIFY_COMBINE
 }

-template<typename dtype_t, int kNumRanks, int kNumThreads>
+template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
 __global__ void __launch_bounds__(kNumThreads, 1)
 combine(dtype_t* recv_x, float* recv_topk_weights,
 const dtype_t* x, const float* topk_weights,
@@ -542,7 +570,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 const auto num_sms = static_cast<int>(gridDim.x);
 const auto thread_id = static_cast<int>(threadIdx.x);
-const auto sm_id = static_cast<int>(blockIdx.x);
+const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
 const auto num_channels = num_sms / 2;
 const bool is_sender = sm_id % 2 == 0;
 const int responsible_channel = sm_id / 2;
@@ -553,16 +581,21 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 auto x_int4 = reinterpret_cast<const int4*>(x);
 auto recv_int4 = reinterpret_cast<int4*>(recv_x);

+// TMA stuffs
+extern __shared__ __align__(1024) uint8_t smem_buffer[];
+auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
+
 if (is_sender) {
 // Workers for sending
 // Several warps are responsible for a single rank
-constexpr int num_send_warps = kNumThreads / 32;
-constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
+constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks;
+constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;
 const auto num_threads_per_rank = num_send_warps_per_rank * 32;
 const auto send_thread_id = thread_id;
-const auto send_lane_id = send_thread_id % 32;
+const auto send_warp_id = send_thread_id / 32;
 const auto send_rank_id = thread_id / num_threads_per_rank;
-const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
+const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
+EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");

 // Calculate pointers by the specific layout
 auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
@@ -595,7 +628,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 // Check destination queue emptiness, or wait a buffer to be released (rare cases)
 auto start_time = clock64();
 int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
-while (send_lane_id == 0) {
+while (lane_id == 0) {
 // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
 int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
 if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
@@ -618,22 +651,22 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 // Copy data
 auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
 auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
-UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
+UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);

 // Send source index
-if (send_lane_id == 0)
+if (lane_id == 0)
 channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);

 // Send `topk_weights`
-if (num_topk > 0 and send_lane_id < num_topk)
-channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
+if (num_topk > 0 and lane_id < num_topk)
+channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
 }
 token_idx += num_round_tokens;
 current_channel_tail_idx += num_round_tokens;

 // Move tail index
 asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
-if (send_lane_id == 0 and send_warp_id_in_rank == 0)
+if (lane_id == 0 and send_warp_id_in_rank == 0)
 st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
 }
 } else {
@@ -641,7 +674,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 // One warp for moving the queue head, others for reduction
 constexpr int num_recv_warps = kNumThreads / 32;
 const auto recv_warp_id = thread_id / 32;
-const auto recv_lane_id = thread_id % 32;
 EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
 EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);

@@ -651,19 +683,19 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 __shared__ volatile bool warp_retired[num_recv_warps];
 if (thread_id < num_recv_warps)
 warp_retired[thread_id] = false;
-if (recv_lane_id < kNumRanks)
-warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
+if (lane_id < kNumRanks)
+warp_channel_head_idx[recv_warp_id][lane_id] = 0;
 if (thread_id < kNumRanks)
 channel_tail_idx[thread_id] = 0;
 asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));

 if (thread_id < 32) {
-int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
+int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
 int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;

 // Queue head updater
 int last_head = 0;
-while (recv_lane_id < kNumRanks) {
+while (lane_id < kNumRanks) {
 // Check retired
 bool retired = true;
 #pragma unroll
@@ -673,13 +705,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 break;

 // Update queue tail
-channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
+channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);

 // Update minimum head
 int min_head = std::numeric_limits<int>::max();
 #pragma unroll
 for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
-min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
+min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
 if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
 st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
 }
@@ -716,11 +748,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
 // Read expected head
 int expected_head = -1;
-if (recv_lane_id < kNumRanks)
-expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
+if (lane_id < kNumRanks)
+expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);

 auto start_time = clock64();
-while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
+while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
 // Timeout check
 if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
 printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
@@ -740,9 +772,16 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 }
 }

-// Reduce data
+// Wait shared memory release
+if (lane_id == 0)
+tma_store_wait();
+__syncwarp();
+
+// Reduce data with pipeline
+constexpr int kNumStages = 8;
+EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
 #pragma unroll
-for (int i = recv_lane_id; i < hidden_int4; i += 32) {
+for (int i = lane_id; i < hidden_int4; i += 32) {
 // Read buffers
 int4 recv_value_int4[kNumRanks];
 #pragma unroll
@@ -759,33 +798,55 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
 values[k] += static_cast<float>(recv_value_dtypes[k]);
 }

-// Cast back to `dtype_t` and write
+// Cast back to `dtype_t`
 int4 out_int4;
 auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
 #pragma unroll
 for (int j = 0; j < kDtypePerInt4; ++ j)
 out_dtypes[j] = static_cast<dtype_t>(values[j]);
-recv_int4[token_idx * hidden_int4 + i] = out_int4;
+
+// Wait TMA arrival
+if (lane_id == 0)
+tma_store_wait<kNumStages - 1>();
+__syncwarp();
+
+// Write into TMA buffer
+auto tma_stage_idx = (i / 32) % kNumStages;
+reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
+
+// Issue TMA
+tma_store_fence();
+__syncwarp();
+if (lane_id == 0) {
+auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
+tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
+recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
+}
+__syncwarp();
 }

 // Reduce `topk_weights`
-if (recv_lane_id < num_topk) {
+if (lane_id < num_topk) {
 float value = 0;
 #pragma unroll
 for (int i = 0; i < num_topk_ranks; ++ i)
-value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
-recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
+value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + lane_id);
+recv_topk_weights[token_idx * num_topk + lane_id] = value;
 }

 // Update head
-if (recv_lane_id < kNumRanks)
-warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
+if (lane_id < kNumRanks)
+warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
 }

 // Retired
 __syncwarp();
-if (recv_lane_id == 0)
+if (lane_id == 0)
 warp_retired[recv_warp_id] = true;
+
+// Make TMA store visible to the next kernel
+if (lane_id == 0)
+tma_store_wait();
 }
 }
 }
@@ -799,15 +860,20 @@ void combine(cudaDataType_t type,
 cudaStream_t stream, int num_sms,
 int num_max_send_tokens, int num_recv_buffer_tokens) {
 constexpr int kNumThreads = 768;
+constexpr int kNumTMABytesPerWarp = 4096;
+constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);

-#define COMBINE_LAUNCH_CASE(dtype, ranks) \
-LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
+#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
+auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
+EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
+cfg.dynamicSmemBytes = smem_size; \
+LAUNCH_KERNEL(&cfg, kernel, \
 reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
 reinterpret_cast<const dtype*>(x), topk_weights, \
 src_idx, rank_prefix_matrix, channel_prefix_matrix, \
 send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
 buffer_ptrs, rank, \
-num_max_send_tokens, num_recv_buffer_tokens); \
+num_max_send_tokens, num_recv_buffer_tokens); } \
 break
 #define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break

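With these constants, dispatch now asks for 24 warps x 8192 B = 192 KiB of dynamic shared memory per block and combine for 24 warps x 4096 B = 96 KiB, which is why both launch macros must opt in via cudaFuncAttributeMaxDynamicSharedMemorySize. A small, hedged sanity check of such budgets against the device's opt-in limit (the attribute query is standard CUDA; the printed interpretation is only a convenience for this example):

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        int device = 0, optin_limit = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&optin_limit, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

        // Budgets taken from the constants introduced in this commit.
        const int dispatch_smem = 8192 * (768 / 32);   // 196608 B = 192 KiB
        const int combine_smem  = 4096 * (768 / 32);   //  98304 B =  96 KiB

        printf("opt-in shared memory limit per block: %d B\n", optin_limit);
        printf("dispatch needs %d B -> %s\n", dispatch_smem,
               dispatch_smem <= optin_limit ? "ok" : "too large for this GPU");
        printf("combine  needs %d B -> %s\n", combine_smem,
               combine_smem <= optin_limit ? "ok" : "too large for this GPU");
        return 0;
    }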