From 9406e2a3a1b96c9eb8d7df7386f9a717b593b7e2 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Mon, 14 Apr 2025 10:10:46 +0800 Subject: [PATCH] Optimize swizzle performance --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 154073a..b4bce1c 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -353,7 +353,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, }, num_former_iters); // TMA checks - constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16); + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes; constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; if constexpr (kSwizzleDMode > 0) { DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom"); @@ -373,13 +374,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, #pragma unroll for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // Swizzle or padding into the correct address - uint8_t* smem_ptr; + uint8_t* smem_ptr = nullptr; if constexpr (kSwizzleDMode > 0) { - // Calculate the base pointer of the swizzling atom + // Calculate the swizzling atom offset and in-atom offset constexpr int kNumBankGroupBytes = 16; - auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP; auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); - smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode; // Calculate the index of the bank group to be written in the atom auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); @@ -391,7 +390,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, col ^= row % 
(kSwizzleDMode / 16); // Add back into the base pointer - smem_ptr += row * kNumBankGroupBytes * 8 + col * kNumBankGroupBytes; + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset + m_offset * (BLOCK_N * kNumElemBytes) + // Wave offset + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset } else { // No swizzling, just padding smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);