Refactor some code.
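Replace the warp-scope `nvshmemi_ibgda_put_nbi_warp` RDMA sends in the normal internode kernels with a new thread-scope `nvshmemi_ibgda_put_nbi_thread`, and fold the local-rank `nvshmemx_signal_op` fallback into `nvshmemi_ibgda_amo_nonfetch_add` behind an `is_local_copy` flag so call sites no longer branch. Also relax the QP-count assertions from exact equality to at-least, and size `num_qps_per_rank` in the internode test to cover both normal and low-latency modes.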

Shangyan Zhou 2025-04-22 10:22:30 +08:00
parent c07fdd197c
commit 20b2aaaf9e
4 changed files with 90 additions and 61 deletions

csrc/kernels/ibgda_device.cuh

@@ -410,20 +410,60 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
 }
 
-__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
-    nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
+__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
+    if (is_local_copy) {
+        // Fallback to NVSHMEM legacy API
+        nvshmemx_signal_op(reinterpret_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
+    } else {
+        nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
 
-    __be32 rkey;
-    uint64_t raddr;
-    ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
+        __be32 rkey;
+        uint64_t raddr;
+        ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
 
-    uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
-    void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
+        uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+        void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
 
-    ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
-                            qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
+        ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
+                                qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
 
-    ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
+        ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
+    }
 }
+
+__device__ static __forceinline__ void
+nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
+                              int dst_pe, int qp_id, bool is_local_copy) {
+    if (is_local_copy) {
+        // Fallback to NVSHMEM legacy API
+        // TODO: rewrite local API copy with unrolling and vectorization
+        nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<uint8_t*>(req_lptr), bytes, dst_pe);
+    } else {
+        uint32_t num_wqes = 0;
+        uint64_t base_wqe_idx = 0;
+        auto qp = ibgda_get_rc(dst_pe, qp_id);
+
+        while (bytes > 0) {
+            __be32 lkey, rkey;
+            uint64_t laddr, raddr, chunk_size;
+            chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
+            bytes -= chunk_size;
+
+            auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+            auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
+
+            // Only the last WQE should send imm
+            ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx, &wqe_ptr);
+
+            req_lptr += chunk_size;
+            req_rptr += chunk_size;
+            if ((num_wqes ++) == 0)
+                base_wqe_idx = wqe_idx;
+        }
+        ibgda_submit_requests<true>(qp, base_wqe_idx, num_wqes);
+    }
+}
 
 }  // namespace deep_ep
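For readers skimming the diff: the new thread-scope put splits a transfer at memory-key boundaries, writes one RDMA-write WQE per chunk, and rings the doorbell once for the whole batch. Below is a standalone sketch of that chunk-and-batch pattern; `put_nbi_chunked`, `MAX_CHUNK`, and the slot counter are hypothetical stand-ins, since the real chunk size comes from `ibgda_get_lkey_and_rkey` and real slots from `ibgda_reserve_wqe_slots`.

// Standalone sketch (not DeepEP code) of the chunking pattern above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr size_t MAX_CHUNK = 16384;  // hypothetical per-key byte limit

void put_nbi_chunked(uint64_t rptr, uint64_t lptr, size_t bytes) {
    uint32_t num_wqes = 0;
    uint64_t base_wqe_idx = 0, next_free_slot = 0;  // fake WQE slot allocator
    while (bytes > 0) {
        const size_t chunk_size = std::min(bytes, MAX_CHUNK);
        const uint64_t wqe_idx = next_free_slot ++;  // stands in for ibgda_reserve_wqe_slots(qp, 1)
        std::printf("WQE %llu: %zu bytes, laddr 0x%llx -> raddr 0x%llx\n",
                    (unsigned long long) wqe_idx, chunk_size,
                    (unsigned long long) lptr, (unsigned long long) rptr);
        lptr += chunk_size, rptr += chunk_size, bytes -= chunk_size;
        if ((num_wqes ++) == 0)
            base_wqe_idx = wqe_idx;  // remember the first slot of the batch
    }
    // Ring the doorbell once for the whole batch, like ibgda_submit_requests<true>
    std::printf("submit %u WQEs starting at slot %llu\n",
                num_wqes, (unsigned long long) base_wqe_idx);
}

int main() { put_nbi_chunked(0x200000, 0x100000, 40000); }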

csrc/kernels/internode.cu

@@ -480,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     const bool is_forwarder = sm_id % 2 == 0;
     const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
 
+    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels);
+
     const auto role_meta = [=]() -> std::pair<WarpRole, int> {
         if (is_forwarder) {
             if (warp_id < NUM_MAX_NVL_PEERS) {
@@ -556,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         // Send number of tokens in this channel by `-value - 1`
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
         for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
+            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
             if (lane_id < NUM_MAX_NVL_PEERS) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
+                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
             } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
+                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
             } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
+                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
             } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
-                rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
+                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
             }
+            __syncwarp();
+            // Issue RDMA for non-local ranks
+            if (dst_rdma_rank != rdma_rank and lane_id == 0) {
+                nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
+                                              reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
+                                              sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
+                                              translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
+                                              channel_id, false);
+            }
-            nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
-                                      translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
         }
         nvshmem_fence();
         sync_rdma_sender_smem();
@@ -711,12 +721,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             if (dst_rdma_rank != rdma_rank) {
                 auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
                 EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
-                const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t);
+                const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
                 const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
                 const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
-                nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
-                                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
-                nvshmem_fence();
+                if (lane_id == dst_rdma_rank)
+                    nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
             } else {
                 // Lighter fence for local RDMA rank
                 memory_fence();
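Dropping `* sizeof(int8_t)` does not change the value (it is 1, and `num_bytes_per_rdma_token` is already a byte count); it only removes a misleading factor. The removed `nvshmem_fence()` is presumably safe because the data put and the subsequent tail update below now go through the same per-channel RC QP, which keeps them ordered.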
@@ -727,13 +737,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             if (lane_id == dst_rdma_rank) {
                 last_issued_tail += num_tokens_to_issue;
                 num_tokens_to_send -= num_tokens_to_issue;
-                if (dst_rdma_rank != rdma_rank) {
-                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
-                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
-                } else {
-                    nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
-                }
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
+                                                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
             }
         }
     }
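This is the `is_local_copy` flag paying off at the call site: passing `dst_rdma_rank == rdma_rank` makes `nvshmemi_ibgda_amo_nonfetch_add` itself fall back to `nvshmemx_signal_op` for the local rank, so the old if/else disappears here and at every other tail/head update in this commit.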
@@ -933,13 +938,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         // Update remote head
         if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-            if (lane_id != rdma_rank) {
-                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
-                                                translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id);
-            } else {
-                nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
-                                   translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
-            }
+            nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
+                                            translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id, lane_id == rdma_rank);
             last_head = min_head;
         }
@@ -1570,12 +1570,12 @@ combine(int4* combined_x, float* combined_topk_weights,
             if (sub_warp_id == kNumWarpsPerForwarder - 1) {
                 if (dst_rdma_rank != rdma_rank) {
                     auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
-                    const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t);
+                    const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
                    const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
                    const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
-                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
-                                                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
-                    nvshmem_fence();
+                    if (lane_id == 0)
+                        nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                      translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
                 } else {
                     memory_fence();
                 }
@@ -1583,13 +1583,8 @@ combine(int4* combined_x, float* combined_topk_weights,
                 // Write new RDMA tail
                 __syncwarp();
                 if (lane_id == 0) {
-                    if (dst_rdma_rank != rdma_rank) {
-                        nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
-                                                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
-                    } else {
-                        nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
-                                           translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
-                    }
+                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
+                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
                 }
             }
         }
@@ -1675,15 +1670,8 @@ combine(int4* combined_x, float* combined_topk_weights,
             for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
                 min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
             if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-                // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
-                //                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
-                if (dst_rdma_rank != rdma_rank) {
-                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
-                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
-                } else {
-                    nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
-                }
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
+                                                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
                 last_rdma_head = min_head;
             }
         } else {

csrc/kernels/internode_ll.cu

@@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     EP_DEVICE_ASSERT(num_sms > 1);
     if (sm_id == 0) {
         // The first SM is also responsible for checking QPs
-        EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
+        EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
 
         // The first SM is also responsible for cleaning the next buffer
         #pragma unroll
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             // Wait local sends issued and send expert counts
             while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
             if (dst_rank != rank) {
-                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false);
             } else {
                 st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
             }
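The relaxed assertion above now only requires at least one RC QP per local expert, matching the normal-mode kernels' new `num_rc_per_pe >= num_channels` check: the runtime may create more QPs than the low-latency kernels strictly need.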

tests/test_internode.py

@@ -218,15 +218,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
 # noinspection PyUnboundLocalVariable
 def test_loop(local_rank: int, num_local_ranks: int):
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
-    num_sms = 24
-    qp_num = num_sms // 2
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    test_ll_compatibility = False
+    test_ll_compatibility = True
     if test_ll_compatibility:
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
 
+    num_sms = 24
+    num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0)
     buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
-                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num))
+                            num_qps_per_rank=num_qps_per_rank)
 
     assert num_local_ranks == 8 and num_ranks > 8
     torch.manual_seed(rank)
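The test folds the two QP requirements into one value: with low-latency compatibility on, `num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks)`, e.g. `max(12, 256 // num_ranks)` for `num_sms = 24`, so a single `deep_ep.Buffer` satisfies both the per-channel check in the normal-mode kernels and the per-expert check in the low-latency kernels.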