mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support zero-copy for low-latency combine
This commit is contained in:
@@ -1100,7 +1100,8 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
|
||||
bool zero_copy, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@@ -1159,7 +1160,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
next_clean_meta.first, next_clean_meta.second,
|
||||
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
|
||||
num_topk, num_experts, rank, num_ranks,
|
||||
workspace, launch_stream, phases);
|
||||
workspace, launch_stream,
|
||||
phases, zero_copy);
|
||||
};
|
||||
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
|
||||
|
||||
@@ -1182,6 +1184,20 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
return {combined_x, event, recv_hook};
|
||||
}
|
||||
|
||||
// Returns a BF16 CUDA tensor that wraps the combine send region of the
// low-latency buffer currently selected by `low_latency_buffer_idx`.
//
// The tensor is created with `torch::from_blob`, i.e. it is built on top of the
// existing memory at `combine_rdma_send_buffer_data_start` rather than on a
// fresh allocation.
// NOTE(review): presumably callers write their combine inputs directly into
// this tensor and then call `low_latency_combine` with `zero_copy = true` —
// confirm against the Python-side caller.
//
// Parameters:
//   num_max_dispatch_tokens_per_rank - forwarded to `LowLatencyLayout`; also
//                                      used for the second dimension and strides.
//   hidden                           - forwarded to `LowLatencyLayout`; size of
//                                      the innermost dimension.
//   num_experts                      - forwarded to `LowLatencyLayout`; the
//                                      first dimension is `num_experts / num_ranks`.
//
// Returns: a tensor of shape
//   {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}
// whose row stride is the per-message element count, not `hidden`.
torch::Tensor
|
||||
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
|
||||
// Recompute the buffer layout over `rdma_buffer_ptr` from the given sizes.
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
// Pick the buffer indexed by the current `low_latency_buffer_idx`.
auto buffer = layout.buffers[low_latency_buffer_idx];
|
||||
auto dtype = torch::kBFloat16;
|
||||
// Size of one combine message expressed in BF16 elements; used as the row stride below.
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
|
||||
|
||||
// The division above is only exact if a message is a whole number of BF16 elements.
EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
|
||||
return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
|
||||
// Sizes. NOTE(review): integer division — assumes `num_experts` is a multiple of `num_ranks`; confirm.
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
// Strides (in elements): consecutive rows are `num_msg_elems` apart.
// NOTE(review): presumably `num_msg_elems >= hidden`, leaving any extra per-message
// elements outside the view — verify against `LowLatencyLayout`.
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
|
||||
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
@@ -1218,5 +1234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
.def("internode_combine", &deep_ep::Buffer::internode_combine)
|
||||
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
|
||||
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
|
||||
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine);
|
||||
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
|
||||
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user