mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Initial commit
i
This commit is contained in:
commit
414a2f3eed
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
build
|
||||
*.so
|
||||
*.egg-info/
|
||||
__pycache__/
|
||||
dist/
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "csrc/cutlass"]
|
||||
path = csrc/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 DeepSeek
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
61
README.md
Normal file
61
README.md
Normal file
@ -0,0 +1,61 @@
|
||||
# FlashMLA
|
||||
|
||||
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
|
||||
|
||||
Currently released:
|
||||
- BF16
|
||||
- Paged kvcache with block size of 64
|
||||
|
||||
## Quick start
|
||||
|
||||
### Install
|
||||
|
||||
```bash
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Benchmark
|
||||
|
||||
```bash
|
||||
python tests/test_flash_mla.py
|
||||
```
|
||||
|
||||
Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6.
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
|
||||
|
||||
for i in range(num_layers):
|
||||
...
|
||||
o_i, lse_i = flash_mla_with_kvcache(
|
||||
q_i, kvcache_i, block_table, cache_seqlens, dv,
|
||||
tile_scheduler_metadata, num_splits, causal=True,
|
||||
)
|
||||
...
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- Hopper GPUs
|
||||
- CUDA 12.3 and above
|
||||
- PyTorch 2.0 and above
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{flashmla2025,
|
||||
title={FlashMLA: Efficient MLA decoding kernel},
|
||||
author={Jiashi Li},
|
||||
year={2025},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
|
||||
}
|
||||
```
|
||||
1
csrc/cutlass
Submodule
1
csrc/cutlass
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
|
||||
203
csrc/flash_api.cpp
Normal file
203
csrc/flash_api.cpp
Normal file
@ -0,0 +1,203 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
|
||||
|
||||
#include <torch/python.h>
|
||||
#include <torch/nn/functional.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cutlass/fast_math.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
std::vector<at::Tensor>
|
||||
get_mla_metadata(
|
||||
at::Tensor &seqlens_k,
|
||||
const int num_heads_per_head_k,
|
||||
const int num_heads_k
|
||||
) {
|
||||
// This should match the logic in the MLA kernel.
|
||||
static constexpr int block_size_m = 64;
|
||||
static constexpr int block_size_n = 64;
|
||||
static constexpr int fixed_overhead_num_blocks = 5;
|
||||
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
TORCH_CHECK(seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
|
||||
|
||||
int batch_size = seqlens_k.size(0);
|
||||
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
|
||||
auto options = seqlens_k.options();
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
int sm_count = dprops->multiProcessorCount;
|
||||
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
|
||||
|
||||
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
|
||||
auto num_splits = torch::empty({batch_size + 1}, options);
|
||||
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
int *num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Mla_metadata_params params = {};
|
||||
params.seqlens_k_ptr = seqlens_k_ptr;
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
|
||||
params.num_splits_ptr = num_splits_ptr;
|
||||
params.batch_size = batch_size;
|
||||
params.block_size_n = block_size_n;
|
||||
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
||||
params.num_sm_parts = num_sm_parts;
|
||||
get_mla_metadata_func(params, stream);
|
||||
|
||||
return {tile_scheduler_metadata, num_splits};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd_kvcache_mla(
|
||||
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
||||
const int head_size_v,
|
||||
const at::Tensor &seqlens_k, // batch_size
|
||||
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor &num_splits // batch_size + 1
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90);
|
||||
|
||||
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kBFloat16);
|
||||
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q_ori = sizes[1];
|
||||
const int num_heads_ori = sizes[2];
|
||||
const int head_size = sizes[3];
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
|
||||
|
||||
const int max_num_blocks_per_seq = block_table.size(1);
|
||||
const int num_blocks = kcache.size(0);
|
||||
const int page_block_size = kcache.size(1);
|
||||
const int num_heads_k = kcache.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
const int ngroups = num_heads_ori / num_heads_k;
|
||||
const int seqlen_q = seqlen_q_ori * ngroups;
|
||||
const int num_heads = num_heads_k;
|
||||
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q, num_heads, head_size});
|
||||
|
||||
int head_size_k = head_size;
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
|
||||
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
|
||||
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
CHECK_CONTIGUOUS(seqlens_k);
|
||||
CHECK_SHAPE(seqlens_k, batch_size);
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
Flash_fwd_mla_params params = {};
|
||||
// Set the sizes.
|
||||
params.b = batch_size;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
|
||||
params.h = num_heads;
|
||||
params.h_h_k_ratio = num_heads / num_heads_k;
|
||||
params.ngroups = ngroups;
|
||||
params.is_causal = is_causal;
|
||||
params.d = head_size;
|
||||
params.d_v = head_size_v;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = kcache.data_ptr();
|
||||
params.v_ptr = vcache.data_ptr();
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = kcache.stride(0);
|
||||
params.v_batch_stride = vcache.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = kcache.stride(-3);
|
||||
params.v_row_stride = vcache.stride(-3);
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = kcache.stride(-2);
|
||||
params.v_head_stride = vcache.stride(-2);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
params.num_sm_parts = tile_scheduler_metadata.size(0);
|
||||
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
||||
CHECK_DEVICE(num_splits);
|
||||
CHECK_CONTIGUOUS(num_splits);
|
||||
params.num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
||||
params.oaccum_ptr = out_accum.data_ptr();
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
|
||||
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
|
||||
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "FlashMLA";
|
||||
m.def("get_mla_metadata", &get_mla_metadata);
|
||||
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
|
||||
}
|
||||
3
csrc/flash_fwd_mla_bf16_sm90.cu
Normal file
3
csrc/flash_fwd_mla_bf16_sm90.cu
Normal file
@ -0,0 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
679
csrc/flash_fwd_mla_kernel.h
Normal file
679
csrc/flash_fwd_mla_kernel.h
Normal file
@ -0,0 +1,679 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#include "named_barrier.h"
|
||||
#include "utils.h"
|
||||
#include "softmax.h"
|
||||
#include "static_switch.h"
|
||||
#include "flash_mla.h"
|
||||
|
||||
|
||||
template<typename PrecType, int DIM, int DIM2 = DIM>
|
||||
constexpr auto getSmemLayoutK() {
|
||||
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
|
||||
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
|
||||
|
||||
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
||||
return GMMA::Layout_K_SW128_Atom<PrecType>{};
|
||||
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
||||
return GMMA::Layout_K_SW64_Atom<PrecType>{};
|
||||
} else {
|
||||
return GMMA::Layout_K_SW32_Atom<PrecType>{};
|
||||
}
|
||||
}
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
struct Flash_fwd_kernel_traits_mla {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
static constexpr int kNWarpsS = 4;
|
||||
static constexpr int kNThreadsS = kNWarpsS * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
|
||||
static_assert(kHeadDimV % 32 == 0);
|
||||
static_assert(kHeadDimV <= kHeadDim);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = decltype(make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
|
||||
GMMA::Major::K, GMMA::Major::K>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
|
||||
|
||||
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
|
||||
using TiledMmaO = decltype(make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim>(),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
||||
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
|
||||
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
||||
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(composition(
|
||||
Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
|
||||
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomO = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopyO = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtomO{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
||||
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
|
||||
using GmemLayoutAtomOaccum = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
|
||||
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
|
||||
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
|
||||
};
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct SharedStorageMLA {
|
||||
union {
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||
};
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
|
||||
__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx,
|
||||
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
||||
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
||||
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
||||
|
||||
// Epilogue
|
||||
|
||||
const int split_offset = __ldg(params.num_splits_ptr + bidb);
|
||||
|
||||
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
|
||||
|
||||
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
|
||||
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
using SmemTiledCopyO = std::conditional_t<
|
||||
!Split,
|
||||
typename Kernel_traits::SmemCopyAtomO,
|
||||
typename Kernel_traits::SmemCopyAtomOaccum
|
||||
>;
|
||||
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
|
||||
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor rO = flash::convert_type<ElementO>(tOrO);
|
||||
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
__syncthreads();
|
||||
|
||||
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
|
||||
|
||||
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
|
||||
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
|
||||
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
|
||||
GmemTiledCopyO gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tidx >= kNThreadsS) { return; }
|
||||
|
||||
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
|
||||
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
|
||||
|
||||
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
|
||||
Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
|
||||
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
||||
if (get<1>(taccOcO_row(0)) == 0) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccOcO_row(mi));
|
||||
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
|
||||
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms,
|
||||
const int bidb, const int bidh, const int m_block,
|
||||
const int n_split_idx, const int seqlen_k,
|
||||
const int n_block_min, const int n_block_max, const bool NoSplit,
|
||||
SharedStorage &shared_storage) {
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
|
||||
constexpr int kNThreads = Kernel_traits::kNThreads;
|
||||
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
|
||||
static_assert(kNThreads == 256 and kNThreadsS == 128);
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
|
||||
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
|
||||
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
|
||||
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
|
||||
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
|
||||
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
|
||||
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
|
||||
Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
|
||||
|
||||
typename Kernel_traits::TiledMmaO tiled_mma_o;
|
||||
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
|
||||
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
|
||||
clear(tOrO);
|
||||
|
||||
flash::Softmax<2 * size<1>(tOrO)> softmax;
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 0) {
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
if (n_block % 2 == 1) {
|
||||
// Double buffer for sK
|
||||
constexpr int sK_offset = size(sK);
|
||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
|
||||
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
||||
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
|
||||
#pragma unroll 1
|
||||
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
|
||||
__syncthreads();
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
|
||||
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
|
||||
|
||||
const bool is_masking_step = masking_step > 0;
|
||||
const bool is_first_masking_step = masking_step == n_masking_steps;
|
||||
|
||||
if (is_masking_step) {
|
||||
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
|
||||
Tensor tScS = thr_mma.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
if constexpr (!Is_causal) { // Just masking based on col
|
||||
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
|
||||
} else {
|
||||
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
|
||||
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
|
||||
int row = int(get<0>(tScS(i)));
|
||||
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
|
||||
if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We have key_padding_mask so we'll need to Check_inf
|
||||
Tensor scale_o = is_first_masking_step
|
||||
? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
||||
: is_masking_step ?
|
||||
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
||||
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = flash::convert_type<Element>(tSrS);
|
||||
cute::copy(rP, tPsP);
|
||||
cute::copy(scale_o, tScale_osScale_o);
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
||||
|
||||
flash::rescale_o(tOrO, scale_o);
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
cute::copy(softmax.row_max, tRow_maxsRow_max);
|
||||
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||
} else {
|
||||
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
|
||||
int cur_block_table = __ldg(&block_table[n_block]);
|
||||
|
||||
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
|
||||
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
|
||||
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
params.seqlen_q - m_block * kBlockM);
|
||||
|
||||
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
|
||||
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
|
||||
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
|
||||
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
|
||||
Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
if (n_block % 2 == 1) {
|
||||
// Double buffer for sK
|
||||
constexpr int sK_offset = size(sK);
|
||||
tKsK.data() = tKsK.data() + sK_offset;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
// We need to clear the sK smem tiles because K is V.
|
||||
const index_t offset_k = cur_block_table * params.k_batch_stride;
|
||||
tKgK.data() = tKgK.data() + offset_k;
|
||||
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
|
||||
seqlen_k - n_block * kBlockN);
|
||||
tKgK.data() = tKgK.data() + -offset_k;
|
||||
cute::cp_async_fence();
|
||||
|
||||
if (n_block - 1 >= n_block_min) {
|
||||
cur_block_table = __ldg(&block_table[n_block - 1]);
|
||||
}
|
||||
|
||||
#pragma unroll 1
|
||||
for (; n_block >= n_block_min; --n_block) {
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
if (n_block - 1 >= n_block_min) {
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tKsK.data() = tKsK.data() + sK_offset;
|
||||
|
||||
const index_t offset_k = cur_block_table * params.k_batch_stride;
|
||||
tKgK.data() = tKgK.data() + offset_k;
|
||||
flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
|
||||
tKgK.data() = tKgK.data() + -offset_k;
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
||||
|
||||
if (n_block - 2 >= n_block_min) {
|
||||
cur_block_table = __ldg(&block_table[n_block - 2]);
|
||||
}
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
|
||||
Tensor rP = make_tensor<Element>(tSrS_layout);
|
||||
Tensor scale_o = make_tensor<float>(Shape<_2>{});
|
||||
cute::copy(tScale_osScale_o, scale_o);
|
||||
cute::copy(tPsP, rP);
|
||||
|
||||
flash::rescale_o(tOrO, scale_o);
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||
cute::copy(tRow_maxsRow_max, softmax.row_max);
|
||||
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
|
||||
}
|
||||
|
||||
if (NoSplit)
|
||||
store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
|
||||
else
|
||||
store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
|
||||
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
|
||||
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
const int m_block = blockIdx.x;
|
||||
const int bidh = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
|
||||
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
|
||||
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
|
||||
int begin_idx = tile_scheduler_metadata.x;
|
||||
int begin_seqlen = tile_scheduler_metadata.y;
|
||||
int end_idx = tile_scheduler_metadata.z;
|
||||
int end_seqlen = tile_scheduler_metadata.w;
|
||||
if (begin_idx >= params.b) return;
|
||||
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
|
||||
|
||||
#pragma unroll 1
|
||||
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
|
||||
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
|
||||
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
|
||||
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
|
||||
const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
|
||||
const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
|
||||
if (batch_id > begin_idx) {
|
||||
__syncthreads(); // Barrier between two tiles.
|
||||
}
|
||||
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
|
||||
__global__ void __launch_bounds__(256, 1, 1)
|
||||
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
||||
constexpr int kNThreads = 128;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int bidx = blockIdx.x;
|
||||
const int hs = params.h * params.seqlen_q;
|
||||
const int batch_idx = bidx / hs;
|
||||
const int hs_idx = bidx % hs;
|
||||
|
||||
const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
|
||||
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
|
||||
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
|
||||
if (actual_num_splits == 1) return;
|
||||
|
||||
__shared__ ElementAccum sLseScale[kMaxSplits];
|
||||
|
||||
const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
|
||||
const index_t row_offset_lse = bidx;
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
|
||||
Shape<Int<kMaxSplits>>{}, make_stride(hs));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
float local_lse[kNLsePerThread];
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + tidx;
|
||||
local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
|
||||
}
|
||||
|
||||
float max_lse = -INFINITY;
|
||||
for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
|
||||
for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
|
||||
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
|
||||
|
||||
float sum_lse = 0;
|
||||
for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
|
||||
for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
|
||||
|
||||
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
|
||||
if (tidx == 0) gLSE(0) = global_lse;
|
||||
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + tidx;
|
||||
if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
static_assert(kHeadDimV % kNThreads == 0);
|
||||
constexpr int Elements = kHeadDimV / kNThreads;
|
||||
const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
|
||||
Shape<Int<kHeadDimV>>{}, Stride<_1>{});
|
||||
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
Layout<Shape<Int<kNThreads>>>{},
|
||||
Layout<Shape<Int<Elements>>>{}));
|
||||
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
|
||||
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
clear(tOrO);
|
||||
|
||||
for (int split = 0; split < actual_num_splits; ++split) {
|
||||
cute::copy(tOgOaccum, tOrOaccum);
|
||||
ElementAccum lse_scale = sLseScale[split];
|
||||
for (int i = 0; i < size(tOrO); ++i) {
|
||||
tOrO(i) += lse_scale * tOrOaccum(i);
|
||||
}
|
||||
tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
|
||||
}
|
||||
|
||||
Tensor rO = flash::convert_type<Element>(tOrO);
|
||||
const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
|
||||
const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
|
||||
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
|
||||
cute::copy(rO, gO);
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, typename SharedStorage>
|
||||
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
|
||||
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
|
||||
constexpr size_t smem_size = sizeof(SharedStorage);
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
|
||||
dim3 grid_combine(params.b * params.h * params.seqlen_q);
|
||||
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
|
||||
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
|
||||
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
|
||||
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
|
||||
});
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
|
||||
template<typename T, int Headdim>
|
||||
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
static_assert(Headdim == 576);
|
||||
FLASH_ASSERT(params.d_v == 512);
|
||||
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
|
||||
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
|
||||
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
|
||||
}
|
||||
|
||||
static constexpr int MaxBatchSize = 4096;
|
||||
|
||||
__global__ void __launch_bounds__(256, 1, 1)
|
||||
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
|
||||
int *seqlens_k_ptr = params.seqlens_k_ptr;
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
|
||||
int *num_splits_ptr = params.num_splits_ptr;
|
||||
int batch_size = params.batch_size;
|
||||
int block_size_n = params.block_size_n;
|
||||
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
|
||||
int num_sm_parts = params.num_sm_parts;
|
||||
|
||||
__shared__ int num_blocks_shared[MaxBatchSize];
|
||||
__shared__ int num_splits_shared[MaxBatchSize];
|
||||
|
||||
int total_num_blocks = 0;
|
||||
for (int i = threadIdx.x; i < batch_size; i += 32) {
|
||||
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
|
||||
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
|
||||
num_blocks_shared[i] = num_blocks;
|
||||
}
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
|
||||
|
||||
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
|
||||
num_splits_shared[0] = 0;
|
||||
for (int i = 0; i < num_sm_parts; ++i) {
|
||||
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
|
||||
tile_scheduler_metadata0[0] = now_idx;
|
||||
tile_scheduler_metadata0[1] = now_block * block_size_n;
|
||||
tile_scheduler_metadata1 = now_n_split_idx;
|
||||
int remain_payload = payload;
|
||||
while (now_idx < batch_size) {
|
||||
int num_blocks = num_blocks_shared[now_idx];
|
||||
int now_remain_blocks = num_blocks - now_block;
|
||||
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
|
||||
cum_num_splits += now_n_split_idx + 1;
|
||||
num_splits_shared[now_idx + 1] = cum_num_splits;
|
||||
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
|
||||
++now_idx;
|
||||
now_block = 0;
|
||||
now_n_split_idx = 0;
|
||||
} else {
|
||||
if (remain_payload - fixed_overhead_num_blocks > 0) {
|
||||
now_block += remain_payload - fixed_overhead_num_blocks;
|
||||
++now_n_split_idx;
|
||||
remain_payload = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
|
||||
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
|
||||
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
|
||||
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
|
||||
}
|
||||
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int i = threadIdx.x; i <= batch_size; i += 32) {
|
||||
num_splits_ptr[i] = num_splits_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) {
|
||||
FLASH_ASSERT(params.batch_size < MaxBatchSize);
|
||||
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
63
csrc/flash_mla.h
Normal file
63
csrc/flash_mla.h
Normal file
@ -0,0 +1,63 @@
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_mla_params {
|
||||
using index_t = int64_t;
|
||||
|
||||
int b, seqlen_q, d, d_v;
|
||||
int h, h_h_k_ratio, ngroups;
|
||||
bool is_causal;
|
||||
float scale_softmax, scale_softmax_log2;
|
||||
int *__restrict__ cu_seqlens_k;
|
||||
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
void *__restrict__ o_ptr;
|
||||
void *__restrict__ softmax_lse_ptr;
|
||||
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t o_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t o_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
int *__restrict__ block_table;
|
||||
index_t block_table_batch_stride;
|
||||
int page_block_size;
|
||||
|
||||
int *__restrict__ tile_scheduler_metadata_ptr;
|
||||
int num_sm_parts;
|
||||
int *__restrict__ num_splits_ptr;
|
||||
|
||||
void *__restrict__ softmax_lseaccum_ptr;
|
||||
void *__restrict__ oaccum_ptr;
|
||||
};
|
||||
|
||||
static constexpr int TileSchedulerMetaDataSize = 8;
|
||||
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim>
|
||||
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
||||
struct Mla_metadata_params {
|
||||
int *__restrict__ seqlens_k_ptr;
|
||||
int *__restrict__ tile_scheduler_metadata_ptr;
|
||||
int *__restrict__ num_splits_ptr;
|
||||
int batch_size;
|
||||
int block_size_n;
|
||||
int fixed_overhead_num_blocks;
|
||||
int num_sm_parts;
|
||||
};
|
||||
|
||||
void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream);
|
||||
15
csrc/named_barrier.h
Normal file
15
csrc/named_barrier.h
Normal file
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/barrier.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Enumerates the reserved named barriers to avoid potential conflicts
|
||||
|
||||
enum class NamedBarriers {
|
||||
SReady = 1,
|
||||
SoftmaxReady = 2,
|
||||
};
|
||||
|
||||
} // flash
|
||||
197
csrc/softmax.h
Normal file
197
csrc/softmax.h
Normal file
@ -0,0 +1,197 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
// If we don't have float around M_LOG2E the multiplication is done in fp64.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
// The following macro will disable the use of fma.
|
||||
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
|
||||
// This macro is set in PyTorch and not FlashAttention
|
||||
#ifdef UNFUSE_FMA
|
||||
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
|
||||
#else
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
SumOp<float> sum_op;
|
||||
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Tensor0, typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scale_o); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
__forceinline__ __device__ Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scale_o;
|
||||
clear(scale_o);
|
||||
if (Is_first) {
|
||||
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? row_max(mi)
|
||||
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
scale_o(mi) = scores_scale;
|
||||
row_sum(mi) *= scores_scale;
|
||||
}
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
// We don't do the reduce across threads here since we don't need to use the row_sum.
|
||||
// We do that reduce at the end when we need to normalize the softmax.
|
||||
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
|
||||
}
|
||||
return scale_o;
|
||||
};
|
||||
|
||||
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT lse = make_fragment_like(row_sum);
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
return lse;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace flash
|
||||
65
csrc/static_switch.h
Normal file
65
csrc/static_switch.h
Normal file
@ -0,0 +1,65 @@
|
||||
#pragma once
|
||||
|
||||
#define CHECK_CUDA(call) \
|
||||
do { \
|
||||
cudaError_t status_ = call; \
|
||||
if (status_ != cudaSuccess) { \
|
||||
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
|
||||
|
||||
|
||||
#define FLASH_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
|
||||
#define FLASH_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr static bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
|
||||
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
|
||||
[&] { \
|
||||
if (NUM_SPLITS <= 32) { \
|
||||
constexpr static int NAME = 32; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 64) { \
|
||||
constexpr static int NAME = 64; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 96) { \
|
||||
constexpr static int NAME = 96; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 128) { \
|
||||
constexpr static int NAME = 128; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 160) { \
|
||||
constexpr static int NAME = 160; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
FLASH_ASSERT(false); \
|
||||
} \
|
||||
}()
|
||||
238
csrc/utils.h
Normal file
238
csrc/utils.h
Normal file
@ -0,0 +1,238 @@
|
||||
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
|
||||
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
|
||||
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
|
||||
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (arrive) {
|
||||
warpgroup_arrive();
|
||||
}
|
||||
if constexpr (zero_init) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
if constexpr (commit) {
|
||||
warpgroup_commit_batch();
|
||||
}
|
||||
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
template<bool Transposed=false, typename Layout0>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
|
||||
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
if constexpr (!Transposed) {
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
} else {
|
||||
return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
|
||||
}
|
||||
|
||||
} else { // SM80
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
if constexpr (!Transposed) {
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
} else {
|
||||
return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
|
||||
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
|
||||
template<typename MMA_Traits, typename Layout0>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
|
||||
using X = Underscore;
|
||||
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
|
||||
auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
|
||||
return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
|
||||
} else {
|
||||
static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
|
||||
static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
|
||||
static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
|
||||
auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
|
||||
// This combines the first two modes (<0, 0> and <0, 1>) into one mode.
|
||||
// Will require register shuffling later to be correct.
|
||||
return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
|
||||
get<1>(acc_layout),
|
||||
coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
|
||||
// This combination is right but doesn't work with register shuffling.
|
||||
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
|
||||
// get<1>(acc_layout),
|
||||
// coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
|
||||
}
|
||||
} else { // SM80
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Blocks until all but N previous cp.async.commit_group operations have committed.
|
||||
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
|
||||
// (which is equivalent to commit_group then wait_group 0).
|
||||
// Instead we just call cp.async.wait_group 0, which is slightly faster.
|
||||
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE
|
||||
void cp_async_wait() {
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
||||
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
// There's no case where !Clear_OOB_K && Clear_OOB_MN
|
||||
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || predicate_K(k)) {
|
||||
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
} else if (Clear_OOB_MN) {
|
||||
cute::clear(D(_, m, _));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
6
flash_mla/__init__.py
Normal file
6
flash_mla/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
__version__ = "1.0.0"
|
||||
|
||||
from flash_mla.flash_mla_interface import (
|
||||
get_mla_metadata,
|
||||
flash_mla_with_kvcache,
|
||||
)
|
||||
67
flash_mla/flash_mla_interface.py
Normal file
67
flash_mla/flash_mla_interface.py
Normal file
@ -0,0 +1,67 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import flash_mla_cuda
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
None,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
)
|
||||
return out, softmax_lse
|
||||
78
setup.py
Normal file
78
setup.py
Normal file
@ -0,0 +1,78 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import subprocess
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
from torch.utils.cpp_extension import (
|
||||
BuildExtension,
|
||||
CUDAExtension,
|
||||
)
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
|
||||
return nvcc_extra_args + ["--threads", nvcc_threads]
|
||||
|
||||
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
|
||||
|
||||
cc_flag = []
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_90a,code=sm_90a")
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
ext_modules = []
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="flash_mla_cuda",
|
||||
sources=[
|
||||
"csrc/flash_api.cpp",
|
||||
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"],
|
||||
"nvcc": append_nvcc_threads(
|
||||
[
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-DNDEBUG",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"--use_fast_math",
|
||||
"--ptxas-options=-v,--register-usage-level=10"
|
||||
]
|
||||
+ cc_flag
|
||||
),
|
||||
},
|
||||
include_dirs=[
|
||||
Path(this_dir) / "csrc",
|
||||
Path(this_dir) / "csrc" / "cutlass" / "include",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
cmd = ['git', 'rev-parse', '--short', 'HEAD']
|
||||
rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
|
||||
except Exception as _:
|
||||
now = datetime.now()
|
||||
date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
rev = '+' + date_time_str
|
||||
|
||||
|
||||
setup(
|
||||
name="flash_mla",
|
||||
version="1.0.0" + rev,
|
||||
packages=find_packages(include=['flash_mla']),
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
114
tests/test_flash_mla.py
Normal file
114
tests/test_flash_mla.py
Normal file
@ -0,0 +1,114 @@
|
||||
import math
|
||||
import random
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
|
||||
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
value = value.float()
|
||||
key = key.repeat_interleave(h_q // h_kv, dim=0)
|
||||
value = value.repeat_interleave(h_q // h_kv, dim=0)
|
||||
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
|
||||
if is_causal:
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight @ value, lse
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
amax_diff = (x - y).abs().max().item()
|
||||
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
|
||||
|
||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
mean_seqlens = cache_seqlens.float().mean().int().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_size = 64
|
||||
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
for i in range(b):
|
||||
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(
|
||||
q, blocked_k, block_table, cache_seqlens, dv,
|
||||
tile_scheduler_metadata, num_splits, causal=causal,
|
||||
)
|
||||
|
||||
def ref_mla():
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
begin = i * max_seqlen_pad
|
||||
end = begin + cache_seqlens[i]
|
||||
O, LSE = scaled_dot_product_attention(
|
||||
q[i].transpose(0, 1),
|
||||
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
|
||||
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
|
||||
is_causal=causal,
|
||||
)
|
||||
out[i] = O.transpose(0, 1)
|
||||
lse[i] = LSE
|
||||
return out, lse
|
||||
|
||||
out_flash, lse_flash = flash_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
cal_diff(out_flash, out_torch, "out")
|
||||
cal_diff(lse_flash, lse_torch, "lse")
|
||||
|
||||
t = triton.testing.do_bench(flash_mla, fast_flush=False)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
|
||||
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
h_kv = 1
|
||||
d, dv = 576, 512
|
||||
causal = True
|
||||
|
||||
for b in [128]:
|
||||
for s in [4096, 8192]:
|
||||
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
|
||||
for s_q in [1, 2]: # MTP = 1, 2
|
||||
for varlen in [False, True]:
|
||||
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
|
||||
Loading…
Reference in New Issue
Block a user