mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
@@ -68,6 +68,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
|
||||
void combine(cudaDataType_t type,
|
||||
void* recv_x, float* recv_topk_weights,
|
||||
const void* x, const float* topk_weights,
|
||||
const void* bias_0, const void* bias_1,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
@@ -121,6 +122,7 @@ void combine(cudaDataType_t type,
|
||||
void* combined_x, float* combined_topk_weights,
|
||||
const bool* is_combined_token_in_rank,
|
||||
const void* x, const float* topk_weights,
|
||||
const void* bias_0, const void* bias_1,
|
||||
const int* combined_rdma_head, const int* combined_nvl_head,
|
||||
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
|
||||
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
|
||||
|
||||
@@ -1139,10 +1139,11 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
|
||||
is_cached_dispatch, cpu_rdma_team);
|
||||
}
|
||||
|
||||
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
|
||||
template <int kNumRanks, bool kMaybeWithBias, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
|
||||
__device__ int combine_token(bool is_token_in_rank, int head_idx,
|
||||
int lane_id, int hidden_int4, int num_topk,
|
||||
int4* combined_row, float* combined_topk_weights,
|
||||
const int4* bias_0_int4, const int4* bias_1_int4,
|
||||
int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
|
||||
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
|
||||
|
||||
@@ -1160,15 +1161,33 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
|
||||
// Reduce data
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < hidden_int4; i += 32) {
|
||||
// Read bias
|
||||
// TODO: make it as a finer-grained template
|
||||
int4 bias_0_value_int4, bias_1_value_int4;
|
||||
if (kMaybeWithBias) {
|
||||
bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0);
|
||||
bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// Read buffers
|
||||
// TODO: maybe too many registers here
|
||||
int4 recv_value_int4[kMaxNumRanks];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j)
|
||||
recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);
|
||||
|
||||
// Clean
|
||||
// Reduce bias
|
||||
float values[kDtypePerInt4] = {0};
|
||||
if (kMaybeWithBias) {
|
||||
auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
|
||||
auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
|
||||
}
|
||||
|
||||
// Reduce all-to-all results
|
||||
float values[kDtypePerInt4] = {0};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j) {
|
||||
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
|
||||
@@ -1210,6 +1229,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32,
|
||||
combine(int4* combined_x, float* combined_topk_weights,
|
||||
const bool* is_combined_token_in_rank,
|
||||
const int4* x, const float* topk_weights,
|
||||
const int4* bias_0, const int4* bias_1,
|
||||
const int* combined_rdma_head, const int* combined_nvl_head,
|
||||
const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
|
||||
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
|
||||
@@ -1470,12 +1490,12 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
|
||||
auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
|
||||
auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
|
||||
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
|
||||
combine_token<NUM_MAX_NVL_PEERS, false, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
|
||||
expected_head, lane_id,
|
||||
hidden_int4, num_topk,
|
||||
static_cast<int4*>(shifted),
|
||||
reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
|
||||
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
|
||||
nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
|
||||
|
||||
// Update head
|
||||
if (lane_id < NUM_MAX_NVL_PEERS)
|
||||
@@ -1549,11 +1569,13 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
// Combine current token
|
||||
auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
|
||||
auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
|
||||
combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
|
||||
combine_token<kNumRDMARanks, true, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
|
||||
expected_head, lane_id,
|
||||
hidden_int4, num_topk,
|
||||
combined_x + token_idx * hidden_int4,
|
||||
combined_topk_weights + token_idx * num_topk,
|
||||
bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,
|
||||
bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,
|
||||
num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
|
||||
}
|
||||
|
||||
@@ -1614,6 +1636,7 @@ void combine(cudaDataType_t type,
|
||||
void* combined_x, float* combined_topk_weights,
|
||||
const bool* is_combined_token_in_rank,
|
||||
const void* x, const float* topk_weights,
|
||||
const void* bias_0, const void* bias_1,
|
||||
const int* combined_rdma_head, const int* combined_nvl_head,
|
||||
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
|
||||
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
|
||||
@@ -1628,6 +1651,7 @@ void combine(cudaDataType_t type,
|
||||
LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
reinterpret_cast<int4*>(combined_x), combined_topk_weights, is_combined_token_in_rank, \
|
||||
reinterpret_cast<const int4*>(x), topk_weights, \
|
||||
reinterpret_cast<const int4*>(bias_0), reinterpret_cast<const int4*>(bias_1), \
|
||||
combined_rdma_head, combined_nvl_head, \
|
||||
reinterpret_cast<const SourceMeta*>(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
|
||||
num_tokens, num_combined_tokens, hidden, num_topk, \
|
||||
|
||||
@@ -587,6 +587,7 @@ template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWa
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
const dtype_t* x, const float* topk_weights,
|
||||
const dtype_t* bias_0, const dtype_t* bias_1,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank,
|
||||
@@ -602,6 +603,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
|
||||
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
|
||||
auto x_int4 = reinterpret_cast<const int4*>(x);
|
||||
auto bias_0_int4 = reinterpret_cast<const int4*>(bias_0);
|
||||
auto bias_1_int4 = reinterpret_cast<const int4*>(bias_1);
|
||||
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
|
||||
|
||||
// TMA stuffs
|
||||
@@ -809,14 +812,26 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < hidden_int4; i += 32) {
|
||||
// Read bias
|
||||
// TODO: make it as a template
|
||||
int4 bias_0_value_int4 = bias_0_int4 != nullptr ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
|
||||
int4 bias_1_value_int4 = bias_1_int4 != nullptr ? __ldg(bias_1_int4 + token_idx * hidden_int4 + i) : make_int4(0, 0, 0, 0);
|
||||
|
||||
// Read buffers
|
||||
int4 recv_value_int4[kNumRanks];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j)
|
||||
recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
|
||||
|
||||
// Reduce bias
|
||||
float values[kDtypePerInt4];
|
||||
auto bias_0_values = reinterpret_cast<const dtype_t*>(&bias_0_value_int4);
|
||||
auto bias_1_values = reinterpret_cast<const dtype_t*>(&bias_1_value_int4);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
values[j] = static_cast<float>(bias_0_values[j]) + static_cast<float>(bias_1_values[j]);
|
||||
|
||||
// Reduce all-to-all results
|
||||
float values[kDtypePerInt4] = {0};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j) {
|
||||
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
|
||||
@@ -887,6 +902,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
void combine(cudaDataType_t type,
|
||||
void* recv_x, float* recv_topk_weights,
|
||||
const void* x, const float* topk_weights,
|
||||
const void* bias_0, const void* bias_1,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
@@ -904,6 +920,7 @@ void combine(cudaDataType_t type,
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
|
||||
reinterpret_cast<const dtype*>(x), topk_weights, \
|
||||
reinterpret_cast<const dtype*>(bias_0), reinterpret_cast<const dtype*>(bias_1), \
|
||||
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
|
||||
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
|
||||
buffer_ptrs, rank, \
|
||||
|
||||
Reference in New Issue
Block a user