Swizzling draft

This commit is contained in:
Chenggang Zhao 2025-04-11 17:19:59 +08:00
parent 76804c096d
commit 4c111418a2

View File

@ -352,6 +352,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
}
}, num_former_iters);
// TMA checks
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode / sizeof(nv_bfloat16);
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
DG_STATIC_ASSERT(WGMMA_M_PER_WARP % 8 == 0, "Invalid swizzling atom");
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
"Unaligned TMA store or too many TMA store instructions");
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
"Swizzling and padding are not compatible");
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
@ -360,19 +370,52 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the base pointer of the swizzling atom
constexpr int kNumBankGroupBytes = 16;
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + in_block_m_offset * BLOCK_N) + atom_offset * kNumBankGroupBytes;
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(WGMMA_M_PER_WARP, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(WGMMA_M_PER_WARP * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
auto row = bank_group_index / 8, col = bank_group_index % 8;
col ^= row % (kSwizzleDMode / 16);
// Add back into the base pointer
smem_ptr += row * kNumBankGroupBytes * 8 + col * kNumBankGroupBytes;
} else {
// No swizzling, just padding
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8
smem_ptr
);
}
// Issue TMA store
cute::tma_store_fence();
if (lane_idx < 16) {
if (kSwizzleDMode > 0 and lane_idx < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_m_offset = m_offset + warp_idx * WGMMA_M_PER_WARP;
auto in_block_n_offset = lane_idx * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_m_offset * BLOCK_N + in_block_n_offset;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
n_block_idx * BLOCK_N + in_block_n_offset,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx) + in_block_m_offset);
} else if (kSwizzleDMode == 0 and lane_idx < WGMMA_M_PER_WARP) {
uint64_t gmem_m_offset = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx);
auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * 16 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
auto smem_ptr = smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
auto gmem_ptr = gmem_d + (gmem_m_offset + m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
auto num_valid_cols = (n_block_idx == ceil_div(SHAPE_N, BLOCK_N) - 1) ? (SHAPE_N - n_block_idx * BLOCK_N) : BLOCK_N;
cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, num_valid_cols * sizeof(nv_bfloat16));
}
@ -468,9 +511,11 @@ public:
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
// TODO: try `FP8MMASelector<BLOCK_N>::type::M` (warpgroup-level TMA)
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
min(BLOCK_M, shape_m), kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
FP8MMASelector<BLOCK_N>::type::M / 4,
kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
swizzle_mode);
}