Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Remove all raw tensors for better P2P overlapping
commit 6cc3497df8
parent f60306409a
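Summary of the change: `packed_recv_count` stops being a `torch::from_blob` view over the raw `dispatch_rdma_atomic_token_counter` RDMA buffer and becomes an ordinary CUDA tensor; the dispatch kernel zeroes it in its send phase, accumulates into it with `atomicAdd` in its receive phase, and uses a grid sync in between so the zeroed values are visible to every block. With no tensor aliasing the RDMA region anymore, the signaling buffers stay private to the kernels, which is presumably what the title means by better P2P overlapping.

As a reference point only, here is a minimal standalone sketch (not DeepEP code; `raw_counter` and the sizes are made up) contrasting the two allocation styles:

// Hypothetical sketch: a tensor that aliases a raw device buffer vs. one that
// owns its own allocation. Illustrative names only.
#include <torch/torch.h>
#include <cuda_runtime.h>

int main() {
    const int num_local_experts = 8;

    // Before: wrap a raw, externally owned device pointer (e.g. an
    // RDMA-registered region). PyTorch must not free it, and any consumer of
    // the tensor keeps the underlying signaling buffer alive and exposed.
    int* raw_counter = nullptr;
    cudaMalloc(&raw_counter, num_local_experts * sizeof(int));
    auto aliased = torch::from_blob(raw_counter, {num_local_experts},
                                    torch::dtype(torch::kInt32).device(torch::kCUDA));

    // After: a plain caching-allocator tensor; the kernel writes per-expert
    // counts into it directly, so no raw buffer leaks out as a tensor.
    auto owned = torch::empty({num_local_experts},
                              torch::dtype(torch::kInt32).device(torch::kCUDA));

    cudaFree(raw_counter);
    return 0;
}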
@@ -93,7 +93,6 @@ struct LowLatencyBuffer {
     void* dispatch_rdma_send_buffer = nullptr;
     void* dispatch_rdma_recv_data_buffer = nullptr;
     int* dispatch_rdma_recv_count_buffer = nullptr;
-    int* dispatch_rdma_atomic_token_counter = nullptr;
 
     void* combine_rdma_send_buffer = nullptr;
     void* combine_rdma_recv_data_buffer = nullptr;
@@ -145,10 +144,8 @@ struct LowLatencyLayout {
 
         // Symmetric signaling buffers
         size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
-        size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
         size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
-        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes,
-                                                 combine_recv_flag_buffer_bytes);
+        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
         total_bytes += signaling_buffer_bytes * 2;
 
         // Assign pointers
@@ -160,7 +157,6 @@ struct LowLatencyLayout {
                 advance(rdma_buffer, send_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
                 advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
-                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
                 advance(rdma_buffer, send_buffer_bytes * i),
                 advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
                 advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
@@ -1048,8 +1048,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
-    auto packed_recv_count = torch::from_blob(buffer.dispatch_rdma_atomic_token_counter,
-                                              {num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+    auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
 
     // Allocate column-majored scales
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
@@ -1061,6 +1060,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     auto launcher = [=](int phases) {
         internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
+                               packed_recv_count.data_ptr<int>(),
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
@@ -132,6 +132,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
 
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -40,9 +40,10 @@ template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+         int* packed_recv_count,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
-         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert,
+         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
          int* next_clean, int num_next_clean_int,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
@@ -215,6 +216,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Clean workspace for next use
         atomic_counter_per_expert[responsible_expert_idx] = 0;
         atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
+
+        // Clean `packed_recv_count`
+        if (dst_rank == 0)
+            packed_recv_count[dst_expert_local_idx] = 0;
     }
     __syncwarp();
@@ -223,6 +228,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
         return;
 
+    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
+    if (phases & LOW_LATENCY_SEND_PHASE)
+        cg::this_grid().sync();
+
     // Receiving and packing
     if (responsible_expert_idx < num_experts) {
         const auto src_rank = responsible_expert_idx / num_local_experts;
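The two added blocks above work together: the send phase zeroes `packed_recv_count`, and when a single launch runs both the send and the receive phase, the receive phase's `atomicAdd` into the same counters must not start before every block has finished cleaning, hence the `cg::this_grid().sync()`. A self-contained CUDA sketch of that pattern (hypothetical kernel and sizes, not the DeepEP kernel; a grid-wide sync needs a cooperative launch and typically `nvcc -rdc=true`):

// Phase 1 zeroes a per-expert counter, a grid-wide barrier makes the zeroes
// visible everywhere, then phase 2 accumulates into the counter.
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <cstdio>
namespace cg = cooperative_groups;

__global__ void two_phase(int* counter, int num_experts) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Phase 1 ("send"): clean the counters, one thread per expert.
    if (idx < num_experts)
        counter[idx] = 0;

    // Without this grid-wide sync, another block could start phase 2 before
    // its counter slot has been cleaned.
    cg::this_grid().sync();

    // Phase 2 ("receive"): every thread reserves one slot for "its" expert.
    atomicAdd(&counter[idx % num_experts], 1);
}

int main() {
    int num_experts = 4, blocks = 8, threads = 32;
    int* counter;
    cudaMalloc(&counter, num_experts * sizeof(int));

    void* args[] = {&counter, &num_experts};
    cudaLaunchCooperativeKernel((void*)two_phase, dim3(blocks), dim3(threads), args, 0, nullptr);
    cudaDeviceSynchronize();

    int host[4];
    cudaMemcpy(host, counter, sizeof(host), cudaMemcpyDeviceToHost);
    printf("%d %d %d %d\n", host[0], host[1], host[2], host[3]);  // 64 each
    cudaFree(counter);
    return 0;
}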
@@ -252,7 +261,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                 while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
             }
             num_recv_tokens = -num_recv_tokens - 1;
-            recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
+            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
             shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
             recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
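Note on the retained context: the sender publishes the per-(expert, rank) token count as `-(count + 1)` into `rdma_recv_count`, so a value of zero unambiguously means "not arrived yet"; the receiver decodes it with `-num_recv_tokens - 1` and now reserves a contiguous output range by `atomicAdd`-ing directly into `packed_recv_count`, whose old value is the begin index. A tiny host-side sketch of that handshake arithmetic (standalone and illustrative, a single-threaded stand-in for the device atomics):

#include <cassert>

int encode_count(int n) { return -n - 1; }   // 0 tokens encodes to -1, never 0
int decode_count(int v) { return -v - 1; }

int main() {
    int packed_recv_count = 0;                      // per-local-expert counter
    int published = encode_count(5);                // sender side
    int num_recv_tokens = decode_count(published);  // receiver side
    int recv_token_begin_idx = packed_recv_count;   // what atomicAdd would return
    packed_recv_count += num_recv_tokens;           // counter now holds the total
    assert(num_recv_tokens == 5 && recv_token_begin_idx == 0);
    return 0;
}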
@@ -290,6 +299,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -311,17 +321,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
 
-    // Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
-    // NOTES: this part will be cleaned in `combine`
-    auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
-
 #define DISPATCH_LAUNCH_CASE(hidden) \
     LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
                   packed_recv_x, packed_recv_x_scales, \
                   packed_recv_src_info, packed_recv_layout_range, \
+                  packed_recv_count, \
                   rdma_recv_x, rdma_recv_count, rdma_x, \
                   x, topk_idx, \
-                  atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
+                  atomic_counter_per_expert, atomic_finish_counter_per_expert, \
                   next_clean, num_next_clean_int, \
                   num_tokens, num_max_dispatch_tokens_per_rank, \
                   num_topk, num_experts, rank, num_ranks, phases); break
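The removed lines above carved `atomic_counter_per_local_expert` out of the tail of `rdma_recv_count`; with `packed_recv_count` taking over that role, the host-side launcher still only carves the two per-expert counters out of the workspace, as the retained lines and `EP_HOST_ASSERT` show. A hedged sketch of that carving pattern (standalone, with an assumed workspace size and example expert count, not DeepEP's actual buffer management):

// Two per-expert int arrays placed back-to-back at the start of a workspace.
#include <cassert>
#include <cuda_runtime.h>

int main() {
    const size_t NUM_WORKSPACE_BYTES = 32 * 1024 * 1024;  // assumed size
    int num_experts = 256;                                 // example value

    void* workspace = nullptr;
    cudaMalloc(&workspace, NUM_WORKSPACE_BYTES);

    auto atomic_counter_per_expert = static_cast<int*>(workspace);
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    assert(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

    cudaFree(workspace);
    return 0;
}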