From 05df5554ff3c0f304b14909cabfb783541672c66 Mon Sep 17 00:00:00 2001 From: Zhicheng Wu Date: Fri, 13 Jun 2025 14:37:59 +0800 Subject: [PATCH] Use one qp per sm for internode normal kernels (#181) let the sender SM use the channel_id, and the receiver SM use channel_id + num_channels --- csrc/kernels/internode.cu | 4 ++-- tests/test_internode.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 57203dc..7e5eb33 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -831,7 +831,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Update remote head if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, - translate_dst_rdma_rank(lane_id, nvl_rank), channel_id, lane_id == rdma_rank); + translate_dst_rdma_rank(lane_id, nvl_rank), channel_id + num_channels, lane_id == rdma_rank); last_head = min_head; } @@ -1563,7 +1563,7 @@ combine(int4* combined_x, float* combined_topk_weights, min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id + num_channels, dst_rdma_rank == rdma_rank); last_rdma_head = min_head; } } else { diff --git a/tests/test_internode.py b/tests/test_internode.py index a4ac104..4aeca49 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -225,7 +225,7 @@ def test_loop(local_rank: int, num_local_ranks: int): ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 num_sms = 24 - num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0) + num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if test_ll_compatibility else 0) buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, num_qps_per_rank=num_qps_per_rank)