From 6dcea4952c6d1a1c7e7ae74c8f1018cbf0117d3f Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Wed, 26 Feb 2025 18:37:36 +0800
Subject: [PATCH] add TransV

---
 csrc/flash_fwd_mla_kernel.h | 76 ++++++++++++++++++++++++++++++++++---
 1 file changed, 71 insertions(+), 5 deletions(-)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 07a4a64..79d6ba7 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -276,10 +276,14 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
-    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
+
+    auto sV = cute::conditional_return<Kernel_traits::Is_FP8>(
+        cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{})),
+        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}));
+
     auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
-        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
-        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
+            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
 
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
@@ -377,6 +381,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         cute::copy(softmax.row_sum, tRow_sumsRow_sum);
         cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
     } else {
+        const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
         const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
         int cur_block_table = __ldg(&block_table[n_block]);
 
@@ -412,7 +417,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                 // Double buffer for sK
                 constexpr int sK_offset = size(sK);
                 tKsK.data() = tKsK.data() + sK_offset;
-                if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
+                if constexpr (!Kernel_traits::Is_FP8) {
+                    tOrVt.data() = tOrVt.data() + sK_offset / 8;
+                }
+                else {
+                    sV.data() = sV.data() + sK_offset;
+                }
             }
 
             // We need to clear the sK smem tiles because K is V.
@@ -445,6 +455,58 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         }
 
         if constexpr (Kernel_traits::Is_FP8) {
+            auto TransV = [&]() {
+                // refer to fa3's TransV: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L697
+                using LDSM_divide_shape = Shape<_64, _8>;
+                using S2RTiledCopyVt = decltype(make_tiled_copy(
+                    Copy_Atom<SM75_U16x8_LDSM_T, Element>{},
+                    Layout<Shape<_32, _4, _1, _1>, Stride<_4, _1, _0, _0>>{},  // thread layout
+                    Layout<Shape<_2, _2, _1, _4>, Stride<_1, _2, _16, _4>>{}   // val layout
+                ));
+
+                using STSM_divide_shape = Shape<_8, _16>;
+                using R2STiledCopyV = decltype(make_tiled_copy(
+                    Copy_Atom<SM90_U32x4_STSM_N, Element>{},
+                    Layout<Shape<_8, _4, _4, _1>, Stride<_4, _1, _32, _0>>{},  // thread layout
+                    Layout<Shape<_1, _4, _2, _2>, Stride<_0, _1, _4, _8>>{}    // val layout
+                ));
+
+                S2RTiledCopyVt s2r_tiled_copy_vt;
+                R2STiledCopyV r2s_tiled_copy_v;
+                auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx);
+                auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx);
+                // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8)
+                Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32)
+                // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64))
+                Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64))
+                CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
+
+                static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
+                Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_)>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
+                Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_)>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{});     // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
+                #pragma unroll
+                for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
+                    Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{})));
+                    static_assert(size<0>(tTransrV) == 16);
+                    Tensor tTransrV_64 = recast<uint2>(tTransrV);
+                    cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i)), tTransrV);
+                    #pragma unroll
+                    for (int j = 0; j < size(tTransrV_64); ++j) {
+                        uint32_t upper = tTransrV_64[j].x;
+                        uint32_t lower = tTransrV_64[j].y;
+                        tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
+                        tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
+                    }
+                    cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i)));
+                }
+            };
+
+            TransV();
             cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
         }
 
@@ -468,7 +530,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
             // Double buffer for sK
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
-            if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            if constexpr (!Kernel_traits::Is_FP8) {
+                tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            } else {
+                sV.data() = sV.data() + sK_offset;
+            }
         }
 
         cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
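
Note (not part of the patch): the two __byte_perm selectors are what make the
16-bit LDSM.T transpose correct for 1-byte FP8 data. LDSM.T transposes at
2-byte granularity, so after the copy each 16-bit unit still holds one byte
from each of two adjacent rows, interleaved; selector 0x6420 gathers the even
bytes of the {upper, lower} register pair and 0x7531 the odd bytes, which
de-interleaves the two rows. Below is a minimal standalone sketch of that
selector behavior with made-up demo byte values (byte_perm_demo and the values
are hypothetical, not code from this patch):

#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

__global__ void byte_perm_demo() {
    // After an LDSM.T on FP8 data, bytes of two logical rows a and b sit
    // interleaved across a register pair (LSB first): a0 b0 a1 b1 | a2 b2 a3 b3.
    uint32_t upper = 0xB1A1B0A0;
    uint32_t lower = 0xB3A3B2A2;
    // __byte_perm(x, y, s) selects bytes 0-7 of the pair {y:x}, one per nibble of s.
    uint32_t row_a = __byte_perm(upper, lower, 0x6420);  // even bytes -> 0xA3A2A1A0
    uint32_t row_b = __byte_perm(upper, lower, 0x7531);  // odd bytes  -> 0xB3B2B1B0
    printf("row_a = 0x%08x, row_b = 0x%08x\n", row_a, row_b);
}

int main() {
    byte_perm_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}

This is also why TransV recasts the fragment to uint2 before the inner loop:
each uint2 holds exactly one such {upper, lower} register pair, so one pair of
__byte_perm calls finishes the byte-level transpose for 8 FP8 values.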