mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Merge pull request #174 from deepseek-ai/p2p-refactor
Low-latency P2P code cleanup and bug fixes
This commit is contained in:
commit
8da1b1f81e
@ -1190,8 +1190,9 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
|
||||
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
|
||||
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
|
||||
auto buffer = layout.buffers[low_latency_buffer_idx];
|
||||
auto dtype = torch::kBFloat16;
|
||||
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
|
||||
|
@ -147,7 +147,7 @@ public:
|
||||
const std::optional<torch::Tensor>& out = std::nullopt);
|
||||
|
||||
torch::Tensor
|
||||
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
|
||||
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
|
||||
};
|
||||
|
||||
} // namespace deep_ep
|
||||
|
@ -432,4 +432,18 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {
|
||||
// Local rank, no need for mapping
|
||||
if (rank == dst_rank)
|
||||
return ptr;
|
||||
auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);
|
||||
|
||||
// RDMA connected
|
||||
if (peer_base == 0)
|
||||
return 0;
|
||||
|
||||
// NVLink P2P is enabled
|
||||
return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
|
@ -149,20 +149,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
slot_idx * num_bytes_per_msg;
|
||||
if (dst_rank != rank) {
|
||||
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||
if (peer_base_addr) {
|
||||
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
|
||||
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
|
||||
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
|
||||
}
|
||||
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
|
||||
if (dst_p2p_ptr == 0) {
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
|
||||
} else {
|
||||
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
|
||||
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
|
||||
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
|
||||
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
|
||||
}
|
||||
|
||||
@ -222,16 +215,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
|
||||
// Wait local sends issued and send expert counts
|
||||
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
|
||||
if (dst_rank != rank) {
|
||||
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||
if (peer_base_addr) { // P2P enabled
|
||||
int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
|
||||
st_na_release(rptr_actual, -num_tokens_sent - 1);
|
||||
} else {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
|
||||
}
|
||||
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
|
||||
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
|
||||
if (dst_p2p_ptr == 0) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
|
||||
} else {
|
||||
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
|
||||
st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
|
||||
}
|
||||
|
||||
// Clean workspace for next use
|
||||
@ -428,22 +417,15 @@ combine(void* combined_x,
|
||||
auto src_idx = __ldg(local_src_info + token_idx);
|
||||
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
|
||||
if (dst_rank == rank) {
|
||||
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
|
||||
if (dst_p2p_ptr == 0) {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
if (not zero_copy)
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
|
||||
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||
if (peer_base_addr) {
|
||||
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
|
||||
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||
}
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||
} else {
|
||||
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
}
|
||||
}
|
||||
|
||||
@ -452,16 +434,12 @@ combine(void* combined_x,
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
while (ld_acquire_global(atomic_clean_flag) == 0);
|
||||
if (dst_rank != rank) {
|
||||
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||
if (peer_base_addr) {
|
||||
int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
|
||||
st_na_release(req_rptr_actual, 1);
|
||||
} else {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
|
||||
}
|
||||
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
|
||||
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
|
||||
if (dst_p2p_ptr == 0) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
|
||||
} else {
|
||||
st_na_release(rdma_recv_flag + global_expert_idx, 1);
|
||||
st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
|
||||
}
|
||||
atomic_add_release_global(atomic_clean_flag, -1);
|
||||
}
|
||||
@ -473,7 +451,7 @@ combine(void* combined_x,
|
||||
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
|
||||
return;
|
||||
|
||||
// Wait all ranks to arrive and notify PCIe usage
|
||||
// Wait all ranks to arrive
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
|
||||
if (sub_warp_id == 0 and lane_id == 0)
|
||||
|
@ -31,7 +31,8 @@ class Buffer:
|
||||
|
||||
def __init__(self, group: dist.ProcessGroup,
|
||||
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 12,
|
||||
allow_nvlink_for_low_latency_mode: bool = False) -> None:
|
||||
"""
|
||||
Initialize the communication buffer.
|
||||
|
||||
@ -42,6 +43,10 @@ class Buffer:
|
||||
low_latency_mode: whether to enable low-latency mode.
|
||||
num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
|
||||
to the number of local experts.
|
||||
allow_nvlink_for_low_latency_mode: whether allow NVLink traffic for low-latency mode, you should notice
|
||||
this is somehow incompatible with the hook-based overlapping.
|
||||
Warning: PCIe connections may lead to errors due to memory ordering issues,
|
||||
please make sure all connections are via NVLink.
|
||||
"""
|
||||
|
||||
# Initialize the CPP runtime
|
||||
@ -68,8 +73,7 @@ class Buffer:
|
||||
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
|
||||
# Enable IBGDA
|
||||
assert num_qps_per_rank > 0
|
||||
if not os.getenv("NVSHMEM_DISABLE_P2P"):
|
||||
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
||||
os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'
|
||||
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
||||
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
|
||||
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
|
||||
|
Loading…
Reference in New Issue
Block a user