From 340d9880f4a418d943d34260d20a79f41f4c0526 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 18 Apr 2025 11:18:23 +0800 Subject: [PATCH] Overlap TMA store --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index cfcb569..5283344 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -364,6 +364,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, DG_STATIC_ASSERT(static_cast(kSwizzleDMode > 0) + static_cast(BLOCK_N_PADDING > 0) <= 1, "Swizzling and padding are not compatible"); + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll @@ -424,10 +428,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_block_idx * BLOCK_N + in_block_n_offset, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); - - // Wait TMA to be finished cute::tma_store_arrive(); - cute::tma_store_wait<0>(); } __syncwarp(); }