Add low-latency kernel PCIe usage flag (#195)

* Add low-latency kernel usage flag

* Update comments
This commit is contained in:
Chenggang Zhao
2025-06-09 14:37:13 +08:00
committed by GitHub
parent 564e375234
commit 0d1a855d81
6 changed files with 57 additions and 13 deletions

View File

@@ -76,6 +76,13 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
*moe_recv_rdma_counter = -1;
}
// Low-latency kernels' usage flag
if (low_latency_mode) {
CUDA_CHECK(cudaMallocHost(&low_latency_usage_flag, sizeof(int), cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&low_latency_usage_flag_mapped, const_cast<int*>(low_latency_usage_flag), 0));
*low_latency_usage_flag = 0;
}
}
Buffer::~Buffer() noexcept(false) {
@@ -997,6 +1004,11 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
return {combined_x, combined_topk_weights, event};
}
uint64_t Buffer::get_low_latency_usage_flag() const {
EP_HOST_ASSERT(low_latency_usage_flag != nullptr);
return reinterpret_cast<uint64_t>(low_latency_usage_flag);
}
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
EP_HOST_ASSERT(low_latency_mode);
@@ -1078,7 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8,
workspace, launch_stream, phases);
workspace, low_latency_usage_flag_mapped, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1165,7 +1177,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream,
workspace, low_latency_usage_flag_mapped, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1238,6 +1250,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("intranode_combine", &deep_ep::Buffer::intranode_combine)
.def("internode_dispatch", &deep_ep::Buffer::internode_dispatch)
.def("internode_combine", &deep_ep::Buffer::internode_combine)
.def("get_low_latency_usage_flag", &deep_ep::Buffer::get_low_latency_usage_flag)
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)