mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix compilation
This commit is contained in:
@@ -713,7 +713,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Iterate over every token from the RDMA buffer
|
||||
for (int i = cached_rdma_channel_head; i < cached_rdma_channel_tail; ++ i) {
|
||||
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
|
||||
auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
|
||||
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
|
||||
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
|
||||
|
||||
// TODO: load into shared memory (only read once)
|
||||
@@ -726,7 +726,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
|
||||
// Skip not selected tokens
|
||||
if (not is_in_dst_nvl_rank) {
|
||||
kCachedMode ? (*shifted_send_nvl_head = -1) : 0;
|
||||
if constexpr (kCachedMode)
|
||||
*shifted_send_nvl_head = -1;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -734,7 +735,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
int dst_slot_idx;
|
||||
if (lane_id == 0) {
|
||||
dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
|
||||
kCachedMode ? (*shifted_send_nvl_head = dst_slot_idx) : 0;
|
||||
if constexpr (kCachedMode)
|
||||
*shifted_send_nvl_head = dst_slot_idx;
|
||||
while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
|
||||
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
|
||||
}
|
||||
@@ -779,7 +781,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Move tail index
|
||||
__syncwarp();
|
||||
if (lane_id == 0)
|
||||
st_release_cta(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank, dst_slot_idx + 1);
|
||||
st_release_cta(const_cast<int*>(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank), dst_slot_idx + 1);
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user