From f8797d3c127fb7d5402f6eb022b51a32547cecee Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 14 Apr 2025 11:29:28 +0800
Subject: [PATCH] Optimize TMA issues

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh | 60 +++++++++++-------------
 1 file changed, 28 insertions(+), 32 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index ced5e03..bf5e0ff 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -354,7 +354,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
     // TMA checks
     constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
-    constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;
+    constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
     constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
     if constexpr (kSwizzleDMode > 0) {
         DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
@@ -384,8 +384,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
 
                         // Reshape the atom in another view and swizzle
-                        // - original: `(WGMMA_M_PER_WARP, kSwizzleDMode / kNumBankGroupBytes)`
-                        // - new: `(WGMMA_M_PER_WARP * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
+                        // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
+                        // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
                         constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
                         auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
                         auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
@@ -393,13 +393,15 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 
                         // Add back into the base pointer
                         // NOTES: think twice before modifying this, as changes may affect the number of instructions
-                        smem_ptr = reinterpret_cast<uint8_t*>(smem_d) +                      // Base pointer
-                                   warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset
-                                   m_offset * (BLOCK_N * kNumElemBytes) +                    // Wave offset
-                                   atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode +          // Swizzle atom offset (constants)
-                                   row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
+                        smem_ptr = reinterpret_cast<uint8_t*>(smem_d) +            // Base pointer
+                                   warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
+                                   m_offset * kSwizzleDMode +                      // Wave offset
+                                   atom_offset * BLOCK_M * kSwizzleDMode +         // Swizzle atom offset (constants)
+                                   row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
                     } else {
                         // No swizzling, just padding
+                        // NOTES: padding must be zero for BF16 output
+                        DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output");
                         smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
                     }
 
@@ -410,29 +412,25 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         smem_ptr
                     );
                 }
-
-                // Issue TMA store
-                cute::tma_store_fence();
-                if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
-                    auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
-                    auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
-                    auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + lane_idx * TMA_D_BLOCK_N * WGMMA_M_PER_WARP;
-                    cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
-                                                  n_block_idx * BLOCK_N + in_block_n_offset,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
-                } else if (kSwizzleDMode == 0 and lane_idx < WGMMA_M_PER_WARP) {
-                    uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
-                    auto smem_ptr = smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
-                    auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
-                    auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
-                    cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
-                }
-                __syncwarp();
             }
+            cute::tma_store_fence();
+            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
 
-            // Wait TMA to be finished
-            cute::tma_store_arrive();
-            cute::tma_store_wait<0>();
+            // Use TMA store to write back to global memory
+            // TODO: compatible with FP32 output
+            DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
+            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
+                auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
+                auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
+                cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
+                                              n_block_idx * BLOCK_N + in_block_n_offset,
+                                              scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
+
+                // Wait TMA to be finished
+                cute::tma_store_arrive();
+                cute::tma_store_wait<0>();
+            }
+            __syncwarp();
         }
     }
 #else
@@ -519,11 +517,9 @@ public:
 
         // Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
        // So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
-        // TODO: try `FP8MMASelector::type::M` (warpgroup-level TMA)
         return make_2d_tma_desc(global_address, Layout::RowMajor,
                                 shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
-                                FP8MMASelector::type::M / 4,
-                                kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
+                                BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
                                 swizzle_mode);
     }
 
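Note on the new store path (not part of the patch itself): after this change, shared-memory D is treated as BLOCK_N / TMA_D_BLOCK_N contiguous (BLOCK_M x TMA_D_BLOCK_N) swizzle atoms, and one thread per atom issues a single TMA store after the named barrier, instead of one store per warp per wave. The standalone sketch below only reproduces that arithmetic; the tile sizes (BLOCK_M = 64, BLOCK_N = 128, kSwizzleDMode = 128 bytes) and the host-side harness are illustrative assumptions, not values fixed by the patch.

// Host-compilable C++ sketch of the epilogue arithmetic introduced above.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t BLOCK_M = 64;         // example tile size (assumption)
    constexpr uint32_t BLOCK_N = 128;        // example tile size (assumption)
    constexpr uint32_t kSwizzleDMode = 128;  // swizzle atom width in bytes (assumption)
    constexpr uint32_t kNumElemBytes = 2;    // sizeof(nv_bfloat16)

    // Same expression as the patched line: with swizzling disabled the whole
    // BLOCK_N is one TMA box, otherwise one box per swizzle atom.
    constexpr uint32_t TMA_D_BLOCK_N =
        kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);

    // One thread issues one TMA store per swizzle atom along N.
    constexpr uint32_t kNumStores = BLOCK_N / TMA_D_BLOCK_N;
    printf("TMA_D_BLOCK_N = %u, TMA stores per (BLOCK_M x BLOCK_N) tile = %u\n",
           TMA_D_BLOCK_N, kNumStores);

    // Atom `t` occupies a contiguous (BLOCK_M x TMA_D_BLOCK_N) slab, so the
    // store for column offset `in_block_n_offset` starts at
    // `smem_d + in_block_n_offset * BLOCK_M`, exactly as in the added
    // `if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)` path.
    for (uint32_t t = 0; t < kNumStores; ++t) {
        uint32_t in_block_n_offset = t * TMA_D_BLOCK_N;
        uint32_t smem_elem_offset = in_block_n_offset * BLOCK_M;
        printf("thread %u: n offset %u, smem element offset %u\n",
               t, in_block_n_offset, smem_elem_offset);
    }
    return 0;
}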