Support CUDA graph for intranode normal kernels (#203)

This commit is contained in:
Chenggang Zhao
2025-06-11 11:08:54 +08:00
committed by GitHub
parent 8da2d7b38d
commit a8299ca7c2
7 changed files with 86 additions and 38 deletions

View File

@@ -45,7 +45,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);

View File

@@ -25,7 +25,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
per_rank_buffer = static_cast<int*>(buffer_ptrs[thread_id]);
per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
}
@@ -48,7 +48,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto local_per_rank_buffer = static_cast<int*>(buffer_ptrs[rank]);
if (thread_id < kNumRanks) {
#pragma unroll
for (int i = 1; i < kNumRanks; ++ i)
@@ -141,7 +141,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto ptr = static_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
ptr[i] = rank_prefix_matrix[i];
@@ -173,7 +173,7 @@ __global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -196,7 +196,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Calculate pointers by the specific layout
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
int target_rank = is_sender ? rank : responsible_rank;
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
@@ -286,7 +286,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the following data
if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
@@ -349,7 +349,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);
// Calculate offset first
auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;
// Receive channel offset
@@ -372,7 +372,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto start_time = clock64();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) {
// NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
// NOTES: unlike the sender, the receiver must ensure that the tail indices held by different warps are the same
while (recv_thread_id_in_rank == 0) {
cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());;
@@ -450,12 +450,25 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
if (lane_id == 0)
tma_store_wait();
}
// Clean unused `recv_topk_idx` as -1
// This path is taken only when `num_worst_tokens` is positive; otherwise nothing is written here.
if (num_worst_tokens > 0) {
// The prefix matrix sits at the start of this rank's buffer; entry `[kNumRanks - 1][rank]` is read as the received-token count
// NOTE(review): presumably an inclusive prefix sum over source ranks, so the last row is the total for `rank` — confirm against the notify kernel
auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
// Element range to overwrite: from `num_recv_tokens * num_topk` up to (not including) `num_worst_tokens * num_topk`
// Each block begins `sm_id * kNumThreads` elements into that range
const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;
const auto clean_end = num_worst_tokens * num_topk;
// One step advances past the elements handled by all `num_sms` blocks
const auto clean_stride = num_sms * kNumThreads;
// NOTE(review): assumes `thread_id` is this thread's index within the block (defined outside this hunk) — confirm
#pragma unroll
for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
recv_topk_idx[i] = -1;
}
}
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
@@ -470,7 +483,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break
@@ -493,7 +506,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
auto ptr = static_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0;
@@ -590,7 +603,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
// Calculate pointers by the specific layout
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[send_rank_id]));
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + rank;
@@ -682,7 +695,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
if (thread_id < 32) {
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
int* channel_head_idx_ptr = static_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
// Queue head updater
@@ -720,13 +733,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
auto channel_rank_offset = responsible_channel * kNumRanks + i;
auto num_channels_total = num_channels * kNumRanks;
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
ptr = reinterpret_cast<void*>(static_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);