mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix bugs
This commit is contained in:
parent
b3e39fcbbb
commit
901cdf79be
@ -365,7 +365,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
const bool is_forwarder = sm_id % 2 == 0;
|
||||
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
|
||||
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms);
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms);
|
||||
|
||||
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
|
||||
if (is_forwarder) {
|
||||
@ -423,8 +423,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
|
||||
__shared__ int rdma_send_channel_lock[kNumRDMARanks];
|
||||
__shared__ int rdma_send_channel_tail[kNumRDMARanks];
|
||||
__shared__ uint32_t rdma_send_channel_wip_window[kNumRDMARanks];
|
||||
__shared__ uint32_t rdma_send_channel_rdy_window[kNumRDMARanks];
|
||||
__shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
|
||||
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
|
||||
|
||||
// Forward warp synchronization
|
||||
@ -465,52 +464,32 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
|
||||
// Iterate over tokens and copy into buffer
|
||||
int64_t token_idx;
|
||||
int cached_rdma_channel_head = 0;
|
||||
int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;
|
||||
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
|
||||
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
|
||||
for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) {
|
||||
// Read RDMA rank existence
|
||||
uint64_t is_token_in_rank_uint64 = 0;
|
||||
if (lane_id < kNumRDMARanks)
|
||||
if (lane_id < kNumRDMARanks) {
|
||||
is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
|
||||
global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Acquire a tail
|
||||
int rdma_tail_idx = -1;
|
||||
if (is_token_in_rank_uint64 != 0) {
|
||||
while (true) {
|
||||
// Acquire lock first
|
||||
acquire_lock(rdma_send_channel_lock + lane_id);
|
||||
// Skip the token which does not belong to this warp
|
||||
if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)
|
||||
continue;
|
||||
auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;
|
||||
|
||||
// If there is no remaining slot, continue
|
||||
auto wip_window = rdma_send_channel_wip_window[lane_id];
|
||||
auto rdy_window = rdma_send_channel_rdy_window[lane_id];
|
||||
auto window = wip_window | rdy_window;
|
||||
auto latest_tail = rdma_send_channel_tail[lane_id];
|
||||
// Wait the remote buffer to be released
|
||||
auto start_time = clock64();
|
||||
while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
|
||||
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
|
||||
|
||||
// The same effect with `EP_DEVICE_ASSERT(window != 0xffffffffu);`
|
||||
EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
|
||||
|
||||
// Use the first available slot
|
||||
// NOTES: do not use `std::countr_one`, as it is buggy in CUDA C++
|
||||
auto offset = __ffs(~window) - 1;
|
||||
rdma_tail_idx = latest_tail + offset;
|
||||
rdma_send_channel_wip_window[lane_id] = wip_window | (1u << offset);
|
||||
|
||||
// Release lock
|
||||
release_lock(rdma_send_channel_lock + lane_id);
|
||||
break;
|
||||
}
|
||||
|
||||
// Wait the remote buffer to be released
|
||||
auto start_time = clock64();
|
||||
while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
|
||||
cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
|
||||
|
||||
// Timeout check
|
||||
if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
|
||||
channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
|
||||
trap();
|
||||
}
|
||||
// Timeout check
|
||||
if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
|
||||
channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
@ -583,22 +562,22 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
acquire_lock(rdma_send_channel_lock + lane_id);
|
||||
|
||||
// Release the transaction slot
|
||||
auto wip_window = rdma_send_channel_wip_window[lane_id];
|
||||
auto rdy_window = rdma_send_channel_rdy_window[lane_id];
|
||||
auto rdy_window = rdma_send_channel_window[lane_id];
|
||||
auto latest_tail = rdma_send_channel_tail[lane_id];
|
||||
auto offset = rdma_tail_idx - latest_tail;
|
||||
|
||||
// Erase bit and move the zeros if possible
|
||||
wip_window ^= 1u << offset;
|
||||
// The same effect with `EP_DEVICE_ASSERT(offset < 32);`
|
||||
EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
|
||||
|
||||
// Erase bit and move the ones if possible
|
||||
rdy_window ^= 1u << offset;
|
||||
if (offset == 0) {
|
||||
EP_DEVICE_ASSERT(rdy_window & 1);
|
||||
auto num_empty_slots = __ffs(~rdy_window) - 1;
|
||||
st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
|
||||
wip_window >>= num_empty_slots, rdy_window >>= num_empty_slots;
|
||||
rdy_window >>= num_empty_slots;
|
||||
}
|
||||
rdma_send_channel_wip_window[lane_id] = wip_window;
|
||||
rdma_send_channel_rdy_window[lane_id] = rdy_window;
|
||||
rdma_send_channel_window[lane_id] = rdy_window;
|
||||
|
||||
// Release lock
|
||||
release_lock(rdma_send_channel_lock + lane_id);
|
||||
@ -613,8 +592,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
|
||||
(lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
|
||||
(lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
|
||||
(lane_id < kNumRDMARanks) ? (rdma_send_channel_wip_window[lane_id] = 0) : 0;
|
||||
(lane_id < kNumRDMARanks) ? (rdma_send_channel_rdy_window[lane_id] = 0) : 0;
|
||||
(lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
|
||||
|
||||
// Synchronize shared memory
|
||||
sync_rdma_sender_smem();
|
||||
|
Loading…
Reference in New Issue
Block a user