Make partition pipelined

Chenggang Zhao, 2025-04-10 18:07:25 +08:00
commit a77009cb14
parent 5bda27244b
2 changed files with 33 additions and 33 deletions


@@ -350,50 +350,40 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 }
             }, num_former_iters);
 
-            // Write back to shared memory using STSM
+            // Write back to shared memory using STSM and issue TMA stores
             DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
             #pragma unroll
             for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                 auto m_offset = local_idx * WAVE_BLOCK_M;
                 auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
                 #pragma unroll
-                for (auto i = 0; i < WGMMA::kNumAccum / 8; ++ i) {
-                    SM90_U32x4_STSM_N<nv_bfloat162>::copy(
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 0], shifted_accum[i * 8 + 1]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 2], shifted_accum[i * 8 + 3]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 4], shifted_accum[i * 8 + 5]}),
-                        __float22bfloat162_rn({shifted_accum[i * 8 + 6], shifted_accum[i * 8 + 7]}),
-                        smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + i * 16 + 8 * (lane_idx / 16)
-                    );
-                }
-                if constexpr (WGMMA::kNumAccum % 8 != 0) {
-                    SM90_U32x2_STSM_N<nv_bfloat162>::copy(
-                        __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 0], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 1]}),
-                        __float22bfloat162_rn({shifted_accum[WGMMA::kNumAccum / 8 * 8 + 2], shifted_accum[WGMMA::kNumAccum / 8 * 8 + 3]}),
-                        smem_d + (m_offset + warp_idx * 16 + lane_idx % 16) * (BLOCK_N + BLOCK_N_PADDING) + WGMMA::kNumAccum / 8 * 16
-                    );
-                }
-            }
-            cute::tma_store_fence();
-            cutlass::arch::NamedBarrier(kNumMathThreads).sync();
-
-            // Use TMA store to write back to global memory
-            if (threadIdx.x == 0) {
-                if (n_block_idx < SHAPE_N / BLOCK_N) {
-                    // Except the last unaligned block
-                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.first, smem_d, 0, n_block_idx,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                } else {
-                    // The last unaligned block
-                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_d.second, smem_d, 0, 0,
-                                                  scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                }
-                cute::tma_store_arrive();
-                cute::tma_store_wait<0>();
-            }
-            __syncwarp();
+                for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
+                    // Store into shared memory
+                    #pragma unroll
+                    for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
+                        auto casted = __float22bfloat162_rn({shifted_accum[i * 4 + partition_idx * 2 + 0],
+                                                             shifted_accum[i * 4 + partition_idx * 2 + 1]});
+                        auto smem_ptr = smem_d + i * 8;
+                        smem_ptr += (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                        SM90_U32x1_STSM_N<nv_bfloat162>::copy(casted, smem_ptr);
+                    }
+
+                    // Issue TMA store
+                    cute::tma_store_fence();
+                    if (lane_idx < 8) {
+                        auto smem_ptr = smem_d + (m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * (BLOCK_N + BLOCK_N_PADDING);
+                        auto gmem_ptr = gmem_d + (m_block_idx * BLOCK_M + m_offset + warp_idx * 16 + partition_idx * 8 + lane_idx) * SHAPE_N + n_block_idx * BLOCK_N;
+                        cute::SM90_BULK_COPY_S2G::copy(smem_ptr, gmem_ptr, BLOCK_N * sizeof(nv_bfloat16));
+                    }
+                    __syncwarp();
+                }
+            }
+
+            // Wait TMA to be finished
+            cute::tma_store_arrive();
+            cute::tma_store_wait<0>();
         }
     }
 #else
     if (blockIdx.x == 0 and threadIdx.x == 0)
         DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
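What the hunk changes: the old epilogue wrote the whole 16-row wave to shared memory with STSM x4/x2, synchronized every math thread on a NamedBarrier, and had a single thread issue one 3D TMA store for the entire output block. The new epilogue splits each wave into two 8-row partitions and pipelines them: as soon as the STSM x1 stores for one partition land in shared memory, tma_store_fence() orders them against the async proxy and lanes 0-7 each issue one cp.async.bulk shared-to-global copy for one BLOCK_N-wide row, so the bulk store of partition 0 overlaps the STSM work of partition 1 and the block-wide barrier disappears. A single tma_store_arrive()/tma_store_wait<0>() pair at the end commits and drains all issued copies.

Below is a minimal standalone sketch of that pattern, not the repo's code: plain shared-memory stores stand in for stmatrix, the helper names (async_proxy_fence, bulk_copy_s2g, pipelined_store_demo) are ours, one warp handles a 16 x BLOCK_N bf16 tile, and gmem must be 16-byte aligned.

// Standalone sketch (ours, not the repo's); compile with nvcc -arch=sm_90a
#include <cuda_bf16.h>
#include <cstdint>

// fence.proxy.async: order generic-proxy smem writes before the async proxy
// reads them (the instruction behind cute::tma_store_fence())
__device__ __forceinline__ void async_proxy_fence() {
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
}

// cp.async.bulk shared -> global (the instruction behind cute::SM90_BULK_COPY_S2G);
// both addresses must be 16-byte aligned and bytes a multiple of 16
__device__ __forceinline__ void bulk_copy_s2g(const void* smem_src, void* gmem_dst, uint32_t bytes) {
    const auto src = static_cast<uint32_t>(__cvta_generic_to_shared(smem_src));
    asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
                 :: "l"(gmem_dst), "r"(src), "r"(bytes) : "memory");
}

// One warp stores a 16 x BLOCK_N bf16 tile, split into two 8-row partitions
template <uint32_t BLOCK_N>
__global__ void pipelined_store_demo(__nv_bfloat16* gmem, const float* accum) {
    static_assert(BLOCK_N % 8 == 0, "Row copies must be multiples of 16 bytes");
    __shared__ alignas(128) __nv_bfloat16 smem[16][BLOCK_N];
    const uint32_t lane_idx = threadIdx.x % 32;

    #pragma unroll
    for (uint32_t partition_idx = 0; partition_idx < 2; ++ partition_idx) {
        // Fill one 8-row partition of the tile (stand-in for the STSM x1 loop)
        for (uint32_t r = 0; r < 8; ++ r) {
            const uint32_t row = partition_idx * 8 + r;
            for (uint32_t col = lane_idx; col < BLOCK_N; col += 32)
                smem[row][col] = __float2bfloat16(accum[row * BLOCK_N + col]);
        }
        // Plain stores need a warp sync before the fence; the real kernel's
        // warp-collective stmatrix does not
        __syncwarp();
        async_proxy_fence();
        // Lanes 0-7 issue one row copy each; these copies are in flight while
        // the next iteration fills the other partition, which is the pipelining
        if (lane_idx < 8) {
            const uint32_t row = partition_idx * 8 + lane_idx;
            bulk_copy_s2g(smem[row], gmem + row * BLOCK_N, BLOCK_N * sizeof(__nv_bfloat16));
        }
        __syncwarp();
    }

    // Commit everything issued so far as one bulk group and drain it before the
    // shared tile may be reused (cute::tma_store_arrive / tma_store_wait<0>)
    asm volatile("cp.async.bulk.commit_group;\n" ::: "memory");
    asm volatile("cp.async.bulk.wait_group.read 0;\n" ::: "memory");
}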


@@ -867,6 +867,16 @@ struct SM90_64x192x32_F32E4M3E4M3_SS {
     static constexpr int kNumAccum = M * N / 128;
 };
 
+template <typename dtype_t>
+struct SM90_U32x1_STSM_N {
+    __device__ __forceinline__ static void
+    copy(dtype_t src_0, void* smem_dst) {
+        const uint32_t src[1] = {*reinterpret_cast<uint32_t*>(&src_0)};
+        asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
+                     :: "l"(smem_dst), "r"(src[0]));
+    }
+};
+
 template <typename dtype_t>
 struct SM90_U32x2_STSM_N {
     __device__ __forceinline__ static void
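The added SM90_U32x1_STSM_N mirrors the existing x2/x4 wrappers but stores a single 8x8 b16 tile, which is what lets the kernel hunk above write one 8-row partition at a time. stmatrix.x1.m8n8 is warp-collective: all 32 lanes contribute one 32-bit register (two bf16 values), while only the address operands of lanes 0-7 are consumed, one 16-byte row base each. A minimal usage sketch follows; the demo kernel is ours, assuming the struct above is in scope and sm_90a:

#include <cuda_bf16.h>
#include <cstdint>

// Launch with one warp, e.g. stsm_x1_demo<<<1, 32>>>(1.f, 2.f);
__global__ void stsm_x1_demo(float lo, float hi) {
    // One 8x8 bf16 tile; each stmatrix row is a 16-byte unit
    __shared__ alignas(16) nv_bfloat16 tile[8][8];
    const uint32_t lane_idx = threadIdx.x % 32;

    // Lane i's register holds row i / 4, columns 2 * (i % 4) and 2 * (i % 4) + 1
    auto frag = __float22bfloat162_rn({lo, hi});

    // Only lanes 0-7's addresses are consumed (lane r supplies row r's base);
    // lanes 8-31 must still execute the warp-collective instruction, but their
    // address operands are ignored
    SM90_U32x1_STSM_N<nv_bfloat162>::copy(frag, tile[lane_idx % 8]);
    __syncwarp();
}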