mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix several bugs
This commit is contained in:
parent
11053474b7
commit
be96674e94
@ -464,7 +464,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
|||||||
|
|
||||||
// Iterate over tokens and copy into buffer
|
// Iterate over tokens and copy into buffer
|
||||||
int64_t token_idx;
|
int64_t token_idx;
|
||||||
int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
|
int cached_rdma_channel_head = 0;
|
||||||
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
|
auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
|
||||||
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
|
for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
|
||||||
// Read RDMA rank existence
|
// Read RDMA rank existence
|
||||||
@ -488,7 +488,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Use the first available slot
|
// Use the first available slot
|
||||||
auto offset = std::__countr_one(window);
|
// NOTES: do not use `std::countr_one`, as it is buggy in CUDA C++
|
||||||
|
auto offset = __ffs(~window) - 1;
|
||||||
rdma_tail_idx = latest_tail + offset;
|
rdma_tail_idx = latest_tail + offset;
|
||||||
rdma_send_channel_window[lane_id] = window | (1u << offset);
|
rdma_send_channel_window[lane_id] = window | (1u << offset);
|
||||||
|
|
||||||
@ -572,21 +573,20 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
|||||||
// Release the transaction slot
|
// Release the transaction slot
|
||||||
auto window = rdma_send_channel_window[lane_id];
|
auto window = rdma_send_channel_window[lane_id];
|
||||||
auto latest_tail = rdma_send_channel_tail[lane_id];
|
auto latest_tail = rdma_send_channel_tail[lane_id];
|
||||||
auto offset = rdma_tail_idx - last_rdma_tail_idx;
|
auto offset = rdma_tail_idx - latest_tail;
|
||||||
|
|
||||||
|
// TODO: remove this assertion after debugging
|
||||||
|
EP_DEVICE_ASSERT((window >> offset) & 1 and (window & 1));
|
||||||
|
|
||||||
// Erase bit and move the zeros if possible
|
// Erase bit and move the zeros if possible
|
||||||
window ^= 1u << offset;
|
window ^= 1u << offset;
|
||||||
if (offset == 0) {
|
if (offset == 0) {
|
||||||
auto num_empty_slots = std::__countr_zero(window);
|
auto num_empty_slots = window == 0 ? 1 : (__ffs(window) - 1);
|
||||||
num_empty_slots = num_empty_slots == 32 ? 1 : num_empty_slots;
|
st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
|
||||||
st_release_cta(rdma_send_channel_tail + lane_id, last_rdma_tail_idx + num_empty_slots);
|
|
||||||
window >>= num_empty_slots;
|
window >>= num_empty_slots;
|
||||||
}
|
}
|
||||||
rdma_send_channel_window[lane_id] = window;
|
rdma_send_channel_window[lane_id] = window;
|
||||||
|
|
||||||
// TODO: remove this assertion after debugging
|
|
||||||
EP_DEVICE_ASSERT((window >> offset) & 1);
|
|
||||||
|
|
||||||
// Release lock
|
// Release lock
|
||||||
release_lock(rdma_send_channel_lock + lane_id);
|
release_lock(rdma_send_channel_lock + lane_id);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user