From 68ae8b3d07c2e78c03ae6eb4af3be3cb67167152 Mon Sep 17 00:00:00 2001
From: cywork121
Date: Fri, 23 May 2025 10:37:45 +0800
Subject: [PATCH] Feature: LL nvlink p2p (#173)

---
 csrc/kernels/internode_ll.cu | 36 ++++++++++++++++++++++++++++++++----
 deep_ep/buffer.py            |  3 ++-
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 8e0d9e4..4bd75e5 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -150,7 +150,15 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                              rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                              slot_idx * num_bytes_per_msg;
         if (dst_rank != rank) {
-            nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+            void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+            if (peer_base_addr) {
+                char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
+                const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+                const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
+                UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+            } else {
+                nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+            }
         } else {
             // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
             const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
@@ -215,7 +223,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
         if (dst_rank != rank) {
-            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
+            void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+            if (peer_base_addr) { // P2P enabled
+                int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
+                st_na_release(rptr_actual, -num_tokens_sent - 1);
+            } else {
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
+            }
         } else {
             st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
         }
@@ -421,7 +435,15 @@ combine(void* combined_x,
             const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
             if (not zero_copy)
                 UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
-            nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
+
+            void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+            if (peer_base_addr) {
+                char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
+                const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
+                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
+            } else {
+                nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
+            }
         }
     }
 
@@ -431,7 +453,13 @@ combine(void* combined_x,
         if (sub_warp_id == 1 and lane_id == 0) {
             while (ld_acquire_global(atomic_clean_flag) == 0);
             if (dst_rank != rank) {
-                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
+                void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+                if (peer_base_addr) {
+                    int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
+                    st_na_release(req_rptr_actual, 1);
+                } else {
+                    nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
+                }
             } else {
                 st_na_release(rdma_recv_flag + global_expert_idx, 1);
             }
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index feeb386..913144b 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -68,7 +68,8 @@ class Buffer:
         if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
             # Enable IBGDA
             assert num_qps_per_rank > 0
-            os.environ['NVSHMEM_DISABLE_P2P'] = '1'
+            if not os.getenv("NVSHMEM_DISABLE_P2P"):
+                os.environ['NVSHMEM_DISABLE_P2P'] = '1'
             os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
             os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
             os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
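
Reviewer note (not part of the patch): all four kernel hunks apply the same
pattern -- translate a local symmetric-heap address into the peer's
NVLink-mapped heap, and fall back to the IBGDA path when no P2P mapping
exists. Below is a minimal sketch of that translation, assuming the
NVSHMEM-internal fields the patch relies on (nvshmemi_device_state_d.heap_base
and peer_heap_base_p2p); the helper name translate_to_peer is hypothetical
and only for illustration, not an NVSHMEM API.

    // Sketch: symmetric-heap offset translation used by the P2P fast path.
    // Assumes NVSHMEM internal device state is visible (nvshmemi_device_state_d).
    __device__ __forceinline__ void* translate_to_peer(void* local_sym_ptr, int dst_rank) {
        // Peer's heap base as mapped into this process; NULL when NVLink P2P
        // is unavailable to dst_rank.
        void* peer_base = (void*)__ldg(
            (const long long unsigned*)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
        if (peer_base == nullptr)
            return nullptr;  // caller falls back to nvshmemi_ibgda_put_nbi_warp / AMO
        // Symmetric objects live at the same offset in every PE's heap, so the
        // local offset can be rebased onto the peer's mapping.
        ptrdiff_t offset = (char*)local_sym_ptr - (char*)nvshmemi_device_state_d.heap_base;
        return (char*)peer_base + offset;
    }

On the Python side, the buffer.py hunk only forces NVSHMEM_DISABLE_P2P=1 when
the variable is unset, so exporting e.g. NVSHMEM_DISABLE_P2P=0 before
constructing Buffer keeps the NVLink P2P path available; previously the
constructor always disabled it.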