From 0337732dc1310d24dbe54ac928da3f685433c0a0 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 08:09:02 +0800 Subject: [PATCH] reorg --- csrc/flash_fwd_mla_kernel.h | 23 +++++++++++++++-------- csrc/fp8_transpose_v.h | 7 ++----- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 6dff13a..b4f3ed7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -150,10 +150,11 @@ struct Flash_fwd_kernel_traits_mla { // ------ for f8 ------ - using SmemLayoutVtMMa = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int >{})); - using SmemFp8Tranpose = SmemTransposeFp8_64x64; + using SmemFp8Tranpose = SmemTransposeFp8_64x64; + // using SmemLayoutVtMMa = decltype(tile_to_shape( + // getSmemLayoutK(), + // Shape, Int >{})); + using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt; }; namespace flash { @@ -292,9 +293,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - auto sVt = cute::conditional_return( - make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}), - make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); + auto sVt = [&](){ + if constexpr(Kernel_traits::Is_FP8){ + return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + } + }(); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); @@ -381,8 +386,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if constexpr (Kernel_traits::Is_FP8) { cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); + __syncthreads(); } - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK @@ -504,6 +509,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cute::copy(tPsP, tOrP); flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) __syncthreads(); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h index 020944b..f1a1919 100644 --- a/csrc/fp8_transpose_v.h +++ b/csrc/fp8_transpose_v.h @@ -1,13 +1,10 @@ #pragma once -template +template struct SmemTransposeFp8_64x64 { - static_assert(sizeof(Element) == 1); static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); - using SmemLayoutK = decltype(tile_to_shape( - GMMA::Layout_K_SW64_Atom{}, - Shape, Int>{})); + using Element = cutlass::float_e4m3_t; using SmemLayoutV = decltype(composition( SmemLayoutK{}, Layout, Int>, Stride<_1, Int>>{}));