Add automatic warp count control for low-latency combine

This commit is contained in:
Chenggang Zhao
2025-06-16 11:42:04 +08:00
parent 632c81f1d7
commit 72beb15827

View File

@@ -389,8 +389,8 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
#undef DISPATCH_LAUNCH_CASE
}
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
template <int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(1024, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
@@ -400,16 +400,18 @@ combine(void* combined_x,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int* usage_flag, int phases, bool zero_copy) {
int* usage_flag,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// Data type staffs
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
@@ -441,10 +443,10 @@ combine(void* combined_x,
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) +
const auto local_x = static_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout
@@ -452,7 +454,7 @@ combine(void* combined_x,
unpack2(layout, num_tokens_to_send, offset);
// Issue IBGDA send
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
@@ -473,9 +475,9 @@ combine(void* combined_x,
}
}
// Put finishing flag
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
// Put the finishing flag
EP_DEVICE_ASSERT(num_warps_per_group > 1);
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
@@ -497,7 +499,7 @@ combine(void* combined_x,
// Wait all ranks to arrive and notify usages
if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
} else if (sm_id == 0 and sub_warp_id == 1 and lane_id == 0) {
@@ -524,7 +526,7 @@ combine(void* combined_x,
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_type = reinterpret_cast<const int*>(static_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
@@ -541,7 +543,7 @@ combine(void* combined_x,
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
(static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
}
@@ -556,20 +558,21 @@ void combine(void* combined_x,
void* workspace, int* usage_flag,
int num_device_sms, cudaStream_t stream,
int phases, bool zero_copy) {
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
constexpr int kNumMaxTopk = 9;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = ceil_div(num_experts, kNumWarpGroups);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = ceil_div(num_experts, num_warp_groups);
// Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
auto atomic_clean_flag = static_cast<int*>(workspace);
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
auto combine_func = combine<hidden, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
@@ -580,6 +583,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
usage_flag, \
num_warp_groups, num_warps_per_group, \
phases, zero_copy); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);