From 414a2f3eedeb5ad3c4a6e89d8641e059519cacc9 Mon Sep 17 00:00:00 2001 From: Jiashi Li <31004720+beginlner@users.noreply.github.com> Date: Fri, 21 Feb 2025 14:31:27 +0800 Subject: [PATCH] Initial commit i --- .gitignore | 5 + .gitmodules | 3 + LICENSE | 21 + README.md | 61 +++ csrc/cutlass | 1 + csrc/flash_api.cpp | 203 +++++++++ csrc/flash_fwd_mla_bf16_sm90.cu | 3 + csrc/flash_fwd_mla_kernel.h | 679 +++++++++++++++++++++++++++++++ csrc/flash_mla.h | 63 +++ csrc/named_barrier.h | 15 + csrc/softmax.h | 197 +++++++++ csrc/static_switch.h | 65 +++ csrc/utils.h | 238 +++++++++++ flash_mla/__init__.py | 6 + flash_mla/flash_mla_interface.py | 67 +++ setup.py | 78 ++++ tests/test_flash_mla.py | 114 ++++++ 17 files changed, 1819 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 LICENSE create mode 100644 README.md create mode 160000 csrc/cutlass create mode 100644 csrc/flash_api.cpp create mode 100644 csrc/flash_fwd_mla_bf16_sm90.cu create mode 100644 csrc/flash_fwd_mla_kernel.h create mode 100644 csrc/flash_mla.h create mode 100644 csrc/named_barrier.h create mode 100644 csrc/softmax.h create mode 100644 csrc/static_switch.h create mode 100644 csrc/utils.h create mode 100644 flash_mla/__init__.py create mode 100644 flash_mla/flash_mla_interface.py create mode 100644 setup.py create mode 100644 tests/test_flash_mla.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bfe80b5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +build +*.so +*.egg-info/ +__pycache__/ +dist/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..8d501cb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "csrc/cutlass"] + path = csrc/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5c48bdc --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..bb55395 --- /dev/null +++ b/README.md @@ -0,0 +1,61 @@ +# FlashMLA + +FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. + +Currently released: +- BF16 +- Paged kvcache with block size of 64 + +## Quick start + +### Install + +```bash +python setup.py install +``` + +### Benchmark + +```bash +python tests/test_flash_mla.py +``` + +Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6. + +### Usage + +```python +from flash_mla import get_mla_metadata, flash_mla_with_kvcache + +tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + +for i in range(num_layers): + ... + o_i, lse_i = flash_mla_with_kvcache( + q_i, kvcache_i, block_table, cache_seqlens, dv, + tile_scheduler_metadata, num_splits, causal=True, + ) + ... +``` + +## Requirements + +- Hopper GPUs +- CUDA 12.3 and above +- PyTorch 2.0 and above + +## Acknowledgement + +FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. + +## Citation + +```bibtex +@misc{flashmla2025, + title={FlashMLA: Efficient MLA decoding kernel}, + author={Jiashi Li}, + year={2025}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}}, +} +``` diff --git a/csrc/cutlass b/csrc/cutlass new file mode 160000 index 0000000..afa1772 --- /dev/null +++ b/csrc/cutlass @@ -0,0 +1 @@ +Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp new file mode 100644 index 0000000..5a1cb8e --- /dev/null +++ b/csrc/flash_api.cpp @@ -0,0 +1,203 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp + +#include +#include +#include +#include + +#include + +#include "flash_mla.h" +#include "static_switch.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::vector +get_mla_metadata( + at::Tensor &seqlens_k, + const int num_heads_per_head_k, + const int num_heads_k +) { + // This should match the logic in the MLA kernel. + static constexpr int block_size_m = 64; + static constexpr int block_size_n = 64; + static constexpr int fixed_overhead_num_blocks = 5; + + CHECK_DEVICE(seqlens_k); + TORCH_CHECK(seqlens_k.is_contiguous()); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); + + int batch_size = seqlens_k.size(0); + int *seqlens_k_ptr = seqlens_k.data_ptr(); + auto options = seqlens_k.options(); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + int sm_count = dprops->multiProcessorCount; + int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m); + + auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); + auto num_splits = torch::empty({batch_size + 1}, options); + int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + int *num_splits_ptr = num_splits.data_ptr(); + + at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + Mla_metadata_params params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = block_size_n; + params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.num_sm_parts = num_sm_parts; + get_mla_metadata_func(params, stream); + + return {tile_scheduler_metadata, num_splits}; +} + +std::vector +mha_fwd_kvcache_mla( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size + c10::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v + const int head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const float softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits // batch_size + 1 +) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90); + + at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16); + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_ori = sizes[2]; + const int head_size = sizes[3]; + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int ngroups = num_heads_ori / num_heads_k; + const int seqlen_q = seqlen_q_ori * ngroups; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3) + .reshape({batch_size, seqlen_q, num_heads, head_size}); + + int head_size_k = head_size; + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + + + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + CHECK_CONTIGUOUS(seqlens_k); + CHECK_SHAPE(seqlens_k, batch_size); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + Flash_fwd_mla_params params = {}; + // Set the sizes. + params.b = batch_size; + params.seqlen_q = seqlen_q; + params.cu_seqlens_k = seqlens_k.data_ptr(); + params.h = num_heads; + params.h_h_k_ratio = num_heads / num_heads_k; + params.ngroups = ngroups; + params.is_causal = is_causal; + params.d = head_size; + params.d_v = head_size_v; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.v_ptr = vcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + // All stride are in elements, not bytes. + params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.v_batch_stride = vcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(-3); + params.v_row_stride = vcache.stride(-3); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(-2); + params.v_head_stride = vcache.stride(-2); + params.o_head_stride = out.stride(-2); + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + CHECK_DEVICE(num_splits); + CHECK_CONTIGUOUS(num_splits); + params.num_splits_ptr = num_splits.data_ptr(); + + at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(head_size == 576); + run_mha_fwd_splitkv_mla(params, stream); + + out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) + .reshape({batch_size, num_heads_ori, seqlen_q_ori}); + + return {out, softmax_lse}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashMLA"; + m.def("get_mla_metadata", &get_mla_metadata); + m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); +} diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu new file mode 100644 index 0000000..35691f2 --- /dev/null +++ b/csrc/flash_fwd_mla_bf16_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h new file mode 100644 index 0000000..55f6811 --- /dev/null +++ b/csrc/flash_fwd_mla_kernel.h @@ -0,0 +1,679 @@ +#pragma once + +#include +#include +#include +#include + +using namespace cute; + +#include "named_barrier.h" +#include "utils.h" +#include "softmax.h" +#include "static_switch.h" +#include "flash_mla.h" + + +template +constexpr auto getSmemLayoutK() { + constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; + + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return GMMA::Layout_K_SW32_Atom{}; + } +} + +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + static_assert(kHeadDimV % 32 == 0); + static_assert(kHeadDimV <= kHeadDim); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + using TiledMmaO = decltype(make_tiled_mma( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::MN>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutV = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + + using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + using GmemLayoutAtom = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomO = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + GmemLayoutAtomO{}, + Layout>{})); // Val layout, 8 vals per store + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; + using GmemLayoutAtomOaccum = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store +}; + +namespace flash { + +using namespace cute; + +template +struct SharedStorageMLA { + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; + }; + struct { + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; + }; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + + // Epilogue + + const int split_offset = __ldg(params.num_splits_ptr + bidb); + + Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); + + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + if (tidx >= kNThreadsS) { return; } + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + ); +} + +template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + clear(tOrO); + + flash::Softmax<2 * size<1>(tOrO)> softmax; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, params.scale_softmax_log2) + : is_masking_step ? + softmax.template softmax(tSrS, params.scale_softmax_log2) + : softmax.template softmax(tSrS, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(tSrS); + cute::copy(rP, tPsP); + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need to clear the sK smem tiles because K is V. + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); + Tensor rP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, rP); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + +#pragma unroll 1 + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. + } + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(256, 1, 1) +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kNThreads = 128; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int hs = params.h * params.seqlen_q; + const int batch_idx = bidx / hs; + const int hs_idx = bidx % hs; + + const int split_offset = __ldg(params.num_splits_ptr + batch_idx); + const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; + FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); + if (actual_num_splits == 1) return; + + __shared__ ElementAccum sLseScale[kMaxSplits]; + + const index_t row_offset_lseaccum = split_offset * hs + hs_idx; + const index_t row_offset_lse = bidx; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, make_stride(hs)); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + float local_lse[kNLsePerThread]; + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; + } + + float max_lse = -INFINITY; + for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); + for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); + for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse; + if (tidx == 0) gLSE(0) = global_lse; + + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + __syncthreads(); + + static_assert(kHeadDimV % kNThreads == 0); + constexpr int Elements = kHeadDimV / kNThreads; + const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape>{}, Stride<_1>{}); + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + Layout>>{}, + Layout>>{})); + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + for (int split = 0; split < actual_num_splits; ++split) { + cute::copy(tOgOaccum, tOrOaccum); + ElementAccum lse_scale = sLseScale[split]; + for (int i = 0; i < size(tOrO); ++i) { + tOrO(i) += lse_scale * tOrOaccum(i); + } + tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; + } + + Tensor rO = flash::convert_type(tOrO); + const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; + const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); + cute::copy(rO, gO); +} + +} // namespace flash + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); + const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + auto kernel = &flash::flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(SharedStorage); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); + + dim3 grid_combine(params.b * params.h * params.seqlen_q); + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { + auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< + typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + combine_kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + static_assert(Headdim == 576); + FLASH_ASSERT(params.d_v == 512); + FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; + run_flash_splitkv_fwd_mla>(params, stream); +} + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h new file mode 100644 index 0000000..2994cb7 --- /dev/null +++ b/csrc/flash_mla.h @@ -0,0 +1,63 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_mla_params { + using index_t = int64_t; + + int b, seqlen_q, d, d_v; + int h, h_h_k_ratio, ngroups; + bool is_causal; + float scale_softmax, scale_softmax_log2; + int *__restrict__ cu_seqlens_k; + + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ o_ptr; + void *__restrict__ softmax_lse_ptr; + + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t o_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t o_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t o_head_stride; + + int *__restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + int *__restrict__ tile_scheduler_metadata_ptr; + int num_sm_parts; + int *__restrict__ num_splits_ptr; + + void *__restrict__ softmax_lseaccum_ptr; + void *__restrict__ oaccum_ptr; +}; + +static constexpr int TileSchedulerMetaDataSize = 8; +// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +struct Mla_metadata_params { + int *__restrict__ seqlens_k_ptr; + int *__restrict__ tile_scheduler_metadata_ptr; + int *__restrict__ num_splits_ptr; + int batch_size; + int block_size_n; + int fixed_overhead_num_blocks; + int num_sm_parts; +}; + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/named_barrier.h b/csrc/named_barrier.h new file mode 100644 index 0000000..cefa936 --- /dev/null +++ b/csrc/named_barrier.h @@ -0,0 +1,15 @@ +#pragma once + +#include "cutlass/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + SReady = 1, + SoftmaxReady = 2, +}; + +} // flash diff --git a/csrc/softmax.h b/csrc/softmax.h new file mode 100644 index 0000000..4ab6ae9 --- /dev/null +++ b/csrc/softmax.h @@ -0,0 +1,197 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + +#pragma once + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } + return tensor; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scale_o); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scale_o; + clear(scale_o); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scale_o(mi) = scores_scale; + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + return scale_o; + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/csrc/static_switch.h b/csrc/static_switch.h new file mode 100644 index 0000000..f156adc --- /dev/null +++ b/csrc/static_switch.h @@ -0,0 +1,65 @@ +#pragma once + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +#define FLASH_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + exit(1); \ + } \ + } while(0) + + +#define FLASH_DEVICE_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while(0) + + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() diff --git a/csrc/utils.h b/csrc/utils.h new file mode 100644 index 0000000..3b8dd52 --- /dev/null +++ b/csrc/utils.h @@ -0,0 +1,238 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h + +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. + // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py new file mode 100644 index 0000000..51b8600 --- /dev/null +++ b/flash_mla/__init__.py @@ -0,0 +1,6 @@ +__version__ = "1.0.0" + +from flash_mla.flash_mla_interface import ( + get_mla_metadata, + flash_mla_with_kvcache, +) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py new file mode 100644 index 0000000..2f3aa46 --- /dev/null +++ b/flash_mla/flash_mla_interface.py @@ -0,0 +1,67 @@ +from typing import Optional, Tuple + +import torch + +import flash_mla_cuda + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..bab4792 --- /dev/null +++ b/setup.py @@ -0,0 +1,78 @@ +import os +from pathlib import Path +from datetime import datetime +import subprocess + +from setuptools import setup, find_packages + +from torch.utils.cpp_extension import ( + BuildExtension, + CUDAExtension, +) + + +def append_nvcc_threads(nvcc_extra_args): + nvcc_threads = os.getenv("NVCC_THREADS") or "32" + return nvcc_extra_args + ["--threads", nvcc_threads] + + +subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) + +cc_flag = [] +cc_flag.append("-gencode") +cc_flag.append("arch=compute_90a,code=sm_90a") + +this_dir = os.path.dirname(os.path.abspath(__file__)) + +ext_modules = [] +ext_modules.append( + CUDAExtension( + name="flash_mla_cuda", + sources=[ + "csrc/flash_api.cpp", + "csrc/flash_fwd_mla_bf16_sm90.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-DNDEBUG", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v,--register-usage-level=10" + ] + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "cutlass" / "include", + ], + ) +) + + +try: + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() +except Exception as _: + now = datetime.now() + date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S") + rev = '+' + date_time_str + + +setup( + name="flash_mla", + version="1.0.0" + rev, + packages=find_packages(include=['flash_mla']), + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, +) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py new file mode 100644 index 0000000..a057f6a --- /dev/null +++ b/tests/test_flash_mla.py @@ -0,0 +1,114 @@ +import math +import random + +import torch +import triton + +from flash_mla import get_mla_metadata, flash_mla_with_kvcache + + +def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + RMSE = ((x - y) * (x - y)).mean().sqrt().item() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + amax_diff = (x - y).abs().max().item() + # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") + assert cos_diff < 1e-5 + + +@torch.inference_mode() +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") + + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, blocked_k, block_table, cache_seqlens, dv, + tile_scheduler_metadata, num_splits, causal=causal, + ) + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_flash, lse_flash = flash_mla() + out_torch, lse_torch = ref_mla() + cal_diff(out_flash, out_torch, "out") + cal_diff(lse_flash, lse_torch, "lse") + + t = triton.testing.do_bench(flash_mla, fast_flush=False) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + + +if __name__ == "__main__": + dtype = torch.bfloat16 + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + + h_kv = 1 + d, dv = 576, 512 + causal = True + + for b in [128]: + for s in [4096, 8192]: + for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 + for s_q in [1, 2]: # MTP = 1, 2 + for varlen in [False, True]: + test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)