Support zero-copy for low-latency combine

This commit is contained in:
Chenggang Zhao
2025-03-18 15:41:50 +08:00
parent 82dcf48fd3
commit dcaf73e5ff
7 changed files with 80 additions and 28 deletions

View File

@@ -102,6 +102,9 @@ struct LowLatencyBuffer {
void* combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr;
void* combine_rdma_send_buffer_data_start = nullptr;
int num_bytes_per_combine_msg = 0;
std::pair<int*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
@@ -163,7 +166,9 @@ struct LowLatencyLayout {
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * i + sizeof(int4)),
static_cast<int>(num_bytes_per_combine_msg)
};
}
}

View File

@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream, phases);
workspace, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
return {combined_x, event, recv_hook};
}
torch::Tensor
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
}
} // namespace deep_ep
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("internode_combine", &deep_ep::Buffer::internode_combine)
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine);
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
}

View File

@@ -143,7 +143,11 @@ public:
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
};
} // namespace deep_ep

View File

@@ -147,7 +147,8 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
void* workspace, cudaStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll

View File

@@ -353,7 +353,7 @@ combine(void* combined_x,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int phases) {
int phases, bool zero_copy) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -420,7 +420,8 @@ combine(void* combined_x,
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
}
}
@@ -500,7 +501,8 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
void* workspace, cudaStream_t stream,
int phases, bool zero_copy) {
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
constexpr int kNumMaxTopk = 9;
@@ -524,7 +526,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases); } break
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);