@@ -365,7 +365,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 const bool is_forwarder = sm_id % 2 == 0;
 const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

-EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms);
+EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms);

 const auto role_meta = [=]() -> std::pair<WarpRole, int> {
 if (is_forwarder) {
@@ -419,9 +419,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);

 // RDMA sender warp synchronization
-__shared__ volatile int rdma_send_next_token_idx;
-__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
-__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
+// NOTES: `rdma_send_channel_tail` is the latest released tail
+// NOTES: `rdma_send_channel_window` tracks the status of the up to 32 in-flight transactions
+__shared__ int rdma_send_channel_lock[kNumRDMARanks];
+__shared__ int rdma_send_channel_tail[kNumRDMARanks];
+__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
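+// NOTES: each lane of a sender warp serves one destination RDMA rank; bit `i` of the window
+// marks that the token at (released tail + `i`) has been copied but not yet released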
 auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
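 // NOTES: named barrier 0 joins the `kNumDispatchRDMASenderWarps` sender warps plus the coordinator warp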

 // Forward warp synchronization
@@ -434,12 +436,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 int token_start_idx, token_end_idx;
 get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

-// Clean shared memory
-EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
-(warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
-(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
-(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
-
 // Send number of tokens in this channel by `-value - 1`
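 // NOTES: a count `n` is sent as `-n - 1`, so every written value is strictly negative and the
 // receiver can tell a finished count (even zero) apart from an empty slot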
 EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
 for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
@@ -468,24 +464,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 // Iterate over tokens and copy into buffer
 int64_t token_idx;
-int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
+int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;
 auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
-for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
+for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) {
 // Read RDMA rank existence
 uint64_t is_token_in_rank_uint64 = 0;
-if (lane_id < kNumRDMARanks)
-is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
-
-// Acquire sequential lock
-while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
+if (lane_id < kNumRDMARanks) {
+is_token_in_rank_uint64 = __ldg(reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS));
+global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
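+// NOTES: lane `i` tracks the running number of tokens headed for RDMA rank `i`, replacing the
+// old sequential lock with lock-free per-destination counting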
+}
 __syncwarp();

-// Acquire next tail
-int rdma_tail_idx = -1;
-if (is_token_in_rank_uint64 != 0) {
-rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;
-while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
-cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
-}
+// Skip tokens that do not belong to this warp
+if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)
+continue;
+auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;
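+// NOTES: every warp scans every token, so the per-lane count stays consistent across warps; the
+// owning warp's slot for this token is the running count minus one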

+// Wait for the remote buffer to be released
+auto start_time = clock64();
+while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
+cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
+
+// Timeout check
+if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
+printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
+channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
+trap();
+}
+}
 __syncwarp();
@@ -493,14 +498,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 if (lane_id < kNumRDMARanks and not kCachedMode)
 send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;

-// Update last token tail
-if (last_rdma_tail_idx >= 0)
-st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
-last_rdma_tail_idx = rdma_tail_idx;
-
-// Release sequential lock
-lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
-
 // Broadcast tails
 SourceMeta src_meta;
 int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
@@ -557,24 +554,46 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
 st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
 }
 __syncwarp();

+// Release the transaction in the window
+if (is_token_in_rank_uint64 != 0) {
+// Acquire lock first
+acquire_lock(rdma_send_channel_lock + lane_id);
+
+// Release the transaction slot
+auto rdy_window = rdma_send_channel_window[lane_id];
+auto latest_tail = rdma_send_channel_tail[lane_id];
+auto offset = rdma_tail_idx - latest_tail;
+
+// The same effect as `EP_DEVICE_ASSERT(offset < 32);`
+EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
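+// NOTES: at most `kNumDispatchRDMASenderWarps` tokens per destination are unreleased at any time
+// (one per sender warp), so `offset` always fits in the 32-slot window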

+// Mark this transaction as ready; if it sits at the window head, advance the released tail over
+// all contiguous ready slots and shift them out of the window
+rdy_window ^= 1u << offset;
+if (offset == 0) {
+EP_DEVICE_ASSERT(rdy_window & 1);
+auto num_empty_slots = __ffs(~rdy_window) - 1;
+st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
+rdy_window >>= num_empty_slots;
+}
+rdma_send_channel_window[lane_id] = rdy_window;
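+// e.g. with bits 1 and 2 already set, a release at offset 0 makes the window 0b111: the tail
+// advances by 3 and the window shifts back to 0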

+// Release lock
+release_lock(rdma_send_channel_lock + lane_id);
+}
+__syncwarp();
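+// NOTES: each token now releases its own tail slot inline, so the old per-token sequential lock
+// and the epilogue below are no longer needed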
 }

-// Epilogue
-// Acquire sequential lock
-while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
-__syncwarp();
-
-// Update last token tail
-if (last_rdma_tail_idx >= 0)
-st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
-__syncwarp();
-
-// Release sequential lock
-lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
 } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
 // NOTES: the ring size is a multiple of the chunk size, so an issued put never splits at the end of the buffer
 EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);

+// Clean shared memory
+EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
+(lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
+(lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
+(lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
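+// NOTES: the coordinator zeroes these before the barrier below, so the sender warps always start
+// from a clean lock, tail and window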

 // Synchronize shared memory
 sync_rdma_sender_smem();
@@ -592,10 +611,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
 // Timeout check
 if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
-printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n",
+printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n",
 channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
 trap();
 }

 // TODO: try thread-level `put_nbi`?
 for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
+// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
 int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
@@ -603,9 +624,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 if (synced_num_tokens_to_send == 0)
 continue;

-// Read progress
-auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
+// Read the latest progress
+// NOTES: `rdma_send_channel_tail` does not need to be protected by the lock
+auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);
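+// NOTES: `ld_acquire_cta` pairs with the senders' `st_release_cta`, so the copied token data is
+// visible to this warp before the chunk is issued over RDMA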
+auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
 auto num_tokens_processed = processed_tail - synced_last_issued_tail;
 if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
 continue;
@@ -625,9 +647,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 // Lighter fence for local RDMA rank
 memory_fence();
 }
-__syncwarp();

 // Update tails
+__syncwarp();
 if (lane_id == dst_rdma_rank) {
 last_issued_tail += num_tokens_to_issue;
 num_tokens_to_send -= num_tokens_to_issue;