diff --git a/.gitignore b/.gitignore
index 5f9e980..982daef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ __pycache__/
dist/
*perf.csv
*.png
+/.vscode
diff --git a/README.md b/README.md
index 1dad9ef..6de1640 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,28 @@
# FlashMLA
+## Performance Update (2025.04.22)
+
+We're excited to announce the new release of FlashMLA, which delivers a 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀
+
+We'd also love to share the technical details behind the new kernel! Check out our deep-dive write-up here:
+
+The new kernel primarily targets compute-intensive settings, i.e. where the number of q heads $\times$ the number of q tokens per request (1 if MTP is disabled) $\ge 64$. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance.
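
To make the rule of thumb above concrete, here is a small illustrative check (plain Python, not part of the repository; the function and argument names are ours):

```python
def prefers_new_kernel(num_q_heads: int, q_tokens_per_request: int = 1) -> bool:
    """Heuristic from the note above: the new kernel targets compute-intensive
    settings where (number of q heads) * (q tokens per request) >= 64.
    q_tokens_per_request is 1 when MTP is disabled."""
    return num_q_heads * q_tokens_per_request >= 64

# Example: 128 query heads with MTP disabled -> compute-bound, use the new kernel.
assert prefers_new_kernel(128, 1)
# Example: 16 query heads with MTP disabled -> memory-bound, prefer b31bfe7.
assert not prefers_new_kernel(16, 1)
```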
+
+## Introduction
+
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
Currently released:
- BF16, FP16
- Paged kvcache with block size of 64
+## Requirements
+
+- Hopper GPUs
+- CUDA 12.3 and above
+ - **But we highly recommend 12.8 or above for the best performance**
+- PyTorch 2.0 and above
+
## Quick start
### Install
@@ -20,7 +37,9 @@ python setup.py install
python tests/test_flash_mla.py
```
-Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
+It achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5, using CUDA 12.8.
+
+Note: For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance.
### Usage
@@ -38,13 +57,6 @@ for i in range(num_layers):
...
```
-## Requirements
-
-- Hopper GPUs
-- CUDA 12.3 and above
- - **But we highly recommend 12.8 or above for the best performance**
-- PyTorch 2.0 and above
-
## Acknowledgement
FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
@@ -91,7 +103,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c
```bibtex
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
- author={Jiashi Li},
+      author={Jiashi Li and Shengyu Liu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py
index 14e1352..95c75f2 100644
--- a/benchmark/bench_flash_mla.py
+++ b/benchmark/bench_flash_mla.py
@@ -435,7 +435,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
- if target not in ["flash_infer", "flash_mla_triton"]:
+ if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]:
# flash_infer has a different lse return value
# flash_mla_triton doesn't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 9015735..a87e1ab 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -10,8 +10,11 @@
#include
-#include "flash_mla.h"
-#include "static_switch.h"
+#include "kernels/config.h"
+#include "kernels/get_mla_metadata.h"
+#include "kernels/mla_combine.h"
+#include "kernels/params.h"
+#include "kernels/splitkv_mla.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
@@ -23,11 +26,6 @@ get_mla_metadata(
const int num_heads_per_head_k,
const int num_heads_k
) {
- // This should match the logic in the MLA kernel.
- static constexpr int block_size_m = 64;
- static constexpr int block_size_n = 64;
- static constexpr int fixed_overhead_num_blocks = 5;
-
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
@@ -38,7 +36,7 @@ get_mla_metadata(
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
- int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
+ int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
@@ -52,10 +50,10 @@ get_mla_metadata(
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
- params.block_size_n = block_size_n;
- params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
+ params.block_size_n = Config::PAGE_BLOCK_SIZE;
+ params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS;
params.num_sm_parts = num_sm_parts;
- get_mla_metadata_func(params, stream);
+ run_get_mla_metadata_kernel(params, stream);
return {tile_scheduler_metadata, num_splits};
}
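
As a side note, the host-side sizing above can be mirrored in a few lines of Python; this is only a sketch under the constants in `kernels/config.h` (`BLOCK_SIZE_M = 64`, `TileSchedulerMetaDataSize = 8`), with placeholder function names:

```python
import math

BLOCK_SIZE_M = 64                 # Config::BLOCK_SIZE_M
TILE_SCHEDULER_METADATA_SIZE = 8  # TileSchedulerMetaDataSize

def metadata_shapes(sm_count: int, batch_size: int,
                    num_heads_per_head_k: int, num_heads_k: int):
    """Mirror of get_mla_metadata's sizing: the SMs are split evenly among
    (kv head, M-block) pairs, and each resulting part gets one metadata row."""
    num_sm_parts = sm_count // num_heads_k // math.ceil(num_heads_per_head_k / BLOCK_SIZE_M)
    return (num_sm_parts, TILE_SCHEDULER_METADATA_SIZE), (batch_size + 1,)

# Example: 132 SMs (H800 SXM5), 1 KV head, 128 folded q heads per KV head.
print(metadata_shapes(132, batch_size=4, num_heads_per_head_k=128, num_heads_k=1))
# -> ((66, 8), (5,))
```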
@@ -64,7 +62,6 @@ std::vector
mha_fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
- std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
@@ -73,138 +70,141 @@ mha_fwd_kvcache_mla(
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
) {
+ // Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);
- at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
-
+ // Check data types
auto q_dtype = q.dtype();
+ TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+ TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
+ TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+ TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
+ TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
- CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
+ // Check device
+ CHECK_DEVICE(q);
+ CHECK_DEVICE(kcache);
+ CHECK_DEVICE(seqlens_k);
+ CHECK_DEVICE(block_table);
+ CHECK_DEVICE(tile_scheduler_metadata);
+ CHECK_DEVICE(num_splits);
+ // Check layout
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
- TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-
- CHECK_DEVICE(block_table);
- TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+ CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+ CHECK_CONTIGUOUS(tile_scheduler_metadata);
+ CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
- const int num_heads_ori = sizes[2];
- const int head_size = sizes[3];
- TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
- TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
+ const int num_heads_q = sizes[2];
+ const int head_size_k = sizes[3];
+ TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
+    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
- TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+ TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
- const int ngroups = num_heads_ori / num_heads_k;
- const int seqlen_q = seqlen_q_ori * ngroups;
+ const int num_q_heads_per_hk = num_heads_q / num_heads_k;
+ const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
- q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
- .reshape({batch_size, seqlen_q, num_heads, head_size});
+ q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
+ .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
- int head_size_k = head_size;
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+ CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
- if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
- CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
-
-
- TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
- CHECK_DEVICE(seqlens_k);
- CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
+ CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+ TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
+ CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
- at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
- at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+ at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
+ at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
+ CHECK_CONTIGUOUS(softmax_lse);
Flash_fwd_mla_params params = {};
// Set the sizes.
params.b = batch_size;
- params.seqlen_q = seqlen_q;
- params.cu_seqlens_k = seqlens_k.data_ptr();
- params.h = num_heads;
- params.h_h_k_ratio = num_heads / num_heads_k;
- params.ngroups = ngroups;
+ params.s_q = seqlen_q_ori;
+ params.q_seq_per_hk = q_seq_per_hk;
+ params.seqlens_k_ptr = seqlens_k.data_ptr();
+ params.h_q = num_heads_q;
+ params.h_k = num_heads_k;
+ params.num_blocks = num_blocks;
+ params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
- params.d = head_size;
+ params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
- params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
- params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
- params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
- params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
-
- TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
- TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
- CHECK_DEVICE(tile_scheduler_metadata);
- CHECK_CONTIGUOUS(tile_scheduler_metadata);
+
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr();
params.num_sm_parts = tile_scheduler_metadata.size(0);
- TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
- CHECK_DEVICE(num_splits);
- CHECK_CONTIGUOUS(num_splits);
params.num_splits_ptr = num_splits.data_ptr();
- at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
- at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
+ const int total_num_splits = batch_size + params.num_sm_parts;
+ at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
+ at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
+ CHECK_CONTIGUOUS(softmax_lse_accum);
+ CHECK_CONTIGUOUS(out_accum);
+ params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
- TORCH_CHECK(head_size == 576);
-
+ TORCH_CHECK(head_size_k == 576);
if (q_dtype == torch::kBFloat16) {
- run_mha_fwd_splitkv_mla(params, stream);
- }
- #ifndef FLASH_MLA_DISABLE_FP16
- else if (q_dtype == torch::kHalf) {
- run_mha_fwd_splitkv_mla(params, stream);
- }
- #endif
- else {
+ run_flash_splitkv_mla_kernel(params, stream);
+ run_flash_mla_combine_kernel(params, stream);
+ } else if (q_dtype == torch::kHalf) {
+#ifdef FLASH_MLA_DISABLE_FP16
+ TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
+#else
+ run_flash_splitkv_mla_kernel(params, stream);
+ run_flash_mla_combine_kernel(params, stream);
+#endif
+ } else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
- out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
- .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
- softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
- .reshape({batch_size, num_heads_ori, seqlen_q_ori});
+ out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
+ .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
+ softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
+ .reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, softmax_lse};
}
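
For readers unfamiliar with the head-folding trick used in `mha_fwd_kvcache_mla` above, here is a hedged PyTorch sketch (illustration only; the helper names are ours): query heads that share a KV head are folded into the sequence dimension before the kernel runs, and the output is unfolded afterwards.

```python
import torch

def fold_q_heads(q: torch.Tensor, num_heads_k: int) -> torch.Tensor:
    """(b, s_q, h_q, d) -> (b, s_q * h_q/h_k, h_k, d), matching the
    view/transpose/reshape applied to q above."""
    b, s_q, h_q, d = q.shape
    g = h_q // num_heads_k
    return (q.view(b, s_q, num_heads_k, g, d)
             .transpose(2, 3)
             .reshape(b, s_q * g, num_heads_k, d))

def unfold_out(out: torch.Tensor, s_q: int, num_heads_q: int) -> torch.Tensor:
    """(b, s_q * g, h_k, d_v) -> (b, s_q, h_q, d_v), the inverse applied to `out`."""
    b, _, h_k, d_v = out.shape
    g = num_heads_q // h_k
    return (out.view(b, s_q, g, h_k, d_v)
               .transpose(2, 3)
               .reshape(b, s_q, num_heads_q, d_v))
```

The kernel then sees a tensor with `h_k` heads and `s_q * h_q / h_k` query positions per head, which is what `q_seq_per_hk` denotes in the parameter struct.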
diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu
deleted file mode 100644
index 35691f2..0000000
--- a/csrc/flash_fwd_mla_bf16_sm90.cu
+++ /dev/null
@@ -1,3 +0,0 @@
-#include "flash_fwd_mla_kernel.h"
-
-template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu
deleted file mode 100644
index abdaf7b..0000000
--- a/csrc/flash_fwd_mla_fp16_sm90.cu
+++ /dev/null
@@ -1,3 +0,0 @@
-#include "flash_fwd_mla_kernel.h"
-
-template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
deleted file mode 100644
index d96acd8..0000000
--- a/csrc/flash_fwd_mla_kernel.h
+++ /dev/null
@@ -1,603 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-
-using namespace cute;
-
-#include "named_barrier.h"
-#include "utils.h"
-#include "softmax.h"
-#include "static_switch.h"
-#include "flash_mla.h"
-
-
-template
-constexpr auto getSmemLayoutK() {
- constexpr int headSizeBytes = sizeof(PrecType) * DIM;
- constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
-
- if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
- return GMMA::Layout_K_SW128_Atom{};
- } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
- return GMMA::Layout_K_SW64_Atom{};
- } else {
- return GMMA::Layout_K_SW32_Atom{};
- }
-}
-
-template
-struct Flash_fwd_kernel_traits_mla {
- using Element = elem_type;
- using ElementAccum = float;
- using index_t = int64_t;
-
- static constexpr int kNWarps = kNWarps_;
- static constexpr int kNThreads = kNWarps * 32;
- static constexpr int kNWarpsS = 4;
- static constexpr int kNThreadsS = kNWarpsS * 32;
-
- static constexpr int kBlockM = kBlockM_;
- static constexpr int kBlockN = kBlockN_;
- static constexpr int kHeadDim = kHeadDim_;
- static_assert(kHeadDim % 32 == 0);
- static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
- static_assert(kHeadDimV % 32 == 0);
- static_assert(kHeadDimV <= kHeadDim);
- static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
- static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
-
- using TiledMma = decltype(make_tiled_mma(
- cute::GMMA::ss_op_selector, Int, Int>,
- GMMA::Major::K, GMMA::Major::K>(),
- Layout, _1, _1>>{}));
-
- static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
- using TiledMmaO = decltype(make_tiled_mma(
- cute::GMMA::rs_op_selector, Int, Int>,
- GMMA::Major::K, GMMA::Major::MN>(),
- Layout, Int, _1>>{}));
-
- using SmemLayoutQ = decltype(tile_to_shape(
- getSmemLayoutK(),
- Shape, Int>{}));
-
- using SmemLayoutK = decltype(tile_to_shape(
- getSmemLayoutK(),
- Shape, Int>{}));
-
- using SmemLayoutV = decltype(tile_to_shape(
- getSmemLayoutK(),
- Shape, Int>{}));
- using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
-
- using SmemLayoutP = Layout, Int, _1, Int>>;
- using SmemLayoutRow = Layout>, Stride<_1, _2>>;
-
- using SmemLayoutAtomO = decltype(composition(
- Swizzle{},
- Layout, Int>, Stride, _1>>{}));
- using SmemLayoutO = decltype(tile_to_shape(
- SmemLayoutAtomO{},
- Shape, Int>{}));
- using SmemCopyAtomO = Copy_Atom;
- using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>;
-
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
- static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
- static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
- using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL;
- static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
- static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
-
- using GmemLayoutAtom = Layout<
- Shape, Int>,
- Stride, _1>>;
- using GmemTiledCopy = decltype(make_tiled_copy(
- Copy_Atom{},
- GmemLayoutAtom{},
- Layout>{})); // Val layout, 8 vals per read
-
- using GmemLayoutAtomO = Layout<
- Shape, Int>,
- Stride, _1>>;
- using GmemTiledCopyO = decltype(make_tiled_copy(
- Copy_Atom, Element>{},
- GmemLayoutAtomO{},
- Layout>{})); // Val layout, 8 vals per store
-
- static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
- static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
- using GmemLayoutAtomOaccum = Layout<
- Shape, Int>,
- Stride, _1>>;
- using GmemTiledCopyOaccum = decltype(make_tiled_copy(
- Copy_Atom, ElementAccum>{},
- GmemLayoutAtomOaccum{},
- Layout>{})); // Val layout, 4 vals per store
-};
-
-namespace flash {
-
-using namespace cute;
-
-template
-struct SharedStorageMLA {
- union {
- struct {
- cute::array_aligned> smem_q;
- cute::array_aligned * 2> smem_k; // Double buffer
- cute::array_aligned> smem_p;
- cute::array_aligned> smem_scale;
- };
- struct {
- cute::array_aligned> smem_max;
- cute::array_aligned> smem_sum;
- cute::array_aligned> smem_o;
- };
- };
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx,
- SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
- constexpr int kBlockM = Kernel_traits::kBlockM;
- constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
- constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
- using Element = typename Kernel_traits::Element;
- using ElementAccum = typename Kernel_traits::ElementAccum;
- using index_t = typename Kernel_traits::index_t;
-
- const int tidx = threadIdx.x;
-
- typename Kernel_traits::TiledMmaO tiled_mma_o;
- auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
-
- // Epilogue
-
- const int split_offset = __ldg(params.num_splits_ptr + bidb);
-
- Tensor lse = softmax.template normalize_softmax_lse*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
-
- using ElementO = std::conditional_t;
- Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
- // Partition sO to match the accumulator partitioning
- using SmemTiledCopyO = std::conditional_t<
- !Split,
- typename Kernel_traits::SmemCopyAtomO,
- typename Kernel_traits::SmemCopyAtomOaccum
- >;
- auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
- auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
- Tensor rO = flash::convert_type(tOrO);
- Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
- Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
-
- __syncthreads();
-
- cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
-
- const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
- const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
- const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
- const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-
- Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
- Shape, Int>{},
- make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
- Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
- Shape>{}, Stride<_1>{});
-
- using GmemTiledCopyO = std::conditional_t;
- GmemTiledCopyO gmem_tiled_copy_Oaccum;
- auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
- Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
- Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
-
- __syncthreads();
-
- if (tidx >= kNThreadsS) { return; }
-
- Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
- cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
-
- Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
- Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
- Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
- CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
- if (get<1>(taccOcO_row(0)) == 0) {
-#pragma unroll
- for (int mi = 0; mi < size(lse); ++mi) {
- const int row = get<0>(taccOcO_row(mi));
- if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
- }
- }
-
- // Construct identity layout for sO
- Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
- // Repeat the partitioning with identity layouts
- Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
- Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum)));
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
- flash::copy*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
- gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
- );
-}
-
-template
-__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms,
- const int bidb, const int bidh, const int m_block,
- const int n_split_idx, const int seqlen_k,
- const int n_block_min, const int n_block_max, const bool NoSplit,
- SharedStorage &shared_storage) {
- constexpr int kBlockM = Kernel_traits::kBlockM;
- constexpr int kBlockN = Kernel_traits::kBlockN;
- constexpr int kHeadDim = Kernel_traits::kHeadDim;
- constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
- constexpr int kNThreads = Kernel_traits::kNThreads;
- constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
- static_assert(kNThreads == 256 and kNThreadsS == 128);
- using Element = typename Kernel_traits::Element;
- using index_t = typename Kernel_traits::index_t;
-
- const int tidx = threadIdx.x;
- int n_block = n_block_max - 1;
-
- Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
- Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
- Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
- Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
-
- Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
- Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
- Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
- Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
- Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
- Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
- Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
- Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
-
- typename Kernel_traits::TiledMmaO tiled_mma_o;
- auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
- Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
- Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
- clear(tOrO);
-
- flash::Softmax<2 * size<1>(tOrO)> softmax;
-
- int warp_group_idx = cutlass::canonical_warp_group_idx();
- if (warp_group_idx == 0) {
- typename Kernel_traits::TiledMma tiled_mma;
- auto thr_mma = tiled_mma.get_thread_slice(tidx);
- Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
- Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
-
- if (n_block % 2 == 1) {
- // Double buffer for sK
- constexpr int sK_offset = size(sK);
- tSrK.data() = tSrK.data() + sK_offset / 8;
- tOrVt.data() = tOrVt.data() + sK_offset / 8;
- }
-
- // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
- // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
- // We will have at least 1 "masking" iteration.
- // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
- // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
- constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
-#pragma unroll 1
- for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
- __syncthreads();
-
- Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
- flash::gemm*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
-
- const bool is_masking_step = masking_step > 0;
- const bool is_first_masking_step = masking_step == n_masking_steps;
-
- if (is_masking_step) {
- Tensor cS = make_identity_tensor(Shape, Int>{});
- Tensor tScS = thr_mma.partition_C(cS);
-#pragma unroll
- for (int i = 0; i < size(tSrS); ++i) {
- if constexpr (!Is_causal) { // Just masking based on col
- if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
- } else {
- // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
- // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
- int row = int(get<0>(tScS(i)));
- int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
- if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
- }
- }
- }
-
- // We have key_padding_mask so we'll need to Check_inf
- Tensor scale_o = is_first_masking_step
- ? softmax.template softmax*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
- : is_masking_step ?
- softmax.template softmax*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
- : softmax.template softmax*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
-
- Tensor rP = flash::convert_type(tSrS);
- cute::copy(rP, tPsP);
- cute::copy(scale_o, tScale_osScale_o);
-
- cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady));
-
- flash::rescale_o(tOrO, scale_o);
-
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout()));
- flash::gemm*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
-
- // Double buffer for sK
- const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
- tSrK.data() = tSrK.data() + sK_offset / 8;
- tOrVt.data() = tOrVt.data() + sK_offset / 8;
- }
-
- cute::copy(softmax.row_max, tRow_maxsRow_max);
- cute::copy(softmax.row_sum, tRow_sumsRow_sum);
- cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady));
- } else {
- const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
- int cur_block_table = __ldg(&block_table[n_block]);
-
- const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
- Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q),
- Shape, Int>{},
- make_stride(params.q_row_stride, _1{}));
- typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
- auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
- Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
- Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
- Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
- Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
- Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ)));
-
- // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
- flash::copy*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
- params.seqlen_q - m_block * kBlockM);
-
- const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
- Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k),
- Shape, Int>{},
- make_stride(params.k_row_stride, _1{}));
- typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
- auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
- Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
- Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
- Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
- Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
- Tensor tKpK = make_tensor(make_shape(size<2>(tKsK)));
-
- if (n_block % 2 == 1) {
- // Double buffer for sK
- constexpr int sK_offset = size(sK);
- tKsK.data() = tKsK.data() + sK_offset;
- tOrVt.data() = tOrVt.data() + sK_offset / 8;
- }
-
- // We need to clear the sK smem tiles because K is V.
- const index_t offset_k = cur_block_table * params.k_batch_stride;
- tKgK.data() = tKgK.data() + offset_k;
- flash::copy*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
- seqlen_k - n_block * kBlockN);
- tKgK.data() = tKgK.data() + -offset_k;
- cute::cp_async_fence();
-
- if (n_block - 1 >= n_block_min) {
- cur_block_table = __ldg(&block_table[n_block - 1]);
- }
-
-#pragma unroll 1
- for (; n_block >= n_block_min; --n_block) {
- flash::cp_async_wait<0>();
- __syncthreads();
-
- if (n_block - 1 >= n_block_min) {
- // Double buffer for sK
- const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
- tKsK.data() = tKsK.data() + sK_offset;
-
- const index_t offset_k = cur_block_table * params.k_batch_stride;
- tKgK.data() = tKgK.data() + offset_k;
- flash::copy*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
- tKgK.data() = tKgK.data() + -offset_k;
- cute::cp_async_fence();
- }
-
- cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady));
-
- if (n_block - 2 >= n_block_min) {
- cur_block_table = __ldg(&block_table[n_block - 2]);
- }
-
- typename Kernel_traits::TiledMma tiled_mma;
- auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout();
- Tensor rP = make_tensor(tSrS_layout);
- Tensor scale_o = make_tensor(Shape<_2>{});
- cute::copy(tScale_osScale_o, scale_o);
- cute::copy(tPsP, rP);
-
- flash::rescale_o(tOrO, scale_o);
-
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout()));
- flash::gemm*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
-
- // Double buffer for sK
- const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
- tOrVt.data() = tOrVt.data() + sK_offset / 8;
- }
-
- cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady));
- cute::copy(tRow_maxsRow_max, softmax.row_max);
- cute::copy(tRow_sumsRow_sum, softmax.row_sum);
- }
-
- if (NoSplit)
- store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
- else
- store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
-}
-
-template
-__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
-flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
- constexpr int kBlockN = Kernel_traits::kBlockN;
- const int m_block = blockIdx.x;
- const int bidh = blockIdx.y;
- const int partition_idx = blockIdx.z;
-
- extern __shared__ char shared_memory[];
- auto &shared_storage = *reinterpret_cast(shared_memory);
-
- int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
- int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr));
- int begin_idx = tile_scheduler_metadata.x;
- int begin_seqlen = tile_scheduler_metadata.y;
- int end_idx = tile_scheduler_metadata.z;
- int end_seqlen = tile_scheduler_metadata.w;
- if (begin_idx >= params.b) return;
- int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
-
-#pragma unroll 1
- for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
- const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
- const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
- const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
- const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
- const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
- if (batch_id > begin_idx) {
- __syncthreads(); // Barrier between two tiles.
- }
- flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-__global__ void __launch_bounds__(256, 1, 1)
-flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
- constexpr int kNThreads = 128;
-
- const int tidx = threadIdx.x;
- const int bidx = blockIdx.x;
- const int hs = params.h * params.seqlen_q;
- const int batch_idx = bidx / hs;
- const int hs_idx = bidx % hs;
-
- const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
- const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
- FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
- if (actual_num_splits == 1) return;
-
- __shared__ ElementAccum sLseScale[kMaxSplits];
-
- const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
- const index_t row_offset_lse = bidx;
- Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
- Shape>{}, make_stride(hs));
- Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse),
- Shape<_1>{}, Stride<_1>{});
-
- int warp_idx = cutlass::canonical_warp_idx_sync();
- if (warp_idx == 0) {
- constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
-
- float local_lse[kNLsePerThread];
- for (int i = 0; i < kNLsePerThread; ++i) {
- const int split = i * 32 + tidx;
- local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
- }
-
- float max_lse = -INFINITY;
- for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
- for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
- max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
-
- float sum_lse = 0;
- for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
- for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
-
- float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
- if (tidx == 0) gLSE(0) = global_lse;
-
- for (int i = 0; i < kNLsePerThread; ++i) {
- const int split = i * 32 + tidx;
- if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
- }
- }
- __syncthreads();
-
- static_assert(kHeadDimV % kNThreads == 0);
- constexpr int Elements = kHeadDimV / kNThreads;
- const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
- Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum),
- Shape>{}, Stride<_1>{});
- using GmemTiledCopyOaccum = decltype(make_tiled_copy(
- Copy_Atom, ElementAccum>{},
- Layout>>{},
- Layout>>{}));
- GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
- auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
- Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
- Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
- Tensor tOrO = make_tensor(shape(tOgOaccum));
- clear(tOrO);
-
- for (int split = 0; split < actual_num_splits; ++split) {
- cute::copy(tOgOaccum, tOrOaccum);
- ElementAccum lse_scale = sLseScale[split];
- for (int i = 0; i < size(tOrO); ++i) {
- tOrO(i) += lse_scale * tOrOaccum(i);
- }
- tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
- }
-
- Tensor rO = flash::convert_type(tOrO);
- const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
- const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
- auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
- Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{});
- cute::copy(rO, gO);
-}
-
-} // namespace flash
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
- FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
- const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- auto kernel = &flash::flash_fwd_splitkv_mla_kernel;
- constexpr size_t smem_size = sizeof(SharedStorage);
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- kernel<<>>(params);
- });
- CHECK_CUDA_KERNEL_LAUNCH();
-
- dim3 grid_combine(params.b * params.h * params.seqlen_q);
- MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
- auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
- typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
- combine_kernel<<>>(params);
- });
- CHECK_CUDA_KERNEL_LAUNCH();
-}
-
-template
-void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
- static_assert(Headdim == 576);
- FLASH_ASSERT(params.d_v == 512);
- FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
- using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
- run_flash_splitkv_fwd_mla>(params, stream);
-}
diff --git a/csrc/kernels/config.h b/csrc/kernels/config.h
new file mode 100644
index 0000000..c9ce159
--- /dev/null
+++ b/csrc/kernels/config.h
@@ -0,0 +1,13 @@
+#pragma once
+
+namespace Config {
+
+static constexpr int BLOCK_SIZE_M = 64;
+static constexpr int PAGE_BLOCK_SIZE = 64;
+
+static constexpr int HEAD_DIM_K = 576;
+static constexpr int HEAD_DIM_V = 512;
+
+static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5;
+
+}
diff --git a/csrc/flash_fwd_mla_metadata.cu b/csrc/kernels/get_mla_metadata.cu
similarity index 79%
rename from csrc/flash_fwd_mla_metadata.cu
rename to csrc/kernels/get_mla_metadata.cu
index 82f5b5a..6b78f9b 100644
--- a/csrc/flash_fwd_mla_metadata.cu
+++ b/csrc/kernels/get_mla_metadata.cu
@@ -1,8 +1,11 @@
-#include "flash_fwd_mla_kernel.h"
+#include "get_mla_metadata.h"
-static constexpr int MaxBatchSize = 4096;
+#include
+#include
-__global__ void __launch_bounds__(256, 1, 1)
+#include "utils.h"
+
+__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
@@ -12,8 +15,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
- __shared__ int num_blocks_shared[MaxBatchSize];
- __shared__ int num_splits_shared[MaxBatchSize];
+ extern __shared__ int shared_mem[];
+ int* num_blocks_shared = shared_mem; // [batch_size]
+ int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
@@ -27,7 +31,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
__syncwarp();
if (threadIdx.x == 0) {
- int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
+ int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
@@ -70,8 +74,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
-void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) {
- FLASH_ASSERT(params.batch_size < MaxBatchSize);
- get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
+void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream) {
+ int smem_size = sizeof(int) * (params.batch_size*2+1);
+ CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
-}
\ No newline at end of file
+}
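
As a quick illustration of the scheduling change in this file (the per-request block counting is partly elided in the hunk, so this only mirrors the expression shown), the payload each SM part aims for can be written as:

```python
import math

FIXED_OVERHEAD_NUM_BLOCKS = 5  # Config::FIXED_OVERHEAD_NUM_BLOCKS

def blocks_per_sm_part(total_num_blocks: int, num_sm_parts: int) -> int:
    """Payload formula from get_mla_metadata_kernel: an even share of the KV
    blocks plus a fixed overhead, but never less than twice the fixed overhead
    (the max() added in this change enforces that lower bound)."""
    return max(math.ceil(total_num_blocks / num_sm_parts) + FIXED_OVERHEAD_NUM_BLOCKS,
               2 * FIXED_OVERHEAD_NUM_BLOCKS)

# Example: 1000 KV blocks spread over 66 SM parts -> each part targets 21 blocks.
print(blocks_per_sm_part(1000, 66))  # 21
```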
diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/kernels/get_mla_metadata.h
new file mode 100644
index 0000000..5130581
--- /dev/null
+++ b/csrc/kernels/get_mla_metadata.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include "params.h"
+
+void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream);
diff --git a/csrc/kernels/mla_combine.cu b/csrc/kernels/mla_combine.cu
new file mode 100644
index 0000000..b6ba8f8
--- /dev/null
+++ b/csrc/kernels/mla_combine.cu
@@ -0,0 +1,207 @@
+#include "mla_combine.h"
+
+#include
+#include
+#include
+#include
+
+#include "params.h"
+#include "utils.h"
+#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
+
+using namespace cute;
+
+template
+__global__ void __launch_bounds__(NUM_THREADS)
+flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
+ // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
+    // Each CTA gathers the activations of some heads from one batch, does the scaling and accumulation, and stores the result
+ static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
+ const int batch_idx = blockIdx.x;
+ const int m_block_idx = blockIdx.y;
+ const int warp_idx = threadIdx.x / 32;
+ const int lane_idx = threadIdx.x % 32;
+
+ const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
+ const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
+ const int my_num_splits = end_split_idx - start_split_idx;
+ FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
+ if (my_num_splits == 1) {
+ return;
+ }
+
+ const int num_q_seqs = params.q_seq_per_hk * params.h_k;
+ const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M);
+ Tensor gLseAccum = make_tensor(
+ make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
+ Shape, Int>{},
+ make_stride(num_q_seqs, _1{})
+ );
+ Tensor gLse = make_tensor(
+ make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
+ Shape>{},
+ Stride<_1>{}
+ );
+
+ extern __shared__ float smem_buf[];
+ Tensor sLseScale = make_tensor(
+ make_smem_ptr(smem_buf),
+ Shape, Int>{},
+ Stride, _1>{} // +1 to avoid bank conflict
+ );
+
+ // Wait for the previous kernel (the MLA kernel) to finish
+ cudaGridDependencySynchronize();
+
+ // Read gLseAccum into sLseScale
+ {
+ #pragma unroll 4
+ for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) {
+ int split_idx = elem_idx / BLOCK_SIZE_M;
+ int seq_idx = elem_idx % BLOCK_SIZE_M;
+ sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY;
+ }
+ __syncthreads();
+ }
+
+ if (warp_idx >= num_cur_valid_q_seqs)
+ return;
+
+ // Warp #i gathers LseAccum for seq #i
+ {
+ constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
+ float local_lse[NUM_LSE_PER_THREAD];
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
+ const int split_idx = i*32 + lane_idx;
+ local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY;
+ }
+
+ float max_lse = -INFINITY;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
+ max_lse = max(max_lse, local_lse[i]);
+ CUTLASS_PRAGMA_UNROLL
+ for (int offset = 16; offset >= 1; offset /= 2)
+ max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
+ max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
+
+ float sum_lse = 0;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
+ sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
+ CUTLASS_PRAGMA_UNROLL
+ for (int offset = 16; offset >= 1; offset /= 2)
+ sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
+
+ float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse;
+ if (lane_idx == 0)
+ gLse(warp_idx) = global_lse / (float)M_LOG2E;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
+ const int split_idx = i*32 + lane_idx;
+ if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse);
+ }
+ }
+
+ __syncwarp();
+
+ // Warp #i accumulates activation for seq #i
+ {
+ const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V;
+ Tensor gOaccum = make_tensor(
+ make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum),
+ Shape, Int>{},
+ make_stride(num_q_seqs*HEAD_DIM_V, _1{})
+ );
+
+ static_assert(HEAD_DIM_V % 32 == 0);
+ constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32;
+ float result[ELEMS_PER_THREAD];
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ELEMS_PER_THREAD; ++i)
+ result[i] = 0.0f;
+
+ #pragma unroll 2
+ for (int split = 0; split < my_num_splits; ++split) {
+ float lse_scale = sLseScale(warp_idx, split);
+ if (lse_scale != 0.f) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
+ result[i] += lse_scale * gOaccum(split, lane_idx + i*32);
+ }
+ }
+ }
+
+ cudaTriggerProgrammaticLaunchCompletion();
+
+ const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx;
+ const int k_head_idx = q_seq_idx / params.q_seq_per_hk;
+ auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride;
+ Tensor gO = make_tensor(
+ make_gmem_ptr(o_ptr),
+ Shape>{},
+ Stride<_1>{}
+ );
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ELEMS_PER_THREAD; ++i)
+ gO(lane_idx+i*32) = (ElementT)result[i];
+ }
+}
+
+
+#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
+ [&] { \
+ if (NUM_SPLITS <= 32) { \
+ constexpr static int NAME = 32; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 64) { \
+ constexpr static int NAME = 64; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 96) { \
+ constexpr static int NAME = 96; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 128) { \
+ constexpr static int NAME = 128; \
+ return __VA_ARGS__(); \
+ } else if (NUM_SPLITS <= 160) { \
+ constexpr static int NAME = 160; \
+ return __VA_ARGS__(); \
+ } else { \
+ FLASH_ASSERT(false); \
+ } \
+ }()
+
+
+template
+void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
+ MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
+ constexpr int BLOCK_SIZE_M = 8;
+ constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
+ constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
+ auto combine_kernel = &flash_fwd_mla_combine_kernel;
+ CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
+ cudaLaunchAttribute attribute[1];
+ attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+ attribute[0].val.programmaticStreamSerializationAllowed = 1;
+ cudaLaunchConfig_t combine_kernel_config = {
+ dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
+ dim3(NUM_THREADS, 1, 1),
+ smem_size,
+ stream,
+ attribute,
+ 1
+ };
+ cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params);
+ });
+ CHECK_CUDA_KERNEL_LAUNCH();
+}
+
+template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
+
+#ifndef FLASH_MLA_DISABLE_FP16
+template void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
+#endif
\ No newline at end of file
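
To summarize what the combine kernel above computes per query row, here is a hedged pure-Python reference (not used by the build; it assumes, as in the kernel, that the split kernel emits LSEs in base 2, hence the final division by log2(e)):

```python
import math

def combine_splits(lse_splits, o_splits):
    """Reference log-sum-exp combine for one query row.
    lse_splits: per-split LSE values in base 2.
    o_splits:   per-split partial outputs (lists of HEAD_DIM_V floats).
    Returns (lse in natural log, combined output row)."""
    max_lse = max(lse_splits)
    max_lse = 0.0 if max_lse == float("-inf") else max_lse    # all splits empty
    sum_lse = sum(2.0 ** (l - max_lse) for l in lse_splits)
    global_lse = math.log2(sum_lse) + max_lse if sum_lse > 0.0 else float("inf")
    scales = [2.0 ** (l - global_lse) for l in lse_splits]    # sLseScale in the kernel
    out = [sum(s * o[d] for s, o in zip(scales, o_splits))
           for d in range(len(o_splits[0]))]
    return global_lse / math.log2(math.e), out
```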
diff --git a/csrc/kernels/mla_combine.h b/csrc/kernels/mla_combine.h
new file mode 100644
index 0000000..69035e9
--- /dev/null
+++ b/csrc/kernels/mla_combine.h
@@ -0,0 +1,6 @@
+#pragma once
+
+#include "params.h"
+
+template
+void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
diff --git a/csrc/flash_mla.h b/csrc/kernels/params.h
similarity index 71%
rename from csrc/flash_mla.h
rename to csrc/kernels/params.h
index 2994cb7..3b4e254 100644
--- a/csrc/flash_mla.h
+++ b/csrc/kernels/params.h
@@ -5,39 +5,41 @@
struct Flash_fwd_mla_params {
using index_t = int64_t;
- int b, seqlen_q, d, d_v;
- int h, h_h_k_ratio, ngroups;
+ int b; // batch size
+ int s_q;
+ int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q
+ int d, d_v; // K/V dimension
+ int h_q, h_k; // The number of Q/K heads
+ int num_blocks; // Number of blocks in total
+ int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
- int *__restrict__ cu_seqlens_k;
-
+
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
- void *__restrict__ v_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
index_t q_batch_stride;
index_t k_batch_stride;
- index_t v_batch_stride;
index_t o_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
- index_t v_row_stride;
index_t o_row_stride;
index_t q_head_stride;
index_t k_head_stride;
- index_t v_head_stride;
index_t o_head_stride;
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
+ int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
+ int total_num_splits;
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
@@ -45,11 +47,6 @@ struct Flash_fwd_mla_params {
static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
-
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
@@ -59,5 +56,3 @@ struct Mla_metadata_params {
int fixed_overhead_num_blocks;
int num_sm_parts;
};
-
-void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu
new file mode 100644
index 0000000..ff29305
--- /dev/null
+++ b/csrc/kernels/splitkv_mla.cu
@@ -0,0 +1,1350 @@
+#include
+
+#include "params.h"
+#include "utils.h"
+#include "config.h"
+#include "traits.h"
+
+using namespace cute;
+using cutlass::arch::NamedBarrier;
+
+// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking
+// The reason is that we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2)
+// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM
+static constexpr float MAX_INIT_VAL_SM = -1e30f;
+static constexpr float MAX_INIT_VAL = -1e33f;
+
+
+__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
+    // In the WGMMA layout of fragment A and fragment C, the data each thread holds resides in two particular rows. This function converts local_row_idx (0 or 1) to the actual row index
+ // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
+ int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
+ return row_idx;
+}
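
As a quick host-side sanity check of the mapping above (illustrative Python, not part of the build): replaying the formula shows that the 128 threads of a warpgroup together cover all 64 fragment rows, with each row held by four threads.

```python
def get_AorC_row_idx(local_row_idx: int, idx_in_warpgroup: int) -> int:
    # Same arithmetic as the device helper above.
    return (idx_in_warpgroup // 32) * 16 + local_row_idx * 8 + (idx_in_warpgroup % 32) // 4

# Every thread holds two rows (local_row_idx 0 and 1); across the 128 threads of a
# warpgroup, each of the 64 rows appears exactly four times.
rows = sorted(get_AorC_row_idx(r, t) for t in range(128) for r in range(2))
assert rows == sorted(list(range(64)) * 4)
```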
+
+// Launch TMA copy for a range of KV tile
+// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64
+template<
+ int START_HEAD_DIM_TILE_IDX,
+ int END_HEAD_DIM_TILE_IDX,
+ typename TMA_K_OneTile,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1
+>
+__forceinline__ __device__ void launch_kv_tiles_copy_tma(
+ Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
+ Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled
+ TMA_K_OneTile &tma_K,
+ TMABarrier* barriers_K,
+ int idx_in_warpgroup
+) {
+ if (idx_in_warpgroup == 0) {
+ auto thr_tma = tma_K.get_slice(_0{});
+ Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{});
+ Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int{});
+ cute::copy(tma_K.with(reinterpret_cast(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV);
+ if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
+ launch_kv_tiles_copy_tma(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup);
+ }
+ }
+}
+
+// Prefetch some KV tiles
+// Currently this is not used because it leads to performance degradation
+template<
+ int START_HEAD_DIM_TILE_IDX,
+ int END_HEAD_DIM_TILE_IDX,
+ typename TMA_K_OneTile,
+ typename Engine0, typename Layout0
+>
+__forceinline__ __device__ void prefetch_kv_tiles(
+ Tensor<Engine0, Layout0> const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
+ TMA_K_OneTile &tma_K,
+ int idx_in_warpgroup
+) {
+ if (idx_in_warpgroup == 0) {
+ auto thr_tma = tma_K.get_slice(_0{});
+ Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
+ cute::prefetch(tma_K, cur_gKV);
+ if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
+ prefetch_kv_tiles<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, tma_K, idx_in_warpgroup);
+ }
+ }
+}
+
+// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
+// * Copyright (c) 2024, Tri Dao.
+template<bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
+__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
+ constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename Tensor0::iterator>::value;
+ // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
+ if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
+ warpgroup_fence_operand(tCrC);
+ if constexpr (arrive) {
+ warpgroup_arrive();
+ }
+ if constexpr (zero_init) {
+ tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+ // Unroll the K mode manually to set scale D to 1
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+ cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ }
+ } else {
+ // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
+ // Unroll the K mode manually to set scale D to 1
+ CUTLASS_PRAGMA_UNROLL
+ for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+ cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ }
+ }
+ if constexpr (commit) {
+ warpgroup_commit_batch();
+ }
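+ // warpgroup_wait<N>() blocks until at most N committed WGMMA groups remain in flight (N = 0 waits for all of them)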
+ if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
+ warpgroup_fence_operand(tCrC);
+ if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
+}
+
+
+// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64)
+// The Q-tile should be in shared memory
+template<
+ typename TiledMMA,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2
+>
+__forceinline__ __device__ void qkt_gemm_one_tile_sQ(
+ TiledMMA &tiled_mma,
+ Tensor<Engine0, Layout0> const &thr_mma_sQ_tile, // (MMA, 1, 4)
+ Tensor<Engine1, Layout1> const &thr_mma_sKV_tile, // (MMA, 1, 4)
+ Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
+ TMABarrier* barrier,
+ bool &cur_phase,
+ int idx_in_warpgroup
+) {
+ if (idx_in_warpgroup == 0) {
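+ // Expect 64*64*2 bytes: one PAGE_BLOCK_SIZE x 64 tile of 2-byte (bf16/fp16) KV data delivered by the TMA copy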
+ barrier->arrive_and_expect_tx(64*64*2);
+ }
+ barrier->wait(cur_phase ? 1 : 0);
+
+ warpgroup_fence_operand(rP);
+ warpgroup_arrive();
+ cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
+ cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
+ cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
+ warpgroup_commit_batch();
+ warpgroup_fence_operand(rP);
+}
+
+template<
+ typename TiledMMA,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2
+>
+__forceinline__ __device__ void qkt_gemm_one_tile_rQ(
+ TiledMMA &tiled_mma,
+ Tensor<Engine0, Layout0> const &thr_mma_rQ_tile, // (MMA, 1, 4)
+ Tensor<Engine1, Layout1> const &thr_mma_sKV_tile, // (MMA, 1, 4)
+ Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
+ TMABarrier* barrier,
+ bool &cur_phase,
+ int idx_in_warpgroup
+) {
+ if (idx_in_warpgroup == 0) {
+ barrier->arrive_and_expect_tx(64*64*2);
+ }
+ barrier->wait(cur_phase ? 1 : 0);
+
+ warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));
+ warpgroup_fence_operand(rP);
+ warpgroup_arrive();
+ cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
+ tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+ cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
+ cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
+ cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
+ warpgroup_commit_batch();
+ warpgroup_fence_operand(rP);
+ warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));
+}
+
+// Pipelined TMA wait and Q K^T gemm
+// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows:
+// - Wait for the 0-th tile to be ready using `barrier.wait()`
+// - Compute Q K^T for the 0-th tile
+// - Wait for the 1-st tile to be ready
+// - Compute Q K^T for the 1-st tile
+// ...
+// This gives later tiles more time to become ready, and thus overlaps the memory copy with computation
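+// PHASE_IDX selects which slice of this schedule the call performs: 0 = warpgroup 0 handles tiles 0~3,
+// 1 = warpgroup 1 handles tiles 4~8 and then 0~3, 2 = warpgroup 0 handles the remaining tiles 4~8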
+template<
+ typename T, // Traits
+ int PHASE_IDX, // See comments in the code
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2,
+ typename Engine3, typename Layout3
+>
+__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm(
+ Tensor<Engine0, Layout0> &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K)
+ Tensor<Engine1, Layout1> &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
+ Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
+ Tensor<Engine3, Layout3> &rQ8, // The 8-th tile of Q. We store it separately to leave some room for storing sP1
+ TMABarrier* barriers,
+ bool &cur_phase,
+ int idx_in_warpgroup
+) {
+ Tensor sQ_tiled = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9)
+ Tensor sKV_tiled = flat_divide(sKV, Shape<Int<T::PAGE_BLOCK_SIZE>, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9)
+ TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){};
+ ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup);
+ Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 4, 9)
+ Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 4, 9)
+ TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){};
+
+ #define QKT_GEMM_ONE_TILE(TILE_IDX) \
+ if constexpr(TILE_IDX != 8) { \
+ qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int<TILE_IDX>{}), thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
+ } else { \
+ qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
+ }
+
+ if constexpr (PHASE_IDX == 0) {
+ // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles
+ tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
+ tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
+ QKT_GEMM_ONE_TILE(0);
+ QKT_GEMM_ONE_TILE(1);
+ QKT_GEMM_ONE_TILE(2);
+ QKT_GEMM_ONE_TILE(3);
+ } else if constexpr (PHASE_IDX == 1) {
+ // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles
+ tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
+ tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
+ QKT_GEMM_ONE_TILE(4);
+ QKT_GEMM_ONE_TILE(5);
+ QKT_GEMM_ONE_TILE(6);
+ QKT_GEMM_ONE_TILE(7);
+ QKT_GEMM_ONE_TILE(8);
+ QKT_GEMM_ONE_TILE(0);
+ QKT_GEMM_ONE_TILE(1);
+ QKT_GEMM_ONE_TILE(2);
+ QKT_GEMM_ONE_TILE(3);
+ cur_phase ^= 1;
+ } else {
+ // In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles
+ static_assert(PHASE_IDX == 2);
+ tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One;
+ tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
+ QKT_GEMM_ONE_TILE(4);
+ QKT_GEMM_ONE_TILE(5);
+ QKT_GEMM_ONE_TILE(6);
+ QKT_GEMM_ONE_TILE(7);
+ QKT_GEMM_ONE_TILE(8);
+ cur_phase ^= 1;
+ }
+}
+
+
+template<
+ typename T,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2
+>
+__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline(
+ Tensor<Engine0, Layout0> &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K)
+ Tensor<Engine1, Layout1> &sKV, // (BLOCK_SIZE_M, HEAD_DIM_K)
+ Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
+ int idx_in_warpgroup
+) {
+ TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){};
+ ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
+ Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/16=36)
+ Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/16=36)
+ gemm(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP);
+}
+
+
+// Compute O += PV, where P resides in register
+template<
+ typename T,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2
+>
+__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP(
+ Tensor<Engine0, Layout0> &rP, // ((2, 2, 8), 1, 1), fragment A layout
+ Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
+ Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1)
+ int idx_in_warpgroup
+) {
+ TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){};
+ ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
+ Tensor rP_retiled = make_tensor(rP.data(), Layout<
+ Shape<Shape<_2, _2, _2>, _1, _4>,
+ Stride<Stride<_1, _2, _4>, _0, _8>
+ >{});
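+ // The accumulator registers holding P are reinterpreted in place with the WGMMA fragment-A layout,
+ // so the PV gemm can consume P directly from registers instead of a round trip through shared memory (cf. the "remoteP" variant below)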
+ Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4)
+ gemm(tiled_mma, rP_retiled, thr_mma_sKV_half, rO);
+}
+
+
+// Compute O += PV, where P resides in shared memory
+template<
+ typename T,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2
+>
+__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP(
+ Tensor<Engine0, Layout0> &sP,
+ Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
+ Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1)
+ int idx_in_warpgroup
+) {
+ TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){};
+ ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
+ Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP);
+ Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4)
+ gemm(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO);
+}
+
+
+template<
+ typename T,
+ bool DO_OOB_FILLING,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2,
+ typename Engine3, typename Layout3,
+ typename Engine4, typename Layout4
+>
+__forceinline__ __device__ void wg0_bunch_0(
+ Tensor<Engine0, Layout0> &rPb, // ((2, 2, 8), 1, 1)
+ Tensor<Engine1, Layout1> &rP0, // ((2, 2, 8), 1, 1)
+ Tensor<Engine2, Layout2> &rO0, // ((2, 2, 32), 1, 1)
+ Tensor<Engine3, Layout3> &sScale0, // (BLOCK_SIZE_M)
+ Tensor<Engine4, Layout4> &sM, // (BLOCK_SIZE_M)
+ float rL[2],
+ int rRightBorderForQSeq[2],
+ float scale_softmax_log2,
+ int start_token_idx,
+ int idx_in_warpgroup
+) {
+ // This piece of code is tightly coupled with the accumulator's layout during WGMMA: https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png
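+ // Online-softmax update per row: new_max = max(old_max, cur_max*scale_softmax_log2); the old partial rO0 and rL
+ // are rescaled by exp2f(old_max - new_max) before the freshly exponentiated row of P is accumulated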
+ CUTLASS_PRAGMA_UNROLL
+ for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
+ int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
+
+ // Mask, and get row-wise max
+ float cur_max = MAX_INIT_VAL;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
+ if constexpr (DO_OOB_FILLING) {
+ int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;
+ rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL;
+ rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL;
+ }
+ cur_max = max(cur_max, max(rP0(i), rP0(i+1)));
+ }
+ cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
+ cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
+
+ // Update sM and sL
+ cur_max *= scale_softmax_log2;
+ float new_max = max(sM(row_idx), cur_max);
+ float scale_for_old = exp2f(sM(row_idx) - new_max);
+ __syncwarp(); // Make sure all reads have finished before updating sM
+ if (idx_in_warpgroup%4 == 0) {
+ sScale0(row_idx) = scale_for_old;
+ sM(row_idx) = new_max;
+ }
+
+ // Scale-O
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {
+ rO0(i) *= scale_for_old;
+ rO0(i+1) *= scale_for_old;
+ }
+
+ // Scale, exp, and get row-wise expsum
+ float cur_sum = 0;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
+ rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max);
+ rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max);
+ rPb(i) = (typename T::InputT)rP0(i);
+ rPb(i+1) = (typename T::InputT)rP0(i+1);
+ cur_sum += rP0(i) + rP0(i+1);
+ }
+ rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;
+ }
+}
+
+
+template<
+ typename T,
+ bool IS_BLK0_LAST,
+ bool IS_BLK1_LAST,
+ bool IS_BLK2_LAST,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2,
+ typename Engine3, typename Layout3,
+ typename Engine4, typename Layout4,
+ typename Engine5, typename Layout5
+>
+__forceinline__ __device__ void wg1_bunch_0(
+ Tensor<Engine0, Layout0> &rP1b, // ((2, 2, 8), 1, 1)
+ Tensor<Engine1, Layout1> &sScale1, // (BLOCK_SIZE_M)
+ Tensor<Engine2, Layout2> &rO1, // ((2, 2, 32), 1, 1)
+ Tensor<Engine3, Layout3> &sM, // (BLOCK_SIZE_M)
+ float rL[2],
+ int rRightBorderForQSeq[2],
+ Tensor<Engine4, Layout4> const &sScale0, // (BLOCK_SIZE_M)
+ Tensor<Engine5, Layout5> &rP1, // ((2, 2, 8), 1, 1)
+ float scale_softmax_log2,
+ int start_token_idx,
+ int idx_in_warpgroup
+) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
+ int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
+
+ // Mask, and get row-wise max
+ float cur_max = MAX_INIT_VAL;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {
+ if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) {
+ // Need to apply the mask when either this block is the last one, or
+ // the next block is the last one (because of the causal mask)
+ int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;
+ rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL;
+ rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP1(i+1) : MAX_INIT_VAL;
+ } else if constexpr (IS_BLK0_LAST) {
+ rP1(i) = rP1(i+1) = MAX_INIT_VAL;
+ }
+ cur_max = max(cur_max, max(rP1(i), rP1(i+1)));
+ }
+ cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
+ cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
+ cur_max *= scale_softmax_log2;
+
+ float old_max = sM(row_idx);
+ float new_max = max(old_max, cur_max);
+ float scale_for_old = exp2f(old_max - new_max);
+ __syncwarp();
+ if (idx_in_warpgroup%4 == 0) {
+ sM(row_idx) = new_max;
+ sScale1(row_idx) = scale_for_old;
+ }
+
+ // Scale, exp, and get row-wise expsum
+ float cur_sum = 0;
+ if constexpr (!IS_BLK0_LAST) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {
+ rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max);
+ rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max);
+ rP1b(i) = (typename T::InputT)rP1(i);
+ rP1b(i+1) = (typename T::InputT)rP1(i+1);
+ cur_sum += rP1(i) + rP1(i+1);
+ }
+ }
+
+ // Scale O
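+ // rO1 has not yet been rescaled by warpgroup 0's earlier max update (recorded in sScale0), so both factors are applied here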
+ float cur_scale_for_o1 = scale_for_old * sScale0(row_idx);
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rO1); i += 4) {
+ rO1(i) *= cur_scale_for_o1;
+ rO1(i+1) *= cur_scale_for_o1;
+ }
+
+ // Update rL
+ rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum;
+ }
+}
+
+
+// Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction
+template<
+ typename T,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1
+>
+__forceinline__ __device__ void save_rPb_to_sP(
+ Tensor<Engine0, Layout0> &rPb,
+ Tensor<Engine1, Layout1> &sP,
+ int idx_in_warpgroup
+) {
+ auto r2s_copy = make_tiled_copy_C(
+ Copy_Atom<SM90_U32x4_STSM_N, typename T::InputT>{},
+ (typename T::TiledMMA_QK_sQ){}
+ );
+ ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
+ Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
+ Tensor thr_copy_sP = thr_copy.partition_D(sP);
+ cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
+}
+
+
+// Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction
+template<
+ typename T,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1
+>
+__forceinline__ __device__ void retrieve_rP_from_sP(
+ Tensor<Engine0, Layout0> &rPb,
+ Tensor<Engine1, Layout1> const &sP,
+ int idx_in_warpgroup
+) {
+ TiledCopy s2r_copy = make_tiled_copy_A(
+ Copy_Atom<SM75_U32x4_LDSM_N, typename T::InputT>{},
+ (typename T::TiledMMA_PV_LocalP){}
+ );
+ ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup);
+ Tensor thr_copy_sP = thr_copy.partition_S(sP);
+ Tensor thr_copy_rPb = thr_copy.retile_D(rPb);
+ cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb);
+}
+
+
+// Rescale rP0 and save the result to rPb
+template<
+ typename T,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2
+>
+__forceinline__ __device__ void wg0_scale_rP0(
+ Tensor<Engine0, Layout0> const &sScale1, // (BLOCK_M)
+ Tensor<Engine1, Layout1> const &rP0, // ((2, 2, 8), 1, 1)
+ Tensor<Engine2, Layout2> &rPb, // ((2, 2, 8), 1, 1)
+ int idx_in_warpgroup
+) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
+ int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
+ float scale_factor = sScale1(row_idx);
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
+ rPb(i) = (typename T::InputT)(rP0(i)*scale_factor);
+ rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor);
+ }
+ }
+}
+
+
+// Rescale rO0 according to sScale1
+template<
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1
+>
+__forceinline__ __device__ void wg0_rescale_rO0(
+ Tensor<Engine0, Layout0> &rO0,
+ Tensor<Engine1, Layout1> &sScale1,
+ float rL[2],
+ int idx_in_warpgroup
+) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
+ int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
+ float scale_factor = sScale1(row_idx);
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {
+ rO0(i) *= scale_factor;
+ rO0(i+1) *= scale_factor;
+ }
+ rL[local_row_idx] *= scale_factor;
+ }
+}
+
+
+// Fill out-of-bound V with 0.0
+// We must fill it since it may contain NaN, which may propagate to the final result
+template<
+ typename T,
+ typename Engine0, typename Layout0
+>
+__forceinline__ __device__ void fill_oob_V(
+ Tensor<Engine0, Layout0> &sV, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE), MN-major SW128-swizzled (tile_to_shape of GMMA::Layout_MN_SW128_Atom<InputT>{} with LayoutRight{})
+ int valid_window_size,
+ int idx_in_warpgroup
+) {
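+ // Reinterpret sV as int64 so that each store zeroes four 2-byte elements (8 bytes) at once;
+ // the 128 threads of the warpgroup then sweep the token columns at or beyond valid_window_size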
+ Tensor sV_int64 = make_tensor(
+ make_smem_ptr((int64_t*)(sV.data().get().get())),
+ tile_to_shape(
+ GMMA::Layout_MN_SW128_Atom<int64_t>{},
+ Shape<Int<T::HEAD_DIM_V/2/4>, Int<T::PAGE_BLOCK_SIZE>>{},
+ LayoutRight{}
+ )
+ );
+ valid_window_size = max(valid_window_size, 0);
+ int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should hold
+ for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) {
+ sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0;
+ }
+}
+
+
+// Store O / OAccum
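+// If IS_NO_SPLIT, the result is normalized by rL, converted back to bf16/fp16, and written to O via TMA;
+// otherwise it is kept in fp32 and written to OAccum so that the per-split partial results can be combined later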
+template<
+ typename T,
+ bool IS_NO_SPLIT,
+ typename TMAParams,
+ typename Engine0, typename Layout0,
+ typename Engine1, typename Layout1
+>
+__forceinline__ __device__ void store_o(
+ Tensor<Engine0, Layout0> &rO, // ((2, 2, 32), 1, 1)
+ Tensor<Engine1, Layout1> &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
+ float rL[2],
+ char* sO_addr,
+ TMAParams &tma_params,
+ int batch_idx,
+ int k_head_idx,
+ int m_block_idx,
+ int num_valid_seq_q,
+ int warpgroup_idx,
+ int idx_in_warpgroup
+) {
+ using InputT = typename T::InputT;
+ if constexpr (IS_NO_SPLIT) {
+ // Should convert the output to bfloat16 / float16, and save it to O
+ Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape(
+ GMMA::Layout_K_SW128_Atom<InputT>{},
+ Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}
+ ));
+
+ Tensor rOb = make_tensor_like(rO);
+ CUTLASS_PRAGMA_UNROLL
+ for (int idx = 0; idx < size(rO); ++idx) {
+ rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]);
+ }
+
+ Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
+ TiledCopy r2s_tiled_copy = make_tiled_copy_C(
+ Copy_Atom<SM90_U32x4_STSM_N, InputT>{},
+ (typename T::TiledMMA_PV_LocalP){}
+ );
+ ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup);
+ Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb);
+ Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf);
+ cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf);
+ cutlass::arch::fence_view_async_shared();
+
+ __syncthreads();
+
+ if (threadIdx.x == 0) {
+ Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM)
+ auto thr_tma = tma_params.tma_O.get_slice(_0{});
+ Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{})(_, _, m_block_idx, _0{});
+ cute::copy(
+ tma_params.tma_O,
+ thr_tma.partition_S(sOutputBuf),
+ thr_tma.partition_D(my_tma_gO)
+ );
+ cute::tma_store_arrive();
+ }
+ } else {
+ // Should save the result to OAccum
+ Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout<
+ Shape<_64, _512>,
+ Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflicts
+ >{});
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int idx = 0; idx < size(rO); idx += 2) {
+ int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
+ int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
+ *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 {
+ rO(idx) / rL[idx%4 >= 2],
+ rO(idx+1) / rL[idx%4 >= 2],
+ };
+ }
+ cutlass::arch::fence_view_async_shared();
+
+ __syncthreads();
+
+ int row = threadIdx.x;
+ if (row < num_valid_seq_q) {
+ SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float));
+ cute::tma_store_arrive();
+ }
+ }
+}
+
+template<
+ typename T,
+ typename TmaParams, typename Tensor0
+>
+__forceinline__ __device__ void launch_q_copy(
+ TmaParams const &tma_params,
+ int batch_idx,
+ int m_block_idx,
+ int k_head_idx,
+ Tensor0 &sQ,
+ TMABarrier* barrier_Q
+) {
+ if (threadIdx.x == 0) {
+ Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM)
+ auto thr_tma = tma_params.tma_Q.get_slice(_0{});
+ Tensor my_tma_gQ = flat_divide(tma_gQ, Shape