Optimize swizzle performance

This commit is contained in:
Chenggang Zhao 2025-04-14 10:10:46 +08:00
parent 5ff0eb24b5
commit 9406e2a3a1

View File

@ -353,7 +353,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
}, num_former_iters); }, num_former_iters);
// TMA checks // TMA checks
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16); constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / kNumElemBytes;
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
if constexpr (kSwizzleDMode > 0) { if constexpr (kSwizzleDMode > 0) {
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom"); DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
@ -373,13 +374,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
#pragma unroll #pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address // Swizzle or padding into the correct address
uint8_t* smem_ptr; uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) { if constexpr (kSwizzleDMode > 0) {
// Calculate the base pointer of the swizzling atom // Calculate the swizzling atom offset and in-atom offset
constexpr int kNumBankGroupBytes = 16; constexpr int kNumBankGroupBytes = 16;
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode;
// Calculate the index of the bank group to be written in the atom // Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
@ -391,7 +390,12 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
col ^= row % (kSwizzleDMode / 16); col ^= row % (kSwizzleDMode / 16);
// Add back into the base pointer // Add back into the base pointer
smem_ptr += row * kNumBankGroupBytes * 8 + col * kNumBankGroupBytes; // NOTE: think twice before modifying the address expression below, as changes to how it is written may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * BLOCK_N * kNumElemBytes) + // Warp offset
m_offset * (BLOCK_N * kNumElemBytes) + // Wave offset
atom_offset * WGMMA_M_PER_WARP * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else { } else {
// No swizzling, just padding // No swizzling, just padding
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8); smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);