Remove all raw tensors for better P2P overlapping

Chenggang Zhao 2025-03-03 14:25:22 +08:00
parent f60306409a
commit 6cc3497df8
4 changed files with 18 additions and 14 deletions

View File

@@ -93,7 +93,6 @@ struct LowLatencyBuffer {
     void* dispatch_rdma_send_buffer = nullptr;
     void* dispatch_rdma_recv_data_buffer = nullptr;
     int* dispatch_rdma_recv_count_buffer = nullptr;
-    int* dispatch_rdma_atomic_token_counter = nullptr;
     void* combine_rdma_send_buffer = nullptr;
     void* combine_rdma_recv_data_buffer = nullptr;
@@ -145,10 +144,8 @@ struct LowLatencyLayout {
         // Symmetric signaling buffers
         size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
-        size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
         size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
-        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes,
-                                                 combine_recv_flag_buffer_bytes);
+        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
         total_bytes += signaling_buffer_bytes * 2;
         // Assign pointers
@@ -160,7 +157,6 @@ struct LowLatencyLayout {
             advance(rdma_buffer, send_buffer_bytes * i),
             advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
             advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
-            advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
             advance(rdma_buffer, send_buffer_bytes * i),
             advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
             advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
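Note: the hunks above drop the dispatch-side atomic token counter from the LowLatencyBuffer fields, from the signaling-buffer size computation, and from the pointer list, so the signaling region only has to cover the larger of the dispatch count buffer and the combine flag buffer. For readers unfamiliar with the pointer-bumping pattern, the standalone sketch below shows how such a layout carves typed sub-buffers out of one contiguous allocation; the advance helper and the field names here are illustrative stand-ins, not the project's actual definitions.

    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper: bump a base pointer by `bytes` and reinterpret the result.
    template <typename OutPtr = void*>
    static OutPtr advance(void* base, std::size_t bytes) {
        return reinterpret_cast<OutPtr>(static_cast<std::uint8_t*>(base) + bytes);
    }

    // Illustrative double-buffered layout: two send slots, two recv slots and two
    // signaling slots, all carved out of a single `rdma_buffer` allocation.
    struct SketchBuffer {
        void* send_buffer;
        void* recv_buffer;
        int* recv_count;   // signaling region: one int per expert (illustrative)
    };

    static SketchBuffer slice(void* rdma_buffer, std::size_t send_bytes, std::size_t recv_bytes,
                              std::size_t signaling_bytes, int i /* buffer index: 0 or 1 */) {
        return SketchBuffer{
            advance(rdma_buffer, send_bytes * i),
            advance(rdma_buffer, send_bytes * 2 + recv_bytes * i),
            advance<int*>(rdma_buffer, send_bytes * 2 + recv_bytes * 2 + signaling_bytes * i)
        };
    }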

View File

@@ -1048,8 +1048,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(torch::kFloat8_e4m3fn));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
-    auto packed_recv_count = torch::from_blob(buffer.dispatch_rdma_atomic_token_counter,
-                                              {num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
+    auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     // Allocate column-majored scales
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
@@ -1061,6 +1060,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     auto launcher = [=](int phases) {
         internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales.data_ptr<float>(),
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
+                               packed_recv_count.data_ptr<int>(),
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                                buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
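Note: the host-side change above is the heart of the commit. packed_recv_count used to be a torch::from_blob view over a raw pointer inside the RDMA signaling buffer; it is now an ordinary tensor allocated through PyTorch and handed to the kernel as an extra argument. Below is a minimal standalone sketch of the two constructions, with placeholder sizes and a plain CUDA tensor standing in for the registered RDMA buffer; it omits error handling and is not the project's code.

    #include <torch/torch.h>

    int main() {
        const int num_local_experts = 8;  // placeholder value
        auto options = torch::dtype(torch::kInt32).device(torch::kCUDA);

        // Old style: wrap an existing device pointer as a tensor view. The view does
        // not own the memory; here `backing` stands in for the RDMA signaling buffer.
        auto backing = torch::empty({num_local_experts}, options);
        auto counter_view = torch::from_blob(backing.data_ptr(), {num_local_experts}, options);

        // New style: a regular allocation from the caching allocator. The kernel fills
        // it with per-expert counts and resets it for the next round, so no raw RDMA
        // memory is exposed to the caller.
        auto counter_owned = torch::empty({num_local_experts}, options);
        return 0;
    }

Decoupling the returned tensors from raw RDMA memory is presumably what the commit title's "better P2P overlapping" refers to.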

View File

@@ -132,6 +132,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,

View File

@@ -40,9 +40,10 @@ template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+         int* packed_recv_count,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
-         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert,
+         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
          int* next_clean, int num_next_clean_int,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
@@ -215,6 +216,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
+        // Clean `packed_recv_count`
+        if (dst_rank == 0)
+            packed_recv_count[dst_expert_local_idx] = 0;
     }
     __syncwarp();
@@ -223,6 +228,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
         return;
+    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
+    if (phases & LOW_LATENCY_SEND_PHASE)
+        cg::this_grid().sync();
     // Receiving and packing
     if (responsible_expert_idx < num_experts) {
         const auto src_rank = responsible_expert_idx / num_local_experts;
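Note: when one launch runs both the send and the receive phase, cg::this_grid().sync() is the barrier that makes the packed_recv_count values zeroed at the end of the send phase visible to every block before the receive phase starts accumulating into them. A grid-wide sync is only valid for kernels started with a cooperative launch; the standalone CUDA sketch below illustrates that pattern (toy kernel and names, no error checking, unrelated to the project's own launch macro).

    #include <cooperative_groups.h>
    #include <cstdio>
    #include <cuda_runtime.h>

    namespace cg = cooperative_groups;

    __global__ void two_phase(int* counter) {
        // Phase 1: a single thread resets the counter.
        if (blockIdx.x == 0 && threadIdx.x == 0)
            *counter = 0;

        // Grid-wide barrier: every block sees the reset before phase 2 begins.
        cg::this_grid().sync();

        // Phase 2: one thread per block accumulates into the now-clean counter.
        if (threadIdx.x == 0)
            atomicAdd(counter, 1);
    }

    int main() {
        int* counter = nullptr;
        cudaMalloc(&counter, sizeof(int));

        // cg::this_grid().sync() requires a cooperative launch.
        void* args[] = {&counter};
        cudaLaunchCooperativeKernel((void*)two_phase, dim3(8), dim3(32), args, 0, nullptr);

        int host = 0;
        cudaMemcpy(&host, counter, sizeof(int), cudaMemcpyDeviceToHost);
        printf("blocks counted: %d\n", host);  // expect 8
        cudaFree(counter);
        return 0;
    }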
@@ -252,7 +261,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
         }
         num_recv_tokens = -num_recv_tokens - 1;
-        recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
+        recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
         shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
         shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
         recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
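Note: with the raw atomic_counter_per_local_expert gone, packed_recv_count now plays both roles: a leader thread atomically reserves a contiguous range of output slots for the tokens arriving from one source rank, and the value left behind once all ranks are processed is exactly the per-expert token count returned to the caller. The standalone CUDA sketch below shows that reservation pattern for a single expert; names and sizes are illustrative, not the project's.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Each block stands in for the leader of one (local expert, source rank) pair and
    // reserves `tokens_per_rank[blockIdx.x]` consecutive slots in the expert's output.
    __global__ void reserve_slots(int* packed_recv_count, const int* tokens_per_rank,
                                  int* begin_idx_out) {
        const int local_expert_idx = 0;  // a single expert in this toy example
        if (threadIdx.x == 0) {
            int n = tokens_per_rank[blockIdx.x];
            // atomicAdd returns the old value: the start index of this rank's range.
            begin_idx_out[blockIdx.x] = atomicAdd(packed_recv_count + local_expert_idx, n);
        }
    }

    int main() {
        const int num_ranks = 4;
        int h_tokens[num_ranks] = {3, 1, 4, 2};

        int *d_count, *d_tokens, *d_begin;
        cudaMalloc(&d_count, sizeof(int));
        cudaMalloc(&d_tokens, sizeof(h_tokens));
        cudaMalloc(&d_begin, num_ranks * sizeof(int));
        cudaMemset(d_count, 0, sizeof(int));
        cudaMemcpy(d_tokens, h_tokens, sizeof(h_tokens), cudaMemcpyHostToDevice);

        reserve_slots<<<num_ranks, 32>>>(d_count, d_tokens, d_begin);

        int h_begin[num_ranks], h_count = 0;
        cudaMemcpy(h_begin, d_begin, sizeof(h_begin), cudaMemcpyDeviceToHost);
        cudaMemcpy(&h_count, d_count, sizeof(int), cudaMemcpyDeviceToHost);
        for (int r = 0; r < num_ranks; ++r)
            printf("rank %d: begin=%d, count=%d\n", r, h_begin[r], h_tokens[r]);
        printf("total tokens for expert 0: %d\n", h_count);  // 3 + 1 + 4 + 2 = 10
        cudaFree(d_count); cudaFree(d_tokens); cudaFree(d_begin);
        return 0;
    }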
@@ -290,6 +299,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -311,17 +321,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
-    // Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
-    // NOTES: this part will be cleaned in `combine`
-    auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
 #define DISPATCH_LAUNCH_CASE(hidden) \
 LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
               packed_recv_x, packed_recv_x_scales, \
               packed_recv_src_info, packed_recv_layout_range, \
+              packed_recv_count, \
               rdma_recv_x, rdma_recv_count, rdma_x, \
               x, topk_idx, \
-              atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
+              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
               next_clean, num_next_clean_int, \
               num_tokens, num_max_dispatch_tokens_per_rank, \
               num_topk, num_experts, rank, num_ranks, phases); break