Add low-latency kernel PCIe usage flag (#195)

* Add low-latency kernel usage flag

* Update comments
Chenggang Zhao 2025-06-09 14:37:13 +08:00 committed by GitHub
parent 564e375234
commit 0d1a855d81
6 changed files with 57 additions and 13 deletions

View File

@@ -76,6 +76,13 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
*moe_recv_rdma_counter = -1;
}
// Low-latency kernels' usage flag
if (low_latency_mode) {
CUDA_CHECK(cudaMallocHost(&low_latency_usage_flag, sizeof(int), cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&low_latency_usage_flag_mapped, const_cast<int*>(low_latency_usage_flag), 0));
*low_latency_usage_flag = 0;
}
}
Buffer::~Buffer() noexcept(false) {
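For reference, the flag allocated above follows the standard CUDA mapped pinned-memory pattern: the counter lives in page-locked host memory that is also mapped into the device address space, so the CPU can read it directly while kernels write through the device-side alias, with no cudaMemcpy in between. Below is a minimal standalone sketch of that pattern, assuming a single device; it is not DeepEP code, it uses `cudaHostAlloc` rather than the three-argument `cudaMallocHost` overload, and error checking is omitted.

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    volatile int* h_flag = nullptr;  // read directly by the CPU
    int* d_flag = nullptr;           // alias of the same memory, passed to kernels
    // Page-locked allocation, mapped into the device address space
    cudaHostAlloc((void**)&h_flag, sizeof(int), cudaHostAllocMapped);
    cudaHostGetDevicePointer((void**)&d_flag, (void*)h_flag, 0);
    *h_flag = 0;
    printf("host pointer: %p, device alias: %p\n", (void*)h_flag, (void*)d_flag);
    cudaFreeHost((void*)h_flag);
    return 0;
}

On platforms with unified virtual addressing the two pointers are typically identical, but keeping both mirrors how the buffer stores `low_latency_usage_flag` for host reads and `low_latency_usage_flag_mapped` for kernel launches.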
@@ -997,6 +1004,11 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
return {combined_x, combined_topk_weights, event};
}
uint64_t Buffer::get_low_latency_usage_flag() const {
EP_HOST_ASSERT(low_latency_usage_flag != nullptr);
return reinterpret_cast<uint64_t>(low_latency_usage_flag);
}
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
EP_HOST_ASSERT(low_latency_mode);
@@ -1078,7 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8,
workspace, launch_stream, phases);
workspace, low_latency_usage_flag_mapped, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1165,7 +1177,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream,
workspace, low_latency_usage_flag_mapped, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1238,6 +1250,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("intranode_combine", &deep_ep::Buffer::intranode_combine)
.def("internode_dispatch", &deep_ep::Buffer::internode_dispatch)
.def("internode_combine", &deep_ep::Buffer::internode_combine)
.def("get_low_latency_usage_flag", &deep_ep::Buffer::get_low_latency_usage_flag)
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)

View File

@@ -71,6 +71,10 @@ private:
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;
// Host-side low-latency kernel usage flag
volatile int* low_latency_usage_flag = nullptr;
int* low_latency_usage_flag_mapped = nullptr;
private:
void move_fifo_slots(int num_slots = 1);
@@ -132,6 +136,8 @@ public:
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
uint64_t get_low_latency_usage_flag() const;
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>

View File

@@ -138,7 +138,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases);
void* workspace, int* usage_flag,
cudaStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
@@ -147,8 +148,8 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream,
int phases, bool zero_copy);
void* workspace, int* usage_flag,
cudaStream_t stream, int phases, bool zero_copy);
} // namespace internode_ll

View File

@@ -47,7 +47,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int phases) {
int* usage_flag, int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -180,6 +180,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
#pragma unroll
for (int i = lane_id; i < num_experts; i += 32)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
} else if (sm_id == 1) {
// The second SM is also responsible for notifying PCIe usage
if (lane_id == 0)
atomicAdd_system(usage_flag, 1);
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
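The `atomicAdd_system` call above is what makes the flag observable from the CPU while the kernel is still in flight: a system-scope atomic on mapped pinned memory is visible to host readers, not just to other threads on the GPU (it requires compute capability 6.0+). Below is a standalone sketch of that idea with made-up kernel and buffer names, where one designated thread raises the flag and the host polls it concurrently; error checking is omitted.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void work_and_notify(int* usage_flag, float* data, int n) {
    // One designated thread performs the system-scope increment, so the
    // launch bumps the flag exactly once regardless of grid size.
    if (blockIdx.x == 0 && threadIdx.x == 0)
        atomicAdd_system(usage_flag, 1);
    // Dummy work standing in for the real send/receive logic
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
        data[i] = data[i] * 2.0f + 1.0f;
}

int main() {
    const int n = 1 << 24;
    float* d_data;
    cudaMalloc(&d_data, n * sizeof(float));
    cudaMemset(d_data, 0, n * sizeof(float));
    volatile int* h_flag = nullptr;
    int* d_flag = nullptr;
    cudaHostAlloc((void**)&h_flag, sizeof(int), cudaHostAllocMapped);
    cudaHostGetDevicePointer((void**)&d_flag, (void*)h_flag, 0);
    *h_flag = 0;
    work_and_notify<<<32, 256>>>(d_flag, d_data, n);
    while (*h_flag == 0) {}  // the CPU observes the bump while the kernel is still running
    printf("kernel has signalled; other PCIe traffic can be scheduled now\n");
    cudaDeviceSynchronize();
    cudaFreeHost((void*)h_flag);
    cudaFree(d_data);
    return 0;
}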
@@ -311,7 +315,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
void* workspace, cudaStream_t stream, int phases) {
void* workspace, int* usage_flag,
cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
@@ -338,7 +343,8 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); } break
num_topk, num_experts, rank, num_ranks, \
usage_flag, phases); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
@@ -356,7 +362,7 @@ combine(void* combined_x,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int phases, bool zero_copy) {
int* usage_flag, int phases, bool zero_copy) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -451,11 +457,14 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Wait all ranks to arrive
// Wait for all ranks to arrive and notify the usage flag
if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
if (sub_warp_id == 0 and lane_id == 0)
if (sub_warp_id == 0 and lane_id == 0) {
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
} else if (sm_id == 0 and sub_warp_id == 1 and lane_id == 0) {
atomicAdd_system(usage_flag, 1);
}
}
cg::this_grid().sync();
@@ -506,8 +515,8 @@ void combine(void* combined_x,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream,
int phases, bool zero_copy) {
void* workspace, int* usage_flag,
cudaStream_t stream, int phases, bool zero_copy) {
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
constexpr int kNumMaxTopk = 9;
@@ -531,6 +540,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
usage_flag, \
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);

View File

@@ -443,6 +443,19 @@ class Buffer:
async_finish, allocate_on_comm_stream)
return combined_x, combined_topk_weights, EventOverlap(event)
def get_low_latency_usage_flag(self):
"""
Return a host-side integer flag that indicates the stage of the low-latency kernels.
The initial value is 0; the low-latency dispatch adds 1 before communication, and the low-latency combine
adds 1 after communication.
This is useful when there is no two-batch overlap and you want to overlap H2D/D2H transfers with the attention layers.
Returns:
flag: the host-side integer flag pointer. The flag itself is an `int`, but the pointer is returned as a `uint64_t`;
please `reinterpret_cast` the returned value into `int*`.
"""
return self.runtime.get_low_latency_usage_flag()
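A possible consumer of this address, sketched as hypothetical C++ helpers (the names and the snapshot-then-wait convention are illustrative, not part of DeepEP):

#include <cstdint>

// `flag_addr` is the integer returned by `get_low_latency_usage_flag()`.
inline volatile int* usage_flag_ptr(uint64_t flag_addr) {
    return reinterpret_cast<volatile int*>(flag_addr);
}

// Busy-wait until the counter moves past a previously observed value.
// The kernels update the pinned int with `atomicAdd_system`, so a plain
// volatile read on the host eventually observes the increment.
inline void wait_until_flag_passes(uint64_t flag_addr, int last_seen) {
    volatile int* flag = usage_flag_ptr(flag_addr);
    while (*flag <= last_seen) { /* spin */ }
}

// Typical use (illustrative): snapshot the counter, launch a low-latency
// kernel, then wait for the counter to advance before issuing your own
// H2D/D2H copies on another stream.

Under a plain dispatch-then-combine sequence (no two-batch overlap), the counter starts even, goes odd when a dispatch announces that communication is about to start, and returns to even when the matching combine reports that its communication has finished, so an even value is a hint that the PCIe link is free for your own transfers. From Python, one possible way to read the same value is `ctypes.c_int.from_address(flag_addr).value`.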
def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
"""
As low-latency kernels require part of the buffer to be zero-initialized, it is vital to clean the buffer

View File

@@ -155,6 +155,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks)
buffer.get_low_latency_usage_flag()
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
do_pressure_test = False