mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Use one qp per sm for internode normal kernels (#181)
let the sender SM use the channel_id, and the receiver SM use channel_id + num_channels
This commit is contained in:
parent
21efbe9b48
commit
05df5554ff
@ -831,7 +831,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
|||||||
// Update remote head
|
// Update remote head
|
||||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
|
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
|
||||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id, lane_id == rdma_rank);
|
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id + num_channels, lane_id == rdma_rank);
|
||||||
last_head = min_head;
|
last_head = min_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1563,7 +1563,7 @@ combine(int4* combined_x, float* combined_topk_weights,
|
|||||||
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
|
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
|
||||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
|
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
|
||||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
|
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id + num_channels, dst_rdma_rank == rdma_rank);
|
||||||
last_rdma_head = min_head;
|
last_rdma_head = min_head;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -225,7 +225,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
|
|||||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||||
|
|
||||||
num_sms = 24
|
num_sms = 24
|
||||||
num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0)
|
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if test_ll_compatibility else 0)
|
||||||
|
|
||||||
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
|
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
|
||||||
num_qps_per_rank=num_qps_per_rank)
|
num_qps_per_rank=num_qps_per_rank)
|
||||||
|
Loading…
Reference in New Issue
Block a user