Add TransVReady named barrier for the FP8 path

This commit is contained in:
chenhongmin.will 2025-02-26 17:57:00 +08:00
parent 59f691763e
commit 6a4eb631e2
2 changed files with 9 additions and 0 deletions

View File

@ -360,6 +360,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
flash::rescale_o(tOrO, scale_o);
if constexpr (Kernel_traits::Is_FP8) {
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
}
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
@ -440,6 +444,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cute::cp_async_fence();
}
if constexpr (Kernel_traits::Is_FP8) {
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
if (n_block - 2 >= n_block_min) {

View File

@ -10,6 +10,7 @@ namespace flash {
// Identifiers for the hardware named barriers used by the attention kernel.
// Call sites cast an enumerator to int and hand it to
// cutlass::arch::NamedBarrier::sync / ::arrive as the barrier id.
// NOTE(review): numbering starts at 1, presumably to leave id 0 to
// __syncthreads() -- confirm before adding or renumbering entries.
enum class NamedBarriers : int {
    // Block-wide sync point in compute_attn_1rowblock_splitkv_mla.
    SReady = 1,
    // NOTE(review): presumably signals that softmax results are available;
    // its use is outside this view -- confirm against the kernel.
    SoftmaxReady = 2,
    // Used only when Kernel_traits::Is_FP8: one side arrives, the other
    // syncs just before the P*V gemm.
    // NOTE(review): presumably "transposed V is ready" -- confirm.
    TransVReady = 3,
};
} // flash