Optimize performance

Chenggang Zhao 2025-04-22 17:48:11 +08:00
parent 59884211ea
commit 07ef809d82


@@ -249,6 +249,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
+ // Preload TMA multicast validity, encouraged to use unified registers
+ bool is_tma_multicast_valid = __shfl_sync(0xffffffff, scheduler.is_tma_multicast_valid(m_block_idx), 0);
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
@@ -276,7 +279,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
- if (kNumTMAMulticast == 1 or not scheduler.is_tma_multicast_valid(m_block_idx)) {
+ if (kNumTMAMulticast == 1 or not is_tma_multicast_valid) {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive() : void();
} else {
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
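
The change hoists the scheduler query out of the hot path: is_tma_multicast_valid is computed once per scheduled block, broadcast from lane 0 with __shfl_sync so every lane holds the same value, and then reused inside the barrier-arrival lambda, where the branch becomes warp-uniform. Below is a minimal, self-contained CUDA sketch of that preload-and-broadcast pattern using hypothetical names (preload_flag_kernel, block_flags, num_stages); it illustrates the idea under those assumptions and is not the DeepGEMM kernel itself.

#include <cuda_runtime.h>

// Sketch only (hypothetical names): evaluate a per-block predicate once,
// broadcast lane 0's result with __shfl_sync so every lane holds the same
// value (which lets the compiler keep it in a uniform register), then branch
// on it in the hot loop instead of re-evaluating the predicate each iteration.
// Launch with blockDim.x a multiple of 32 so the full-warp mask is valid.
__global__ void preload_flag_kernel(const int* __restrict__ block_flags,
                                    int* __restrict__ out,
                                    int num_blocks, int num_stages) {
    const int lane_idx = threadIdx.x % 32;

    // Persistent-style loop over blocks, as in the kernel above.
    for (int block_idx = blockIdx.x; block_idx < num_blocks; block_idx += gridDim.x) {
        // Preload the predicate once per block; lane 0's value is broadcast,
        // so the flag is warp-uniform from here on.
        bool flag = __shfl_sync(0xffffffffu, block_flags[block_idx] != 0, 0);

        // Hot inner loop: the branch is uniform and the predicate is never
        // re-evaluated per stage.
        for (int s = 0; s < num_stages; ++s) {
            if (flag) {
                if (lane_idx == 0)                  // e.g. a single lane signals
                    atomicAdd(&out[block_idx], 1);
            } else {
                atomicAdd(&out[block_idx], 1);      // e.g. every lane does its own work
            }
        }
    }
}

In the commit itself the same idea lets empty_barrier_arrive branch on a value that is already in a register instead of calling scheduler.is_tma_multicast_valid(m_block_idx) on every pipeline stage.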