Refactor some code.

This commit is contained in:
Shangyan Zhou
2025-04-22 10:22:30 +08:00
parent c07fdd197c
commit 20b2aaaf9e
4 changed files with 90 additions and 61 deletions

View File

@@ -480,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const bool is_forwarder = sm_id % 2 == 0;
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels);
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
if (is_forwarder) {
if (warp_id < NUM_MAX_NVL_PEERS) {
@@ -556,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
if (lane_id < NUM_MAX_NVL_PEERS) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
}
__syncwarp();
// Issue RDMA for non-local ranks
if (dst_rdma_rank != rdma_rank and lane_id == 0) {
nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id, false);
}
nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
nvshmem_fence();
sync_rdma_sender_smem();
// Iterate over tokens and copy into buffer
@@ -711,12 +721,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t);
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
nvshmem_fence();
if (lane_id == dst_rdma_rank)
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
} else {
// Lighter fence for local RDMA rank
memory_fence();
@@ -727,13 +737,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
} else {
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
}
}
}
@@ -933,13 +938,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
if (lane_id != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id);
} else {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
}
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id, lane_id == rdma_rank);
last_head = min_head;
}
@@ -1570,12 +1570,12 @@ combine(int4* combined_x, float* combined_topk_weights,
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t);
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
nvshmem_fence();
if (lane_id == 0)
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
} else {
memory_fence();
}
@@ -1583,13 +1583,8 @@ combine(int4* combined_x, float* combined_topk_weights,
// Write new RDMA tail
__syncwarp();
if (lane_id == 0) {
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
} else {
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
}
}
}
@@ -1675,15 +1670,8 @@ combine(int4* combined_x, float* combined_topk_weights,
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
// nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
} else {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
}
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
last_rdma_head = min_head;
}
} else {