Support statistics tensor for low-latency kernels (#196)

This commit is contained in:
Chenggang Zhao
2025-06-09 15:50:56 +08:00
committed by GitHub
parent 0d1a855d81
commit 5a2e37fa28
6 changed files with 27 additions and 3 deletions

View File

@@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,

View File

@@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
if (cumulative_local_expert_recv_stats != nullptr)
atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
}
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
@@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
@@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
cumulative_local_expert_recv_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \