diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 0c575c7..512fb9b 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -12,6 +12,7 @@ using namespace cute;
 #include "softmax.h"
 #include "static_switch.h"
 #include "flash_mla.h"
+#include "fp8_transpose_v.h"
 
 template<typename PrecType, int DIM, int DIM2 = DIM>
@@ -86,20 +87,11 @@ struct Flash_fwd_kernel_traits_mla {
             getSmemLayoutK<Element, kHeadDim>(),
             Shape<Int<kBlockN>, Int<kHeadDim>>{}));
 
-    // ------ for f16 ------
     using SmemLayoutV = decltype(tile_to_shape(
             getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
             Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
     using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
 
-    // ------ for f8 ------
-    using SmemLayoutVtLoad = decltype(tile_to_shape(
-            getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-            Shape<Int<kHeadDimV>, Int<kBlockN>>{}));
-    using SmemLayoutVtMMa = decltype(tile_to_shape(
-            getSmemLayoutK<Element, kBlockN>(),
-            Shape<Int<kHeadDimV>, Int<kBlockN>>{}));
-
     using SmemLayoutP = std::conditional_t<
             Is_FP8,
             Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 16>>>,
@@ -155,6 +147,13 @@ struct Flash_fwd_kernel_traits_mla {
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
             Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+
+
+    // ------ for f8 ------
+    using SmemLayoutVtMMa = decltype(tile_to_shape(
+            getSmemLayoutK<Element, kBlockN>(),
+            Shape<Int<kHeadDimV>, Int<kBlockN>>{}));
+    using SmemFp8Tranpose = SmemTransposeFp8_64x64<Element, kBlockN, kHeadDimV>;
 };
 
 namespace flash {
@@ -292,10 +291,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
-    auto sV = cute::conditional_return<Kernel_traits::Is_FP8>(
-            cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtLoad{})),
-            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}));
-
+    auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
     auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
             make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
             make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
@@ -438,9 +434,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         if constexpr (!Kernel_traits::Is_FP8) {
             tOrVt.data() = tOrVt.data() + sK_offset / 8;
         }
-        else {
-            sV.data() = sV.data() + sK_offset;
-        }
     }
 
     // We need to clear the sK smem tiles because K is V.
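
Side note on the sVt definition above: cute::conditional_return<Is_FP8>(a, b) lets one variable take either of two unrelated static types depending on a compile-time flag, which an ordinary ternary cannot do. A minimal standalone sketch of the idiom (the local conditional_return below is written for this note to show the mechanism; the kernel uses the helper shipped with its CuTe/FA3 headers):

    #include <type_traits>

    // if constexpr in a function template propagates the *selected* branch's
    // return type to the caller; the discarded branch is never instantiated.
    template <bool Cond, class T0, class T1>
    constexpr auto conditional_return(T0 const &t0, T1 const &t1) {
        if constexpr (Cond) { return t0; } else { return t1; }
    }

    static_assert(std::is_same_v<decltype(conditional_return<true>(1, 2.0)), int>);
    static_assert(std::is_same_v<decltype(conditional_return<false>(1, 2.0)), double>);

This is why the fp8 and f16 paths can share a single sVt variable even though SmemLayoutVtMMa and SmemLayoutVtransposed are unrelated layout types.
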
@@ -474,53 +467,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     if constexpr (Kernel_traits::Is_FP8) {
         auto TransV = [&]() {
-            // refer to fa3's TransV: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L697
-            using LDSM_divide_shape = Shape<_64, _8>;
-            using S2RTiledCopyVt = decltype(make_tiled_copy(
-                Copy_Atom<SM75_U16x8_LDSM_T, Element>{},
-                Layout<Shape<_32, _4, _1, _1>, Stride<_4, _1, _0, _0>>{},  // thread layout
-                Layout<Shape<_2, _2, _1, _4>, Stride<_1, _2, _16, _4>>{}  // val layout
-            ));
+            using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose;
+            SmemFp8Tranpose smem_transpose_V;
+            Tensor sV_divide = as_position_independent_swizzle_tensor(
+                make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{}));
+            Tensor sVt_divide = as_position_independent_swizzle_tensor(
+                make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{}));
 
-            using STSM_divide_shape = Shape<_8, _16>;
-            using R2STiledCopyV = decltype(make_tiled_copy(
-                Copy_Atom<SM90_U32x4_STSM_N, Element>{},
-                Layout<Shape<_8, _4, _4, _1>, Stride<_4, _1, _32, _0>>{},  // thread layout
-                Layout<Shape<_1, _4, _2, _2>, Stride<_0, _1, _4, _8>>{}  // val layout
-            ));
-
-            S2RTiledCopyVt s2r_tiled_copy_vt;
-            R2STiledCopyV r2s_tiled_copy_v;
-            auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx);
-            auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx);
-            // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8)
-            Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sV, LDSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32)
-            // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64))
-            Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sVt, STSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64))
-            CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
-
-            static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
-            Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_)>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
-            Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_)>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
-            #pragma unroll
-            for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
-                Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{})));
-                static_assert(size<0>(tTransrV) == 16);
-                Tensor tTransrV_64 = recast<uint2>(tTransrV);
-                cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i)), tTransrV);
-                #pragma unroll
-                for (int j = 0; j < size(tTransrV_64); ++j) {
-                    uint32_t upper = tTransrV_64[j].x;
-                    uint32_t lower = tTransrV_64[j].y;
-                    tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
-                    tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
+            if (n_block % 2 == 1) {
+                sV_divide.data() = sV_divide.data() + size(sK);
+            }
+
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) {
+                CUTLASS_PRAGMA_UNROLL
+                for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) {
+                    smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j)));
                 }
-                cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i)));
             }
         };
 
@@ -548,8 +511,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
             const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
             if constexpr (!Kernel_traits::Is_FP8) {
                 tOrVt.data() = tOrVt.data() + sK_offset / 8;
-            } else {
-                sV.data() = sV.data() + sK_offset;
             }
         }
 
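
The common ingredient of the old TransV and the new SmemTransposeFp8_64x64 is the byte shuffle: ldmatrix.trans (LDSM_T) transposes at 16-bit granularity, so for 1-byte fp8 elements each transposed 16-bit unit still holds a byte pair belonging to two different output rows. The selectors 0x6420 and 0x7531 regroup even and odd bytes across a pair of 32-bit registers before stmatrix (STSM) writes the tile back. A host-side sketch of the permutation (byte_perm below is an emulation written for this note, not CUDA's intrinsic):

    #include <cstdint>
    #include <cstdio>

    // Host emulation of CUDA's __byte_perm(a, b, sel): the 8 source bytes are
    // a's bytes 0-3 followed by b's bytes 4-7; each selector nibble picks one.
    static uint32_t byte_perm(uint32_t a, uint32_t b, uint32_t sel) {
        uint8_t src[8];
        for (int i = 0; i < 4; ++i) {
            src[i]     = static_cast<uint8_t>(a >> (8 * i));
            src[i + 4] = static_cast<uint8_t>(b >> (8 * i));
        }
        uint32_t out = 0;
        for (int i = 0; i < 4; ++i)
            out |= static_cast<uint32_t>(src[(sel >> (4 * i)) & 0x7]) << (8 * i);
        return out;
    }

    int main() {
        uint32_t upper = 0x03020100;  // fp8 bytes 0,1,2,3 (two 16-bit granules)
        uint32_t lower = 0x07060504;  // fp8 bytes 4,5,6,7
        std::printf("%08x\n", byte_perm(upper, lower, 0x6420));  // 06040200: even bytes
        std::printf("%08x\n", byte_perm(upper, lower, 0x7531));  // 07050301: odd bytes
        return 0;
    }
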
diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h
new file mode 100644
index 0000000..ba9e31e
--- /dev/null
+++ b/csrc/fp8_transpose_v.h
@@ -0,0 +1,82 @@
+#pragma once
+
+template <typename Element, int kBlockN, int kHeadDim>
+struct SmemTransposeFp8_64x64 {
+  static_assert(sizeof(Element) == 1);
+  static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
+
+  using SmemLayoutK = decltype(tile_to_shape(
+      GMMA::Layout_K_SW64_Atom<Element>{},
+      Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+  using SmemLayoutV = decltype(composition(
+      SmemLayoutK{},
+      Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
+  using TransposeShapeAtomV = Shape<_64, _64>;
+
+  // for fp8 in-kernel transpose -- src layout
+  using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
+  using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
+  using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
+      shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
+  using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
+
+  // For fp8, this is the memory transpose.
+  using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+  using SmemLayoutVt = decltype(tile_to_shape(
+      SmemLayoutAtomVt{},
+      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+
+  // for fp8 in-kernel transpose -- dst layout
+  using SmemLayoutVtTrans = decltype(composition(
+      SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
+  using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
+  using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
+  using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}),
+      shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
+  using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
+
+  using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
+  using ldsm_value_shape = Shape<_2, _8, _2, _1>;
+  using ldsm_value_stride = Stride<_2, _4, _1, _0>;
+  using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
+                                                 Layout<ldsm_value_shape, ldsm_value_stride>{}));
+  TiledCopyLDSM tiled_copy_ldsm;
+
+  using stsm_thread_shape = Shape<_4, _1, _8, _4>;
+  // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
+  using stsm_value_shape = Shape<_4, _4, _1, _2>;
+  using stsm_value_stride = Stride<_1, _8, _0, _4>;
+
+  using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
+                                                 Layout<stsm_value_shape, stsm_value_stride>{}));
+  TiledCopySTSM tiled_copy_stsm;
+
+  template <class SmemTensor, class SmemTensorOut>
+  CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
+    using namespace cute;
+
+    auto tid = threadIdx.x;
+    auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
+    auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
+
+    auto tXsX = thr_copy_ldsm.partition_S(s_in);
+    auto tXrX = make_tensor<Element>(shape(tXsX));
+    auto tXsX_out = thr_copy_stsm.partition_D(s_out);
+
+    cute::copy(tiled_copy_ldsm, tXsX, tXrX);
+
+    auto data = tXrX.data();
+    CUTLASS_PRAGMA_UNROLL
+    for (int n = 0; n < size(tXrX); n += 8) {
+      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+      auto upper = data_32bit[0];
+      auto lower = data_32bit[1];
+      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
+      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
+    }
+
+    cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
+  }
+};
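
For orientation, a hypothetical caller of the new struct, mirroring the TransV lambda in the kernel diff above; the function name, buffer names, element type, and tile extents below are illustrative only, not part of the patch:

    #include <cute/tensor.hpp>
    #include <cutlass/numeric_types.h>
    #include "fp8_transpose_v.h"

    using Transposer = SmemTransposeFp8_64x64<cutlass::float_e4m3_t, /*kBlockN=*/64, /*kHeadDim=*/512>;

    __device__ void transpose_v_block(cutlass::float_e4m3_t *smem_v, cutlass::float_e4m3_t *smem_vt) {
        using namespace cute;
        Transposer t;
        // Same construction as the kernel's sV_divide / sVt_divide tensors.
        Tensor sV = as_position_independent_swizzle_tensor(
            make_tensor(make_smem_ptr(smem_v), typename Transposer::SmemLayoutTransposeV{}));
        Tensor sVt = as_position_independent_swizzle_tensor(
            make_tensor(make_smem_ptr(smem_vt), typename Transposer::SmemLayoutTransposeVt{}));
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < shape<2>(typename Transposer::SmemLayoutTransposeV{}); ++j) {
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < shape<1>(typename Transposer::SmemLayoutTransposeV{}); ++i) {
                // Each call moves one 64x64 tile; all threads take part in the
                // LDSM load, register byte shuffle, and STSM store.
                t.transpose(flatten(sV(_, i, j)), flatten(sVt(_, i, j)));
            }
        }
        __syncthreads();  // sVt must be visible before it feeds the P * V^T GEMM
    }

The in-smem transpose exists because Hopper fp8 WGMMA consumes both operands K-major, so V has to be physically transposed before the O = P * V^T GEMM; the f16 path instead reuses the same buffer through the logical SmemLayoutVtransposed view.
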