mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Add transaction windows
parent 185ecf5c4a
commit d7d13878e0
@@ -429,9 +429,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };

     // Forward warp synchronization
-    __shared__ volatile int forward_channel_nvl_tail_allocator[NUM_MAX_NVL_PEERS];
-    __shared__ volatile int forward_channel_nvl_tail[kNumRDMARanks][NUM_MAX_NVL_PEERS];
-    __shared__ volatile bool forward_channel_retired[kNumRDMARanks];
+    __shared__ int forward_channel_nvl_lock[NUM_MAX_NVL_PEERS];
+    __shared__ int forward_channel_nvl_tail[NUM_MAX_NVL_PEERS];
+    __shared__ uint32_t forward_channel_nvl_wip_window[NUM_MAX_NVL_PEERS];
+    __shared__ uint32_t forward_channel_nvl_rdy_window[NUM_MAX_NVL_PEERS];
+    __shared__ bool forward_channel_retired[kNumRDMARanks];
     auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMARanks + 1) * 32)); };

     if (warp_role == WarpRole::kRDMASender) {
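
Note: the removed `atomicAdd`-based allocator handed out strictly increasing slot indices, but forwarder warps finish writing their slots out of order, so the consumer-visible tail had to be tracked per RDMA rank. The new scheme keeps one committed tail per NVL peer plus two 32-bit transaction windows above it: `wip` marks slots allocated but still being written, `rdy` marks slots whose writes finished but which cannot be committed while an older slot is still in flight. All three are mutated under `forward_channel_nvl_lock`. The lock helpers are not shown in this diff; below is a minimal runnable sketch of the shape such a lock can take (an assumption, not DeepEP's actual `acquire_lock`/`release_lock`; names are illustrative, with `std::atomic` standing in for CUDA's `atomicCAS`/`atomicExch`):

    #include <atomic>
    #include <cstdio>
    #include <thread>

    // Hedged sketch of a one-word spin lock like the one these calls suggest
    static std::atomic<int> lock_word{0};

    static void acquire_lock_sketch() {
        int expected = 0;
        // Spin until we swap 0 -> 1; acquire ordering plays the role of the
        // fence a device-side implementation would need after taking the lock
        while (!lock_word.compare_exchange_weak(expected, 1, std::memory_order_acquire))
            expected = 0;
    }

    static void release_lock_sketch() {
        lock_word.store(0, std::memory_order_release);  // publish protected writes
    }

    static int protected_counter = 0;  // stands in for the tail/window state

    int main() {
        auto worker = [] {
            for (int i = 0; i < 100000; ++i) {
                acquire_lock_sketch();
                ++protected_counter;   // mutate "window state" under the lock
                release_lock_sketch();
            }
        };
        std::thread a(worker), b(worker);
        a.join(), b.join();
        std::printf("%d\n", protected_counter);  // 200000
        return 0;
    }
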
@@ -628,7 +630,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv

         // Read the latest progress
         // NOTES: `rdma_send_channel_tail` does not need to be protected by lock
-        auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);
+        auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(rdma_send_channel_tail + dst_rdma_rank), 0);
         auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
         auto num_tokens_processed = processed_tail - synced_last_issued_tail;
         if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
@@ -752,16 +754,37 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                     continue;
                 }

-                // Allocate a slot tail and wait until that we can ensure the slot is safe to overwrite
-                int dst_slot_idx;
+                // Allocate a slot tail
+                int nvl_tail_idx;
                 if (lane_id == 0) {
-                    dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
+                    // Acquire lock first
+                    acquire_lock(forward_channel_nvl_lock + dst_nvl_rank);
+
+                    // Allocate a new bit slot
+                    auto wip_window = forward_channel_nvl_wip_window[dst_nvl_rank];
+                    auto rdy_window = forward_channel_nvl_rdy_window[dst_nvl_rank];
+                    auto latest_tail = forward_channel_nvl_tail[dst_nvl_rank];
+                    auto window = wip_window | rdy_window;
+
+                    // The same effect as `EP_DEVICE_ASSERT(window != 0xffffffffu);`
+                    EP_STATIC_ASSERT(kNumRDMARanks < 32, "Invalid RDMA ranks");
+
+                    // Use the first available slot
+                    // NOTES: do not use `std::countr_one`, as it is buggy in CUDA C++
+                    auto offset = __ffs(~window) - 1;
+                    nvl_tail_idx = latest_tail + offset;
+                    forward_channel_nvl_wip_window[dst_nvl_rank] = wip_window | (1u << offset);
+
+                    // Release lock
+                    release_lock(forward_channel_nvl_lock + dst_nvl_rank);
+
+                    // Wait until we can ensure the slot is safe to overwrite
                     if constexpr (kCachedMode)
-                        *shifted_send_nvl_head = dst_slot_idx;
-                    while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
+                        *shifted_send_nvl_head = nvl_tail_idx;
+                    while (nvl_tail_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
                         cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
                 }
-                dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx % num_max_nvl_chunked_recv_tokens, 0);
+                auto dst_slot_idx = __shfl_sync(0xffffffff, nvl_tail_idx % num_max_nvl_chunked_recv_tokens, 0);

                 // Copy data
                 // The `shifted` should be restored
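
For reference, allocation under the lock is a first-zero-bit scan over the union of the two windows: a set bit in either one means that offset above the committed tail is taken, and the `EP_STATIC_ASSERT(kNumRDMARanks < 32, ...)` guarantees a zero bit always exists, since each of the (at most 31) forwarder warps holds at most one in-flight slot per NVL peer. A standalone host-side sketch of the same arithmetic (`first_free_slot` is a hypothetical helper, not DeepEP code; `__builtin_ffs` stands in for the device `__ffs`):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Pick the first free slot at or past the committed tail, mirroring the
    // kernel's `__ffs(~(wip_window | rdy_window)) - 1`
    static int first_free_slot(uint32_t wip_window, uint32_t rdy_window, int tail) {
        uint32_t window = wip_window | rdy_window;  // set bit = slot taken
        assert(window != 0xffffffffu);              // the kernel bounds this statically
        int offset = __builtin_ffs(~window) - 1;    // index of the lowest zero bit
        return tail + offset;
    }

    int main() {
        // Offsets 0 and 2 are in flight, offset 1 is free -> slot 41
        std::printf("%d\n", first_free_slot(0b101u, 0b000u, /*tail=*/40));
        return 0;
    }
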
@@ -805,11 +828,32 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                     weight_value = idx_value >= 0 ? weight_value : 0.0f;
                     st_na_global(dst_nvl_channel_topk_weights + dst_slot_idx * num_topk + lane_id, weight_value);
                 }

-                // Move tail index
                 __syncwarp();
-                if (lane_id == 0)
-                    st_release_cta(const_cast<int*>(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank), dst_slot_idx + 1);
+                // Release the window slot
+                // TODO: parallelize all these
+                if (lane_id == 0) {
+                    // Acquire lock first
+                    acquire_lock(forward_channel_nvl_lock + dst_nvl_rank);
+
+                    // Erase bit and move the zeros if possible
+                    auto wip_window = forward_channel_nvl_wip_window[dst_nvl_rank];
+                    auto rdy_window = forward_channel_nvl_rdy_window[dst_nvl_rank];
+                    auto latest_tail = forward_channel_nvl_tail[dst_nvl_rank];
+                    auto offset = nvl_tail_idx - latest_tail;
+                    wip_window ^= 1u << offset;
+                    rdy_window ^= 1u << offset;
+                    if (offset == 0) {
+                        auto num_empty_slots = __ffs(~rdy_window) - 1;
+                        st_release_cta(forward_channel_nvl_tail + dst_nvl_rank, latest_tail + num_empty_slots);
+                        wip_window >>= num_empty_slots, rdy_window >>= num_empty_slots;
+                    }
+                    forward_channel_nvl_wip_window[dst_nvl_rank] = wip_window;
+                    forward_channel_nvl_rdy_window[dst_nvl_rank] = rdy_window;
+
+                    // Release lock
+                    release_lock(forward_channel_nvl_lock + dst_nvl_rank);
+                }
                 __syncwarp();
             }
         }
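
Release is the inverse: the finishing warp toggles its bit out of `wip` and into `rdy`, and only when the oldest outstanding slot (offset 0) completes does the committed tail advance, by the length of the contiguous run of ready bits; both windows then shift down so bit 0 stays aligned with the new tail. A worked host-side sketch of the same bookkeeping (hypothetical `Windows`/`release_slot` names, same `__builtin_ffs` stand-in as above):

    #include <cstdint>
    #include <cstdio>

    struct Windows { uint32_t wip, rdy; int tail; };

    // Release the slot at absolute index `slot_idx`, mirroring the kernel's
    // XOR toggles and trailing-ones tail advance
    static void release_slot(Windows& w, int slot_idx) {
        int offset = slot_idx - w.tail;
        w.wip ^= 1u << offset;   // bit was set in `wip`: clear it
        w.rdy ^= 1u << offset;   // bit was clear in `rdy`: set it
        if (offset == 0) {       // oldest slot done: commit the contiguous ready run
            int num_empty_slots = __builtin_ffs(~w.rdy) - 1;  // trailing ones of `rdy`
            w.tail += num_empty_slots;
            w.wip >>= num_empty_slots, w.rdy >>= num_empty_slots;
        }
    }

    int main() {
        // Slots 40..42 in flight; 41 and 42 finish first, then 40 commits all three
        Windows w{/*wip=*/0b111, /*rdy=*/0b000, /*tail=*/40};
        release_slot(w, 41);
        release_slot(w, 42);   // rdy == 0b110, tail still 40
        release_slot(w, 40);   // offset 0: rdy == 0b111 -> advance tail by 3
        std::printf("tail = %d\n", w.tail);  // 43
        return 0;
    }
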
@@ -838,10 +882,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             return;

         // Clean shared memory
-        for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)
-            forward_channel_nvl_tail[i / NUM_MAX_NVL_PEERS][i % NUM_MAX_NVL_PEERS] = 0;
-        if (lane_id < NUM_MAX_NVL_PEERS)
-            forward_channel_nvl_tail_allocator[lane_id] = 0;
+        if (lane_id < NUM_MAX_NVL_PEERS) {
+            forward_channel_nvl_lock[lane_id] = 0;
+            forward_channel_nvl_tail[lane_id] = 0;
+            forward_channel_nvl_wip_window[lane_id] = 0;
+            forward_channel_nvl_rdy_window[lane_id] = 0;
+        }
         if (lane_id < kNumRDMARanks)
             forward_channel_retired[lane_id] = 0;
         sync_forwarder_smem();
@@ -849,19 +895,20 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         // Update minimum tail
         int last_tail = 0, dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
         while (true) {
-            // Find minimum tail
-            int min_tail = std::numeric_limits<int>::max();
-            #pragma unroll
-            for (int i = 0; i < kNumRDMARanks; ++ i) if (not forward_channel_retired[i])
-                min_tail = min(min_tail, forward_channel_nvl_tail[i][dst_nvl_rank]);
-            if (__all_sync(0xffffffff, min_tail == std::numeric_limits<int>::max()))
+            if (__all_sync(0xffffffff, forward_channel_retired[dst_nvl_rank]))
                 break;

             // Update remote tail
             // TODO: control update interval
-            if (lane_id < NUM_MAX_NVL_PEERS and min_tail != std::numeric_limits<int>::max() and min_tail > last_tail)
-                st_release_sys_global(nvl_channel_tail.buffer(), min_tail);
+            if (lane_id < NUM_MAX_NVL_PEERS) {
+                auto new_tail = ld_acquire_cta(forward_channel_nvl_tail + dst_nvl_rank);
+                if (new_tail > last_tail)
+                    st_release_sys_global(nvl_channel_tail.buffer(), last_tail = new_tail);
+            }
             __syncwarp();

             // Let other warps work
             __nanosleep(NUM_WAIT_NANOSECONDS);
         }
     } else {
         // NVL consumers
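
With the committed tail now per NVL peer in shared memory, the coordinator no longer recomputes a minimum across per-RDMA-rank tails: one lane per NVL peer just polls `forward_channel_nvl_tail` and republishes to the remote-visible `nvl_channel_tail` whenever it grows. A rough host-side analogue of that publish-on-increase loop (illustrative names; `std::atomic` stands in for the CUDA ld/st helpers):

    #include <atomic>
    #include <cstdio>
    #include <thread>

    static std::atomic<int> committed_tail{0};  // forward_channel_nvl_tail[peer] analogue
    static std::atomic<int> remote_tail{0};     // nvl_channel_tail analogue
    static std::atomic<bool> retired{false};    // forward_channel_retired analogue

    static void coordinator() {
        int last_tail = 0;
        while (!retired.load(std::memory_order_acquire)) {
            // Publish only on growth, mirroring the kernel's
            // `if (new_tail > last_tail) st_release_sys_global(..., last_tail = new_tail)`
            int new_tail = committed_tail.load(std::memory_order_acquire);
            if (new_tail > last_tail)
                remote_tail.store(last_tail = new_tail, std::memory_order_release);
            std::this_thread::yield();  // stand-in for `__nanosleep(...)`
        }
        // Final sweep, only to keep this sketch's output deterministic
        int new_tail = committed_tail.load(std::memory_order_acquire);
        if (new_tail > last_tail)
            remote_tail.store(new_tail, std::memory_order_release);
    }

    int main() {
        std::thread t(coordinator);
        for (int i = 1; i <= 3; ++i)
            committed_tail.store(i, std::memory_order_release);  // forwarders commit slots
        retired.store(true, std::memory_order_release);
        t.join();
        std::printf("remote tail = %d\n", remote_tail.load());   // 3
        return 0;
    }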