mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Feature: LL nvlink p2p (#173)
This commit is contained in:
parent
d5ca4495c0
commit
68ae8b3d07
@ -150,7 +150,15 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
|||||||
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||||
slot_idx * num_bytes_per_msg;
|
slot_idx * num_bytes_per_msg;
|
||||||
if (dst_rank != rank) {
|
if (dst_rank != rank) {
|
||||||
|
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||||
|
if (peer_base_addr) {
|
||||||
|
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
|
||||||
|
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||||
|
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
|
||||||
|
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
|
||||||
|
} else {
|
||||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
|
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
|
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
|
||||||
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||||
@ -215,7 +223,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
|||||||
// Wait local sends issued and send expert counts
|
// Wait local sends issued and send expert counts
|
||||||
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
|
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
|
||||||
if (dst_rank != rank) {
|
if (dst_rank != rank) {
|
||||||
|
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||||
|
if (peer_base_addr) { // P2P enabled
|
||||||
|
int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
|
||||||
|
st_na_release(rptr_actual, -num_tokens_sent - 1);
|
||||||
|
} else {
|
||||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
|
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
|
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
|
||||||
}
|
}
|
||||||
@ -421,9 +435,17 @@ combine(void* combined_x,
|
|||||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||||
if (not zero_copy)
|
if (not zero_copy)
|
||||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||||
|
|
||||||
|
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||||
|
if (peer_base_addr) {
|
||||||
|
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
|
||||||
|
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
|
||||||
|
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||||
|
} else {
|
||||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Put finishing flag
|
// Put finishing flag
|
||||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
||||||
@ -431,7 +453,13 @@ combine(void* combined_x,
|
|||||||
if (sub_warp_id == 1 and lane_id == 0) {
|
if (sub_warp_id == 1 and lane_id == 0) {
|
||||||
while (ld_acquire_global(atomic_clean_flag) == 0);
|
while (ld_acquire_global(atomic_clean_flag) == 0);
|
||||||
if (dst_rank != rank) {
|
if (dst_rank != rank) {
|
||||||
|
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
|
||||||
|
if (peer_base_addr) {
|
||||||
|
int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
|
||||||
|
st_na_release(req_rptr_actual, 1);
|
||||||
|
} else {
|
||||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
|
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
st_na_release(rdma_recv_flag + global_expert_idx, 1);
|
st_na_release(rdma_recv_flag + global_expert_idx, 1);
|
||||||
}
|
}
|
||||||
|
@ -68,6 +68,7 @@ class Buffer:
|
|||||||
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
|
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
|
||||||
# Enable IBGDA
|
# Enable IBGDA
|
||||||
assert num_qps_per_rank > 0
|
assert num_qps_per_rank > 0
|
||||||
|
if not os.getenv("NVSHMEM_DISABLE_P2P"):
|
||||||
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
||||||
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
||||||
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
|
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
|
||||||
|
Loading…
Reference in New Issue
Block a user