From 07ef809d82d6dfe727323267ace516403e580db2 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 22 Apr 2025 17:48:11 +0800 Subject: [PATCH] Optimize performance --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index aa39c16..97005fa 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -249,6 +249,9 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Preload TMA multicast validity, encouraged to use unified registers + bool is_tma_multicast_valid = __shfl_sync(0xffffffff, scheduler.is_tma_multicast_valid(m_block_idx), 0); + // Decide the number of scales B to load DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; @@ -276,7 +279,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Empty barrier arrival auto empty_barrier_arrive = [&](int s) { - if (kNumTMAMulticast == 1 or not scheduler.is_tma_multicast_valid(m_block_idx)) { + if (kNumTMAMulticast == 1 or not is_tma_multicast_valid) { lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive() : void(); } else { lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();