From 6a4eb631e2b0b7b8986f1485dcf90236ac41120a Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 17:57:00 +0800 Subject: [PATCH] add transv barrier --- csrc/flash_fwd_mla_kernel.h | 8 ++++++++ csrc/named_barrier.h | 1 + 2 files changed, 9 insertions(+) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 7493f79..07a4a64 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -360,6 +360,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f flash::rescale_o(tOrO, scale_o); + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady)); + } + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); @@ -440,6 +444,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cute::cp_async_fence(); } + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady)); + } + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady)); if (n_block - 2 >= n_block_min) { diff --git a/csrc/named_barrier.h b/csrc/named_barrier.h index cefa936..940c934 100644 --- a/csrc/named_barrier.h +++ b/csrc/named_barrier.h @@ -10,6 +10,7 @@ namespace flash { enum class NamedBarriers { SReady = 1, SoftmaxReady = 2, + TransVReady = 3, }; } // flash