Always use STSMx2

This commit is contained in:
Chenggang Zhao 2025-04-11 14:37:37 +08:00
parent 93c92c2c89
commit 76804c096d

View File

@ -359,20 +359,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
SM90_U32x4_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
__float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
__float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
);
}
if constexpr (WGMMA::kNumAccum % 8 != 0) {
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
__float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_d + (m_offset + warp_idx * 16 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8
);
}