Performance Update (2025.04.22) (#71)

* Fix benchmark script

* Performance optimization for compute-bound cases

* Add new testcase (s_k = 16384)

* Update README.md

* Update comment

* Update README.md

* Add the deep-dive blog

* Add background color for MLA Kernel Sched.drawio.svg

* Use relative path for the schedule image

* Move flash_mla.h to kernels/params.h
This commit is contained in:
Shengyu Liu 2025-04-22 17:50:57 +08:00 committed by GitHub
parent b31bfe72a8
commit c2067be3ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 2757 additions and 1228 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ __pycache__/
dist/
*perf.csv
*.png
/.vscode

View File

@ -1,11 +1,28 @@
# FlashMLA
## Performance Update (2025.04.22)
We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀
Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up here: <LINK>
The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance.
## Introduction
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
Currently released:
- BF16, FP16
- Paged kvcache with block size of 64
## Requirements
- Hopper GPUs
- CUDA 12.3 and above
- **But we highly recommend 12.8 or above for the best performance**
- PyTorch 2.0 and above
## Quick start
### Install
@ -20,7 +37,9 @@ python setup.py install
python tests/test_flash_mla.py
```
Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
Note. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance.
### Usage
@ -38,13 +57,6 @@ for i in range(num_layers):
...
```
## Requirements
- Hopper GPUs
- CUDA 12.3 and above
- **But we highly recommend 12.8 or above for the best performance**
- PyTorch 2.0 and above
## Acknowledgement
FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
@ -91,7 +103,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c
```bibtex
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
author={Jiashi Li},
author={Jiashi Li, Shengyu Liu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},

View File

@ -435,7 +435,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
if target not in ["flash_infer", "flash_mla_triton"]:
if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]:
# flash_infer has a different lse return value
# flash_mla_triton doesn't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"

View File

@ -10,8 +10,11 @@
#include <cutlass/fast_math.h>
#include "flash_mla.h"
#include "static_switch.h"
#include "kernels/config.h"
#include "kernels/get_mla_metadata.h"
#include "kernels/mla_combine.h"
#include "kernels/params.h"
#include "kernels/splitkv_mla.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
@ -23,11 +26,6 @@ get_mla_metadata(
const int num_heads_per_head_k,
const int num_heads_k
) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
@ -38,7 +36,7 @@ get_mla_metadata(
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
@ -52,10 +50,10 @@ get_mla_metadata(
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.block_size_n = Config::PAGE_BLOCK_SIZE;
params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS;
params.num_sm_parts = num_sm_parts;
get_mla_metadata_func(params, stream);
run_get_mla_metadata_kernel(params, stream);
return {tile_scheduler_metadata, num_splits};
}
@ -64,7 +62,6 @@ std::vector<at::Tensor>
mha_fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
@ -73,138 +70,141 @@ mha_fwd_kvcache_mla(
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
) {
// Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
// Check device
CHECK_DEVICE(q);
CHECK_DEVICE(kcache);
CHECK_DEVICE(seqlens_k);
CHECK_DEVICE(block_table);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_DEVICE(num_splits);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
CHECK_CONTIGUOUS(tile_scheduler_metadata);
CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_ori = sizes[2];
const int head_size = sizes[3];
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int ngroups = num_heads_ori / num_heads_k;
const int seqlen_q = seqlen_q_ori * ngroups;
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
.reshape({batch_size, seqlen_q, num_heads, head_size});
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int head_size_k = head_size;
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse);
Flash_fwd_mla_params params = {};
// Set the sizes.
params.b = batch_size;
params.seqlen_q = seqlen_q;
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
params.h = num_heads;
params.h_h_k_ratio = num_heads / num_heads_k;
params.ngroups = ngroups;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_CONTIGUOUS(tile_scheduler_metadata);
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
params.num_sm_parts = tile_scheduler_metadata.size(0);
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(num_splits);
CHECK_CONTIGUOUS(num_splits);
params.num_splits_ptr = num_splits.data_ptr<int>();
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse_accum);
CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
TORCH_CHECK(head_size_k == 576);
if (q_dtype == torch::kBFloat16) {
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
}
#ifndef FLASH_MLA_DISABLE_FP16
else if (q_dtype == torch::kHalf) {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
}
#endif
else {
run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
} else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, softmax_lse};
}

View File

@ -1,3 +0,0 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -1,3 +0,0 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -1,603 +0,0 @@
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
using namespace cute;
#include "named_barrier.h"
#include "utils.h"
#include "softmax.h"
#include "static_switch.h"
#include "flash_mla.h"
template<typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
return GMMA::Layout_K_SW128_Atom<PrecType>{};
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
return GMMA::Layout_K_SW64_Atom<PrecType>{};
} else {
return GMMA::Layout_K_SW32_Atom<PrecType>{};
}
}
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla {
using Element = elem_type;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kNWarpsS = 4;
static constexpr int kNThreadsS = kNWarpsS * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 32 == 0);
static_assert(kHeadDimV <= kHeadDim);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = decltype(make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
using TiledMmaO = decltype(make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim>(),
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutK = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutV = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomO = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
using GmemLayoutAtomOaccum = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
};
namespace flash {
using namespace cute;
template<typename Kernel_traits>
struct SharedStorageMLA {
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
};
struct {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
};
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
// Epilogue
const int split_offset = __ldg(params.num_splits_ptr + bidb);
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor rO = flash::convert_type<ElementO>(tOrO);
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
__syncthreads();
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads();
if (tidx >= kNThreadsS) { return; }
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreads = Kernel_traits::kNThreads;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
static_assert(kNThreads == 256 and kNThreadsS == 128);
using Element = typename Kernel_traits::Element;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
int n_block = n_block_max - 1;
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
clear(tOrO);
flash::Softmax<2 * size<1>(tOrO)> softmax;
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 0) {
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
#pragma unroll 1
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
__syncthreads();
Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
const bool is_masking_step = masking_step > 0;
const bool is_first_masking_step = masking_step == n_masking_steps;
if (is_masking_step) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if constexpr (!Is_causal) { // Just masking based on col
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
Tensor scale_o = is_first_masking_step
? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: is_masking_step ?
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(tSrS);
cute::copy(rP, tPsP);
cute::copy(scale_o, tScale_osScale_o);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cute::copy(softmax.row_max, tRow_maxsRow_max);
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
} else {
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tKsK.data() = tKsK.data() + sK_offset;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need to clear the sK smem tiles because K is V.
const index_t offset_k = cur_block_table * params.k_batch_stride;
tKgK.data() = tKgK.data() + offset_k;
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
seqlen_k - n_block * kBlockN);
tKgK.data() = tKgK.data() + -offset_k;
cute::cp_async_fence();
if (n_block - 1 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 1]);
}
#pragma unroll 1
for (; n_block >= n_block_min; --n_block) {
flash::cp_async_wait<0>();
__syncthreads();
if (n_block - 1 >= n_block_min) {
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tKsK.data() = tKsK.data() + sK_offset;
const index_t offset_k = cur_block_table * params.k_batch_stride;
tKgK.data() = tKgK.data() + offset_k;
flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
tKgK.data() = tKgK.data() + -offset_k;
cute::cp_async_fence();
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
if (n_block - 2 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 2]);
}
typename Kernel_traits::TiledMma tiled_mma;
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
Tensor rP = make_tensor<Element>(tSrS_layout);
Tensor scale_o = make_tensor<float>(Shape<_2>{});
cute::copy(tScale_osScale_o, scale_o);
cute::copy(tPsP, rP);
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
cute::copy(tRow_maxsRow_max, softmax.row_max);
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
}
if (NoSplit)
store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
else
store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
constexpr int kBlockN = Kernel_traits::kBlockN;
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int partition_idx = blockIdx.z;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int begin_seqlen = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int end_seqlen = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
if (batch_id > begin_idx) {
__syncthreads(); // Barrier between two tiles.
}
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
constexpr int kNThreads = 128;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int hs = params.h * params.seqlen_q;
const int batch_idx = bidx / hs;
const int hs_idx = bidx % hs;
const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
if (actual_num_splits == 1) return;
__shared__ ElementAccum sLseScale[kMaxSplits];
const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
const index_t row_offset_lse = bidx;
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
Shape<Int<kMaxSplits>>{}, make_stride(hs));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<_1>{}, Stride<_1>{});
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
float local_lse[kNLsePerThread];
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + tidx;
local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
}
float max_lse = -INFINITY;
for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0;
for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
if (tidx == 0) gLSE(0) = global_lse;
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + tidx;
if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
__syncthreads();
static_assert(kHeadDimV % kNThreads == 0);
constexpr int Elements = kHeadDimV / kNThreads;
const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<kHeadDimV>>{}, Stride<_1>{});
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape<Int<kNThreads>>>{},
Layout<Shape<Int<Elements>>>{}));
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
clear(tOrO);
for (int split = 0; split < actual_num_splits; ++split) {
cute::copy(tOgOaccum, tOrOaccum);
ElementAccum lse_scale = sLseScale[split];
for (int i = 0; i < size(tOrO); ++i) {
tOrO(i) += lse_scale * tOrOaccum(i);
}
tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
}
Tensor rO = flash::convert_type<Element>(tOrO);
const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
cute::copy(rO, gO);
}
} // namespace flash
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
constexpr size_t smem_size = sizeof(SharedStorage);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
dim3 grid_combine(params.b * params.h * params.seqlen_q);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
static_assert(Headdim == 576);
FLASH_ASSERT(params.d_v == 512);
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}

13
csrc/kernels/config.h Normal file
View File

@ -0,0 +1,13 @@
#pragma once
namespace Config {
static constexpr int BLOCK_SIZE_M = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5;
}

View File

@ -1,8 +1,11 @@
#include "flash_fwd_mla_kernel.h"
#include "get_mla_metadata.h"
static constexpr int MaxBatchSize = 4096;
#include <cuda_runtime_api.h>
#include <cutlass/fast_math.h>
__global__ void __launch_bounds__(256, 1, 1)
#include "utils.h"
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
@ -12,8 +15,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
__shared__ int num_blocks_shared[MaxBatchSize];
__shared__ int num_splits_shared[MaxBatchSize];
extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
@ -27,7 +31,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
__syncwarp();
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
@ -70,8 +74,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.batch_size < MaxBatchSize);
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*2+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}
}

View File

@ -0,0 +1,5 @@
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream);

207
csrc/kernels/mla_combine.cu Normal file
View File

@ -0,0 +1,207 @@
#include "mla_combine.h"
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "params.h"
#include "utils.h"
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
using namespace cute;
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
const int batch_idx = blockIdx.x;
const int m_block_idx = blockIdx.y;
const int warp_idx = threadIdx.x / 32;
const int lane_idx = threadIdx.x % 32;
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx;
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
if (my_num_splits == 1) {
return;
}
const int num_q_seqs = params.q_seq_per_hk * params.h_k;
const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M);
Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
make_stride(num_q_seqs, _1{})
);
Tensor gLse = make_tensor(
make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
Shape<Int<BLOCK_SIZE_M>>{},
Stride<_1>{}
);
extern __shared__ float smem_buf[];
Tensor sLseScale = make_tensor(
make_smem_ptr(smem_buf),
Shape<Int<BLOCK_SIZE_M>, Int<MAX_SPLITS>>{},
Stride<Int<MAX_SPLITS+1>, _1>{} // +1 to avoid bank conflict
);
// Wait for the previous kernel (the MLA kernel) to finish
cudaGridDependencySynchronize();
// Read gLseAccum into sLseScale
{
#pragma unroll 4
for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) {
int split_idx = elem_idx / BLOCK_SIZE_M;
int seq_idx = elem_idx % BLOCK_SIZE_M;
sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY;
}
__syncthreads();
}
if (warp_idx >= num_cur_valid_q_seqs)
return;
// Warp #i gathers LseAccum for seq #i
{
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
float local_lse[NUM_LSE_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY;
}
float max_lse = -INFINITY;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
max_lse = max(max_lse, local_lse[i]);
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2)
max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2)
sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse;
if (lane_idx == 0)
gLse(warp_idx) = global_lse / (float)M_LOG2E;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse);
}
}
__syncwarp();
// Warp #i accumulates activation for seq #i
{
const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V;
Tensor gOaccum = make_tensor(
make_gmem_ptr(reinterpret_cast<float *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<MAX_SPLITS>, Int<HEAD_DIM_V>>{},
make_stride(num_q_seqs*HEAD_DIM_V, _1{})
);
static_assert(HEAD_DIM_V % 32 == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32;
float result[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
result[i] = 0.0f;
#pragma unroll 2
for (int split = 0; split < my_num_splits; ++split) {
float lse_scale = sLseScale(warp_idx, split);
if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
result[i] += lse_scale * gOaccum(split, lane_idx + i*32);
}
}
}
cudaTriggerProgrammaticLaunchCompletion();
const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx;
const int k_head_idx = q_seq_idx / params.q_seq_per_hk;
auto o_ptr = reinterpret_cast<ElementT *>(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride;
Tensor gO = make_tensor(
make_gmem_ptr(o_ptr),
Shape<Int<HEAD_DIM_V>>{},
Stride<_1>{}
);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
gO(lane_idx+i*32) = (ElementT)result[i];
}
}
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 8;
constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, Config::HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t combine_kernel_config = {
dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
dim3(NUM_THREADS, 1, 1),
smem_size,
stream,
attribute,
1
};
cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#endif

View File

@ -0,0 +1,6 @@
#pragma once
#include "params.h"
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -5,39 +5,41 @@
struct Flash_fwd_mla_params {
using index_t = int64_t;
int b, seqlen_q, d, d_v;
int h, h_h_k_ratio, ngroups;
int b; // batch size
int s_q;
int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q
int d, d_v; // K/V dimension
int h_q, h_k; // The number of Q/K heads
int num_blocks; // Number of blocks in total
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
int *__restrict__ cu_seqlens_k;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t o_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t o_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
index_t o_head_stride;
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
int total_num_splits;
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
@ -45,11 +47,6 @@ struct Flash_fwd_mla_params {
static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
@ -59,5 +56,3 @@ struct Mla_metadata_params {
int fixed_overhead_num_blocks;
int num_sm_parts;
};
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);

1350
csrc/kernels/splitkv_mla.cu Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
#pragma once
#include "params.h"
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);

106
csrc/kernels/traits.h Normal file
View File

@ -0,0 +1,106 @@
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>
#include "config.h"
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
using namespace cute;
template<typename InputT_>
struct Traits {
using InputT = InputT_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
using TiledMMA_QK_sQ = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using SmemLayoutQ = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutK = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutV = decltype(composition(
SmemLayoutK{},
make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})
)); // A transposed version of SmemLayoutK
using SmemLayoutP0 = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
));
using rP0Layout = decltype(layout(partition_fragment_C(
TiledMMA_QK_sQ{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
)));
struct SharedMemoryPlan {
cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;
cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;
cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;
TMABarrier barriers_K0[HEAD_DIM_K/64];
TMABarrier barriers_K1[HEAD_DIM_K/64];
TMABarrier barrier_Q;
};
};
template<
typename ShapeQ, typename TMA_Q,
typename ShapeK, typename TMA_K,
typename ShapeO, typename TMA_O
>
struct TmaParams {
ShapeQ shape_Q;
TMA_Q tma_Q;
ShapeK shape_K;
TMA_K tma_K;
ShapeO shape_O;
TMA_O tma_O;
};
enum NamedBarriers : int {
sScale0Ready = 0,
sScale1Ready = 1,
sP0Ready = 2,
rO1sP0sV0RIssued = 3
};

View File

@ -5,7 +5,7 @@
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
exit(1); \
} \
} while(0)
@ -29,37 +29,4 @@
} \
} while(0)
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()
#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); }

View File

@ -1,15 +0,0 @@
#pragma once
#include "cutlass/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
enum class NamedBarriers {
SReady = 1,
SoftmaxReady = 2,
};
} // flash

View File

@ -1,200 +0,0 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
}
}
return tensor;
}
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
SumOp<float> sum_op;
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
}
}
template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scale_o); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scale_o;
clear(scale_o);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scale_o(mi) = scores_scale;
row_sum(mi) *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
return scale_o;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
};
} // namespace flash

View File

@ -1,241 +0,0 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_bf16.h>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
} else {
static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
// This combines the first two modes (<0, 0> and <0, 1>) into one mode.
// Will require register shuffling later to be correct.
return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
get<1>(acc_layout),
coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
// This combination is right but doesn't work with register shuffling.
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
// get<1>(acc_layout),
// coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -0,0 +1,77 @@
# A Deep-Dive Into the New Flash MLA Kernel
In the [previous version](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) of the Flash MLA kernel, we have achieved impressive performance: 3000 GB/s in memory-intensive settings and 580 TFlops in compute-bound settings. Now, we're pushing these numbers even further, reaching up to 660 TFlops.
In this blog, we present a deep dive into the new kernel, explaining the optimizations and techniques behind this performance boost. We'll first explain why the MLA kernel is compute-bound despite being a decoding-stage attention kernel, then discuss our high-level kernel schedule design, and finally cover the technical details of the new kernel.
## A Theoretical Analysis of the MLA Algorithm
GPU kernels can be classified as either compute-bound (limited by floating-point operations per second, FLOPs) or memory-bound (limited by memory bandwidth). To identify the kernel's bottleneck, we calculate the ratio of FLOPs to memory bandwidth (FLOPs/byte) and compare it with the GPU's capacity.
Assume the number of q heads is $h_q$, the number of q tokens per request is $s_q$ (should be 1 if MTP / speculative decoding is disabled), the number of kv tokens per request is $s_k\ (s_k \gg h_q s_q)$, and the head dimensions of K and V are $d_k$ and $d_v$ respectively. The number of FLOPs is roughly $2 (h_q s_q \cdot d_k \cdot s_k + h_q s_q \cdot s_k \cdot d_v) = 2 h_q s_q s_k (d_k+d_v)$, and the memory access volume (in bytes) is $\mathop{\text{sizeof}}(\text{bfloat16}) \times (h_q s_q d_k + s_k d_k + h_q s_q d_v) \approx 2s_k d_k$. Thus, the compute-memory ratio is $h_q s_q \cdot \frac{d_k+d_v}{d_k} \approx 2 h_q s_q$.
An NVIDIA H800 SXM5 GPU has a peak memory bandwidth of 3.35 TB/s and peak FLOPs of 990 TFlops. However, due to throttling (reducing to ~1600 MHz in our case), the practical peak FLOPs drops to ~865 TFlops. Therefore, when $h_qs_q \ge \frac{1}{2} \cdot \frac{865}{3.35} = 128$, the kernel is compute-bound; otherwise, it's memory-bound.
According to [the overview of DeepSeek's Online Inference System](https://github.com/deepseek-ai/open-infra-index/blob/main/202502OpenSourceWeek/day_6_one_more_thing_deepseekV3R1_inference_system_overview.md), we don't use Tensor Parallel for decoding instances, meaning $h_q$ is 128 and the kernel is compute-bound. Thus, we need to optimize the kernel for compute-bound settings.
## High-Level Design of the New Kernel
To fully utilize GPU compute resources, we need to overlap CUDA Core operations with Tensor Core operations and memory access with computation, keeping the Tensor Core constantly busy. This requires redesigning the kernel's "schedule."
[FlashAttention-3's paper](https://arxiv.org/abs/2205.14135) introduces ping-pong scheduling and intra-warpgroup GEMM-softmax pipelining to overlap block-wise matmul and CUDA Core operations. However, these techniques can't be directly applied here due to resource constraints. The output matrix (scaled and accumulated during each mainloop round, similar to [FlashAttention's algorithm](https://arxiv.org/abs/2205.14135)) must be stored in registers due to [WGMMA instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) requirements. Each $64 \times 512$ output matrix occupies 32,768 32-bit registers. With only 65,536 32-bit registers per SM, we can store only one output matrix per SM. This eliminates the possiblility of having two output matrices and letting them use CUDA Core and Tensor Core in a interleaved manner. We need to find another clever way to overlap CUDA Core and Tensor Core computation.
(You might pause here to ponder - perhaps you can find a better solution than ours!)
Our solution involves an additional mathematical transformation beyond FlashAttention's online softmax and accumulation approach. In each step, we take two KV blocks (called $K_0$, $K_1$, $V_0$, and $V_1$). Since the output matrix occupies 32,768 registers (too many for one warpgroup), we split it vertically into $O_L$ and $O_R$ (each $64 \times 256$). We similarly split $V_0$ and $V_1$ into $V_{0L}$, $V_{0R}$, $V_{1L}$, and $V_{1R}$ (each $64 \times 256$). The output matrix is then computed as follows:
0. Maintain a running max $m$ (initialized to $-\infty$, shared between the two warpgroups) and output matrices $\vec o_L, \vec o_R$ (initialized to 0).
1. [0] Compute $\vec p_0 = \vec q K_0^\intercal / qk\_scale$.
2. [1] Compute $\vec p_1 = \vec q K_1^\intercal / qk\_scale$.
3. [0] Compute $mp_0 = \max(\vec p_0)$, $m\_new_0 = \max(m, mp_0)$, and $scale_0 = \exp(m\_new_0 - m)$. Update $m \gets m\_new_0$.
4. [0] Perform softmax on $\vec p_0$: $\vec p_0 \gets \exp(\vec p_0 - m\_new_0)$.
5. [0] Update $\vec o_L \gets \vec o_L \cdot scale_0 + \vec p_0 V_{0L}$.
6. [1] Compute $mp_1 = \max(\vec p_1)$, $m\_new_1 = \max(m, mp_1)$, and $scale_1 = \exp(m\_new_1 - m)$. Update $m \gets m\_new_1$.
7. [1] Perform softmax on $\vec p_1$: $\vec p_1 \gets \exp(\vec p_1 - m\_new_1)$.
8. [1] Update $\vec o_R \gets \vec o_R \cdot (scale_0 \cdot scale_1) + \vec p_1 V_{1R}$.
9. [0] Update $\vec p_0 \gets \vec p_0 \cdot scale_1$.
10. [1] Update $\vec o_R \gets \vec o_R + \vec p_0 V_{0R}$.
11. [0] Update $\vec o_L \gets \vec o_L \cdot scale_1 + \vec p_1 V_{1L}$.
Note: We assume one q head for simplicity, so $\vec q$ and $\vec o$ are vectors. Bracketed numbers indicate the warpgroup performing the operation. Assume $\vec o_L$ resides in warpgroup 0's register and $\vec o_R$ resides in warpgroup 1's register.
This schedule can be viewed as a "ping-pong" variant using one output matrix—we call it "seesaw" scheduling. It's mathematically equivalent to FlashAttention's online softmax algorithm. This schedule allows us to overlap CUDA Core and Tensor Core operations by interleaving the two warpgroups, and also allows us to overlap memory access with computation since we can launch the corresponding Tensor Memory Accelerator (TMA) instructions right after data is no longer needed.
The complete schedule is shown below (remember that in MLA, $K$ and $V$ are the same with different names):
![MLA Kernel Sched](assets/MLA%20Kernel%20Sched.drawio.svg)
## Discussion of Technical Details
This section covers technical details of the new kernel.
First, although the kernel targets compute-bound scenarios (where memory bandwidth isn't the bottleneck), we can't ignore memory latency. If the data is not ready when we want to use it, we have to wait. To solve this problem, we employ the following techniques:
- **Fine-grained TMA copy - GEMM pipelining:** For a $64 \times 576$ K block, we launch 9 TMA copies (each moving a $64 \times 64$ block). GEMM operations begin as soon as each TMA copy completes (When the first TMA copy is done, we can start the first GEMM operation, and so on), improving memory latency tolerance.
- **Cache hints:** Using `cute::TMA::CacheHintSm90::EVICT_FIRST` for TMA copies improves L2 cache hit rates, as shown by experiments.
These optimizations achieve up to 80% Tensor Core utilization (of the throttled theoretical peak) and 3 TB/s memory bandwidth on an H800 SXM5 GPU. While slightly slower (~2%) than the old ping-pong buffer version in memory-bound settings, this is acceptable.
Other performance improvements include:
- **Programmatic Dependent Launch.** We use [programmatic dependent launch](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization) to overlap `splitkv_mla` and `combine` kernels.
- **Tile Scheduler.** We implement a tile scheduler to allocate jobs (requests and blocks) to SMs. This ensures a balanced load across SMs.
## Acknowledgements
FlashMLA's algorithm and scheduling is inspired by [FlashAttention](https://github.com/dao-AILab/flash-attention/), [Flash-Decoding](https://crfm.stanford.edu/2023/10/12/flashdecoding.html), and [CUTLASS](https://github.com/nvidia/cutlass), as well as many projects behind them. We thank the authors for their great work.
## Citation
```bibtex
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
author={Jiashi Li, Shengyu Liu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
```

View File

@ -0,0 +1,856 @@
<svg host="65bd71144e" xmlns="http://www.w3.org/2000/svg" style="background: #ffffff; background-color: light-dark(#ffffff, #121212);" xmlns:xlink="http://www.w3.org/1999/xlink" version="1.1" width="540px" height="840px" viewBox="-0.5 -0.5 540 840" content="&lt;mxfile scale=&quot;1&quot; border=&quot;0&quot;&gt;&lt;diagram name=&quot;Page-1&quot; id=&quot;t0OnRo4-AS1naIkeEJg6&quot;&gt;7V1bc5s4FP41nkkenOFu/Jik2+3OplPvpt3dtx1sZJstBgo4TvvrVwIJg4SNAckltpKZBAQSQud8R+cmMdIfN6+/xk60/hi6wB9pivs60t+NNE1XTAP+QyXf8xJN0a28ZBV7bl6m7guevR8AFyq4dOu5IKncmIahn3pRtXARBgFYpJUyJ47DXfW2ZehXnxo5K/xEZV/wvHB8wNz2t+em67zUNkt3fwDeak2erCr4ytxZfF3F4TbAzxtp+jL7yS9vHNIWvj9ZO264KxXpv4z0xzgM0/xo8/oIfDS4ZNjyeu8PXC36HYMgPaXCVMtrvDj+FpAuZx1Lv5PByF4HoArKSH/Yrb0UPEfOAl3dQfrDsnW68eGZCg+TNA6/gsfQD2NYEoQBvO0BPwPEKXg92FG1eH3IVyDcgDT+Dm/BFcYTMmSYp8amhQt2ewqZBi5bl6hjk0IHc8WqaH0/MvAAD079QJ0wTiBw7xHvwbO5k4B3TrKujk11IP/bbqJnXBuP09IPd/eBB7nECwN8W8Eh2Um4jRdgBmIPvgCIERm8YEUupk6cYiwZ5LymR1n5ew+9alYP9rt6VmrjAJk0BbgrcJRuZarUEIWUxcCHb/tSxV0dofATZqEHe7JnC8ug2EKzKGrjMcvrlaHQ2JRBMw4cuBVImaYy3ile/SR20hl2imfwYUp2rCR/oEP0dCX5XWEZDY79kzOHQrfCXo7vrRDXLCCtAETfAwKcB4XaPb6w8VwXtfEQg8T74cyz9hDNI/RG2TuaDyPzXS3dtYrExZX3gowB90EkK3eqYUz7UZ/cEi6XCehLC4M7tGkkVjFVxdupKB08Eg1FubOmyv7HrKLJ7orLpoYZxPODqVkDU7UepuowYGpcLEwt/jCtTK0VxcUFS2frp63ItQyD9L2z8XzEmB+A/wLQffgCxr+q4nP2QT5ilYdCf2RvqFEYjguaWnWiUYNorTEMXjIVmiKRGGpnWcQ0JVD4TGoY3so4wfVeKoxvfduGGQt5ARiTEbyHt0xh58yMROQWeLTC//dNVYsQe9Y2jy6Mk4y/ssaj16yW0uKxv4Ks6czUUsgD4UDkz6z2AxbX9O7MHf4SuU4KUOsfO3VXZN/i2byYi8rq43O4TDfO6w0su23o9Ll7/KnUS8QDN7CkqY/D4oP4qVN36dLz6wqipyd2GrB46iLKZEC6iC3GZGg99R6f/9/gxNzZeG+e4/lNzFOG+r8lCTwchAFg8wLdWLlTjAFhjviGpQtuYCi2aDdZd4+bOMyqKsM9WBnRHogGNd+b838pT8OAM+H6CzToVf4+dWnRX4dFb9i0Gd7ZuUi3JFAEsU7/n2zQH2leOWAzHTPoVeGGHIcOCzPoOfQtnql7i35vxRfjmkROUPvsfdh5vMilzH1m3sK+jF0n/npTikM/ZiwY34zHK5BdxFWyC/A+VUO/t7d5z+ubi1fzGyWvQf7d5v/RFc0085PyAW6QfuWy8Z+/HkOEvn4LHoT5tPf6E7+FenvBZOlGlZMBfsEjdhbZdYoT6gDdpG+qonxeoHNKbRfQXvhOkniLjmo1xNnUmuiOdWWK9RJqy6UxAKprggksX2zjl2LcXKhsZyfqW1CqTZNShbuHyYzp+bRqAXHha/TFGmZnYtN8QxtjHInNxkQH5HpVuQU8Bud75R/wkL5XLqhVqtjrk2QlDLRsvATbU3vXq1r2vap/DgTP3GIpg/O9Fq8gfa/CVMQL9b3SkdZpV5FDNzQRJoG0g8GfikcHmf7XYPsS8POQa8pUVytU1LhwmVatwEfmtYs3SbtY2sW5nFIa7JvTJR5tY3EUcfxXMUgNnQf30DkOGrdsCSZZiiM31Sx8qGZLJPk6CKKxDyRb4oKXPwhY/yABzskEvzNKq5Oq6lD3jMYm054j2lmvGmWgJ3l2OcmNGoh9zjfBeFhob+dw662rzm0TuXGvWldd2guwWBzVVdHJzElhS0FWoilvQ4Olk6ON7oukdbopXZxcYh2HEgQSBF3n6QnFubraeWpWaYtQnGFH4CVBIEHAYSagQaDSXtfuMwHtCOaIAdZ/KzEgMdB1IqBXnfSfBwSwfDtntQzQyQBdSdu/U6f7H7vC7Z3dj8ebFZfzo7PLJpLyYvc8eDe7luAdEQxXFrzT2fjGZxAkKN2bojp8+1QcwZ1tGiaYrowYJU7U0kxKis6f4NyWHcYW7SUwCU+VBJdVI7hoD0CX/fV0NuLw+OXdvaQuL+qaE7OZunXTEhfqssELSV2e1GUSrs9KXTZ+IGUzZ/oyu5+eUzazwYi/nThC7xuNNBKmkmTmMAXTPpViQ9sSmaeiyMy628tkZjc4lGTuimZ60fk5yUwedaJnAQ+r9Cqc6lU4yCBD8Q8wer7ROYVXb2yKnx/AOLMTWKZvXkj6JrU8rbMDjNaxxe0ObtTtFNLT+XuVaxqpebZ7LpZ9NtKzfq4BrXAkjHl5KxwN1v8kdHaRIcaLDjFSYfbua6upTBNxSaBGu/xiqV5J9apWK+q+gyulp4lbDmicObnWteaWeeWsvlwutQZZP3hWt2j3y+m8LZCZ2yXJSlvhxBxptXNWUNMCaY6fp2FdeUMyFjBnXp6xYLZzRUnMnTjl99iECDZ1RzY+K1oTpy2bGsMBMy8CaI8/FxZ//oi28ds5XsZ6ASryhgNMwr4XCMx2jjNpxUsr/ph4oqx4ld4jqXOuvMCl2+Tbu3Ju4rzFuN6Z+mxT4tbMmawf8w1NTNyW7w9tYrJYTV2isg8q+4NQXMa6xZoHSemzoflud0PAG+HKQeCNtETb3Xzwx+rrDAmgqhehQ6jH+N8fYqhjAaRQNXzwfA9KlahnszDxsDLlg2VaotoTdbmgHqFznEPmMJkPIrRNlhmdTFjAs4Ri9QiK+6QfWayKvlsp4/k2WKzHbySVkAcRJsyuUmfMAbNOUBJ5wGHpe9GHA9CgWb0RGxSWhEDDoBMw7TMig1Xddiv1+pDB0GByRmAI8KbLBMlrWXZZYdvuuwQK/KT0wYABtgxvglfUlb+UpytZTmlxC1K8qeWUEwFeIinnrlPOcftmr8C0ygmrW9ULPvVaBN+Em6/tbQm+M6e5ydyfC8v96b2NLx1VEin12IW5ktklszd72Ruyjk93uk+ON8SR1fl/oei6Ayy9Y950SwJpzzptDih3f/4U5a5G+eL2ZaGhBTpt1sMgccgDh9wCnwJzT2w28HnAvTQUIJKxuUAgaq2AKHU/qfvV6n49PmBAp88JFDxnXr0pmf0ymJ3Zb6eHM5PeWZlpiiO7SyeWZHcOsp35LHdn2S7wA98268RK/vlnpCN98jkN4yyvFZIUcv/ayc82YBOigbAyVnG9F3i4QodxfcUYrLwEMbSmQKoDUhH2rFSXQdjhBIymTJWe6Rc+4oyx68RfbyBjLhT0O9LgQMK/+mSJfm8pBsYQOZKVDy1SxT7GxC1UYZ3ejbD43kg5jaYuh4P2hXbJ4bDbuYLYDa4o5LsmsF2jTi7a2ly3rLeAfDbbjCSKt5/oGpviiP126ThXQUpm/CddKak1tcSPkATrkpBtgNQdkwIpWbf2NZ8vkSpXIan1bRuSC+N8OkOT7zR6zehDLpMZ9ikMI3h5jjQ+J6OEh2ZnOAiZh8YtT8z5s8jMzEz0vLp0E/k/MlIvM+0gDbOebDYAzde4fyNL2a3Uf5PtHOqqqQc510Kb0R/tamfV4ifvrEk096OZKn00hWLZfQmEhZzqqSnA0zhE5N2zfuxE64+Qs9Ad/wM=&lt;/diagram&gt;&lt;/mxfile&gt;">
<defs/>
<rect fill="#ffffff" width="100%" height="100%" x="0" y="0" style="fill: light-dark(rgb(255, 255, 255), rgb(18, 18, 18));"/>
<g>
<g>
<rect x="0" y="0" width="540" height="840" fill="#ffffff" stroke="none" pointer-events="all" style="fill: light-dark(#ffffff, var(--ge-dark-color, #121212));"/>
</g>
<g>
<path d="M 80 300 L 80 120" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 77 300 L 83 300" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 83 120 L 77 120" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 198px; margin-left: 81px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rP0 = sQ @ sK0
</div>
</div>
</div>
</foreignObject>
<text x="81" y="201" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rP0 = sQ @ sK0
</text>
</switch>
</g>
</g>
<g>
<path d="M 319.31 480 L 319.31 300" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 316.31 480 L 322.31 480" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 322.31 300 L 316.31 300" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 378px; margin-left: 320px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rP1 = sQ @ sK1
</div>
</div>
</div>
</foreignObject>
<text x="320" y="381" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rP1 = sQ @ sK1
</text>
</switch>
</g>
</g>
<g>
<path d="M 160 380 L 160 300" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 157 380 L 163 380" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 163 300 L 157 300" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 335px; margin-left: 161px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
<div style="line-height: 90%;">
<div>
<font style="font-size: 9px; line-height: 90%;">
Get sScale0
</font>
</div>
<div>
<font style="font-size: 9px; line-height: 90%;">
Update sM
</font>
</div>
<font style="font-size: 9px; line-height: 90%;">
rPb = rP0 = Softmax(rP0)
</font>
<div>
<font style="font-size: 9px; line-height: 90%;">
rO0 = Scale(rO0)
</font>
</div>
<div>
<font style="font-size: 9px; line-height: 90%;">
Update rL
</font>
</div>
</div>
</div>
</div>
</div>
</foreignObject>
<text x="161" y="338" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Get sScale0...
</text>
</switch>
</g>
</g>
<g>
<path d="M 160 400 L 160 380" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 157 400 L 163 400" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 163 380 L 157 380" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 391px; margin-left: 161px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue
</div>
</div>
</div>
</foreignObject>
<text x="161" y="394" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue
</text>
</switch>
</g>
</g>
<g>
<path d="M 80 560 L 80 480" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 77 560 L 83 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 83 480 L 77 480" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 515px; margin-left: 81px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rO0 += rPb @ sV0L
</div>
</div>
</div>
</foreignObject>
<text x="81" y="518" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rO0 += rPb @ sV0L
</text>
</switch>
</g>
</g>
<g>
<path d="M 240 540 L 240 480" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 237 540 L 243 540" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 243 480 L 237 480" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 507px; margin-left: 241px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
<div style="line-height: 90%;">
<div>
<font style="line-height: 90%; font-size: 9px;">
Get sScale1
</font>
</div>
<div>
<font style="line-height: 90%; font-size: 9px;">
Update sM
</font>
</div>
<font style="line-height: 90%; font-size: 9px;">
rP1b = Softmax(rP1
<span style="background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); color: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));">
)
</span>
</font>
<div>
<font style="line-height: 90%; font-size: 9px;">
rO1 = Scale(rO1)
<span style="background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); color: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"></span>
</font>
</div>
<div>
<span style="background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); color: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));">
<font style="line-height: 90%; font-size: 9px;">
Update rL
</font>
</span>
</div>
</div>
</div>
</div>
</div>
</foreignObject>
<text x="241" y="510" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Get sScale1...
</text>
</switch>
</g>
</g>
<g>
<path d="M 170 380 L 227.5 475.83" fill="none" stroke="#9673a6" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
<path d="M 229.42 479.04 L 224.71 476.04 L 227.5 475.83 L 229 473.47 Z" fill="#9673a6" stroke="#9673a6" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(150, 115, 166), rgb(149, 119, 163)); stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
</g>
<g>
<path d="M 270 560 L 270 540" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 267 560 L 273 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 273 540 L 267 540" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 551px; margin-left: 271px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue
</div>
</div>
</div>
</foreignObject>
<text x="271" y="554" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue
</text>
</switch>
</g>
</g>
<g>
<path d="M 320 640 L 320 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 317 640 L 323 640" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 323 560 L 317 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 595px; margin-left: 321px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rO1 += rP1b @ sV1R
</div>
</div>
</div>
</foreignObject>
<text x="321" y="598" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rO1 += rP1b @ sV1R
</text>
</switch>
</g>
</g>
<g>
<path d="M 160 650 L 160 630" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 157 650 L 163 650" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 163 630 L 157 630" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 640px; margin-left: 161px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rO0 = Scale(rO0)
</div>
</div>
</div>
</foreignObject>
<text x="161" y="643" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rO0 = Scale(rO0)
</text>
</switch>
</g>
</g>
<g>
<path d="M 220 540 L 174.52 558.19" fill="none" stroke="#9673a6" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
<path d="M 171.04 559.58 L 174.75 555.41 L 174.52 558.19 L 176.61 560.05 Z" fill="#9673a6" stroke="#9673a6" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(150, 115, 166), rgb(149, 119, 163)); stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
</g>
<g>
<path d="M 80 800 L 80 720" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 77 800 L 83 800" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 83 720 L 77 720" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 755px; margin-left: 81px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rO0 += sP1 @ sV1L
</div>
</div>
</div>
</foreignObject>
<text x="81" y="758" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rO0 += sP1 @ sV1L
</text>
</switch>
</g>
</g>
<g>
<path d="M 319.6 720 L 320 640" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 316.6 719.99 L 322.6 720.01" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 323 640.01 L 317 639.99" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 675px; margin-left: 320px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
rO1 += sP0 @ sV0R
</div>
</div>
</div>
</foreignObject>
<text x="320" y="678" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
rO1 += sP0 @ sV0R
</text>
</switch>
</g>
</g>
<g>
<path d="M 160 100 L 94.47 128.08" fill="none" stroke="#b85450" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke" style="stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
<path d="M 91.03 129.56 L 94.64 125.29 L 94.47 128.08 L 96.61 129.89 Z" fill="#b85450" stroke="#b85450" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(184, 84, 80), rgb(215, 129, 126)); stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
</g>
<g>
<path d="M 250 250 L 306.56 306.56" fill="none" stroke="#b85450" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke" style="stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
<path d="M 309.21 309.21 L 303.91 307.44 L 306.56 306.56 L 307.44 303.91 Z" fill="#b85450" stroke="#b85450" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(184, 84, 80), rgb(215, 129, 126)); stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
</g>
<g>
<path d="M 150 390 L 92.92 466.11" fill="none" stroke="#b85450" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke" style="stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
<path d="M 90.67 469.11 L 91.67 463.61 L 92.92 466.11 L 95.67 466.61 Z" fill="#b85450" stroke="#b85450" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(184, 84, 80), rgb(215, 129, 126)); stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
</g>
<g>
<path d="M 280 550 L 305.38 558.46" fill="none" stroke="#b85450" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke" style="stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
<path d="M 308.94 559.65 L 303.41 560.44 L 305.38 558.46 L 304.99 555.69 Z" fill="#b85450" stroke="#b85450" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(184, 84, 80), rgb(215, 129, 126)); stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
</g>
<g>
<path d="M 159.8 600 L 159.8 580" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 156.8 600 L 162.8 600" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 162.8 580 L 156.8 580" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 590px; margin-left: 160px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
sP0 = Scale(rP0)
</div>
</div>
</div>
</foreignObject>
<text x="160" y="593" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
sP0 = Scale(rP0)
</text>
</switch>
</g>
</g>
<g>
<rect x="60" y="40" width="60" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 55px; margin-left: 90px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Tensor
</div>
</div>
</div>
</foreignObject>
<text x="90" y="58" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Tensor
</text>
</switch>
</g>
</g>
<g>
<rect x="145" y="40" width="50" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 55px; margin-left: 170px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
CUDA
</div>
</div>
</div>
</foreignObject>
<text x="170" y="58" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
CUDA
</text>
</switch>
</g>
</g>
<g>
<rect x="230" y="40" width="50" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 55px; margin-left: 255px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
CUDA
</div>
</div>
</div>
</foreignObject>
<text x="255" y="58" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
CUDA
</text>
</switch>
</g>
</g>
<g>
<rect x="300" y="40" width="60" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 55px; margin-left: 330px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Tensor
</div>
</div>
</div>
</foreignObject>
<text x="330" y="58" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Tensor
</text>
</switch>
</g>
</g>
<g>
<rect x="90" y="20" width="90" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 35px; margin-left: 135px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Warpgroup 0
</div>
</div>
</div>
</foreignObject>
<text x="135" y="38" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Warpgroup 0
</text>
</switch>
</g>
</g>
<g>
<rect x="240" y="20" width="90" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 35px; margin-left: 285px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Warpgroup 1
</div>
</div>
</div>
</foreignObject>
<text x="285" y="38" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Warpgroup 1
</text>
</switch>
</g>
</g>
<g>
<path d="M 60 70 L 360 70" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<path d="M 170 600 L 225.13 600" fill="none" stroke="#9673a6" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
<path d="M 228.88 600 L 223.88 602.5 L 225.13 600 L 223.88 597.5 Z" fill="#9673a6" stroke="#9673a6" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(150, 115, 166), rgb(149, 119, 163)); stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
</g>
<g>
<path d="M 240 620 L 240 600" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 237 620 L 243 620" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 243 600 L 237 600" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 611px; margin-left: 241px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue
</div>
</div>
</div>
</foreignObject>
<text x="241" y="614" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue
</text>
</switch>
</g>
</g>
<g>
<path d="M 250 610 L 305.65 637.82" fill="none" stroke="#b85450" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke" style="stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
<path d="M 309 639.5 L 303.41 639.5 L 305.65 637.82 L 305.65 635.03 Z" fill="#b85450" stroke="#b85450" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(184, 84, 80), rgb(215, 129, 126)); stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
</g>
<g>
<path d="M 230 620 L 174.8 629.2" fill="none" stroke="#9673a6" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
<path d="M 171.1 629.82 L 175.62 626.53 L 174.8 629.2 L 176.45 631.46 Z" fill="#9673a6" stroke="#9673a6" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(150, 115, 166), rgb(149, 119, 163)); stroke: light-dark(rgb(150, 115, 166), rgb(149, 119, 163));"/>
</g>
<g>
<path d="M 90 560 L 145.13 560" fill="none" stroke="#d6b656" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
<path d="M 148.88 560 L 143.88 562.5 L 145.13 560 L 143.88 557.5 Z" fill="#d6b656" stroke="#d6b656" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(214, 182, 86), rgb(109, 81, 0)); stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
</g>
<g>
<path d="M 160 670 L 160 650" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 157 670 L 163 670" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 163 650 L 157 650" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 661px; margin-left: 161px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue
</div>
</div>
</div>
</foreignObject>
<text x="161" y="664" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue
</text>
</switch>
</g>
</g>
<g>
<path d="M 170 110 L 169.86 80" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 167 110.01 L 173 109.99" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 172.86 79.99 L 166.86 80.01" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 97px; margin-left: 170px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Pipelined TMA wait and issue
</div>
</div>
</div>
</foreignObject>
<text x="170" y="100" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Pipelined TMA wait and issue
</text>
</switch>
</g>
</g>
<g>
<path d="M 150 660 L 93.44 716.56" fill="none" stroke="#b85450" stroke-miterlimit="10" stroke-dasharray="1 1" pointer-events="stroke" style="stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
<path d="M 90.79 719.21 L 92.56 713.91 L 93.44 716.56 L 96.09 717.44 Z" fill="#b85450" stroke="#b85450" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(184, 84, 80), rgb(215, 129, 126)); stroke: light-dark(rgb(184, 84, 80), rgb(215, 129, 126));"/>
</g>
<g>
<path d="M 240 260 L 240 230" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 237 260 L 243 260" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 243 230 L 237 230" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 247px; margin-left: 241px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Pipelined TMA wait and issue
</div>
</div>
</div>
</foreignObject>
<text x="241" y="250" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Pipelined TMA wait and issue
</text>
</switch>
</g>
</g>
<g>
<path d="M 240 560 L 240 540" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 237 560 L 243 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 243 540 L 237 540" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 551px; margin-left: 231px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
sP1 = rP1b
</div>
</div>
</div>
</foreignObject>
<text x="231" y="554" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
sP1 = rP1b
</text>
</switch>
</g>
</g>
<g>
<path d="M 110 310 L 107.5 310 Q 105 310 105 320 L 105 327.5 Q 105 335 102.5 335 L 101.25 335 Q 100 335 102.5 335 L 103.75 335 Q 105 335 105 345 L 105 352.5 Q 105 360 107.5 360 L 110 360" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<rect x="20" y="320" width="90" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 335px; margin-left: 65px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; ">
<div style="display: inline-block; font-size: 12px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; white-space: nowrap; ">
wg0-bunch-0
</div>
</div>
</div>
</foreignObject>
<text x="65" y="339" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="12px" text-anchor="middle">
wg0-bunch-0
</text>
</switch>
</g>
</g>
<g>
<path d="M 300 480 L 297.5 480 Q 295 480 295 490 L 295 497.5 Q 295 505 292.5 505 L 291.25 505 Q 290 505 292.5 505 L 293.75 505 Q 295 505 295 515 L 295 522.5 Q 295 530 297.5 530 L 300 530" fill="none" stroke="#000000" stroke-miterlimit="10" transform="translate(295,0)scale(-1,1)translate(-295,0)" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<rect x="290" y="490" width="90" height="30" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 505px; margin-left: 335px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; ">
<div style="display: inline-block; font-size: 12px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; white-space: nowrap; ">
wg1-bunch-0
</div>
</div>
</div>
</foreignObject>
<text x="335" y="509" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="12px" text-anchor="middle">
wg1-bunch-0
</text>
</switch>
</g>
</g>
<g>
<path d="M 160 580 L 160 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 157 580 L 163 580" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 163 560 L 157 560" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 570px; margin-left: 161px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue TMA (nxt V0L)
</div>
</div>
</div>
</foreignObject>
<text x="161" y="573" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue TMA (nxt V0L)
</text>
</switch>
</g>
</g>
<g>
<path d="M 160 820 L 160 800" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 157 820 L 163 820" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 163 800 L 157 800" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 810px; margin-left: 161px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue TMA (nxt V1L)
</div>
</div>
</div>
</foreignObject>
<text x="161" y="813" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue TMA (nxt V1L)
</text>
</switch>
</g>
</g>
<g>
<path d="M 90 800 L 145.13 800" fill="none" stroke="#d6b656" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
<path d="M 148.88 800 L 143.88 802.5 L 145.13 800 L 143.88 797.5 Z" fill="#d6b656" stroke="#d6b656" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(214, 182, 86), rgb(109, 81, 0)); stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
</g>
<g>
<path d="M 310 640 L 254.87 640" fill="none" stroke="#d6b656" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
<path d="M 251.12 640 L 256.12 637.5 L 254.87 640 L 256.12 642.5 Z" fill="#d6b656" stroke="#d6b656" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(214, 182, 86), rgb(109, 81, 0)); stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
</g>
<g>
<path d="M 240 660 L 240 640" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 237 660 L 243 660" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 243 640 L 237 640" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 651px; margin-left: 241px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue TMA (nxt V1R)
</div>
</div>
</div>
</foreignObject>
<text x="241" y="654" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue TMA (nxt V1R)
</text>
</switch>
</g>
</g>
<g>
<path d="M 240 740 L 240 720" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 237 740 L 243 740" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
<path d="M 243 720 L 237 720" fill="none" stroke="#000000" stroke-miterlimit="10" pointer-events="all" style="stroke: light-dark(rgb(0, 0, 0), rgb(255, 255, 255));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 731px; margin-left: 241px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; background-color: #ffffff; ">
<div style="display: inline-block; font-size: 11px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; background-color: light-dark(#ffffff, var(--ge-dark-color, #121212)); white-space: nowrap; ">
Issue TMA (nxt V0R)
</div>
</div>
</div>
</foreignObject>
<text x="241" y="734" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="11px" text-anchor="middle">
Issue TMA (nxt V0R)
</text>
</switch>
</g>
</g>
<g>
<path d="M 310 720 L 254.87 720" fill="none" stroke="#d6b656" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
<path d="M 251.12 720 L 256.12 717.5 L 254.87 720 L 256.12 722.5 Z" fill="#d6b656" stroke="#d6b656" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(214, 182, 86), rgb(109, 81, 0)); stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
</g>
<g>
<path d="M 90 300 L 145.13 300" fill="none" stroke="#d6b656" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
<path d="M 148.88 300 L 143.88 302.5 L 145.13 300 L 143.88 297.5 Z" fill="#d6b656" stroke="#d6b656" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(214, 182, 86), rgb(109, 81, 0)); stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
</g>
<g>
<path d="M 310 480 L 254.87 480" fill="none" stroke="#d6b656" stroke-miterlimit="10" stroke-dasharray="3 3" pointer-events="stroke" style="stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
<path d="M 251.12 480 L 256.12 477.5 L 254.87 480 L 256.12 482.5 Z" fill="#d6b656" stroke="#d6b656" stroke-miterlimit="10" pointer-events="all" style="fill: light-dark(rgb(214, 182, 86), rgb(109, 81, 0)); stroke: light-dark(rgb(214, 182, 86), rgb(109, 81, 0));"/>
</g>
<g>
<rect x="330" y="100" width="190" height="40" fill="none" stroke="#c0c0c0" stroke-dasharray="8 8" pointer-events="all" style="stroke: light-dark(rgb(192, 192, 192), rgb(127, 127, 127));"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe flex-start; width: 1px; height: 1px; padding-top: 120px; margin-left: 332px;">
<div style="box-sizing: border-box; font-size: 0; text-align: left; color: #000000; ">
<div style="display: inline-block; font-size: 12px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; white-space: nowrap; ">
sXX: Stored on shared memory
<div>
rXX: Stored on register file
</div>
</div>
</div>
</div>
</foreignObject>
<text x="332" y="124" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="12px">
sXX: Stored on shared memory...
</text>
</switch>
</g>
</g>
<g>
<path d="M 20 305 L 220 305" fill="none" stroke="#82b366" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(130, 179, 102), rgb(68, 110, 44));"/>
</g>
<g>
<path d="M 220 485 L 520 485" fill="none" stroke="#82b366" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(130, 179, 102), rgb(68, 110, 44));"/>
</g>
<g>
<path d="M 220 305 L 220 485" fill="none" stroke="#82b366" stroke-miterlimit="10" pointer-events="stroke" style="stroke: light-dark(rgb(130, 179, 102), rgb(68, 110, 44));"/>
</g>
<g>
<rect x="330" y="450" width="200" height="40" fill="none" stroke="none" pointer-events="all"/>
</g>
<g>
<g transform="translate(-0.5 -0.5)">
<switch>
<foreignObject style="overflow: visible; text-align: left;" pointer-events="none" width="100%" height="100%" requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility">
<div xmlns="http://www.w3.org/1999/xhtml" style="display: flex; align-items: unsafe center; justify-content: unsafe center; width: 1px; height: 1px; padding-top: 470px; margin-left: 430px;">
<div style="box-sizing: border-box; font-size: 0; text-align: center; color: #000000; ">
<div style="display: inline-block; font-size: 10px; font-family: &quot;Helvetica&quot;; color: light-dark(#000000, #ffffff); line-height: 1.2; pointer-events: all; white-space: nowrap; ">
<font style="font-size: 9px;">
Loop boundary in our code
</font>
<div>
<font style="font-size: 9px;">
(plz refer to comments in `wg1_subroutine`)
</font>
</div>
</div>
</div>
</div>
</foreignObject>
<text x="430" y="473" fill="light-dark(#000000, #ffffff)" font-family="&quot;Helvetica&quot;" font-size="10px" text-anchor="middle">
Loop boundary in our code...
</text>
</switch>
</g>
</g>
</g>
<switch>
<g requiredFeatures="http://www.w3.org/TR/SVG11/feature#Extensibility"/>
<a transform="translate(0,-5)" xlink:href="https://www.drawio.com/doc/faq/svg-export-text-problems" target="_blank">
<text text-anchor="middle" font-size="10px" x="50%" y="100%">
Text is not SVG - cannot display
</text>
</a>
</switch>
</svg>

After

Width:  |  Height:  |  Size: 74 KiB

View File

@ -55,7 +55,6 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,

View File

@ -11,29 +11,13 @@ from torch.utils.cpp_extension import (
IS_WINDOWS,
)
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return nvcc_extra_args + ["--threads", nvcc_threads]
def get_sources():
sources = [
"csrc/flash_api.cpp",
"csrc/flash_fwd_mla_bf16_sm90.cu",
"csrc/flash_fwd_mla_metadata.cu",
]
if not DISABLE_FP16:
sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
return sources
def get_features_args():
features_args = []
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]
if DISABLE_FP16:
features_args.append("-DFLASH_MLA_DISABLE_FP16")
return features_args
@ -56,7 +40,12 @@ ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_cuda",
sources=get_sources(),
sources=[
"csrc/flash_api.cpp",
"csrc/kernels/get_mla_metadata.cu",
"csrc/kernels/mla_combine.cu",
"csrc/kernels/splitkv_mla.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
"nvcc": append_nvcc_threads(

View File

@ -127,7 +127,7 @@ def main(torch_dtype):
causal = True
for b in [128]:
for s in [4096, 8192]:
for s in [4096, 8192, 16384]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]: