diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index b2d5024..b46a3a0 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1190,8 +1190,9 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
 }
 
 torch::Tensor
-Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
+Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
     LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
+
     auto buffer = layout.buffers[low_latency_buffer_idx];
     auto dtype = torch::kBFloat16;
     auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index e0ad4d6..a12a8a0 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -147,7 +147,7 @@ public:
                         const std::optional<torch::Tensor>& out = std::nullopt);
 
     torch::Tensor
-    get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
+    get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
 };
 
 } // namespace deep_ep
diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh
index 9f8c37c..2259bc4 100644
--- a/csrc/kernels/ibgda_device.cuh
+++ b/csrc/kernels/ibgda_device.cuh
@@ -432,4 +432,18 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
     }
 }
 
+__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {
+    // Local rank, no need for mapping
+    if (rank == dst_rank)
+        return ptr;
+    auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);
+
+    // RDMA connected
+    if (peer_base == 0)
+        return 0;
+
+    // NVLink P2P is enabled
+    return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
+}
+
 } // namespace deep_ep
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 4bd75e5..899cc11 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -149,20 +149,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                              dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                              rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                              slot_idx * num_bytes_per_msg;
-        if (dst_rank != rank) {
-            void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
-            if (peer_base_addr) {
-                char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
-                const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
-                const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
-                UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
-            } else {
-                nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
-            }
+        const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+        if (dst_p2p_ptr == 0) {
+            nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
         } else {
             // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
             const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
-            const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
+            const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
             UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
         }
 
@@ -222,16 +215,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
-        if (dst_rank != rank) {
-            void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
-            if (peer_base_addr) { // P2P enabled
-                int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
-                st_na_release(rptr_actual, -num_tokens_sent - 1);
-            } else {
-                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
-            }
+        auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
+        auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+        if (dst_p2p_ptr == 0) {
+            nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
         } else {
-            st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
+            st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
         }
 
         // Clean workspace for next use
@@ -428,22 +417,15 @@ combine(void* combined_x,
             auto src_idx = __ldg(local_src_info + token_idx);
             const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
             const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
-            if (dst_rank == rank) {
-                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
-                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
-            } else {
+            const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+            if (dst_p2p_ptr == 0) {
                 const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                 if (not zero_copy)
                     UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
-
-                void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
-                if (peer_base_addr) {
-                    char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
-                    const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
-                    UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
-                } else {
-                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
-                }
+                nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
+            } else {
+                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
+                UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
             }
         }
 
@@ -452,16 +434,12 @@ combine(void* combined_x,
         asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
         if (sub_warp_id == 1 and lane_id == 0) {
             while (ld_acquire_global(atomic_clean_flag) == 0);
-            if (dst_rank != rank) {
-                void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
-                if (peer_base_addr) {
-                    int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
-                    st_na_release(req_rptr_actual, 1);
-                } else {
-                    nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
-                }
+            auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
+            auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+            if (dst_p2p_ptr == 0) {
+                nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
             } else {
-                st_na_release(rdma_recv_flag + global_expert_idx, 1);
+                st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
             }
             atomic_add_release_global(atomic_clean_flag, -1);
         }
@@ -473,7 +451,7 @@ combine(void* combined_x,
     if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
         return;
 
-    // Wait all ranks to arrive and notify PCIe usage
+    // Wait all ranks to arrive
    if (responsible_expert_idx < num_experts) {
         EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
         if (sub_warp_id == 0 and lane_id == 0)
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 913144b..8527f24 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -31,7 +31,8 @@ class Buffer:
 
     def __init__(self, group: dist.ProcessGroup,
                  num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
-                 low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
+                 low_latency_mode: bool = False, num_qps_per_rank: int = 12,
+                 allow_nvlink_for_low_latency_mode: bool = False) -> None:
         """
         Initialize the communication buffer.
 
@@ -42,6 +43,10 @@ class Buffer:
             low_latency_mode: whether to enable low-latency mode.
             num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
                 to the number of local experts.
+            allow_nvlink_for_low_latency_mode: whether to allow NVLink traffic in low-latency mode; note that
+                this is somewhat incompatible with the hook-based overlapping.
+                Warning: PCIe connections may lead to errors due to memory ordering issues,
+                please make sure all connections are via NVLink.
         """
 
         # Initialize the CPP runtime
@@ -68,8 +73,7 @@ class Buffer:
         if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
             # Enable IBGDA
             assert num_qps_per_rank > 0
-            if not os.getenv("NVSHMEM_DISABLE_P2P"):
-                os.environ['NVSHMEM_DISABLE_P2P'] = '1'
+            os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'
             os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
             os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
             os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
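
Note (not part of the patch): a minimal usage sketch of the new allow_nvlink_for_low_latency_mode flag introduced above. The sizes (hidden, num_experts, num_max_dispatch_tokens_per_rank) are placeholder values, and Buffer.get_low_latency_rdma_size_hint is assumed to be available as in the repository's existing low-latency examples.

    # Sketch assuming a torchrun-style multi-GPU launch; all sizes are placeholders
    import torch.distributed as dist
    from deep_ep import Buffer

    dist.init_process_group(backend='nccl')
    group = dist.new_group(list(range(dist.get_world_size())))

    hidden, num_experts, num_max_dispatch_tokens_per_rank = 7168, 256, 128

    # Size the RDMA buffer for low-latency mode (helper assumed from the existing examples)
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)

    # Opt in to NVLink P2P for intra-node traffic; the default (False) keeps the previous
    # behavior of setting NVSHMEM_DISABLE_P2P=1, so all low-latency traffic stays on RDMA
    buffer = Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                    num_qps_per_rank=num_experts // group.size(),
                    allow_nvlink_for_low_latency_mode=True)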