diff --git a/.gitignore b/.gitignore index 5f9e980..982daef 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__/ dist/ *perf.csv *.png +/.vscode diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 9015735..0a19789 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -10,8 +10,10 @@ #include -#include "flash_mla.h" -#include "static_switch.h" +#include "kernels/config.h" +#include "kernels/get_mla_metadata.h" +#include "kernels/mla_combine.h" +#include "kernels/splitkv_mla.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -23,11 +25,6 @@ get_mla_metadata( const int num_heads_per_head_k, const int num_heads_k ) { - // This should match the logic in the MLA kernel. - static constexpr int block_size_m = 64; - static constexpr int block_size_n = 64; - static constexpr int fixed_overhead_num_blocks = 5; - CHECK_DEVICE(seqlens_k); TORCH_CHECK(seqlens_k.is_contiguous()); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); @@ -38,7 +35,7 @@ get_mla_metadata( auto dprops = at::cuda::getCurrentDeviceProperties(); int sm_count = dprops->multiProcessorCount; - int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m); + int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M); auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); auto num_splits = torch::empty({batch_size + 1}, options); @@ -52,10 +49,10 @@ get_mla_metadata( params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; params.num_splits_ptr = num_splits_ptr; params.batch_size = batch_size; - params.block_size_n = block_size_n; - params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.block_size_n = Config::PAGE_BLOCK_SIZE; + params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS; params.num_sm_parts = num_sm_parts; - get_mla_metadata_func(params, stream); + run_get_mla_metadata_kernel(params, stream); return {tile_scheduler_metadata, num_splits}; } @@ -64,7 +61,6 @@ std::vector mha_fwd_kvcache_mla( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size - std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq @@ -73,138 +69,141 @@ mha_fwd_kvcache_mla( const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits // batch_size + 1 ) { + // Check the architecture auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; TORCH_CHECK(is_sm90); - at::Tensor vcache = vcache_.has_value() ? 
vcache_.value() : kcache; - + // Check data types auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + // Check device + CHECK_DEVICE(q); + CHECK_DEVICE(kcache); + CHECK_DEVICE(seqlens_k); + CHECK_DEVICE(block_table); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_DEVICE(num_splits); + // Check layout TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + CHECK_CONTIGUOUS(seqlens_k); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + CHECK_CONTIGUOUS(num_splits); const auto sizes = q.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; - const int num_heads_ori = sizes[2]; - const int head_size = sizes[3]; - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(batch_size > 0, "batch size must be postive"); - TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q_ori == 1) { is_causal = false; } - const int ngroups = num_heads_ori / num_heads_k; - const int seqlen_q = seqlen_q_ori * ngroups; + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; const int num_heads = num_heads_k; - q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3) - .reshape({batch_size, seqlen_q, num_heads, head_size}); + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); - int head_size_k = head_size; - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); - if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - - - TORCH_CHECK(seqlens_k.dtype() ==
torch::kInt32, "seqlens_k must have dtype int32"); - CHECK_DEVICE(seqlens_k); - CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_SHAPE(num_splits, batch_size+1); at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); - at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse); Flash_fwd_mla_params params = {}; // Set the sizes. params.b = batch_size; - params.seqlen_q = seqlen_q; - params.cu_seqlens_k = seqlens_k.data_ptr(); - params.h = num_heads; - params.h_h_k_ratio = num_heads / num_heads_k; - params.ngroups = ngroups; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; params.is_causal = is_causal; - params.d = head_size; + params.d = head_size_k; params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = kcache.data_ptr(); - params.v_ptr = vcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); // All stride are in elements, not bytes. 
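The view/transpose/reshape above folds every KV head's query heads into the row dimension, so the kernel sees q_seq_per_hk = s_q * h_q / h_k query rows per KV head. A minimal standalone libtorch sketch of the same folding on a toy tensor (the toy shapes and the main() wrapper are illustrative only, not part of this patch):

// Illustrative only: reproduces the q head-folding performed in mha_fwd_kvcache_mla.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t b = 2, s_q = 1, h_q = 128, h_k = 1, d = 576;   // toy sizes
    const int64_t q_head_per_hk = h_q / h_k;                     // 128
    const int64_t q_seq_per_hk  = s_q * q_head_per_hk;           // 128
    torch::Tensor q = torch::randn({b, s_q, h_q, d});
    // Group the query heads under their KV head, then fold them into the row dimension.
    torch::Tensor q_folded = q.view({b, s_q, h_k, q_head_per_hk, d})
                              .transpose(2, 3)
                              .reshape({b, q_seq_per_hk, h_k, d});
    std::cout << q_folded.sizes() << std::endl;                  // [2, 128, 1, 576]
    return 0;
}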
params.q_batch_stride = q.stride(0); params.k_batch_stride = kcache.stride(0); - params.v_batch_stride = vcache.stride(0); params.o_batch_stride = out.stride(0); params.q_row_stride = q.stride(-3); params.k_row_stride = kcache.stride(-3); - params.v_row_stride = vcache.stride(-3); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = kcache.stride(-2); - params.v_head_stride = vcache.stride(-2); params.o_head_stride = out.stride(-2); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; - - TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); - TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); - CHECK_DEVICE(tile_scheduler_metadata); - CHECK_CONTIGUOUS(tile_scheduler_metadata); + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); params.num_sm_parts = tile_scheduler_metadata.size(0); - TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - CHECK_DEVICE(num_splits); - CHECK_CONTIGUOUS(num_splits); params.num_splits_ptr = num_splits.data_ptr(); - at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse_accum); + CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(head_size == 576); - + TORCH_CHECK(head_size_k == 576); if (q_dtype == torch::kBFloat16) { - run_mha_fwd_splitkv_mla(params, stream); - } - #ifndef FLASH_MLA_DISABLE_FP16 - else if (q_dtype == torch::kHalf) { - run_mha_fwd_splitkv_mla(params, stream); - } - #endif - else { + run_flash_splitkv_mla_kernel(params, stream); + run_flash_mla_combine_kernel(params, stream); + } else if (q_dtype == torch::kHalf) { +#ifdef FLASH_MLA_DISABLE_FP16 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); +#else + run_flash_splitkv_mla_kernel(params, stream); + run_flash_mla_combine_kernel(params, stream); +#endif + } else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } - out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) - .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); - softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) - .reshape({batch_size, num_heads_ori, seqlen_q_ori}); + out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); return {out, softmax_lse}; } diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu deleted file mode 100644 index 35691f2..0000000 --- a/csrc/flash_fwd_mla_bf16_sm90.cu +++ /dev/null @@ -1,3 +0,0 @@ -#include "flash_fwd_mla_kernel.h" - -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu deleted file mode 100644 index abdaf7b..0000000 --- a/csrc/flash_fwd_mla_fp16_sm90.cu +++ /dev/null @@ -1,3 +0,0 @@ -#include "flash_fwd_mla_kernel.h" - -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h deleted file mode 100644 index d96acd8..0000000 --- a/csrc/flash_fwd_mla_kernel.h +++ /dev/null @@ -1,603 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -using namespace cute; - -#include "named_barrier.h" -#include "utils.h" -#include "softmax.h" -#include "static_switch.h" -#include "flash_mla.h" - - -template -constexpr auto getSmemLayoutK() { - constexpr int headSizeBytes = sizeof(PrecType) * DIM; - constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; - - if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { - return GMMA::Layout_K_SW64_Atom{}; - } else { - return GMMA::Layout_K_SW32_Atom{}; - } -} - -template -struct Flash_fwd_kernel_traits_mla { - using Element = elem_type; - using ElementAccum = float; - using index_t = int64_t; - - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - static constexpr int kNWarpsS = 4; - static constexpr int kNThreadsS = kNWarpsS * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; - static_assert(kHeadDimV % 32 == 0); - static_assert(kHeadDimV <= kHeadDim); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; - - using TiledMma = decltype(make_tiled_mma( - cute::GMMA::ss_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::K>(), - Layout, _1, _1>>{})); - - static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; - using TiledMmaO = decltype(make_tiled_mma( - cute::GMMA::rs_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::MN>(), - Layout, Int, _1>>{})); - - using SmemLayoutQ = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - - using SmemLayoutK = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - - using SmemLayoutV = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - - using SmemLayoutP = Layout, Int, _1, Int>>; - using SmemLayoutRow = Layout>, Stride<_1, _2>>; - - using SmemLayoutAtomO = decltype(composition( - Swizzle{}, - Layout, Int>, Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; - static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; - static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - - using GmemLayoutAtom = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopy = decltype(make_tiled_copy( - Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - - using GmemLayoutAtomO = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom, Element>{}, - GmemLayoutAtomO{}, - Layout>{})); // Val layout, 8 vals per store - - static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); - static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; - using GmemLayoutAtomOaccum = Layout< - Shape, Int>, - Stride, _1>>; - using GmemTiledCopyOaccum = decltype(make_tiled_copy( - Copy_Atom, ElementAccum>{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store -}; - -namespace flash { - -using namespace cute; - -template -struct SharedStorageMLA { - union { - struct { - cute::array_aligned> smem_q; - cute::array_aligned * 2> smem_k; // Double buffer - cute::array_aligned> smem_p; - cute::array_aligned> smem_scale; - }; - struct { - cute::array_aligned> smem_max; - cute::array_aligned> smem_sum; - cute::array_aligned> smem_o; - }; - }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, - SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int tidx = threadIdx.x; - - typename Kernel_traits::TiledMmaO tiled_mma_o; - 
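The deleted getSmemLayoutK above picks the widest shared-memory swizzle atom whose byte width divides both head-dimension row sizes. A dependency-free sketch restating that rule (the helper name and the 2-byte stand-in element type are assumptions for illustration, not CUTLASS types):

#include <cstdio>

// Restates the branch in the removed getSmemLayoutK(): choose 128B swizzling when both
// head-dimension rows are 128-byte multiples, otherwise fall back to 64B, then 32B.
template <typename PrecType, int DIM, int DIM2>
constexpr int smem_swizzle_bytes() {
    constexpr int head_size_bytes  = sizeof(PrecType) * DIM;
    constexpr int head_size_bytes2 = sizeof(PrecType) * DIM2;
    return (head_size_bytes % 128 == 0 && head_size_bytes2 % 128 == 0) ? 128
         : (head_size_bytes % 64  == 0 && head_size_bytes2 % 64  == 0) ? 64
         : 32;
}

int main() {
    // For a 2-byte element type (stand-in for bf16/fp16) with K dim 576 and V dim 512,
    // rows are 1152B and 1024B, so the 128B swizzle atom is selected.
    printf("%d\n", smem_swizzle_bytes<unsigned short, 576, 512>());
    return 0;
}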
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - - // Epilogue - - const int split_offset = __ldg(params.num_splits_ptr + bidb); - - Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); - - using ElementO = std::conditional_t; - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum - >; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); - auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(tOrO); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - __syncthreads(); - - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - - const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? 
row_offset_lseaccum : row_offset_lse)), - Shape>{}, Stride<_1>{}); - - using GmemTiledCopyO = std::conditional_t; - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - - __syncthreads(); - - if (tidx >= kNThreadsS) { return; } - - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); - - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) - Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { -#pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } - } - } - - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM - ); -} - -template -__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, - const int bidb, const int bidh, const int m_block, - const int n_split_idx, const int seqlen_k, - const int n_block_min, const int n_block_max, const bool NoSplit, - SharedStorage &shared_storage) { - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kHeadDimV = Kernel_traits::kHeadDimV; - constexpr int kNThreads = Kernel_traits::kNThreads; - constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - static_assert(kNThreads == 256 and kNThreadsS == 128); - using Element = typename Kernel_traits::Element; - using index_t = typename Kernel_traits::index_t; - - const int tidx = threadIdx.x; - int n_block = n_block_max - 1; - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); - - Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); - Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); - Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); - Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); - Tensor sRow_sum = 
make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); - Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); - - typename Kernel_traits::TiledMmaO tiled_mma_o; - auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); - Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) - Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - clear(tOrO); - - flash::Softmax<2 * size<1>(tOrO)> softmax; - - int warp_group_idx = cutlass::canonical_warp_group_idx(); - if (warp_group_idx == 0) { - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - - if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; -#pragma unroll 1 - for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { - __syncthreads(); - - Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) - flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); - - const bool is_masking_step = masking_step > 0; - const bool is_first_masking_step = masking_step == n_masking_steps; - - if (is_masking_step) { - Tensor cS = make_identity_tensor(Shape, Int>{}); - Tensor tScS = thr_mma.partition_C(cS); -#pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col - if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; - } else { - // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - int row = int(get<0>(tScS(i))); - int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; - if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; - } - } - } - - // We have key_padding_mask so we'll need to Check_inf - Tensor scale_o = is_first_masking_step - ? softmax.template softmax(tSrS, params.scale_softmax_log2) - : is_masking_step ? - softmax.template softmax(tSrS, params.scale_softmax_log2) - : softmax.template softmax(tSrS, params.scale_softmax_log2); - - Tensor rP = flash::convert_type(tSrS); - cute::copy(rP, tPsP); - cute::copy(scale_o, tScale_osScale_o); - - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - cute::copy(softmax.row_max, tRow_maxsRow_max); - cute::copy(softmax.row_sum, tRow_sumsRow_sum); - cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - } else { - const int *block_table = params.block_table + bidb * params.block_table_batch_stride; - int cur_block_table = __ldg(&block_table[n_block]); - - const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; - auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); - Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.seqlen_q - m_block * kBlockM); - - const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; - auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); - Tensor tKgK = gmem_thr_copy_K.partition_S(gK); - Tensor tKsK = gmem_thr_copy_K.partition_D(sK); - Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); - - if (n_block % 2 == 1) { - // Double buffer for sK - constexpr int sK_offset = size(sK); - tKsK.data() = tKsK.data() + sK_offset; - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - // We need to clear the sK smem tiles because K is V. - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, - seqlen_k - n_block * kBlockN); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - - if (n_block - 1 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 1]); - } - -#pragma unroll 1 - for (; n_block >= n_block_min; --n_block) { - flash::cp_async_wait<0>(); - __syncthreads(); - - if (n_block - 1 >= n_block_min) { - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - tKsK.data() = tKsK.data() + sK_offset; - - const index_t offset_k = cur_block_table * params.k_batch_stride; - tKgK.data() = tKgK.data() + offset_k; - flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); - tKgK.data() = tKgK.data() + -offset_k; - cute::cp_async_fence(); - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); - - if (n_block - 2 >= n_block_min) { - cur_block_table = __ldg(&block_table[n_block - 2]); - } - - typename Kernel_traits::TiledMma tiled_mma; - auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); - Tensor rP = make_tensor(tSrS_layout); - Tensor scale_o = make_tensor(Shape<_2>{}); - cute::copy(tScale_osScale_o, scale_o); - cute::copy(tPsP, rP); - - flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tOrVt.data() = tOrVt.data() + sK_offset / 8; - } - - cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); - cute::copy(tRow_maxsRow_max, softmax.row_max); - cute::copy(tRow_sumsRow_sum, softmax.row_sum); - } - - if (NoSplit) - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); - else - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); -} - -template -__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kBlockN = Kernel_traits::kBlockN; - const int m_block = blockIdx.x; - const int bidh = blockIdx.y; - const int partition_idx = blockIdx.z; - - extern __shared__ char shared_memory[]; - auto &shared_storage = *reinterpret_cast(shared_memory); - - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int begin_seqlen = tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int end_seqlen = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); - -#pragma unroll 1 - for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { - const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; - const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); - const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; - const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); - const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); - if (batch_id > begin_idx) { - __syncthreads(); // Barrier between two tiles. 
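Each tile-scheduler record above expands into a per-batch range of KV blocks, with only the first and last batch of a partition possibly split. A host-side sketch of that mapping, using made-up metadata and sequence lengths (not taken from the patch):

#include <cstdio>

// Illustrative host-side replay of how one tile-scheduler record is consumed:
// an SM partition walks batches begin_idx..end_idx and derives its KV-block range.
constexpr int kBlockN = 64;  // PAGE_BLOCK_SIZE in the kernel

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical record: this partition covers batch 3 from token 128
    // through batch 5 up to token 320 (exclusive).
    int begin_idx = 3, begin_seqlen = 128, end_idx = 5, end_seqlen = 320;
    int seqlens_k[] = {0, 0, 0, 900, 450, 700};  // made-up per-batch KV lengths

    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        int seqlen_k = seqlens_k[batch_id];
        int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        int n_block_max = batch_id == end_idx ? ceil_div(end_seqlen, kBlockN)
                                              : ceil_div(seqlen_k, kBlockN);
        bool no_split = n_block_min == 0 && n_block_max == ceil_div(seqlen_k, kBlockN);
        printf("batch %d: KV blocks [%d, %d), no_split=%d\n",
               batch_id, n_block_min, n_block_max, no_split);
    }
    return 0;
}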
- } - flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void __launch_bounds__(256, 1, 1) -flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { - constexpr int kNThreads = 128; - - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - const int hs = params.h * params.seqlen_q; - const int batch_idx = bidx / hs; - const int hs_idx = bidx % hs; - - const int split_offset = __ldg(params.num_splits_ptr + batch_idx); - const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; - FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); - if (actual_num_splits == 1) return; - - __shared__ ElementAccum sLseScale[kMaxSplits]; - - const index_t row_offset_lseaccum = split_offset * hs + hs_idx; - const index_t row_offset_lse = bidx; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), - Shape>{}, make_stride(hs)); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape<_1>{}, Stride<_1>{}); - - int warp_idx = cutlass::canonical_warp_idx_sync(); - if (warp_idx == 0) { - constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); - - float local_lse[kNLsePerThread]; - for (int i = 0; i < kNLsePerThread; ++i) { - const int split = i * 32 + tidx; - local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; - } - - float max_lse = -INFINITY; - for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); - for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); - max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf - - float sum_lse = 0; - for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); - for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); - - float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
INFINITY : logf(sum_lse) + max_lse; - if (tidx == 0) gLSE(0) = global_lse; - - for (int i = 0; i < kNLsePerThread; ++i) { - const int split = i * 32 + tidx; - if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); - } - } - __syncthreads(); - - static_assert(kHeadDimV % kNThreads == 0); - constexpr int Elements = kHeadDimV / kNThreads; - const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape>{}, Stride<_1>{}); - using GmemTiledCopyOaccum = decltype(make_tiled_copy( - Copy_Atom, ElementAccum>{}, - Layout>>{}, - Layout>>{})); - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - for (int split = 0; split < actual_num_splits; ++split) { - cute::copy(tOgOaccum, tOrOaccum); - ElementAccum lse_scale = sLseScale[split]; - for (int i = 0; i < size(tOrO); ++i) { - tOrO(i) += lse_scale * tOrOaccum(i); - } - tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; - } - - Tensor rO = flash::convert_type(tOrO); - const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; - const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; - Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); - cute::copy(rO, gO); -} - -} // namespace flash - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); - const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - auto kernel = &flash::flash_fwd_splitkv_mla_kernel; - constexpr size_t smem_size = sizeof(SharedStorage); - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - kernel<<>>(params); - }); - CHECK_CUDA_KERNEL_LAUNCH(); - - dim3 grid_combine(params.b * params.h * params.seqlen_q); - MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { - auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< - typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; - combine_kernel<<>>(params); - }); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -template -void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { - static_assert(Headdim == 576); - FLASH_ASSERT(params.d_v == 512); - FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV - using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; - run_flash_splitkv_fwd_mla>(params, stream); -} diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h index 2994cb7..3b4e254 100644 --- a/csrc/flash_mla.h +++ b/csrc/flash_mla.h @@ -5,39 +5,41 @@ struct Flash_fwd_mla_params { using index_t = int64_t; - int b, seqlen_q, d, d_v; - int h, h_h_k_ratio, ngroups; + int b; // batch size + int s_q; + int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q + int d, d_v; // K/V 
dimension + int h_q, h_k; // The number of Q/K heads + int num_blocks; // Number of blocks in total + int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; - int *__restrict__ cu_seqlens_k; - + void *__restrict__ q_ptr; void *__restrict__ k_ptr; - void *__restrict__ v_ptr; void *__restrict__ o_ptr; void *__restrict__ softmax_lse_ptr; index_t q_batch_stride; index_t k_batch_stride; - index_t v_batch_stride; index_t o_batch_stride; index_t q_row_stride; index_t k_row_stride; - index_t v_row_stride; index_t o_row_stride; index_t q_head_stride; index_t k_head_stride; - index_t v_head_stride; index_t o_head_stride; int *__restrict__ block_table; index_t block_table_batch_stride; int page_block_size; + int *__restrict__ seqlens_k_ptr; int *__restrict__ tile_scheduler_metadata_ptr; int num_sm_parts; int *__restrict__ num_splits_ptr; + int total_num_splits; void *__restrict__ softmax_lseaccum_ptr; void *__restrict__ oaccum_ptr; }; @@ -45,11 +47,6 @@ struct Flash_fwd_mla_params { static constexpr int TileSchedulerMetaDataSize = 8; // [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); - struct Mla_metadata_params { int *__restrict__ seqlens_k_ptr; int *__restrict__ tile_scheduler_metadata_ptr; @@ -59,5 +56,3 @@ struct Mla_metadata_params { int fixed_overhead_num_blocks; int num_sm_parts; }; - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/config.h b/csrc/kernels/config.h new file mode 100644 index 0000000..c9ce159 --- /dev/null +++ b/csrc/kernels/config.h @@ -0,0 +1,13 @@ +#pragma once + +namespace Config { + +static constexpr int BLOCK_SIZE_M = 64; +static constexpr int PAGE_BLOCK_SIZE = 64; + +static constexpr int HEAD_DIM_K = 576; +static constexpr int HEAD_DIM_V = 512; + +static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5; + +} diff --git a/csrc/flash_fwd_mla_metadata.cu b/csrc/kernels/get_mla_metadata.cu similarity index 79% rename from csrc/flash_fwd_mla_metadata.cu rename to csrc/kernels/get_mla_metadata.cu index 82f5b5a..6b78f9b 100644 --- a/csrc/flash_fwd_mla_metadata.cu +++ b/csrc/kernels/get_mla_metadata.cu @@ -1,8 +1,11 @@ -#include "flash_fwd_mla_kernel.h" +#include "get_mla_metadata.h" -static constexpr int MaxBatchSize = 4096; +#include +#include -__global__ void __launch_bounds__(256, 1, 1) +#include "utils.h" + +__global__ void __launch_bounds__(32, 1, 1) get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { int *seqlens_k_ptr = params.seqlens_k_ptr; int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; @@ -12,8 +15,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; int num_sm_parts = params.num_sm_parts; - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; + extern __shared__ int shared_mem[]; + int* num_blocks_shared = shared_mem; // [batch_size] + int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] int total_num_blocks = 0; for (int i = threadIdx.x; i < batch_size; i += 32) { @@ -27,7 +31,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { __syncwarp(); if (threadIdx.x == 0) { - int payload = 
cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks); int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; num_splits_shared[0] = 0; @@ -70,8 +74,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { } } -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); +void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream) { + int smem_size = sizeof(int) * (params.batch_size*2+1); + CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params); CHECK_CUDA_KERNEL_LAUNCH(); -} \ No newline at end of file +} diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/kernels/get_mla_metadata.h new file mode 100644 index 0000000..5faa665 --- /dev/null +++ b/csrc/kernels/get_mla_metadata.h @@ -0,0 +1,5 @@ +#pragma once + +#include "flash_mla.h" + +void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/mla_combine.cu b/csrc/kernels/mla_combine.cu new file mode 100644 index 0000000..681dfe0 --- /dev/null +++ b/csrc/kernels/mla_combine.cu @@ -0,0 +1,207 @@ +#include "mla_combine.h" + +#include +#include +#include +#include + +#include "flash_mla.h" +#include "utils.h" +#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V + +using namespace cute; + +template +__global__ void __launch_bounds__(NUM_THREADS) +flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] + // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result + static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m + const int batch_idx = blockIdx.x; + const int m_block_idx = blockIdx.y; + const int warp_idx = threadIdx.x / 32; + const int lane_idx = threadIdx.x % 32; + + const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); + const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); + const int my_num_splits = end_split_idx - start_split_idx; + FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); + if (my_num_splits == 1) { + return; + } + + const int num_q_seqs = params.q_seq_per_hk * params.h_k; + const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M); + Tensor gLseAccum = make_tensor( + make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), + Shape, Int>{}, + make_stride(num_q_seqs, _1{}) + ); + Tensor gLse = make_tensor( + make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), + Shape>{}, + Stride<_1>{} + ); + + extern __shared__ float smem_buf[]; + Tensor sLseScale = make_tensor( + make_smem_ptr(smem_buf), + Shape, Int>{}, + Stride, _1>{} // +1 to avoid bank conflict + ); + + // Wait for the previous kernel (the MLA kernel) to finish + cudaGridDependencySynchronize(); + + // Read gLseAccum into sLseScale + { + #pragma unroll 4 + for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) { + int split_idx = elem_idx / BLOCK_SIZE_M; + int seq_idx = elem_idx % BLOCK_SIZE_M; 
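The reduction below merges the per-split partial results with a standard log-sum-exp combine: global_lse = log2(sum_i exp2(lse_i)) and out = sum_i exp2(lse_i - global_lse) * out_i, where the kernel keeps the accumulated LSE in base 2 and divides by M_LOG2E once at the end. A scalar host-side sketch of that combine over hypothetical split values (values are made up, for illustration only):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical per-split values for a single (head, query-row, channel) element.
    std::vector<float> lse  = {3.0f, 5.0f, 4.0f};    // per-split LSE, base-2 as in the kernel
    std::vector<float> oacc = {0.2f, 0.7f, 0.4f};    // per-split partial output
    float max_lse = -INFINITY;
    for (float v : lse) max_lse = std::max(max_lse, v);
    float sum = 0.f;
    for (float v : lse) sum += std::exp2(v - max_lse);
    float global_lse = std::log2(sum) + max_lse;     // combined LSE, still base 2
    float out = 0.f;
    for (size_t i = 0; i < lse.size(); ++i)
        out += std::exp2(lse[i] - global_lse) * oacc[i];
    printf("combined lse=%f out=%f\n", global_lse, out);
    return 0;
}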
+ sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY; + } + __syncthreads(); + } + + if (warp_idx >= num_cur_valid_q_seqs) + return; + + // Warp #i gathers LseAccum for seq #i + { + constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32); + float local_lse[NUM_LSE_PER_THREAD]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { + const int split_idx = i*32 + lane_idx; + local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY; + } + + float max_lse = -INFINITY; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) + max_lse = max(max_lse, local_lse[i]); + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) + max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) + sum_lse = sum_lse + exp2f(local_lse[i] - max_lse); + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) + sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse; + if (lane_idx == 0) + gLse(warp_idx) = global_lse / (float)M_LOG2E; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { + const int split_idx = i*32 + lane_idx; + if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse); + } + } + + __syncwarp(); + + // Warp #i accumulates activation for seq #i + { + const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V; + Tensor gOaccum = make_tensor( + make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + make_stride(num_q_seqs*HEAD_DIM_V, _1{}) + ); + + static_assert(HEAD_DIM_V % 32 == 0); + constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32; + float result[ELEMS_PER_THREAD]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) + result[i] = 0.0f; + + #pragma unroll 2 + for (int split = 0; split < my_num_splits; ++split) { + float lse_scale = sLseScale(warp_idx, split); + if (lse_scale != 0.f) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) { + result[i] += lse_scale * gOaccum(split, lane_idx + i*32); + } + } + } + + cudaTriggerProgrammaticLaunchCompletion(); + + const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx; + const int k_head_idx = q_seq_idx / params.q_seq_per_hk; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride; + Tensor gO = make_tensor( + make_gmem_ptr(o_ptr), + Shape>{}, + Stride<_1>{} + ); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) + gO(lane_idx+i*32) = (ElementT)result[i]; + } +} + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) 
\ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() + + +template +void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { + constexpr int BLOCK_SIZE_M = 8; + constexpr int NUM_THREADS = BLOCK_SIZE_M*32; + constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); + auto combine_kernel = &flash_fwd_mla_combine_kernel; + CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + cudaLaunchConfig_t combine_kernel_config = { + dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1), + dim3(NUM_THREADS, 1, 1), + smem_size, + stream, + attribute, + 1 + }; + cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +#ifndef FLASH_MLA_DISABLE_FP16 +template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +#endif \ No newline at end of file diff --git a/csrc/kernels/mla_combine.h b/csrc/kernels/mla_combine.h new file mode 100644 index 0000000..3f33b8f --- /dev/null +++ b/csrc/kernels/mla_combine.h @@ -0,0 +1,6 @@ +#pragma once + +#include "flash_mla.h" + +template +void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu new file mode 100644 index 0000000..950e594 --- /dev/null +++ b/csrc/kernels/splitkv_mla.cu @@ -0,0 +1,1349 @@ +#include + +#include "flash_mla.h" +#include "utils.h" +#include "config.h" +#include "traits.h" + +using namespace cute; +using cutlass::arch::NamedBarrier; + +// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking +// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) +// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM +static constexpr float MAX_INIT_VAL_SM = -1e30f; +static constexpr float MAX_INIT_VAL = -1e33f; + + +__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. 
This function converts the local_row_idx (0~2) to the actual row_idx + // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + +// Launch TMA copy for a range of KV tile +// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64 +template< + int START_HEAD_DIM_TILE_IDX, + int END_HEAD_DIM_TILE_IDX, + typename TMA_K_OneTile, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void launch_kv_tiles_copy_tma( + Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled + TMA_K_OneTile &tma_K, + TMABarrier* barriers_K, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + auto thr_tma = tma_K.get_slice(_0{}); + Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); + Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int{}); + cute::copy(tma_K.with(reinterpret_cast(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV); + if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { + launch_kv_tiles_copy_tma(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup); + } + } +} + +// Prefetch some KV tiles +// Currently this is not used because it leads to performance degradation +template< + int START_HEAD_DIM_TILE_IDX, + int END_HEAD_DIM_TILE_IDX, + typename TMA_K_OneTile, + typename Engine0, typename Layout0 +> +__forceinline__ __device__ void prefetch_kv_tiles( + Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + TMA_K_OneTile &tma_K, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + auto thr_tma = tma_K.get_slice(_0{}); + Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); + cute::prefetch(tma_K, cur_gKV); + if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { + prefetch_kv_tiles(gKV, tma_K, idx_in_warpgroup); + } + } +} + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +// * Copyright (c) 2024, Tri Dao. 
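In the gemm() wrapper that follows, zero_init is realized by unrolling the K mode: only the first MMA in the chain uses ScaleOut::Zero, after which the accumulator is kept with ScaleOut::One. A scalar stand-in for that accumulate-scale semantic (plain C++, a hypothetical 1x1 "MMA" per k-block, purely illustrative):

#include <cstdio>

// Scalar stand-in for the per-k-block accumulate in the gemm() wrapper below:
// the first k-block writes D = A*B (ScaleOut::Zero), later k-blocks do D += A*B (ScaleOut::One).
int main() {
    const int num_k_blocks = 4;
    float a[num_k_blocks] = {1.f, 2.f, 3.f, 4.f};
    float b[num_k_blocks] = {0.5f, 0.5f, 0.5f, 0.5f};
    float d = 123.f;                      // stale accumulator contents
    bool scale_d = false;                 // zero_init == true: first MMA ignores old D
    for (int k = 0; k < num_k_blocks; ++k) {
        d = (scale_d ? d : 0.f) + a[k] * b[k];
        scale_d = true;                   // ScaleOut::One for the remaining k-blocks
    }
    printf("%f\n", d);                    // 5.0, independent of the stale value
    return 0;
}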
+template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + + +// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64) +// The Q-tile should be in shared memory +template< + typename TiledMMA, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void qkt_gemm_one_tile_sQ( + TiledMMA &tiled_mma, + Tensor const &thr_mma_sQ_tile, // (MMA, 1, 4) + Tensor const &thr_mma_sKV_tile, // (MMA, 1, 4) + Tensor &rP, // ((2, 2, 8), 1, 1) + TMABarrier* barrier, + bool &cur_phase, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + barrier->arrive_and_expect_tx(64*64*2); + } + barrier->wait(cur_phase ? 1 : 0); + + warpgroup_fence_operand(rP); + warpgroup_arrive(); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP); + warpgroup_commit_batch(); + warpgroup_fence_operand(rP); +} + +template< + typename TiledMMA, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void qkt_gemm_one_tile_rQ( + TiledMMA &tiled_mma, + Tensor const &thr_mma_rQ_tile, // (MMA, 1, 4) + Tensor const &thr_mma_sKV_tile, // (MMA, 1, 4) + Tensor &rP, // ((2, 2, 8), 1, 1) + TMABarrier* barrier, + bool &cur_phase, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + barrier->arrive_and_expect_tx(64*64*2); + } + barrier->wait(cur_phase ? 
1 : 0); + + warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile)); + warpgroup_fence_operand(rP); + warpgroup_arrive(); + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP); + cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP); + warpgroup_commit_batch(); + warpgroup_fence_operand(rP); + warpgroup_fence_operand(const_cast &>(thr_mma_rQ_tile)); +} + +// Pipelined TMA wait and Q K^T gemm +// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows: +// - Wait for the 0-th tile to be ready using `barrier.wait()` +// - Compute Q K^T for the 0-th tile +// - Wait for the 1-st tile to be ready +// - Compute Q K^T for the 1-st tile +// ... +// This gives latter tiles more time to be ready, and thus can overlap the memory copy and computation +template< + typename T, // Traits + int PHASE_IDX, // See comments in the code + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3 +> +__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm( + Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + Tensor &rP, // ((2, 2, 8), 1, 1) + Tensor &rQ8, // The 8-th tile of Q. We store it separately to leave some room for storing sP1 + TMABarrier* barriers, + bool &cur_phase, + int idx_in_warpgroup +) { + Tensor sQ_tiled = flat_divide(sQ, Shape, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9) + Tensor sKV_tiled = flat_divide(sKV, Shape, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9) + TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){}; + ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup); + Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 4, 9) + Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 4, 9) + TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){}; + + #define QKT_GEMM_ONE_TILE(TILE_IDX) \ + if constexpr(TILE_IDX != 8) { \ + qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int{}), thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \ + } else { \ + qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \ + } + + if constexpr (PHASE_IDX == 0) { + // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; + tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(0); + QKT_GEMM_ONE_TILE(1); + QKT_GEMM_ONE_TILE(2); + QKT_GEMM_ONE_TILE(3); + } else if constexpr (PHASE_IDX == 1) { + // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; + tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(4); + QKT_GEMM_ONE_TILE(5); + QKT_GEMM_ONE_TILE(6); + QKT_GEMM_ONE_TILE(7); + QKT_GEMM_ONE_TILE(8); + QKT_GEMM_ONE_TILE(0); + QKT_GEMM_ONE_TILE(1); + QKT_GEMM_ONE_TILE(2); + QKT_GEMM_ONE_TILE(3); + cur_phase ^= 1; + } else { + // In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles + 
static_assert(PHASE_IDX == 2); + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One; + tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(4); + QKT_GEMM_ONE_TILE(5); + QKT_GEMM_ONE_TILE(6); + QKT_GEMM_ONE_TILE(7); + QKT_GEMM_ONE_TILE(8); + cur_phase ^= 1; + } +} + + +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline( + Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &sKV, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &rP, // ((2, 2, 8), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/16=36) + Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/16=36) + gemm(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP); +} + + +// Compute O += PV, where P resides in register +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP( + Tensor &rP, // ((2, 2, 8), 1, 1), fragment A layout + Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) + Tensor &rO, // ((2, 2, 32), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor rP_retiled = make_tensor(rP.data(), Layout< + Shape, _1, _4>, + Stride, _0, _8> + >{}); + Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4) + gemm(tiled_mma, rP_retiled, thr_mma_sKV_half, rO); +} + + +// Compute O += PV, where P resides in shared memory +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP( + Tensor &sP, + Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) + Tensor &rO, // ((2, 2, 32), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP); + Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4) + gemm(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO); +} + + +template< + typename T, + bool DO_OOB_FILLING, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4 +> +__forceinline__ __device__ void wg0_bunch_0( + Tensor &rPb, // ((2, 2, 8), 1, 1) + Tensor &rP0, // ((2, 2, 8), 1, 1) + Tensor &rO0, // ((2, 2, 32), 1, 1) + Tensor &sScale0, // (BLOCK_SIZE_M) + Tensor &sM, // (BLOCK_SIZE_M) + float rL[2], + int rRightBorderForQSeq[2], + float scale_softmax_log2, + int start_token_idx, + int idx_in_warpgroup +) { + // This piece of code is tightly coupled [Accumulate's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png) + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + + // Mask, and get row-wise max + float cur_max = MAX_INIT_VAL; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 
2 : 0; i < size(rP0); i += 4) { + if constexpr (DO_OOB_FILLING) { + int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; + rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL; + rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL; + } + cur_max = max(cur_max, max(rP0(i), rP0(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + + // Update sM and sL + cur_max *= scale_softmax_log2; + float new_max = max(sM(row_idx), cur_max); + float scale_for_old = exp2f(sM(row_idx) - new_max); + __syncwarp(); // Make sure all reads have finished before updating sM + if (idx_in_warpgroup%4 == 0) { + sScale0(row_idx) = scale_for_old; + sM(row_idx) = new_max; + } + + // Scale-O + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { + rO0(i) *= scale_for_old; + rO0(i+1) *= scale_for_old; + } + + // Scale, exp, and get row-wise expsum + float cur_sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { + rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max); + rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max); + rPb(i) = (typename T::InputT)rP0(i); + rPb(i+1) = (typename T::InputT)rP0(i+1); + cur_sum += rP0(i) + rP0(i+1); + } + rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; + } +} + + +template< + typename T, + bool IS_BLK0_LAST, + bool IS_BLK1_LAST, + bool IS_BLK2_LAST, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5 +> +__forceinline__ __device__ void wg1_bunch_0( + Tensor &rP1b, // ((2, 2, 8), 1, 1) + Tensor &sScale1, // (BLOCK_SIZE_M) + Tensor &rO1, // ((2, 2, 32), 1, 1) + Tensor &sM, // (BLOCK_SIZE_M) + float rL[2], + int rRightBorderForQSeq[2], + Tensor const &sScale0, // (BLOCK_SIZE_M) + Tensor &rP1, // ((2, 2, 8), 1, 1) + float scale_softmax_log2, + int start_token_idx, + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + + // Mask, and get row-wise max + float cur_max = MAX_INIT_VAL; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { + if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) { + // Need to apply the mask when either this block is the last one, or + // the next block is the last one (because of the causal mask) + int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; + rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL; + rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? 
rP1(i+1) : MAX_INIT_VAL; + } else if constexpr (IS_BLK0_LAST) { + rP1(i) = rP1(i+1) = MAX_INIT_VAL; + } + cur_max = max(cur_max, max(rP1(i), rP1(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + cur_max *= scale_softmax_log2; + + float old_max = sM(row_idx); + float new_max = max(old_max, cur_max); + float scale_for_old = exp2f(old_max - new_max); + __syncwarp(); + if (idx_in_warpgroup%4 == 0) { + sM(row_idx) = new_max; + sScale1(row_idx) = scale_for_old; + } + + // Scale, exp, and get row-wise expsum + float cur_sum = 0; + if constexpr (!IS_BLK0_LAST) { + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { + rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max); + rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max); + rP1b(i) = (typename T::InputT)rP1(i); + rP1b(i+1) = (typename T::InputT)rP1(i+1); + cur_sum += rP1(i) + rP1(i+1); + } + } + + // Scale O + float cur_scale_for_o1 = scale_for_old * sScale0(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO1); i += 4) { + rO1(i) *= cur_scale_for_o1; + rO1(i+1) *= cur_scale_for_o1; + } + + // Update rL + rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum; + } +} + + +// Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void save_rPb_to_sP( + Tensor &rPb, + Tensor &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + (typename T::TiledMMA_QK_sQ){} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +// Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void retrieve_rP_from_sP( + Tensor &rPb, + Tensor const &sP, + int idx_in_warpgroup +) { + TiledCopy s2r_copy = make_tiled_copy_A( + Copy_Atom{}, + (typename T::TiledMMA_PV_LocalP){} + ); + ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_sP = thr_copy.partition_S(sP); + Tensor thr_copy_rPb = thr_copy.retile_D(rPb); + cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); +} + + +// Rescale rP0 and save the result to rPb +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +__forceinline__ __device__ void wg0_scale_rP0( + Tensor const &sScale1, // (BLOCK_M) + Tensor const &rP0, // ((2, 2, 8), 1, 1) + Tensor &rPb, // ((2, 2, 8), 1, 1) + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + float scale_factor = sScale1(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 
2 : 0; i < size(rP0); i += 4) { + rPb(i) = (typename T::InputT)(rP0(i)*scale_factor); + rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor); + } + } +} + + +// Rescale rO0 according to sScale1 +template< + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void wg0_rescale_rO0( + Tensor &rO0, + Tensor &sScale1, + float rL[2], + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + float scale_factor = sScale1(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { + rO0(i) *= scale_factor; + rO0(i+1) *= scale_factor; + } + rL[local_row_idx] *= scale_factor; + } +} + + +// Fill out-of-bound V with 0.0 +// We must fill it since it may contain NaN, which may propagate to the final result +template< + typename T, + typename Engine0, typename Layout0 +> +__forceinline__ __device__ void fill_oob_V( + Tensor &sV, // tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape, Int>{}, LayoutRight{} ); + int valid_window_size, + int idx_in_warpgroup +) { + Tensor sV_int64 = make_tensor( + make_smem_ptr((int64_t*)(sV.data().get().get())), + tile_to_shape( + GMMA::Layout_MN_SW128_Atom{}, + Shape, Int>{}, + LayoutRight{} + ) + ); + valid_window_size = max(valid_window_size, 0); + int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should holds + for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) { + sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0; + } +} + + +// Store O / OAccum +template< + typename T, + bool IS_NO_SPLIT, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +__forceinline__ __device__ void store_o( + Tensor &rO, // ((2, 2, 32), 1, 1) + Tensor &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) + float rL[2], + char* sO_addr, + TMAParams &tma_params, + int batch_idx, + int k_head_idx, + int m_block_idx, + int num_valid_seq_q, + int warpgroup_idx, + int idx_in_warpgroup +) { + using InputT = typename T::InputT; + if constexpr (IS_NO_SPLIT) { + // Should convert the output to bfloat16 / float16, and save it to O + Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + Tensor rOb = make_tensor_like(rO); + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); ++idx) { + rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]); + } + + Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); + TiledCopy r2s_tiled_copy = make_tiled_copy_C( + Copy_Atom{}, + (typename T::TiledMMA_PV_LocalP){} + ); + ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); + Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); + Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); + cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); + cutlass::arch::fence_view_async_shared(); + + __syncthreads(); + + if (threadIdx.x == 0) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, m_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sOutputBuf), + thr_tma.partition_D(my_tma_gO) + ); + 
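+            // Note: only this one thread issues the TMA store; tma_store_arrive() below commits it, and the
+            // matching cute::tma_store_wait<0>() is executed by the caller after store_o() returns, before the
+            // shared buffer behind sO_addr (which overlaps sK0/sK1) is reused for the next K block.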
cute::tma_store_arrive(); + } + } else { + // Should save the result to OAccum + Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout< + Shape<_64, _512>, + Stride, _1> // We use stride = 520 here to avoid bank conflict + >{}); + + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); idx += 2) { + int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0); + int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; + *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 { + rO(idx) / rL[idx%4 >= 2], + rO(idx+1) / rL[idx%4 >= 2], + }; + } + cutlass::arch::fence_view_async_shared(); + + __syncthreads(); + + int row = threadIdx.x; + if (row < num_valid_seq_q) { + SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float)); + cute::tma_store_arrive(); + } + } +} + +template< + typename T, + typename TmaParams, typename Tensor0 +> +__forceinline__ __device__ void launch_q_copy( + TmaParams const &tma_params, + int batch_idx, + int m_block_idx, + int k_head_idx, + Tensor0 &sQ, + TMABarrier* barrier_Q +) { + if (threadIdx.x == 0) { + Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) + auto thr_tma = tma_params.tma_Q.get_slice(_0{}); + Tensor my_tma_gQ = flat_divide(tma_gQ, Shape, Int>{})(_, _, m_block_idx, _0{}); + cute::copy( + tma_params.tma_Q.with(reinterpret_cast(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), + thr_tma.partition_S(my_tma_gQ), + thr_tma.partition_D(sQ) + ); + barrier_Q->arrive_and_expect_tx(64*576*2); + } +} + +template< + typename T, + bool IS_R, + typename Engine0, typename Layout0 +> +__forceinline__ __device__ auto get_half_V( + Tensor &sK +) { + Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){}); + return flat_divide(sV, Shape, Int>{})(_, _, Int<(int)IS_R>{}, _0{}); +} + +template< + typename T, + bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ... + bool IS_BLK1_LAST, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5, + typename Engine6, typename Layout6, + typename Engine7, typename Layout7, + typename Engine8, typename Layout8, + typename Engine9, typename Layout9, + typename Engine10, typename Layout10, + typename Engine11, typename Layout11 +> +__forceinline__ __device__ void wg0_subroutine( + Tensor &tma_gK, + Tensor &sQ, + Tensor &sK0, + Tensor &sK1, + Tensor &sP0, + Tensor &sP1, + Tensor &sM, + Tensor &sScale0, + Tensor &sScale1, + Tensor &rQ8, + Tensor &rP0, + Tensor &rO0, + float rL[2], + int rRightBorderForQSeq[2], + TMABarrier barriers_K0[9], + TMABarrier barriers_K1[9], + bool &cur_phase_K0, + const TMAParams &tma_params, + const Flash_fwd_mla_params ¶ms, + int* block_table_ptr, + int seqlen_k, + int block_idx, + int end_block_idx, + int idx_in_warpgroup +) { + int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; + #define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 
0 : __ldg(block_table_ptr + (block_idx))) + int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); + int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); + + Tensor sV0L = get_half_V(sK0); + Tensor sV1L = get_half_V(sK1); + + Tensor rPb = make_tensor(Shape, _1, _4>{}); + // Calc P0 = softmax(P0) + wg0_bunch_0(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready); + + // Issue rO0 += rPb @ sV0L + if constexpr (IS_BLK0_LAST) { + fill_oob_V(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + warpgroup_cooperative_pv_gemm_localP(rPb, sV0L, rO0, idx_in_warpgroup); + + // Wait for rO0, launch TMA for the next V0L + cute::warpgroup_wait<0>(); + + // Wait for warpgroup 1, rescale P0, notify warpgroup 1 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready); + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + // Put it here seems to be faster, don't know why + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); + } + wg0_scale_rP0(sScale1, rP0, rPb, idx_in_warpgroup); + save_rPb_to_sP(rPb, sP0, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready); + + // Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L + if constexpr (!IS_BLK0_LAST) { + if constexpr (IS_BLK1_LAST) { + fill_oob_V(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); + wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup); + warpgroup_cooperative_pv_gemm_remoteP(sP1, sV1L, rO0, idx_in_warpgroup); + } + + // Issue P0 = Q @ K0^T + // Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline. 
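+    // (The call below appears to be the PHASE-0 variant of warpgroup_cooperative_qkt_gemm, issuing tiles 0..3 of
+    // the next P0 = Q @ K0^T; the second, PHASE-2 call further down issues the remaining tiles 4..8, so the full
+    // product is available at the final "Wait for P0 = Q @ K0^T" below.)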
+ if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); + } + + // Wait for rO0 += rPb @ sV1L, launch TMA + if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) { + cute::warpgroup_wait<4>(); + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); + } + + // Issue P0 = Q @ K0^T + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup); + } + + // Wait for P0 = Q @ K0^T + cute::warpgroup_wait<0>(); +} + + +template< + typename T, + bool IS_BLK0_LAST, + bool IS_BLK1_LAST, + bool IS_BLK2_LAST, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5, + typename Engine6, typename Layout6, + typename Engine7, typename Layout7, + typename Engine8, typename Layout8, + typename Engine9, typename Layout9, + typename Engine10, typename Layout10, + typename Engine11, typename Layout11 +> +__forceinline__ __device__ void wg1_subroutine( + Tensor &tma_gK, + Tensor &sQ, + Tensor &sK0, + Tensor &sK1, + Tensor &sP0, + Tensor &sP1, + Tensor &sM, + Tensor &sScale0, + Tensor &sScale1, + Tensor &rQ8, + Tensor &rP1, + Tensor &rO1, + float rL[2], + int rRightBorderForQSeq[2], + TMABarrier barriers_K0[9], + TMABarrier barriers_K1[9], + bool &cur_phase_K1, + const TMAParams &tma_params, + const Flash_fwd_mla_params ¶ms, + int* block_table_ptr, + int seqlen_k, + int block_idx, + int end_block_idx, + int idx_in_warpgroup +) { + int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; + int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); + int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); + + Tensor rP1b = make_tensor(Shape, _1, _4>{}); + + Tensor sV0R = get_half_V(sK0); + Tensor sV1R = get_half_V(sK1); + + // Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready); + wg1_bunch_0(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready); + + // Save rPb to sP, and issue rO1 += rP1b @ sV1R + // We do this after notifying warpgroup 1, since both "saving rPb to sP" and "issuing" WGMMA are high-latency operations + if constexpr (!IS_BLK0_LAST) { + save_rPb_to_sP(rP1b, sP1, idx_in_warpgroup); + } + if constexpr (!IS_BLK0_LAST) { + if constexpr (IS_BLK1_LAST) { + fill_oob_V(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + warpgroup_cooperative_pv_gemm_localP(rP1b, sV1R, rO1, idx_in_warpgroup); + if constexpr (!IS_BLK1_LAST) { + // We use this proxy for making sP1 visible to the async proxy + // We skip it if IS_BLK1_LAST, since in that case we have already put a fence + cutlass::arch::fence_view_async_shared(); + } + } + + // Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready); + if constexpr (IS_BLK0_LAST) { + fill_oob_V(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + } + warpgroup_cooperative_pv_gemm_remoteP(sP0, sV0R, rO1, 
idx_in_warpgroup); + if constexpr (!IS_BLK0_LAST) { + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); + } + + // Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { + cute::warpgroup_wait<1>(); + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); + } + + // Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + cute::warpgroup_wait<0>(); + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); + } + + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { + // Issue rP1 = sQ @ sK1, wait + warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); + } + + // We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise + // nvcc cannot correctly analyse the loop, and will think that we are using accumulator + // registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized. + cute::warpgroup_wait<0>(); +} + +// A helper function for determining the length of the causal mask for one q token +__forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params ¶ms, int m_block_idx, int local_seq_q_idx) { + int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; + if (global_seq_q_idx < params.q_seq_per_hk) { + int s_q_idx = global_seq_q_idx / params.q_head_per_hk; + return params.s_q - s_q_idx - 1; + } else { + // Out-of-bound request, regard as no masks + return 0; + } +} + +template +__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) { + // grid shape: [ + // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), + // num_kv_heads, + // num_sm_parts + // ] + // An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx). + // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) + // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). 
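+    // Illustrative example of the launch geometry and scheduler metadata (numbers are hypothetical):
+    //   num_m_blocks = ceil_div(q_seq_per_hk, BLOCK_SIZE_M); e.g. s_q = 2, 128 query heads and 1 KV head give
+    //   q_seq_per_hk = 2*128 = 256 and num_m_blocks = 256/64 = 4.
+    //   Each sm_part reads one row of tile_scheduler_metadata as an int4 {begin_idx, begin_seqlen, end_idx, end_seqlen}
+    //   plus begin_n_split_idx at offset 4. For instance {2, 128, 3, 448} with begin_n_split_idx = 1 means: start at
+    //   request 2, token 128 (KV block 128/64 = 2, as split index 1 of that request), and stop at request 3 after
+    //   token 448 (its KV blocks 0..ceil(448/64)-1 = 0..6).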
+ + const int m_block_idx = blockIdx.x; + const int k_head_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int warpgroup_idx = threadIdx.x / 128; + const int idx_in_warpgroup = threadIdx.x % 128; + + // Define shared tensors + extern __shared__ char wksp_buf[]; + using SharedMemoryPlan = typename T::SharedMemoryPlan; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){}); + Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){}); + Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){}); + Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){}); + Tensor sP1 = flat_divide(sQ, Shape, Int>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile + Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int{})); + Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{})); + Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int{})); + Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int{})); + char* sO_addr = (char*)plan.smem_sK0.data(); // Overlap with sK0 and sK1 + + // Prefetch TMA descriptors + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + // Define TMA stuffs + Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _); + TMABarrier* barriers_K0 = plan.barriers_K0; + TMABarrier* barriers_K1 = plan.barriers_K1; + TMABarrier* barrier_Q = &(plan.barrier_Q); + + // Initialize TMA barriers + if (threadIdx.x == 0) { + barrier_Q->init(1); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 9; ++i) { + barriers_K0[i].init(1); + barriers_K1[i].init(1); + } + cutlass::arch::fence_view_async_shared(); + } + __syncthreads(); + bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0; + + // Programmatic Dependent Launch: Wait for the previous kernel to finish + cudaGridDependencySynchronize(); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + // Copy the first Q + launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + constexpr int kBlockN = T::PAGE_BLOCK_SIZE; + const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; + int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); + const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0; + int end_block_idx = batch_idx == end_idx ? 
cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN); + + int rRightBorderForQSeq[2]; + if (params.is_causal) { + // The causal mask looks like: + // XXXX + // XXXX + // ... + // XXXX + // XXX + // XXX + // ... + // XXX + // XX + // XX + // ... + // XX + // Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a "reduction in the length of the k-sequence.", and adjust end_block_idx based on it, to save some calculation. + // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks + // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling + int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); + end_block_idx = batch_idx == end_idx ? cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE); + } + } else { + rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k; + } + + // Define global tensors + using InputT = typename T::InputT; + InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1) + float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) + int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1) + + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout< + Shape>, + Stride<_1> + >{}); + + // Copy K0 and K1 + launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x); + if (start_block_idx+1 < end_block_idx) { + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); + } + + Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape, Int>{}); // ((2, 2, 32), 1, 1) + float rL[2]; + rL[0] = rL[1] = 0.0f; + + // Clear buffers + cute::fill(rO, 0.); + if (threadIdx.x < size(sM)) { + sM[threadIdx.x] = MAX_INIT_VAL_SM; + } + + // Wait for Q + barrier_Q->wait(cur_phase_Q); + cur_phase_Q ^= 1; + + Tensor rQ8 = make_tensor(Shape, _1, _4>{}); + retrieve_rP_from_sP(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), 
idx_in_warpgroup); + + if (warpgroup_idx == 0) { + // Warpgroup 0 + Tensor rP0 = make_tensor((typename T::rP0Layout){}); + + // NOTE We don't use the pipelined version of Q K^T here since it leads + // to a slow-down (or even register spilling, thanks to the great NVCC) + // Wait for K0 + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 9; ++i) { + if (idx_in_warpgroup == 0) + barriers_K0[i].arrive_and_expect_tx(64*64*2); + barriers_K0[i].wait(cur_phase_K0); + } + cur_phase_K0 ^= 1; + + // Issue P0 = Q @ K0^T, wait + warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + cute::warpgroup_wait<0>(); + + #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ + wg0_subroutine( \ + tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ + rQ8, rP0, rO, rL, rRightBorderForQSeq, \ + barriers_K0, barriers_K1, cur_phase_K0, \ + tma_params, params, \ + block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ + ); + + int block_idx = start_block_idx; + #pragma unroll 1 + for (; block_idx < end_block_idx-2; block_idx += 2) { + LAUNCH_WG0_SUBROUTINE(false, false); + } + + if (block_idx+1 < end_block_idx) { + LAUNCH_WG0_SUBROUTINE(false, true); + } else if (block_idx < end_block_idx) { + LAUNCH_WG0_SUBROUTINE(true, false); + } + + } else { + // Warpgroup 1 + Tensor rP1 = make_tensor((typename T::rP0Layout){}); + + if (start_block_idx+1 < end_block_idx) { + // Issue rP1 = sQ @ sK1, wait + warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup); + cute::warpgroup_wait<0>(); + } + + #define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \ + wg1_subroutine( \ + tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \ + rQ8, rP1, rO, rL, rRightBorderForQSeq, \ + barriers_K0, barriers_K1, cur_phase_K1, \ + tma_params, params, \ + block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ + ); + + int block_idx = start_block_idx; + #pragma unroll 1 + for (; block_idx < end_block_idx-3; block_idx += 2) { + LAUNCH_WG1_SUBROUTINE(false, false, false); + } + + if (block_idx+2 < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(false, false, true); + block_idx += 2; + LAUNCH_WG1_SUBROUTINE(true, false, false); + } else if (block_idx+1 < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(false, true, false); + } else if (block_idx < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(true, false, false); + } + } + + // Reduce rL across threads within the same warp + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + + // Reduce rL across warpgroups + int my_row = get_AorC_row_idx(0, idx_in_warpgroup); + if (idx_in_warpgroup%4 == 0) { + sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0]; + sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1]; + } + __syncthreads(); + if (warpgroup_idx == 0) { + rL[0] += sL_reduction_wksp[my_row + 64]; + rL[1] += sL_reduction_wksp[my_row + 8 + 64]; + } else { + if (idx_in_warpgroup%4 == 0) { + sL_reduction_wksp[my_row] += rL[0]; + sL_reduction_wksp[my_row + 8] += rL[1]; + } + __syncwarp(); + rL[0] = sL_reduction_wksp[my_row]; + rL[1] = sL_reduction_wksp[my_row+8]; + } + + // Prune out when rL is 0.0f or NaN + // rL may be 0.0f if there are large values (~10^12) in QK^T, which leads + // to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error. + // When this happens, we set rL to 1.0f. This aligns with the old version + // of the MLA kernel. 
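+        // For reference (assuming scale_softmax_log2 == softmax_scale * log2(e), as the name suggests): with
+        // m = sM(row) = max_j(s_j) * scale_softmax_log2 and L = rL(row) = sum_j exp2(s_j * scale_softmax_log2 - m),
+        //   ln(sum_j exp(s_j * softmax_scale))   = logf(L)  + m / M_LOG2E   -> written to softmax_lse when is_no_split
+        //   log2(sum_j exp(s_j * softmax_scale)) = log2f(L) + m             -> written to softmax_lseaccum otherwise
+        // i.e. the split path keeps the per-row log-sum-exp in base-2 form for the combine kernel that runs next.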
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i]; + + // Copy Q for the next batch + if (batch_idx+1 <= end_idx) { + launch_q_copy(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q); + } else { + // Allow the next kernel (the combine kernel) to launch + // The next kernel MUST be the combine kernel + cudaTriggerProgrammaticLaunchCompletion(); + } + + int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M); + if (is_no_split) { + store_o(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL_reduction_wksp[i]; + gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E; + } + + cute::tma_store_wait<0>(); + } else { + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout< + Shape>, + Stride<_1> + >{}); + store_o(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL_reduction_wksp[i]; + gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? 
-INFINITY : log2f(cur_L) + sM(i); + } + + cute::tma_store_wait<0>(); + } + + if (batch_idx != end_idx) + __syncthreads(); + } +} + + +template +void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + using T = Traits; + auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((InputT*)params.q_ptr), + make_layout( + shape_Q, + make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride) + ) + ), + tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + ) + ); + auto shape_K = make_shape(Int{}, Int{}, params.h_k, params.num_blocks); + auto tma_K = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((InputT*)params.k_ptr), + make_layout( + shape_K, + make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride) + ) + ), + tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Layout< + Shape, Int<64>>, + Stride, _1> + >{} + ) + ); + auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((InputT*)params.o_ptr), + make_layout( + shape_O, + make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride) + ) + ), + tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + ) + ); + TmaParams tma_params = { + shape_Q, tma_Q, + shape_K, tma_K, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) + const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); + cudaLaunchAttribute mla_kernel_attributes[1]; + mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; + cudaLaunchConfig_t mla_kernel_config = { + dim3(num_m_block, params.h_k, params.num_sm_parts), + dim3(T::NUM_THREADS, 1, 1), + smem_size, + stream, + mla_kernel_attributes, + 1 + }; + cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +#ifndef FLASH_MLA_DISABLE_FP16 +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +#endif diff --git a/csrc/kernels/splitkv_mla.h b/csrc/kernels/splitkv_mla.h new file mode 100644 index 0000000..42109d4 --- /dev/null +++ b/csrc/kernels/splitkv_mla.h @@ -0,0 +1,6 @@ +#pragma once + +#include "flash_mla.h" + +template +void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels/traits.h b/csrc/kernels/traits.h new file mode 100644 index 0000000..31c1388 --- /dev/null +++ b/csrc/kernels/traits.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +#include "config.h" + +using TMABarrier = cutlass::arch::ClusterTransactionBarrier; +using namespace cute; + +template +struct Traits { + using InputT = InputT_; + + static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; + static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; + static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; + static constexpr int HEAD_DIM_V = 
Config::HEAD_DIM_V; + + static constexpr int NUM_THREADS = 256; + + static_assert(std::is_same_v || std::is_same_v); + + using TiledMMA_QK_sQ = decltype(make_tiled_mma( + GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), + Layout>{} + )); + + using TiledMMA_QK_rQ = decltype(make_tiled_mma( + GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), + Layout>{} + )); + + using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), + Layout>{} + )); + + using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), + Layout>{} + )); + + using SmemLayoutQ = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + using SmemLayoutK = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + using SmemLayoutV = decltype(composition( + SmemLayoutK{}, + make_layout(Shape, Int>{}, GenRowMajor{}) + )); // A transposed version of SmemLayoutK + + using SmemLayoutP0 = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + using rP0Layout = decltype(layout(partition_fragment_C( + TiledMMA_QK_sQ{}, + Shape, Int>{} + ))); + + struct SharedMemoryPlan { + cute::array_aligned> smem_sQ; + cute::array_aligned> smem_sK0; + cute::array_aligned> smem_sK1; + cute::array_aligned> smem_sP0; + cute::array_aligned smem_sM; + cute::array_aligned sL_reduction_wksp; + cute::array_aligned smem_sScale0; + cute::array_aligned smem_sScale1; + TMABarrier barriers_K0[HEAD_DIM_K/64]; + TMABarrier barriers_K1[HEAD_DIM_K/64]; + TMABarrier barrier_Q; + }; + +}; + +template< + typename ShapeQ, typename TMA_Q, + typename ShapeK, typename TMA_K, + typename ShapeO, typename TMA_O +> +struct TmaParams { + ShapeQ shape_Q; + TMA_Q tma_Q; + ShapeK shape_K; + TMA_K tma_K; + ShapeO shape_O; + TMA_O tma_O; +}; + +enum NamedBarriers : int { + sScale0Ready = 0, + sScale1Ready = 1, + sP0Ready = 2, + rO1sP0sV0RIssued = 3 +}; diff --git a/csrc/static_switch.h b/csrc/kernels/utils.h similarity index 57% rename from csrc/static_switch.h rename to csrc/kernels/utils.h index f156adc..ae9d0fc 100644 --- a/csrc/static_switch.h +++ b/csrc/kernels/utils.h @@ -5,7 +5,7 @@ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ + exit(1); \ } \ } while(0) @@ -29,37 +29,4 @@ } \ } while(0) - -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr static bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() - - -#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ - [&] { \ - if (NUM_SPLITS <= 32) { \ - constexpr static int NAME = 32; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 64) { \ - constexpr static int NAME = 64; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 96) { \ - constexpr static int NAME = 96; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 128) { \ - constexpr static int NAME = 128; \ - return __VA_ARGS__(); \ - } else if (NUM_SPLITS <= 160) { \ - constexpr static int NAME = 160; \ - return __VA_ARGS__(); \ - } else { \ - FLASH_ASSERT(false); \ - } \ - }() +#define println(fmt, ...) 
{ print(fmt, ##__VA_ARGS__); print("\n"); } diff --git a/csrc/named_barrier.h b/csrc/named_barrier.h deleted file mode 100644 index cefa936..0000000 --- a/csrc/named_barrier.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "cutlass/barrier.h" - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Enumerates the reserved named barriers to avoid potential conflicts - -enum class NamedBarriers { - SReady = 1, - SoftmaxReady = 2, -}; - -} // flash diff --git a/csrc/softmax.h b/csrc/softmax.h deleted file mode 100644 index 17e293a..0000000 --- a/csrc/softmax.h +++ /dev/null @@ -1,200 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include - -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } - } -} - -template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++){ - dst(i) = Allreduce<4>::run(src(i), op); - } -} - -template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); -} - -template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); -} - -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - // The following macro will disable the use of fma. 
- // See: https://github.com/pytorch/pytorch/issues/121558 for more details - // This macro is set in PyTorch and not FlashAttention - #ifdef UNFUSE_FMA - tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); - #else - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - #endif - } - } - return tensor; -} - -// Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; - sum(mi) = 0; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - SumOp sum_op; - sum(mi) = Allreduce<4>::run(sum(mi), sum_op); - } -} - -template -__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - #pragma unroll - for (int mi = 0; mi < size(scale_o); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax { - - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; - - __forceinline__ __device__ Softmax() {}; - - template - __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - TensorT scale_o; - clear(scale_o); - if (Is_first) { - flash::template reduce_max(scores, row_max); - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - flash::reduce_sum(scores, row_sum); - } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - flash::template reduce_max(scores, row_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = !Check_inf - ? row_max(mi) - : (row_max(mi) == -INFINITY ? 
0.0f : row_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scale_o(mi) = scores_scale; - row_sum(mi) *= scores_scale; - } - flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); - // We don't do the reduce across threads here since we don't need to use the row_sum. - // We do that reduce at the end when we need to normalize the softmax. - flash::reduce_sum(scores, row_sum); - } - return scale_o; - }; - - template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT lse = make_fragment_like(row_sum); - // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); - float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } - } - return lse; - }; -}; - -} // namespace flash diff --git a/csrc/utils.h b/csrc/utils.h deleted file mode 100644 index 50295f7..0000000 --- a/csrc/utils.h +++ /dev/null @@ -1,241 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MaxOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } -}; - -template <> -struct MaxOp { -// This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ __forceinline__ T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Allreduce<2> { -template -static __device__ __forceinline__ T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { - constexpr bool Is_RS = !cute::is_base_of::value; - // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const - if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } - warpgroup_fence_operand(tCrC); - if constexpr (arrive) { - warpgroup_arrive(); - } - if constexpr (zero_init) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } else { - // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } - if constexpr (commit) { - warpgroup_commit_batch(); - } - if constexpr (wg_wait >= 0) { warpgroup_wait(); } - warpgroup_fence_operand(tCrC); - if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) -// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = acc_layout; - if constexpr (!Transposed) { - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); - } else { - return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); - } - - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - 
static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - if constexpr (!Transposed) { - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); - } else { - return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. -// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) -// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) -template -__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { - using X = Underscore; - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); - if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { - auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) - return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); - } else { - static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); - static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); - static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); - auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) - // This combines the first two modes (<0, 0> and <0, 1>) into one mode. - // Will require register shuffling later to be correct. - return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), - get<1>(acc_layout), - coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) - // This combination is right but doesn't work with register shuffling. 
-            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
-            //                    get<1>(acc_layout),
-            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
-        }
-    } else {  // SM80
-        static_assert(decltype(size<0>(acc_layout))::value == 4);
-        static_assert(decltype(rank(acc_layout))::value == 3);
-        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
-        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
-        if constexpr (mma_shape_K == 8) {
-            return acc_layout;
-        } else {
-            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
-            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
-        }
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <typename To_type, typename Engine, typename Layout>
-__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
-    using From_type = typename Engine::value_type;
-    constexpr int numel = decltype(size(tensor))::value;
-    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
-    // HACK: this requires tensor to be "contiguous"
-    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
-    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Blocks until all but N previous cp.async.commit_group operations have committed.
-// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
-// (which is equivalent to commit_group then wait_group 0).
-// Instead we just call cp.async.wait_group 0, which is slightly faster.
-// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
-template <int N>
-CUTE_HOST_DEVICE
-void cp_async_wait() {
-#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
-    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
-#endif
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
-                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
-    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
-    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
-    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
-    // There's no case where !Clear_OOB_K && Clear_OOB_MN
-    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
-    #pragma unroll
-    for (int m = 0; m < size<1>(S); ++m) {
-        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-            #pragma unroll
-            for (int k = 0; k < size<2>(S); ++k) {
-                if (Is_even_K || predicate_K(k)) {
-                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
-                } else if (Clear_OOB_K) {
-                    cute::clear(D(_, m, k));
-                }
-            }
-        } else if (Clear_OOB_MN) {
-            cute::clear(D(_, m, _));
-        }
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-}  // namespace flash
diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index b2922af..47637f8 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -55,7 +55,6 @@ def flash_mla_with_kvcache(
     out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
         q,
         k_cache,
-        None,
         head_dim_v,
         cache_seqlens,
         block_table,
diff --git a/setup.py b/setup.py
index cd311f2..131ceff 100644
--- a/setup.py
+++ b/setup.py
@@ -11,29 +11,13 @@ from torch.utils.cpp_extension import (
     IS_WINDOWS,
 )
 
-DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
-
-
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
-
-def get_sources():
-    sources = [
-        "csrc/flash_api.cpp",
-        "csrc/flash_fwd_mla_bf16_sm90.cu",
-        "csrc/flash_fwd_mla_metadata.cu",
-    ]
-
-    if not DISABLE_FP16:
-        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
-
-    return sources
-
-
 def get_features_args():
     features_args = []
+    DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]
     if DISABLE_FP16:
         features_args.append("-DFLASH_MLA_DISABLE_FP16")
     return features_args
@@ -56,7 +40,12 @@ ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla_cuda",
-        sources=get_sources(),
+        sources=[
+            "csrc/flash_api.cpp",
+            "csrc/kernels/get_mla_metadata.cu",
+            "csrc/kernels/mla_combine.cu",
+            "csrc/kernels/splitkv_mla.cu",
+        ],
         extra_compile_args={
             "cxx": cxx_args + get_features_args(),
             "nvcc": append_nvcc_threads(