Add draft

This commit is contained in:
Chenggang Zhao 2025-06-23 11:45:05 +08:00
parent 7b0c25f864
commit 1c277c303e
2 changed files with 92 additions and 40 deletions

View File

@ -419,9 +419,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
// RDMA sender warp synchronization // RDMA sender warp synchronization
__shared__ volatile int rdma_send_next_token_idx; // NOTES: `rdma_send_channel_tail` means the latest released tail
__shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
__shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; __shared__ int rdma_send_channel_lock[kNumRDMARanks];
__shared__ int rdma_send_channel_tail[kNumRDMARanks];
__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
// Forward warp synchronization // Forward warp synchronization
@ -434,12 +436,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int token_start_idx, token_end_idx; int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
(warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
// Send number of tokens in this channel by `-value - 1` // Send number of tokens in this channel by `-value - 1`
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
@ -476,14 +472,31 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id < kNumRDMARanks) if (lane_id < kNumRDMARanks)
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS); is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
// Acquire sequential lock // Acquire a tail
while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
__syncwarp();
// Acquire next tail
int rdma_tail_idx = -1; int rdma_tail_idx = -1;
if (is_token_in_rank_uint64 != 0) { if (is_token_in_rank_uint64 != 0) {
rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++; do {
// Acquire lock first
acquire_lock(rdma_send_channel_lock + lane_id);
// If there is no remaining slot, continue
auto window = rdma_send_channel_window[lane_id];
auto latest_tail = rdma_send_channel_tail[lane_id];
if (window == 0xffffffffu) {
release_lock(rdma_send_channel_lock + lane_id);
continue;
}
// Use the first available slot
auto offset = std::__countr_one(window);
rdma_tail_idx = latest_tail + offset;
rdma_send_channel_window[lane_id] = window | (1u << offset);
// Release lock
release_lock(rdma_send_channel_lock + lane_id);
} while (rdma_tail_idx == -1);
// Wait the remote buffer to be released
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id))); cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
} }
@ -493,14 +506,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id < kNumRDMARanks and not kCachedMode) if (lane_id < kNumRDMARanks and not kCachedMode)
send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
last_rdma_tail_idx = rdma_tail_idx;
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
// Broadcast tails // Broadcast tails
SourceMeta src_meta; SourceMeta src_meta;
int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
@ -557,24 +562,46 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value); st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
} }
__syncwarp();
// Release the transaction in the window
if (is_token_in_rank_uint64 != 0) {
// Acquire lock first
acquire_lock(rdma_send_channel_lock + lane_id);
// Release the transaction slot
auto window = rdma_send_channel_window[lane_id];
auto latest_tail = rdma_send_channel_tail[lane_id];
auto offset = rdma_tail_idx - last_rdma_tail_idx;
// Erase bit and move the zeros if possible
window ^= 1u << offset;
if (offset == 0) {
auto num_empty_slots = std::__countr_zero(window);
num_empty_slots = num_empty_slots == 32 ? 1 : num_empty_slots;
st_release_cta(rdma_send_channel_tail + lane_id, last_rdma_tail_idx + num_empty_slots);
window >>= num_empty_slots;
}
rdma_send_channel_window[lane_id] = window;
// TODO: remove this assertion after debugging
EP_DEVICE_ASSERT((window >> offset) & 1);
// Release lock
release_lock(rdma_send_channel_lock + lane_id);
}
__syncwarp();
} }
// Epilogue
// Acquire sequential lock
while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
__syncwarp();
// Update last token tail
if (last_rdma_tail_idx >= 0)
st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
__syncwarp();
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
} else if (warp_role == WarpRole::kRDMASenderCoordinator) { } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
// NOTES: in case of splitting, the issued put at the end of the buffer // NOTES: in case of splitting, the issued put at the end of the buffer
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// Clean shared memory
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
(lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
(lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
(lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
// Synchronize shared memory // Synchronize shared memory
sync_rdma_sender_smem(); sync_rdma_sender_smem();
@ -592,10 +619,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
// Timeout check // Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n", printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send); channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
trap(); trap();
} }
// TODO: try thread-level `put_nbi`?
for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) { for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
// To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
@ -603,9 +632,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (synced_num_tokens_to_send == 0) if (synced_num_tokens_to_send == 0)
continue; continue;
// Read progress // Read the latest progress
auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); // NOTES: `rdma_send_channel_tail` does not need to be protected by lock
auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0); auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);
auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
auto num_tokens_processed = processed_tail - synced_last_issued_tail; auto num_tokens_processed = processed_tail - synced_last_issued_tail;
if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens) if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
continue; continue;
@ -625,9 +655,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Lighter fence for local RDMA rank // Lighter fence for local RDMA rank
memory_fence(); memory_fence();
} }
__syncwarp();
// Update tails // Update tails
__syncwarp();
if (lane_id == dst_rdma_rank) { if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue; last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue;

View File

@ -466,4 +466,26 @@ barrier_block(int** barrier_signal_ptrs, int rank) {
__syncthreads(); __syncthreads();
} }
__forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) {
int ret;
asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(addr), "r"(x), "r"(y) : "memory");
return ret;
}
__forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) {
int ret;
asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(x) : "memory");
return ret;
}
__forceinline__ __device__ void acquire_lock(int* mutex) {
// To make later memory operations valid, we must use `acquire` for memory semantics
while (atomic_cas_cta_acquire(mutex, 0, 1) != 0);
}
__forceinline__ __device__ void release_lock(int* mutex) {
// To make previous memory operations visible to other threads, we must use `release` for memory semantics
atomic_exch_cta_release(mutex, 0);
}
} // namespace deep_ep } // namespace deep_ep