From cd5c57fb2af52655b3490797f0a985784d2cc8cf Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 20 Jun 2025 11:15:03 +0800 Subject: [PATCH] Fix compilation --- csrc/kernels/internode.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 5327f8f..9a98887 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -713,7 +713,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Iterate over every token from the RDMA buffer for (int i = cached_rdma_channel_head; i < cached_rdma_channel_tail; ++ i) { auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; - auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; + void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; auto src_meta = ld_nc_global(reinterpret_cast(static_cast(shifted) + hidden_bytes)); // TODO: load into shared memory (only read once) @@ -726,7 +726,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Skip not selected tokens if (not is_in_dst_nvl_rank) { - kCachedMode ? (*shifted_send_nvl_head = -1) : 0; + if constexpr (kCachedMode) + *shifted_send_nvl_head = -1; continue; } @@ -734,7 +735,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv int dst_slot_idx; if (lane_id == 0) { dst_slot_idx = atomicAdd_block(const_cast(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1); - kCachedMode ? (*shifted_send_nvl_head = dst_slot_idx) : 0; + if constexpr (kCachedMode) + *shifted_send_nvl_head = dst_slot_idx; while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens) cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); } @@ -779,7 +781,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Move tail index __syncwarp(); if (lane_id == 0) - st_release_cta(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank, dst_slot_idx + 1); + st_release_cta(const_cast(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank), dst_slot_idx + 1); __syncwarp(); } }