Fix send heads

This commit is contained in:
Chenggang Zhao
2025-06-19 18:05:59 +08:00
parent 55bbd8caaf
commit 177e491e92

View File

@@ -432,7 +432,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
// Forward warp synchronization
// TODO: allocate some shared memory buffers
__shared__ volatile int forward_channel_nvl_tail_allocator[NUM_MAX_NVL_PEERS];
__shared__ volatile int forward_channel_nvl_tail[kNumRDMARanks][NUM_MAX_NVL_PEERS];
__shared__ volatile bool forward_channel_retired[kNumRDMARanks];
@@ -736,18 +735,24 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Try to send to different NVL ranks
// TODO: shuffle the destination ranks
for (int dst_nvl_rank = 0; dst_nvl_rank < NUM_MAX_NVL_PEERS; ++ dst_nvl_rank) {
if (not src_meta.is_token_in_nvl_rank(dst_nvl_rank))
const bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
auto shifted_send_nvl_head = send_nvl_head + ((src_rdma_channel_prefix + i) * NUM_MAX_NVL_PEERS + dst_nvl_rank);
// Skip tokens that are not selected for this NVL rank
if (not is_in_dst_nvl_rank) {
kCachedMode ? (*send_nvl_head = -1) : 0;
continue;
}
// Allocate a slot tail and wait until we can ensure the slot is safe to overwrite
int dst_slot_idx;
auto shifted_send_nvl_head = send_nvl_head + (src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank);
if (lane_id == 0) {
dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
kCachedMode ? (*send_nvl_head = dst_slot_idx) : 0;
while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
}
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, 0);
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx % num_max_nvl_chunked_recv_tokens, 0);
// Copy data
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,