diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 79d6ba7..12e9883 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -14,18 +14,29 @@ using namespace cute; #include "flash_mla.h" -template +template constexpr auto getSmemLayoutK() { constexpr int headSizeBytes = sizeof(PrecType) * DIM; constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; - if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { - return GMMA::Layout_K_SW64_Atom{}; + if constexpr (major == GMMA::Major::K) { + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return GMMA::Layout_K_SW32_Atom{}; + } } else { - return GMMA::Layout_K_SW32_Atom{}; + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } else { + return GMMA::Layout_MN_SW32_Atom{}; + } } + } template @@ -75,11 +86,16 @@ struct Flash_fwd_kernel_traits_mla { getSmemLayoutK(), Shape, Int>{})); + // ------ for f16 ------ using SmemLayoutV = decltype(tile_to_shape( getSmemLayoutK(), Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + // ------ for f8 ------ + using SmemLayoutVtLoad = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); using SmemLayoutVtMMa = decltype(tile_to_shape( getSmemLayoutK(), Shape, Int >{})); @@ -135,11 +151,6 @@ struct Flash_fwd_kernel_traits_mla { Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>>{})); // Val layout, 4 vals per store - - - - // for fp8 trans-v - }; namespace flash { @@ -278,7 +289,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); auto sV = cute::conditional_return( - cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{})), + cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtLoad{})), make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{})); auto sVt = cute::conditional_return( @@ -476,9 +487,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx); auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx); // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8) - Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32) + Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sV, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32) // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64)) - Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64)) + Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sVt, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64)) CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));