From 9f361aa02e513d7040eca5737ffaff75252abec3 Mon Sep 17 00:00:00 2001 From: Gareth Jones Date: Sun, 23 Feb 2025 18:23:07 -0800 Subject: [PATCH] Stage accumulator fragment to shared memory using tiled copy --- csrc/flash_fwd_mla_kernel.h | 724 ++++++++++++++++++------------------ 1 file changed, 362 insertions(+), 362 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 55f6811..a06cc03 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -13,10 +13,14 @@ using namespace cute; #include "static_switch.h" #include "flash_mla.h" - -template +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper: Decide K-Layout at SMEM level given type and dimension. +/// Swizzling is determined primarily by alignment constraints. +/// Return GMMA Layout at compile time. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template constexpr auto getSmemLayoutK() { - constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes = sizeof(PrecType) * DIM; constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { @@ -28,466 +32,462 @@ constexpr auto getSmemLayoutK() { } } -template -struct Flash_fwd_kernel_traits_mla { - using Element = elem_type; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Kernel Trait: FWD MLA for Flash Attention +/// - Templated on HeadDim (kHeadDim_), block tiling, warp usage, etc. +/// - Provides all necessary sub-layouts for Q/K/V, softmax partials, etc. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template < + int kHeadDim_, + int kBlockM_, + int kBlockN_, + int kNumWarps_, + typename ElemType = cutlass::bfloat16_t, + int kHeadDimV_ = 0 +> +struct FlashFwdKernelTraitsMLA { + using Element = ElemType; using ElementAccum = float; - using index_t = int64_t; + using IndexT = int64_t; - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - static constexpr int kNWarpsS = 4; - static constexpr int kNThreadsS = kNWarpsS * 32; + // Warp organization + static constexpr int kNumWarps = kNumWarps_; + static constexpr int kNumThreads = kNumWarps * 32; + static constexpr int kNumWarpsSoftmax = 4; + static constexpr int kNumThreadsSoftmax = kNumWarpsSoftmax * 32; - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; + // Tiling in M, N, K + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; static_assert(kHeadDim % 32 == 0); - static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + + // Possibly distinct V-dimension + static constexpr int kHeadDimV = (kHeadDimV_ != 0) ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + // SMEM swizzling for partial K/V + static constexpr int kBlockKSmem = (kHeadDim % 64 == 0) ? 64 : 32; + static constexpr int kSwizzle = (kBlockKSmem == 32) ? 
2 : 3; + + // GMMA Tiled Mma + // Q*K -> S using TiledMma = decltype(make_tiled_mma( - cute::GMMA::ss_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::K>(), - Layout, _1, _1>>{})); + cute::GMMA::ss_op_selector< + Element, Element, ElementAccum, + Shape, Int, Int>, + GMMA::Major::K, GMMA::Major::K + >(), + Layout, _1, _1>>{} + )); - static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + // S*V -> O + // For the O “outer product,” we define the shape in [M, HeadDimV, N]. + static constexpr int AtomLayoutNO = kNumThreads / kNumThreadsSoftmax; using TiledMmaO = decltype(make_tiled_mma( - cute::GMMA::rs_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::MN>(), - Layout, Int, _1>>{})); + cute::GMMA::rs_op_selector< + Element, Element, ElementAccum, + Shape, Int, Int>, + GMMA::Major::K, GMMA::Major::MN + >(), + Layout, Int, _1>>{} + )); - using SmemLayoutQ = decltype(tile_to_shape( + //////////////////////////////////////////////////////////////////////////////////////////////////// + /// SMEM Layout definitions: Q/K/V, P, row-scale, etc. + //////////////////////////////////////////////////////////////////////////////////////////////////// + using SmemLayoutQ = decltype( + tile_to_shape( getSmemLayoutK(), - Shape, Int>{})); + Shape, Int>{} + ) + ); - using SmemLayoutK = decltype(tile_to_shape( + using SmemLayoutK = decltype( + tile_to_shape( getSmemLayoutK(), - Shape, Int>{})); + Shape, Int>{} + ) + ); - using SmemLayoutV = decltype(tile_to_shape( + using SmemLayoutV = decltype( + tile_to_shape( getSmemLayoutK(), - Shape, Int>{})); - using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + Shape, Int>{} + ) + ); + using SmemLayoutVtransposed = decltype( + composition( + SmemLayoutV{}, + make_layout( + Shape, Int>{}, + GenRowMajor{} + ) + ) + ); - using SmemLayoutP = Layout, Int, _1, Int>>; - using SmemLayoutRow = Layout>, Stride<_1, _2>>; + // For partial S data (softmax region) + using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; - using SmemLayoutAtomO = decltype(composition( + // Layout for the O tile in smem + using SmemLayoutAtomO = decltype( + composition( Swizzle{}, - Layout, Int>, Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( + Layout, Int>, Stride, _1>>{} + ) + ); + using SmemLayoutO = decltype( + tile_to_shape( SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + Shape, Int>{} + ) + ); - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + //////////////////////////////////////////////////////////////////////////////////////////////////// + /// Copy Atoms for SMEM read/write + //////////////////////////////////////////////////////////////////////////////////////////////////// + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + /// GMEM Tiled Copies for Q/K/V + //////////////////////////////////////////////////////////////////////////////////////////////////// + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must align with vector load size"); static constexpr int kGmemThreadsPerRow = kBlockKSmem / 
kGmemElemsPerLoad; - using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; - static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; - static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemCopyStruct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNumThreadsLoad = kNumThreads - kNumThreadsSoftmax; + static_assert(kNumThreadsLoad % kGmemThreadsPerRow == 0, "Thread counts must match row partitions"); using GmemLayoutAtom = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopy = decltype(make_tiled_copy( - Copy_Atom{}, + Shape, Int>, + Stride, _1> + >; + using GmemTiledCopy = decltype( + make_tiled_copy( + Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + Layout>{} // 8 vals per read + ) + ); + // For storing O to GMEM using GmemLayoutAtomO = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyO = decltype(make_tiled_copy( + Shape, Int>, + Stride, _1> + >; + using GmemTiledCopyO = decltype( + make_tiled_copy( Copy_Atom, Element>{}, GmemLayoutAtomO{}, - Layout>{})); // Val layout, 8 vals per store + Layout>{} // 8 vals per store + ) + ); - static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + // For accumulation path (split) + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; using GmemLayoutAtomOaccum = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Shape, Int>, + Stride, _1> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store + Layout>{} // 4 vals per store + ) + ); }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Shared Storage Container for MLA +/// - Re-used union across Q/K/P/O or row sums, etc. +//////////////////////////////////////////////////////////////////////////////////////////////////// namespace flash { using namespace cute; -template +template struct SharedStorageMLA { union { struct { - cute::array_aligned> smem_q; - cute::array_aligned * 2> smem_k; // Double buffer - cute::array_aligned> smem_p; - cute::array_aligned> smem_scale; + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // double buffer + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; }; struct { - cute::array_aligned> smem_max; - cute::array_aligned> smem_sum; - cute::array_aligned> smem_o; + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; }; }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, - SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; +/// store() Epilogue for partial or non-partial results +/// - Manages writing O/accumulation to global memory + writing out LSE for row block. 
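// Aside: the offset arithmetic inside store() linearizes (split, head, query-row)
// coordinates into flat indices for the partial accumulator and partial LSE buffers.
// A minimal host-side sketch of that index math; the names below are illustrative,
// and `num_splits_prefix` stands for __ldg(params.num_splits_ptr + batch_id).
#include <cstdint>

struct SplitOffsets {
    int64_t oaccum;    // offset into oaccum_ptr, laid out as [splits, h, seqlen_q, d_v]
    int64_t lseaccum;  // offset into softmax_lseaccum_ptr, laid out as [splits, h, seqlen_q]
};

inline SplitOffsets split_offsets(int64_t num_splits_prefix, int64_t n_split_idx,
                                  int64_t head_id, int64_t num_heads,
                                  int64_t seqlen_q, int64_t m_block,
                                  int64_t kBlockM, int64_t d_v) {
    // First query row covered by this (split, head, m_block) tile.
    const int64_t row = ((num_splits_prefix + n_split_idx) * num_heads + head_id) * seqlen_q
                      + m_block * kBlockM;
    // The O accumulator stores d_v floats per row; the LSE accumulator stores one float per row.
    return SplitOffsets{row * d_v, row};
}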
+//////////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename KernelTraits, + bool Split, + typename SharedStorage, + typename AccO, + typename Softmax +> +__forceinline__ __device__ +void store( + const Flash_fwd_mla_params ¶ms, + const int batch_id, + const int head_id, + const int m_block, + const int n_split_idx, + SharedStorage &shared_storage, + AccO tOrO, + Softmax softmax +) { + constexpr int kBlockM = KernelTraits::kBlockM; + constexpr int kHeadDimV = KernelTraits::kHeadDimV; + constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax; + using Element = typename KernelTraits::Element; + using ElementAccum = typename KernelTraits::ElementAccum; + using IndexT = typename KernelTraits::IndexT; const int tidx = threadIdx.x; - typename Kernel_traits::TiledMmaO tiled_mma_o; + typename KernelTraits::TiledMmaO tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - // Epilogue - - const int split_offset = __ldg(params.num_splits_ptr + bidb); - - Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + // Softmax LSE for final normalization + auto lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + // Decide if writing ephemeral partial results (float accumulation) or final (Element). using ElementO = std::conditional_t; - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum - >; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + + // Prepare SMEM for O + Tensor sOaccum = make_tensor( + make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), + typename KernelTraits::SmemLayoutO{} + ); + auto smem_tiled_copy_Oaccum = make_tiled_copy_C( + std::conditional_t{}, + tiled_mma_o + ); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(tOrO); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_Oaccum.retile_S(rO); + Tensor taccOsO = smem_thr_copy_Oaccum.partition_D(sOaccum); __syncthreads(); + cute::copy(smem_tiled_copy_Oaccum, taccOrO, taccOsO); - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + // Compute GMEM offsets + const IndexT row_offset_o = batch_id * params.o_batch_stride + + m_block * kBlockM * params.o_row_stride + + head_id * params.o_head_stride; + const IndexT row_offset_oaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx) + * params.h + head_id) + * params.seqlen_q + (m_block * kBlockM)) * params.d_v; + const IndexT row_offset_lse = (batch_id * params.h + head_id) * params.seqlen_q + m_block * kBlockM; + const IndexT row_offset_lseaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx) + * params.h + head_id) + * params.seqlen_q + (m_block * kBlockM)); - const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * 
params.d_v; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + // Prepare GMEM for final or partial O + Tensor gOaccum = make_tensor( + make_gmem_ptr( + reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + + (Split ? row_offset_oaccum : row_offset_o) + ), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}) + ); - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), - Shape>{}, Stride<_1>{}); + // Prepare GMEM LSE + Tensor gLSEaccum = make_tensor( + make_gmem_ptr( + reinterpret_cast( + Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr + ) + (Split ? row_offset_lseaccum : row_offset_lse) + ), + Shape>{}, + Stride<_1>{} + ); - using GmemTiledCopyO = std::conditional_t; - GmemTiledCopyO gmem_tiled_copy_Oaccum; + // Tiled copy from SMEM -> GMEM for O + using GmemTiledCopyOAccum = std::conditional_t< + !Split, + typename KernelTraits::GmemTiledCopyO, + typename KernelTraits::GmemTiledCopyOaccum + >; + GmemTiledCopyOAccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); __syncthreads(); - if (tidx >= kNThreadsS) { return; } + // If out of range of the "softmax" portion, do not store + if (tidx >= kNumThreadsS) { return; } + // Load from SMEM Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) - Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + // Write out the LSE + auto caccO = make_identity_tensor(Shape, Int>{}); + auto taccOcO = thr_mma_o.partition_C(caccO); + auto taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); + if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + if (row < params.seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } } } - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - // Clear_OOB_K must be false since we don't want to write zeros to gmem + // Identity layout for sO + auto cO = make_identity_tensor( + make_shape(size<0>(sOaccum), size<1>(sOaccum)) + ); + auto tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + auto tOpO 
= make_tensor(make_shape(size<2>(tOgOaccum))); + + // Copy final O back to GMEM flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, + params.seqlen_q - m_block * kBlockM ); } -template -__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, - const int bidb, const int bidh, const int m_block, - const int n_split_idx, const int seqlen_k, - const int n_block_min, const int n_block_max, const bool NoSplit, - SharedStorage &shared_storage) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreads = Kernel_traits::kNThreads; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - static_assert(kNThreads == 256 and kNThreadsS == 128); - using Element = typename Kernel_traits::Element; - using index_t = typename Kernel_traits::index_t; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// compute_attn_1rowblock_splitkv_mla() +/// - Core logic for Q*K -> S -> Softmax -> S*V -> O +/// - Includes partial accumulation for splits and optional causal masking. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ +void compute_attn_1rowblock_splitkv_mla( + const Flash_fwd_mla_params ¶ms, + const int batch_id, + const int head_id, + const int m_block, + const int n_split_idx, + const int seqlen_k, + const int n_block_min, + const int n_block_max, + const bool no_split, + SharedStorage &shared_storage +) { + constexpr int kBlockM = KernelTraits::kBlockM; + constexpr int kBlockN = KernelTraits::kBlockN; + constexpr int kHeadDim = KernelTraits::kHeadDim; + constexpr int kHeadDimV = KernelTraits::kHeadDimV; + constexpr int kNumThreads = KernelTraits::kNumThreads; + constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax; + using Element = typename KernelTraits::Element; + using IndexT = typename KernelTraits::IndexT; + + static_assert(kNumThreads == 256 && kNumThreadsS == 128, "Expected 256 main threads, 128 softmax threads."); const int tidx = threadIdx.x; - int n_block = n_block_max - 1; + int n_block = n_block_max - 1; - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + // Smem pointers for Q, K, V, partial S, etc. 
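// Aside: sK, sV and sVt below are three CUTE views over the same shared-memory buffer
// (shared_storage.smem_k): V re-reads the leading kHeadDimV columns of the K tile, and
// sVt is a transposed view of V for the S*V GEMM. A plain-C++ sketch of that aliasing,
// ignoring swizzling and double buffering (KVTileView and its members are illustrative):
template <typename T, int kBlockN, int kHeadDim, int kHeadDimV>
struct KVTileView {
    T *base;  // one kBlockN x kHeadDim tile in shared memory, row-major in this sketch

    __device__ T &k (int n, int d) { return base[n * kHeadDim + d]; }  // K[n, d], d < kHeadDim
    __device__ T &v (int n, int d) { return base[n * kHeadDim + d]; }  // V[n, d], d < kHeadDimV (same storage as K)
    __device__ T &vt(int d, int n) { return base[n * kHeadDim + d]; }  // V^T[d, n], i.e. V[n, d] read column-wise
};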
+ Tensor sQ = make_tensor( + make_smem_ptr(shared_storage.smem_q.data()), + typename KernelTraits::SmemLayoutQ{} + ); + Tensor sK = make_tensor( + make_smem_ptr(shared_storage.smem_k.data()), + typename KernelTraits::SmemLayoutK{} + ); + Tensor sV = make_tensor( + make_smem_ptr(shared_storage.smem_k.data()), + typename KernelTraits::SmemLayoutV{} + ); + Tensor sVt = make_tensor( + make_smem_ptr(shared_storage.smem_k.data()), + typename KernelTraits::SmemLayoutVtransposed{} + ); - Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); - Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); - Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); - Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); - Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + // Softmax partial + Tensor sP = make_tensor( + make_smem_ptr(shared_storage.smem_p.data()), + typename KernelTraits::SmemLayoutP{} + ); + Tensor tPsP = sP(_, tidx % kNumThreadsS, _, _); - typename Kernel_traits::TiledMmaO tiled_mma_o; + // Row-based scale, sum, etc. + Tensor sScale = make_tensor( + make_smem_ptr(shared_storage.smem_scale.data()), + typename KernelTraits::SmemLayoutRow{} + ); + Tensor tScale = sScale(_, tidx % kNumThreadsS); + Tensor sRowMax = make_tensor( + make_smem_ptr(shared_storage.smem_max.data()), + typename KernelTraits::SmemLayoutRow{} + ); + Tensor tRowMax = sRowMax(_, tidx % kNumThreadsS); + Tensor sRowSum = make_tensor( + make_smem_ptr(shared_storage.smem_sum.data()), + typename KernelTraits::SmemLayoutRow{} + ); + Tensor tRowSum = sRowSum(_, tidx % kNumThreadsS); + + // Mma for O + typename KernelTraits::TiledMmaO tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) - Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(tOrO); + // Combined softmax utility flash::Softmax<2 * size<1>(tOrO)> softmax; - int warp_group_idx = cutlass::canonical_warp_group_idx(); - if (warp_group_idx == 0) { - typename Kernel_traits::TiledMma tiled_mma; + // Warp group logic: warpGroupIdx=0 does Q*K->S, warpGroupIdx=1 does async loads for next iteration + int warpGroupIdx = cutlass::canonical_warp_group_idx(); + if (warpGroupIdx == 0) { + // Main matmul Q*K -> S + typename KernelTraits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); + Tensor tSrK = thr_mma.partition_fragment_B(sK); + + // If n_block is odd => shift for double-buffer if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; + constexpr int sKOffset = size(sK); + tSrK.data() += (sKOffset / 8); + tOrVt.data() += (sKOffset / 8); } - // We need masking on S for the very 
last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; -#pragma unroll 1 - for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { - __syncthreads(); - - Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); - - const bool is_masking_step = masking_step > 0; - const bool is_first_masking_step = masking_step == n_masking_steps; - - if (is_masking_step) { - Tensor cS = make_identity_tensor(Shape, Int>{}); - Tensor tScS = thr_mma.partition_C(cS); -#pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col - if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; - } else { - // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - int row = int(get<0>(tScS(i))); - int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; - if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; - } - } - } - - // We have key_padding_mask so we'll need to Check_inf - Tensor scale_o = is_first_masking_step - ? softmax.template softmax(tSrS, params.scale_softmax_log2) - : is_masking_step ? - softmax.template softmax(tSrS, params.scale_softmax_log2) - : softmax.template softmax(tSrS, params.scale_softmax_log2); - - Tensor rP = flash::convert_type(tSrS); - cute::copy(rP, tPsP); - cute::copy(scale_o, tScale_osScale_o); - - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - cute::copy(softmax.row_max, tRow_maxsRow_max); - cute::copy(softmax.row_sum, tRow_sumsRow_sum); - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - } else { - const int *block_table = params.block_table + bidb * params.block_table_batch_stride; - int cur_block_table = __ldg(&block_table[n_block]); - - const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); - Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.seqlen_q - m_block * kBlockM); - - const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; - auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); - Tensor tKgK = gmem_thr_copy_K.partition_S(gK); - Tensor tKsK = gmem_thr_copy_K.partition_D(sK); - Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); - - if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tKsK.data() = tKsK.data() + sK_offset; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need to clear the sK smem tiles because K is V. - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, - seqlen_k - n_block * kBlockN); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - - if (n_block - 1 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 1]); - } + // We have a loop from n_block_max-1 down to n_block_min + // Need to do “masking step(s)” for partial or causal scenarios. + constexpr int nMaskingSteps = !IsCausal + ? 1 + : cute::ceil_div(kBlockM, kBlockN) + 1; #pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - flash::cp_async_wait<0>(); - __syncthreads(); - - if (n_block - 1 >= n_block_min) { - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - tKsK.data() = tKsK.data() + sK_offset; - - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); - - if (n_block - 2 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 2]); - } - - typename Kernel_traits::TiledMma tiled_mma; - auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); - Tensor rP = make_tensor(tSrS_layout); - Tensor scale_o = make_tensor(Shape<_2>{}); - cute::copy(tScale_osScale_o, scale_o); - cute::copy(tPsP, rP); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - cute::copy(tRow_maxsRow_max, softmax.row_max); - cute::copy(tRow_sumsRow_sum, softmax.row_sum); - } - - if (NoSplit) - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); - else - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); -} - -template -__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kBlockN = Kernel_traits::kBlockN; - const int m_block = blockIdx.x; - const int bidh = blockIdx.y; - const int partition_idx = blockIdx.z; - - extern __shared__ char shared_memory[]; - auto &shared_storage = *reinterpret_cast(shared_memory); - - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int begin_seqlen = tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int end_seqlen = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); - -#pragma unroll 1 - for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { - const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; - const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); - const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; - const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); - const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); - if (batch_id > begin_idx) { - __syncthreads(); // Barrier between two tiles. 
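// Aside: the scheduler loop above converts the per-partition metadata
// (begin_idx, begin_seqlen, end_idx, end_seqlen) into a K-block range per batch element;
// a batch element whose whole K range is covered by one partition skips the split/combine
// path. A host-side sketch of that range computation (BlockRange and k_block_range are
// illustrative names; ceil_div_i mirrors cute::ceil_div):
constexpr int ceil_div_i(int a, int b) { return (a + b - 1) / b; }

struct BlockRange {
    int n_block_min;
    int n_block_max;
    bool no_split;
};

inline BlockRange k_block_range(int batch_id, int begin_idx, int end_idx,
                                int begin_seqlen, int end_seqlen,
                                int seqlen_k, int kBlockN) {
    BlockRange r;
    // Only the first batch element of a partition starts mid-sequence; only the last ends early.
    r.n_block_min = (batch_id == begin_idx) ? begin_seqlen / kBlockN : 0;
    r.n_block_max = (batch_id == end_idx) ? ceil_div_i(end_seqlen, kBlockN)
                                          : ceil_div_i(seqlen_k, kBlockN);
    // No split/combine pass is needed when this partition covers the entire K range.
    r.no_split = (r.n_block_min == 0) && (r.n_block_max == ceil_div_i(seqlen_k, kBlockN));
    return r;
}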
-        }
-        flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-__global__ void __launch_bounds__(256, 1, 1)
-flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
-    constexpr int kNThreads = 128;
-
-    const int tidx = threadIdx.x;
-    const int bidx = blockIdx.x;
+        for (int masking
     const int hs = params.h * params.seqlen_q;
     const int batch_idx = bidx / hs;
     const int hs_idx = bidx % hs;
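// Aside: flash_fwd_splitkv_mla_combine_kernel (its body is cut off above) is the second
// pass over the per-split partials written by store<..., Split=true>: each split s holds
// a partial output o[s] (normalized within its own K range) and a partial log-sum-exp
// lse[s]. A simplified scalar sketch of the standard split-KV combination used for this
// kind of merge, with illustrative names; per (head, query row) across d_v channels:
#include <cmath>
#include <vector>

inline void combine_splits(const std::vector<std::vector<float>> &o,   // o[s][d], d < d_v
                           const std::vector<float> &lse,              // lse[s]
                           std::vector<float> &o_out, float &lse_out) {
    // Global LSE via a numerically stable log-sum-exp over the per-split LSEs.
    float m = -INFINITY;
    for (float v : lse) m = std::fmax(m, v);
    float sum = 0.f;
    for (float v : lse) sum += std::exp(v - m);
    lse_out = m + std::log(sum);

    // Re-weight each split's output by exp(lse[s] - lse_out) and accumulate.
    o_out.assign(o[0].size(), 0.f);
    for (size_t s = 0; s < o.size(); ++s) {
        const float w = std::exp(lse[s] - lse_out);
        for (size_t d = 0; d < o[s].size(); ++d) {
            o_out[d] += w * o[s][d];
        }
    }
}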