From d1689ab64f3d7db5e5e7f4f068c7a1b0679e253c Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Thu, 27 Feb 2025 10:56:43 +0800 Subject: [PATCH] use mm1's Aregs instead of mma0's Cregs --- csrc/flash_fwd_mla_kernel.h | 24 ++++++++++++++---------- csrc/utils.h | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 3b3cd9c..0c575c7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -100,7 +100,11 @@ struct Flash_fwd_kernel_traits_mla { getSmemLayoutK(), Shape, Int >{})); - using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutP = std::conditional_t< + Is_FP8, + Layout, Int, _1, _2, Int>>, + Layout, Int, _1, _2, Int>> + >; using SmemLayoutRow = Layout>, Stride<_1, _2>>; using SmemLayoutAtomO = decltype(composition( @@ -297,7 +301,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); - Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); @@ -368,8 +372,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f : softmax.template softmax(tSrS, params.scale_softmax_log2); if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } - Tensor rP = flash::convert_type(tSrS); - cute::copy(rP, tPsP); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + + cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0 cute::copy(scale_o, tScale_osScale_o); cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); @@ -380,7 +387,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); } - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK @@ -529,15 +535,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f } typename Kernel_traits::TiledMma tiled_mma; - auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); - Tensor rP = make_tensor(tSrS_layout); + auto tSrS_layout = flash::convert_layout_acc_Aregs(partition_fragment_C(tiled_mma, Shape, Int>{}).layout()); + Tensor tOrP = make_tensor(tSrS_layout); Tensor scale_o = make_tensor(Shape<_2>{}); cute::copy(tScale_osScale_o, scale_o); - cute::copy(tPsP, rP); + cute::copy(tPsP, tOrP); flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK diff --git a/csrc/utils.h b/csrc/utils.h index 854c75f..716c50c 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -255,4 +255,20 @@ CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. + using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash