mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Low latency kernels use rdma atomic to support AR.
This commit is contained in:
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
// Wait local sends issued and send expert counts
|
||||
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
|
||||
dst_rank, dst_expert_local_idx, 0);
|
||||
nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
|
||||
} else {
|
||||
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
|
||||
}
|
||||
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int num_recv_tokens, recv_token_begin_idx;
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
if (src_rank != rank) {
|
||||
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
|
||||
num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
|
||||
EP_DEVICE_ASSERT(num_recv_tokens != 0);
|
||||
} else {
|
||||
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
|
||||
}
|
||||
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
|
||||
num_recv_tokens = -num_recv_tokens - 1;
|
||||
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
|
||||
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
|
||||
@@ -439,7 +431,7 @@ combine(void* combined_x,
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
while (ld_acquire_global(atomic_clean_flag) == 0);
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
|
||||
} else {
|
||||
st_na_release(rdma_recv_flag + global_expert_idx, 1);
|
||||
}
|
||||
@@ -456,16 +448,8 @@ combine(void* combined_x,
|
||||
// Wait all ranks to arrive and notify PCIe usage
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
|
||||
if (sub_warp_id == 0 and lane_id == 0) {
|
||||
// TODO: refactor QP indices
|
||||
auto src_rank = responsible_expert_idx / num_local_experts;
|
||||
auto src_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
if (src_rank != rank) {
|
||||
nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
|
||||
} else {
|
||||
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
|
||||
}
|
||||
}
|
||||
if (sub_warp_id == 0 and lane_id == 0)
|
||||
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
|
||||
}
|
||||
cg::this_grid().sync();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user