Fix compilation

This commit is contained in:
Chenggang Zhao
2025-06-20 11:15:03 +08:00
parent 49b9084268
commit cd5c57fb2a

View File

@@ -713,7 +713,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate over every token from the RDMA buffer
for (int i = cached_rdma_channel_head; i < cached_rdma_channel_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
- auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
+ void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
// TODO: load into shared memory (only read once)
@@ -726,7 +726,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Skip not selected tokens
if (not is_in_dst_nvl_rank) {
- kCachedMode ? (*shifted_send_nvl_head = -1) : 0;
+ if constexpr (kCachedMode)
+ *shifted_send_nvl_head = -1;
continue;
}
@@ -734,7 +735,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int dst_slot_idx;
if (lane_id == 0) {
dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
- kCachedMode ? (*shifted_send_nvl_head = dst_slot_idx) : 0;
+ if constexpr (kCachedMode)
+ *shifted_send_nvl_head = dst_slot_idx;
while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
}
@@ -779,7 +781,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Move tail index
__syncwarp();
if (lane_id == 0)
- st_release_cta(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank, dst_slot_idx + 1);
+ st_release_cta(const_cast<int*>(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank), dst_slot_idx + 1);
__syncwarp();
}
}