From f4b205bfa3e4d0a109450d0f8f4d8af7170784a6 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 23 Apr 2025 13:32:19 +0800 Subject: [PATCH] Minor fixes --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index d557f1e..a019594 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -226,8 +226,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait unaligned cases #pragma unroll for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->arrive(); empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); } }, 0); } @@ -249,9 +249,6 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Preload TMA multicast validity, encouraged to use unified registers - bool is_tma_multicast_valid = __shfl_sync(0xffffffff, scheduler.is_tma_multicast_valid(m_block_idx), 0); - // Decide the number of scales B to load DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; @@ -279,8 +276,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Empty barrier arrival auto empty_barrier_arrive = [&](int s) { - if (kNumTMAMulticast == 1 or not is_tma_multicast_valid) { - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive() : void(); + if (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); } else { lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); }