mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support zero-copy for low-latency combine
This commit is contained in:
@@ -147,7 +147,8 @@ void combine(void* combined_x,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases);
|
||||
void* workspace, cudaStream_t stream,
|
||||
int phases, bool zero_copy);
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
|
||||
@@ -353,7 +353,7 @@ combine(void* combined_x,
|
||||
int num_combined_tokens, int hidden, int num_topk,
|
||||
int num_max_dispatch_tokens_per_rank,
|
||||
int num_experts, int rank, int num_ranks,
|
||||
int phases) {
|
||||
int phases, bool zero_copy) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
@@ -420,7 +420,8 @@ combine(void* combined_x,
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
if (not zero_copy)
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||
}
|
||||
}
|
||||
@@ -500,7 +501,8 @@ void combine(void* combined_x,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases) {
|
||||
void* workspace, cudaStream_t stream,
|
||||
int phases, bool zero_copy) {
|
||||
constexpr int kNumWarpsPerGroup = 10;
|
||||
constexpr int kNumWarpGroups = 3;
|
||||
constexpr int kNumMaxTopk = 9;
|
||||
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases); } break
|
||||
phases, zero_copy); } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
|
||||
|
||||
Reference in New Issue
Block a user