Use TMA instead of LD/ST for intra-node normal kernels (#191)

* Update CMake files

* Use TMA instead of LD/ST for intranode dispatch

* Use TMA instead of LD/ST for intranode combine

* Adjust configs

* Test default configs as well

* More warps for combine

* Add inter-thread fence

* Enable more warps

* Do not use TMA for senders

* Update configs

* Remove useless wait
This commit is contained in:
Chenggang Zhao
2025-06-06 15:40:17 +08:00
committed by GitHub
parent df4debe30c
commit c8dceba110
6 changed files with 230 additions and 87 deletions

View File

@@ -174,7 +174,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}
template <int kNumRanks, int kNumThreads>
template <int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
__global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
@@ -183,7 +183,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
void **buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
const bool is_sender = sm_id % 2 == 0;
EP_DEVICE_ASSERT(num_sms % 2 == 0);
@@ -232,19 +232,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto half_hidden_int4 = hidden_int4 / 2;
auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + half_hidden_bytes);
uint32_t tma_phase = 0;
if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();
if (is_sender) {
// Workers for sending
constexpr int num_send_warps = kNumThreads / 32;
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
const auto send_thread_id = thread_id;
const auto send_lane_id = send_thread_id % 32;
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
EP_DEVICE_ASSERT(kNumRanks <= 32);
EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
// NOTES: this is for distinguishing zero tokens
if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
if (lane_id == 0 and send_warp_id_in_rank == 0) {
int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
@@ -262,7 +276,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
// NOTES: the head index received by different warps may not be the same
auto start_time = clock64();
while (send_lane_id == 0) {
while (lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
@@ -279,7 +293,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
// NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
// Skip if not selected
@@ -294,30 +308,29 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Copy data
auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
__ldg, st_na_global);
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, __ldg, st_na_global);
// Copy source index
if (send_lane_id == 0)
if (lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
// Copy `topk_idx` and `topk_weights` with transformed index
if (send_lane_id < num_topk) {
if (lane_id < num_topk) {
// Top-k index
int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id);
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;
// Top-k weights
auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = weight_value;
}
// Copy `x_scales`
#pragma unroll
for (int i = send_lane_id; i < num_scales; i += 32)
for (int i = lane_id; i < num_scales; i += 32)
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
}
@@ -328,7 +341,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Move tail index
// NOTES: here all warps should share the same new tail
asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
if (send_warp_id_in_rank == 0 and send_lane_id == 0)
if (send_warp_id_in_rank == 0 and lane_id == 0)
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
}
} else {
@@ -336,7 +349,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
constexpr int num_recv_warps = kNumThreads / 32;
constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
const auto recv_thread_id = thread_id;
const auto recv_lane_id = recv_thread_id % 32;
const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
EP_DEVICE_ASSERT(kNumRanks <= 32);
@@ -348,9 +360,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
// Receive channel offset
int total_offset, num_tokens_to_recv;
while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
if (recv_lane_id == 0) {
while (lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
while (lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
if (lane_id == 0) {
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
if (recv_warp_id_in_rank == 0)
recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
@@ -393,8 +405,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
ld_nc_global, st_na_global);
#pragma unroll
for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
tma_store_wait();
tma_load_1d(tma_buffer, shifted_buffer_x_int4 + i * half_hidden_int4, tma_mbarrier, half_hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes);
mbarrier_wait(tma_mbarrier, tma_phase);
tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
}
__syncwarp();
}
// Copy `src_idx`
@@ -426,12 +445,16 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
cached_channel_head_idx += num_recv_tokens;
total_offset += num_recv_tokens;
asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
// Exit
num_tokens_to_recv -= num_recv_tokens;
}
// Make TMA store visible to the next kernel
if (lane_id == 0)
tma_store_wait();
}
}
@@ -441,17 +464,22 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 512;
constexpr int kNumThreads = 768;
constexpr int kNumTMABytesPerWarp = 8192;
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
#define DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
break
#define DISPATCH_LAUNCH_CASE(ranks) { \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
cfg.dynamicSmemBytes = smem_size; \
LAUNCH_KERNEL(&cfg, kernel, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
} break
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(num_sms % 2 == 0);
@@ -532,7 +560,7 @@ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
#undef CACHED_NOTIFY_COMBINE
}
template<typename dtype_t, int kNumRanks, int kNumThreads>
template<typename dtype_t, int kNumRanks, int kNumThreads, int kNumTMABytesPerWarp>
__global__ void __launch_bounds__(kNumThreads, 1)
combine(dtype_t* recv_x, float* recv_topk_weights,
const dtype_t* x, const float* topk_weights,
@@ -542,7 +570,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
const auto num_channels = num_sms / 2;
const bool is_sender = sm_id % 2 == 0;
const int responsible_channel = sm_id / 2;
@@ -553,16 +581,21 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
auto x_int4 = reinterpret_cast<const int4*>(x);
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
if (is_sender) {
// Workers for sending
// Several warps are responsible for a single rank
constexpr int num_send_warps = kNumThreads / 32;
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks;
constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks;
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
const auto send_thread_id = thread_id;
const auto send_lane_id = send_thread_id % 32;
const auto send_warp_id = send_thread_id / 32;
const auto send_rank_id = thread_id / num_threads_per_rank;
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");
// Calculate pointers by the specific layout
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
@@ -595,7 +628,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
auto start_time = clock64();
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
while (send_lane_id == 0) {
while (lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
@@ -618,22 +651,22 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// Copy data
auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
UNROLLED_WARP_COPY(4, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Send source index
if (send_lane_id == 0)
if (lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
// Send `topk_weights`
if (num_topk > 0 and send_lane_id < num_topk)
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
if (num_topk > 0 and lane_id < num_topk)
channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
}
token_idx += num_round_tokens;
current_channel_tail_idx += num_round_tokens;
// Move tail index
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
if (send_lane_id == 0 and send_warp_id_in_rank == 0)
if (lane_id == 0 and send_warp_id_in_rank == 0)
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
}
} else {
@@ -641,7 +674,6 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
// One warp for moving the queue head, others for reduction
constexpr int num_recv_warps = kNumThreads / 32;
const auto recv_warp_id = thread_id / 32;
const auto recv_lane_id = thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
@@ -651,19 +683,19 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
__shared__ volatile bool warp_retired[num_recv_warps];
if (thread_id < num_recv_warps)
warp_retired[thread_id] = false;
if (recv_lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
if (lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][lane_id] = 0;
if (thread_id < kNumRanks)
channel_tail_idx[thread_id] = 0;
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
if (thread_id < 32) {
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
// Queue head updater
int last_head = 0;
while (recv_lane_id < kNumRanks) {
while (lane_id < kNumRanks) {
// Check retired
bool retired = true;
#pragma unroll
@@ -673,13 +705,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
break;
// Update queue tail
channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
// Update minimum head
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
}
@@ -716,11 +748,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
// Read expected head
int expected_head = -1;
if (recv_lane_id < kNumRanks)
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
if (lane_id < kNumRanks)
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);
auto start_time = clock64();
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
@@ -740,9 +772,16 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
}
}
// Reduce data
// Wait shared memory release
if (lane_id == 0)
tma_store_wait();
__syncwarp();
// Reduce data with pipeline
constexpr int kNumStages = 8;
EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, "Invalid count");
#pragma unroll
for (int i = recv_lane_id; i < hidden_int4; i += 32) {
for (int i = lane_id; i < hidden_int4; i += 32) {
// Read buffers
int4 recv_value_int4[kNumRanks];
#pragma unroll
@@ -759,33 +798,55 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
// Cast back to `dtype_t`
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
recv_int4[token_idx * hidden_int4 + i] = out_int4;
// Wait TMA arrival
if (lane_id == 0)
tma_store_wait<kNumStages - 1>();
__syncwarp();
// Write into TMA buffer
auto tma_stage_idx = (i / 32) % kNumStages;
reinterpret_cast<int4*>(tma_buffer)[tma_stage_idx * 32 + lane_id] = out_int4;
// Issue TMA
tma_store_fence();
__syncwarp();
if (lane_id == 0) {
auto tma_bytes = min(32, hidden_int4 - i) * static_cast<int>(sizeof(int4));
tma_store_1d(reinterpret_cast<int4*>(tma_buffer) + tma_stage_idx * 32,
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
}
__syncwarp();
}
// Reduce `topk_weights`
if (recv_lane_id < num_topk) {
if (lane_id < num_topk) {
float value = 0;
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + lane_id);
recv_topk_weights[token_idx * num_topk + lane_id] = value;
}
// Update head
if (recv_lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
if (lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
}
// Retired
__syncwarp();
if (recv_lane_id == 0)
if (lane_id == 0)
warp_retired[recv_warp_id] = true;
// Make TMA store visible to the next kernel
if (lane_id == 0)
tma_store_wait();
}
}
}
@@ -799,15 +860,20 @@ void combine(cudaDataType_t type,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
constexpr int kNumTMABytesPerWarp = 4096;
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
#define COMBINE_LAUNCH_CASE(dtype, ranks) \
LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
cfg.dynamicSmemBytes = smem_size; \
LAUNCH_KERNEL(&cfg, kernel, \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
num_max_send_tokens, num_recv_buffer_tokens); } \
break
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break