diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index 1e8d871..eebdc59 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -410,20 +410,60 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe( st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); } -__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) { - nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); +__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) { + if (is_local_copy) { + // Fallback to NVSHMEM legacy API + nvshmemx_signal_op(reinterpret_cast(rptr), value, NVSHMEM_SIGNAL_ADD, pe); + } else { + nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); - __be32 rkey; - uint64_t raddr; - ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey); + __be32 rkey; + uint64_t raddr; + ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey); - uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); - void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); + uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); - ibgda_write_amo_add_wqe(qp, value, reinterpret_cast(qp->ibuf.buf), - qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs); + ibgda_write_amo_add_wqe(qp, value, reinterpret_cast(qp->ibuf.buf), + qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs); - ibgda_submit_requests(qp, my_wqe_idx, 1); + ibgda_submit_requests(qp, my_wqe_idx, 1); + } +} + +__device__ static __forceinline__ void +nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, + int dst_pe, int qp_id, bool is_local_copy) { + if (is_local_copy) { + // Fallback to NVSHMEM legacy API + // TODO: rewrite local API copy with unrolling and vectorization + nvshmem_uint8_put_nbi(reinterpret_cast(req_rptr), reinterpret_cast(req_lptr), bytes, dst_pe); + } else { + uint32_t num_wqes = 0; + uint64_t base_wqe_idx = 0; + + auto qp = ibgda_get_rc(dst_pe, qp_id); + while (bytes > 0) { + __be32 lkey, rkey; + uint64_t laddr, raddr, chunk_size; + + chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey)); + bytes -= chunk_size; + + auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx); + + // Only the last WQE should send imm + ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx,&wqe_ptr); + + req_lptr += chunk_size; + req_rptr += chunk_size; + if ((num_wqes ++) == 0) + base_wqe_idx = wqe_idx; + } + + ibgda_submit_requests(qp, base_wqe_idx, num_wqes); + } } } // namespace deep_ep diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 5fcf7a3..0d29aec 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -480,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv const bool is_forwarder = sm_id % 2 == 0; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels); + const auto role_meta = [=]() -> std::pair { if (is_forwarder) { if (warp_id < NUM_MAX_NVL_PEERS) { @@ -556,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Send number of tokens in this channel by `-value - 1` EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); if (lane_id < NUM_MAX_NVL_PEERS) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; + dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + // Issue RDMA for non-local ranks + if (dst_rdma_rank != rdma_rank and lane_id == 0) { + nvshmemi_ibgda_put_nbi_thread(reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)), + reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)), + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2), + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), + channel_id, false); } - nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); } - nvshmem_fence(); sync_rdma_sender_smem(); // Iterate over tokens and copy into buffer @@ -711,12 +721,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (dst_rdma_rank != rdma_rank) { auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); - const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t); + const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); - nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); - nvshmem_fence(); + if (lane_id == dst_rdma_rank) + nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false); } else { // Lighter fence for local RDMA rank memory_fence(); @@ -727,13 +737,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; - if (dst_rdma_rank != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); } } } @@ -933,13 +938,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Update remote head if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - if (lane_id != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, - translate_dst_rdma_rank(lane_id, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(lane_id, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, + translate_dst_rdma_rank(lane_id, nvl_rank), channel_id, lane_id == rdma_rank); last_head = min_head; } @@ -1570,12 +1570,12 @@ combine(int4* combined_x, float* combined_topk_weights, if (sub_warp_id == kNumWarpsPerForwarder - 1) { if (dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; - const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t); + const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); - nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); - nvshmem_fence(); + if (lane_id == 0) + nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false); } else { memory_fence(); } @@ -1583,13 +1583,8 @@ combine(int4* combined_x, float* combined_topk_weights, // Write new RDMA tail __syncwarp(); if (lane_id == 0) { - if (dst_rdma_rank != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); } } } @@ -1675,15 +1670,8 @@ combine(int4* combined_x, float* combined_topk_weights, for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, - // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - if (dst_rdma_rank != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); last_rdma_head = min_head; } } else { diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index c33e062..74fe0bd 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, EP_DEVICE_ASSERT(num_sms > 1); if (sm_id == 0) { // The first SM is also responsible for checking QPs - EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts); + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts); // The first SM is also responsible for cleaning the next buffer #pragma unroll @@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // Wait local sends issued and send expert counts while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); if (dst_rank != rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); + nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false); } else { st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1); } diff --git a/tests/test_internode.py b/tests/test_internode.py index 456ea64..e9b3d57 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -218,15 +218,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in # noinspection PyUnboundLocalVariable def test_loop(local_rank: int, num_local_ranks: int): num_nodes = int(os.getenv('WORLD_SIZE', 1)) - num_sms = 24 - qp_num = num_sms // 2 rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - test_ll_compatibility = False + test_ll_compatibility = True if test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 + num_sms = 24 + num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0) + buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num)) + num_qps_per_rank=num_qps_per_rank) assert num_local_ranks == 8 and num_ranks > 8 torch.manual_seed(rank)