Make partition pipelined

Chenggang Zhao 2025-04-10 18:07:25 +08:00
parent 5bda27244b
commit a77009cb14
2 changed files with 33 additions and 33 deletions


@@ -350,48 +350,38 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 }
             }, num_former_iters);
 
-            // Write back to shared memory using STSM
+            // Write back to shared memory using STSM and issue TMA stores
             DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
             #pragma unroll
             for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                 auto m_offset = local_idx * WAVE_BLOCK_M;
                 auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
-                #pragma unroll
-                for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
-                    SM90_U32x4_STSM_N<nv_bfloat162>::copy(
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
-                        smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
-                    );
-                }
-                if constexpr (WGMMA::kNumAccum % 8 != 0) {
-                    SM90_U32x2_STSM_N<nv_bfloat162>::copy(
-                        __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                        __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                        smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
-                    );
-                }
-            }
-            cute::tma_store_fence();
-            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
-
-            // Use TMA store to write back to global memory
-            if (threadIdx.x == 0) {
-                if (n_block_idx < SHAPE_N / BLOCK_N) {
-                    // Except the last unaligned block
-                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                } else {
-                    // The last unaligned block
-                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                }
-                cute::tma_store_arrive();
-                cute::tma_store_wait<0>();
-            }
-            __syncwarp();
+                for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
+                    // Store into shared memory
+                    #pragma unroll
+                    for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        auto casted = __float22bfloat162_rn({shifted_accum[i * 4 + partition_idx * 2 + 0],
+                                                             shifted_accum[i * 4 + partition_idx * 2 + 1]});
+                        auto smem_ptr = smem_d + i * 8;
+                        smem_ptr += (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                        SM90_U32x1_STSM_N<nv_bfloat162>::copy(casted, smem_ptr);
+                    }
+
+                    // Issue TMA store
+                    cute::tma_store_fence();
+                    if (lane_idx < 8) {
+                        auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                        auto gmem_ptr = gmem_d + (m_block_idx * BLOCK_M + m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
+                        cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
+                    }
+                    __syncwarp();
+                }
+            }
+
+            // Wait TMA to be finished
+            cute::tma_store_arrive();
+            cute::tma_store_wait<0>();
         }
     }
 #else
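
For context, the new write-back path splits each warp's 16 output rows into two 8-row partitions and pipelines them: each partition is staged into shared memory with `stmatrix`, made visible to the async proxy with `cute::tma_store_fence()`, and immediately pushed to global memory with a per-row bulk store, so staging one partition overlaps the in-flight store of the other; a single `tma_store_arrive()` / `tma_store_wait<0>()` pair drains everything at the end. The following is a minimal, self-contained sketch of that pattern, not DeepGEMM code: `pipelined_store`, `bulk_copy_s2g`, and `kNumCols` are illustrative names, plain shared-memory stores stand in for the `stmatrix` writes, and it assumes one warp per block, `kNumCols` a multiple of 8 (so each row is a multiple of 16 bytes), and compilation with `-arch=sm_90a`.

#include <cuda_bf16.h>
#include <cstdint>

// Bulk shared-to-global copy with bulk-group completion; this is the PTX that
// cute::SM90_BULK_COPY_S2G wraps (needs 16-byte aligned addresses and sizes)
__device__ __forceinline__ void bulk_copy_s2g(void* gmem_dst, const void* smem_src, uint32_t bytes) {
    asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
                 :: "l"(gmem_dst),
                    "r"(static_cast<uint32_t>(__cvta_generic_to_shared(smem_src))),
                    "r"(bytes));
}

// One warp per block; each block owns 16 rows of a row-major bf16 matrix with
// kNumCols columns, written back as two pipelined 8-row partitions
template <uint32_t kNumCols>
__global__ void pipelined_store(const nv_bfloat16* in, nv_bfloat16* out) {
    __shared__ alignas(16) nv_bfloat16 smem[16 * kNumCols];
    const uint32_t lane_idx = threadIdx.x % 32;

    #pragma unroll
    for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
        // Stage one 8-row partition into shared memory
        // (plain stores here; the kernel uses `stmatrix` via SM90_U32x1_STSM_N)
        for (uint32_t i = lane_idx; i < 8 * kNumCols; i += 32) {
            const uint32_t row = partition_idx * 8 + i / kNumCols;
            const uint32_t col = i % kNumCols;
            smem[row * kNumCols + col] = in[(blockIdx.x * 16 + row) * kNumCols + col];
        }
        __syncwarp();

        // Make the generic-proxy stores visible to the async proxy (what
        // cute::tma_store_fence() emits), then issue one bulk store per row
        // from lanes 0-7, mirroring the `lane_idx < 8` branch in the diff
        asm volatile("fence.proxy.async.shared::cta;\n");
        if (lane_idx < 8) {
            const uint32_t row = partition_idx * 8 + lane_idx;
            bulk_copy_s2g(out + (blockIdx.x * 16 + row) * kNumCols,
                          smem + row * kNumCols,
                          kNumCols * static_cast<uint32_t>(sizeof(nv_bfloat16)));
        }
        // Keep the warp converged; the next partition's staging proceeds
        // while this partition's bulk stores are still in flight
        __syncwarp();
    }

    // Commit all issued bulk stores and wait for completion, i.e. what
    // cute::tma_store_arrive() / cute::tma_store_wait<0>() expand to
    asm volatile("cp.async.bulk.commit_group;\n");
    asm volatile("cp.async.bulk.wait_group.read 0;\n" ::: "memory");
}

A launch such as `pipelined_store<64><<<num_blocks, 32>>>(in, out)` would exercise the sketch. The design point is the same as in the commit: instead of serializing a full-tile STSM, a block-wide named barrier, and one monolithic TMA store, the staging of partition 1 hides the latency of partition 0's in-flight store.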


@@ -867,6 +867,16 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
     static constexpr int kNumAccum = M * N / 128;
 };
 
+template <typename dtype_t>
+struct SM90_U32x1_STSM_N {
+    __device__ __forceinline__ static void
+    copy(dtype_t src_0, void* smem_dst) {
+        const uint32_t src[1] = {*reinterpret_cast<uint32_t*>(&src_0)};
+        asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
+                     :: "l"(smem_dst), "r"(src[0]));
+    }
+};
+
 template <typename dtype_t>
 struct SM90_U32x2_STSM_N {
     __device__ __forceinline__ static void