Feature: Support FlashMLA decoding via FlashAttention-2 (#29)

Changes:
1. Implement FlashMLA with the matrix-absorption algorithm on top of FlashAttention-2.
2. Add a golden test on the MXMACA platform.
Kevin Zhang 2025-02-24 23:55:21 +08:00
parent bcb90f2afd
commit e0557deb3a
18 changed files with 1197 additions and 702 deletions

LICENSE_MX (new file, +11 lines)

@@ -0,0 +1,11 @@
This software is licensed under MIT. For the licenses of third-party open source components in this software, please refer to the statement of third-party open source components for details.
Copyright © 2025 MetaX Integrated Circuits (Shanghai) Co., Ltd.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

README_MX.md (new file, +28 lines)

@@ -0,0 +1,28 @@
# FlashMLA on MXMACA
We provide an implementation of FlashMLA built on FlashAttention-2 (version 2.6.3), targeting the MACA toolkit and C500 chips; the matrix-absorption idea behind it is sketched below.
FlashAttention-2 on MXMACA currently supports:
1. Data types fp16 and bf16.
2. Multi-token parallelism = 1.
3. Paged KV cache with block size equal to 2^n (n >= 0).
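
The implementation follows the matrix-absorption formulation of MLA decoding. As a rough sketch in standard MLA notation (the symbols below are assumptions for illustration, not identifiers from this repository): with a cached per-token latent c^KV and key/value up-projections W^UK, W^UV, the key up-projection is absorbed into the query and the value up-projection into the output, so the kernel only ever attends over the compressed cache.

```latex
% Hedged sketch of matrix absorption in MLA decoding; the RoPE slice of the head
% is kept separate and is not absorbed.
\mathrm{score}_t \propto (W^{UQ} h_t)^\top (W^{UK} c^{KV})
                 = \big((W^{UK})^\top W^{UQ} h_t\big)^\top c^{KV},
\qquad
o_t = \mathrm{softmax}(\mathrm{score}_t)\,(W^{UV} c^{KV})
    = W^{UV}\big(\mathrm{softmax}(\mathrm{score}_t)\, c^{KV}\big).
```

In this form the query/key head dimension seen by the kernel is the compressed latent plus the RoPE slice (the host code enforces head_size == 576, i.e. 512 + 64 in DeepSeek-style MLA), while the value/output head dimension head_size_v stays at 512 in the provided kernel instantiations.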
## How to run on an MXMACA device
## Installation
Requirements:
- MXMACA GPUs.
- MACA development toolkit.
- Mctlass source code.
- PyTorch 2.0 or later from the MACA toolkit wheel package.
To install flash-attn in a conda environment:
1. Make sure the MACA build of PyTorch 2.0 is installed.
2. Download the mctlass source code from: https://sw-download.metax-tech.com/
### Set environment variables
```bash
export MACA_PATH=/your/maca/path
export CUDA_PATH=$MACA_PATH/tools/cu-bridge
export MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin
export LD_LIBRARY_PATH=$MACA_PATH/lib:$MACA_PATH/mxgpu_llvm/lib:$MACA_PATH/ompi/lib:$LD_LIBRARY_PATH
```
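
### Usage sketch
Once the extension is built, the two operators exposed by the pybind module (`get_mla_metadata` and `fwd_kvcache_mla`, see `csrc/flash_api.cpp`) can be driven roughly as below. This is a minimal sketch: the module name `flash_mla_cuda`, the tensor sizes, and the softmax scale are assumptions for illustration, not part of this commit.

```python
import torch
import flash_mla_cuda  # assumed extension name; use whatever TORCH_EXTENSION_NAME resolves to

# Shapes follow the checks in mha_fwd_kvcache_mla: q is (b, s_q, h_q, 576) in bf16,
# the paged K cache is (num_blocks, page_block_size, h_kv, 576), and head_size_v is 512.
b, s_q, h_q, h_kv = 2, 1, 16, 1
head_size, head_size_v = 576, 512
page_block_size, pages_per_seq = 64, 16

q = torch.randn(b, s_q, h_q, head_size, dtype=torch.bfloat16, device="cuda")
kcache = torch.randn(b * pages_per_seq, page_block_size, h_kv, head_size,
                     dtype=torch.bfloat16, device="cuda")
seqlens_k = torch.full((b,), 1000, dtype=torch.int32, device="cuda")
block_table = torch.arange(b * pages_per_seq, dtype=torch.int32,
                           device="cuda").reshape(b, pages_per_seq)

# Tile-scheduler metadata first, then the decode kernel (vcache=None reuses the K cache).
tile_metadata, num_splits = flash_mla_cuda.get_mla_metadata(seqlens_k, s_q * h_q // h_kv, h_kv)
out, lse = flash_mla_cuda.fwd_kvcache_mla(
    q, kcache, None, head_size_v, seqlens_k, block_table,
    head_size ** -0.5,  # illustrative softmax scale
    True, tile_metadata, num_splits,
)
# out: (b, s_q, h_q, head_size_v), lse: (b, h_q, s_q)
```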

@@ -1 +0,0 @@
Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008

@@ -1,61 +1,30 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/fast_math.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "flash_mla.h"
#include "static_switch.h"
#include <mctlass/numeric_types.h>
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#include "flash_api_mla.h"
std::vector<at::Tensor>
get_mla_metadata(
at::Tensor &seqlens_k,
const int num_heads_per_head_k,
const int num_heads_k
) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
int batch_size = seqlens_k.size(0);
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
auto options = seqlens_k.options();
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
int *num_splits_ptr = num_splits.data_ptr<int>();
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
Mla_metadata_params params = {};
params.seqlens_k_ptr = seqlens_k_ptr;
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.num_sm_parts = num_sm_parts;
get_mla_metadata_func(params, stream);
return {tile_scheduler_metadata, num_splits};
}
);
std::vector<at::Tensor>
mha_fwd_kvcache_mla(
@@ -69,135 +38,12 @@ mha_fwd_kvcache_mla(
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_ori = sizes[2];
const int head_size = sizes[3];
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int ngroups = num_heads_ori / num_heads_k;
const int seqlen_q = seqlen_q_ori * ngroups;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
.reshape({batch_size, seqlen_q, num_heads, head_size});
int head_size_k = head_size;
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
Flash_fwd_mla_params params = {};
// Set the sizes.
params.b = batch_size;
params.seqlen_q = seqlen_q;
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
params.h = num_heads;
params.h_h_k_ratio = num_heads / num_heads_k;
params.ngroups = ngroups;
params.is_causal = is_causal;
params.d = head_size;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All strides are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_CONTIGUOUS(tile_scheduler_metadata);
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
params.num_sm_parts = tile_scheduler_metadata.size(0);
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(num_splits);
CHECK_CONTIGUOUS(num_splits);
params.num_splits_ptr = num_splits.data_ptr<int>();
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
return {out, softmax_lse};
}
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
m.doc() = "FlashAttention";
//FlashMLA
m.def("get_mla_metadata", &get_mla_metadata);
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
}
}

csrc/flash_api_mla.cpp (new file, +214 lines)

@@ -0,0 +1,214 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <mctlass/fast_math.h>
#include "flash_api_mla.h"
#include "static_switch.h"
#include "run_mha.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
std::vector<at::Tensor>
get_mla_metadata(
at::Tensor &seqlens_k,
const int num_heads_per_head_k,
const int num_heads_k
) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
printf("%s,num_heads_per_head_k:%d,num_heads_k:%d\n",__func__,num_heads_per_head_k,num_heads_k);
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
int batch_size = seqlens_k.size(0);
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
auto options = seqlens_k.options();
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
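// Number of SM partitions available for split-KV scheduling: total SMs divided by the CTAs
// needed to cover all KV heads, with ceil(num_heads_per_head_k / block_size_m) M-tiles per head.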
int num_sm_parts = sm_count / num_heads_k / mctlass::ceil_div(num_heads_per_head_k, block_size_m);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
int *num_splits_ptr = num_splits.data_ptr<int>();
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
Mla_metadata_params params = {};
params.seqlens_k_ptr = seqlens_k_ptr;
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.num_sm_parts = num_sm_parts;
// get_mla_metadata_func(params, stream);
return {tile_scheduler_metadata, num_splits};
}
std::vector<at::Tensor>
mha_fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
c10::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_ori = sizes[2];
const int head_size = sizes[3];
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int ngroups = num_heads_ori / num_heads_k;
const int seqlen_q = seqlen_q_ori * ngroups;
const int num_heads = num_heads_k;
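// Fold the ngroups query heads that share a KV head into the sequence dimension, so the kernel
// runs with num_heads_k heads over seqlen_q_ori * ngroups query rows; the output and LSE are
// transposed back to the original layout after the kernel.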
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
.reshape({batch_size, seqlen_q, num_heads, head_size});
int head_size_k = head_size;
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
mcFlashAttn::Flash_fwd_params params = {};
params.rotary_dim = 0;
// Set the sizes.
params.b = batch_size;
params.seqlen_q = seqlen_q;
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
params.is_seqlens_k_cumulative = false; // seqlens_k always has value
params.h = num_heads;
params.h_h_k_ratio = num_heads / num_heads_k;
// params.ngroups = ngroups; // TODO: to be checked
params.is_causal = is_causal;
params.d = head_size;
params.d_rounded = head_size;
params.d_value = head_size_v;
params.d_value_rounded = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All strides are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_CONTIGUOUS(tile_scheduler_metadata);
// params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
// params.num_sm_parts = tile_scheduler_metadata.size(0);
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(num_splits);
CHECK_CONTIGUOUS(num_splits);
// params.num_splits_ptr = num_splits.data_ptr<int>();
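// Run the generic split-KV kernel with a single split; the accumulators below keep a leading
// num_splits dimension in the layout that kernel expects.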
params.num_splits = 1;
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));
at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(torch::kFloat32));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
// run_mha_fwd_splitkv_mla<mctlass::bfloat16_t, 576>(params, stream);
params.is_bf16 = true;
run_mha_fwd(params,stream, /*force_split_kernel*/true);
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
return {out, softmax_lse};
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.doc() = "FlashMLA";
// m.def("get_mla_metadata", &get_mla_metadata);
// m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
// }

@@ -1,3 +0,0 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

@@ -0,0 +1,599 @@
#pragma once
#include <cute/algorithm/copy.hpp>
#include <mctlass/mctlass.h>
#include <mctlass/array.h>
#include <mctlass/numeric_types.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "mask.h"
#include "dropout.h"
#include "rotary.h"
#include "attn_mask.h"
namespace flash {
using namespace cute;
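// Computes one (kBlockM x kHeadDimV) output tile of split-KV attention for a single
// (batch, head, m_block, n_split): the QK^T GEMM runs over kHeadDim, while the PV GEMM and the
// output write run over kHeadDimV.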
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Is_page_attn, typename Params>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_k64_mla_V1x8(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kBlockKSmem = Kernel_traits::kBlockKSmem;
constexpr int kBlockKGmem = Kernel_traits::UseWarpsNx1 ? Kernel_traits::kBlockKSmem : 128;
static_assert(kBlockKSmem == 64);
static_assert(kBlockM % (kNWarps * 16) == 0);
using GmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::GmemTiledCopyO,
typename Kernel_traits::GmemTiledCopyOaccum
>;
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
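// Evenly distribute this sequence's KV blocks across the num_n_splits split-KV partitions.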
const int n_blocks_per_split = ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
const int n_block_min = !Is_local
? n_split_idx * n_blocks_per_split
: std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
if (Is_causal || Is_local) {
n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
}
if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
// We exit early and write 0 to gOaccum and -inf to gLSEaccum.
// Otherwise we might read OOB elements from gK and gV,
// or get wrong results when we combine gOaccum from different blocks.
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
clear(tOrOaccum);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d_value; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgOaccum); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
}
return;
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
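// Paged KV cache: translate the last KV block index into a page index plus an in-page row
// offset via the per-batch block table.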
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.v_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
//Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutVtNoSwizzle{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtNoSwizzle{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QK;
auto gmem_thr_copy_QK = gmem_tiled_copy_QK.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QK.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QK.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QK.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QK.partition_D(sK);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_V;
auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx);
Tensor tVgV = gmem_thr_copy_V.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_V.partition_D(sV);
Tensor tVrV = make_fragment_like(tVgV);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
//
// Copy Atom retiling
//
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomB64{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomB64{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK_ori = smem_thr_copy_K.partition_S(sK);
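// Per-lane shared-memory pointer adjustment for the K tile ("swz334" swizzle): odd lanes shift
// by +/-4 elements depending on the parity of (lane_id / 16); even lanes are unshifted.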
const int offset_swz334 = (__lane_id() % 2) ? ((__lane_id() / 16 % 2) ? -4 : 4) : 0;
Tensor tSsK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(tSsK_ori.data().get()) + offset_swz334), tSsK_ori.layout());
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVtNoSwizzle);
// PREDICATES
//
// // Allocate predicate tensors for m and n
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QK.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QK.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tVcV = gmem_thr_copy_V.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Allocate predicate tensors for k
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
// Set predicates for k bounds
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
}
// Prologue
// Copy from Knew to K, optionally apply rotary embedding.
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
if constexpr (Append_KV) {
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
// if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
// This maps to accessing the first 64 rows of knew_ptr.
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.knew_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor tKgKnew = gmem_thr_copy_QK.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew = gmem_thr_copy_QK.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
Tensor tVgVcache = gmem_thr_copy_QK.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
auto tKgK_data = tKgK.data();
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
flash::copy_w_min_idx<Is_even_K>(
tVgVnew, tVgVcache, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
if (params.rotary_dim == 0) {
flash::copy_w_min_idx<Is_even_K>(
tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
} else {
if (params.is_rotary_interleaved) {
// Don't clear OOB_K because we're writing to global memory
flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
);
tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
} else {
// Don't clear OOB_K because we're writing to global memory
flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
);
tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
}
}
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
if constexpr (!Is_page_attn) {
tVgVcache.data() = tVgVcache.data() + (-int(kBlockN * params.v_row_stride));
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgVcache.data() = tVgVcache.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
}
}
}
// Need this before we can read in K again, so that we'll see the updated K values.
__syncthreads();
tKgK.data() = tKgK_data;
}
// Read Q from gmem to smem, optionally apply rotary embedding.
Tensor tQrQ = make_fragment_like(tQgQ);
if (!Append_KV || params.rotary_dim == 0) {
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy_b128<Is_even_MN, Is_even_K>(tQgQ, tQrQ, tQcQ, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
if (params.is_rotary_interleaved) {
flash::copy_rotary_interleaved<Is_even_K>(
tQgQ, tQrQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
0, params.d, params.rotary_dim
);
} else {
flash::copy_rotary_contiguous<Is_even_K>(
tQgQ, tQrQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
0, params.d, params.rotary_dim
);
}
}
cute::copy(tQrQ, tQsQ);
if (Kernel_traits::Is_Q_in_regs) {
flash::sync_threads();
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ);
flash::sync_threads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
Tensor tKrK = make_fragment_like(tKgK);
if constexpr (!Is_page_attn) {
flash::copy_b128<Is_even_MN, Is_even_K>(tKgK, tKrK, tKVcKV, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
flash::copy_b128_page_one<Kernel_traits, Is_even_MN, Is_even_K>(gK, tKgK, tKrK, tKVcKV, params.d, n_block,
block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size, binfo.actual_seqlen_k - n_block * kBlockN);
}
// flash::cp_async_wait<0>();
// __syncthreads();
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
// __syncthreads();
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
swap_swz334(tKrK);
cute::copy(tKrK, tKsK);
clear(acc_s);
// Advance gV
if (masking_step > 0) {
if constexpr (!Is_page_attn) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tVgV, tVrV, tVcV, params.d_value);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gV, tVgV, tVrV, tVcV, params.d_value, n_block,
block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size);
}
} else {
if constexpr (!Is_page_attn) {
// Clear the smem tiles to account for predicated off loads
flash::copy_b128<Is_even_MN, Is_even_K>(
tVgV, tVrV, tVcV, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN
);
} else {
flash::copy_b128_page_one<Kernel_traits, Is_even_MN, Is_even_K>(gV, tVgV, tVrV, tVcV, params.d_value, n_block,
block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
flash::sync_threads();
flash::gemm_opt</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 64) * 16 + (tidx & 0xf), kNWarps * 16
);
cute::copy(tVrV, tVsV);
if (n_block > n_block_min) {
// Advance gK
if constexpr (!Is_page_attn) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tKgK, tKrK, tKVcKV, params.d);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gK, tKgK, tKrK, tKVcKV, params.d, n_block - 1,
block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size);
}
}
// We have key_padding_mask so we'll need to Check_inf
masking_step == 0
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN, true, true>(acc_s, acc_o, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN, true, true>(acc_s, acc_o, params.scale_softmax_log2);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
// Convert acc_s from fp32 to fp16/bf16
//Tensor rP = flash::convert_type<Element>(acc_s);
CONVERT_TENSOR_TYPE(ElementAccum, Element, acc_s, rP)
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
//Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor tOrP = make_tensor(rP.data(), acc_s.layout());
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
swap_swz334(tKrK);
cute::copy(tKrK, tKsK);
clear(acc_s);
// Advance gV
if constexpr (!Is_page_attn) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tVgV, tVrV, tVcV, params.d_value);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gV, tVgV, tVrV, tVcV, params.d_value, n_block,
block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size);
}
flash::sync_threads();
flash::gemm_opt</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
cute::copy(tVrV, tVsV);
if (n_block > n_block_min) {
// Advance gK
if constexpr (!Is_page_attn) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tKgK, tKrK, tKVcKV, params.d);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gK, tKgK, tKrK, tKVcKV, params.d, n_block - 1,
block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size);
}
}
if constexpr (Is_softcap){
flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask</*Causal_mask=*/false>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 64) * 16 + (tidx & 0xf), kNWarps * 16
);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local, true, true>(acc_s, acc_o, params.scale_softmax_log2);
//Tensor rP = flash::convert_type<Element>(acc_s);
CONVERT_TENSOR_TYPE(ElementAccum, Element, acc_s, rP)
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
//Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor tOrP = make_tensor(rP.data(), acc_s.layout());
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, /*Return_lse*/true, Split>(acc_o, params.scale_softmax);
// if (cute::thread0()) { print(lse); }
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
//Tensor rO = flash::convert_type<ElementO>(acc_o);
CONVERT_TENSOR_TYPE(ElementAccum, ElementO, acc_o, rO)
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sOaccum is larger than sQ, so we need to syncthreads here
// TODO: allocate enough smem for sOaccum
if constexpr (Split || Kernel_traits::Share_Q_K_smem) { flash::sync_threads(); }
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
// if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
flash::sync_threads();
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0>(taccOcO))::value == 4);
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_4>{})(make_coord(0, _), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy_reg_to_global<Is_even_MN, Is_even_K>(
tOrOaccum, tOgOaccum, tOcO, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM
);
}
} // namespace flash

@@ -0,0 +1,20 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_parameter.h"
#include "flash_run_fwd_template_impl.h"
#include <mctlass/numeric_types.h>
template void run_flash_splitkv_fwd_template<
576,
32,
16,
2,
true,
true,
cutlass::bfloat16_t,
false,
512
>(Flash_fwd_params &params, mcFlashAttn::Flash_launch_params& launch_params,cudaStream_t stream);

@@ -0,0 +1,20 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_parameter.h"
#include "flash_run_fwd_template_impl.h"
#include <mctlass/numeric_types.h>
template void run_flash_splitkv_fwd_template<
576,
32,
16,
2,
true,
true,
cutlass::bfloat16_t,
true,
512
>(Flash_fwd_params &params, mcFlashAttn::Flash_launch_params& launch_params,cudaStream_t stream);

@@ -0,0 +1,20 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_parameter.h"
#include "flash_run_fwd_template_impl.h"
#include <mctlass/numeric_types.h>
template void run_flash_splitkv_fwd_template<
576,
32,
16,
2,
true,
true,
cutlass::half_t,
false,
512
>(Flash_fwd_params &params, mcFlashAttn::Flash_launch_params& launch_params,cudaStream_t stream);

@@ -0,0 +1,20 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_parameter.h"
#include "flash_run_fwd_template_impl.h"
#include <mctlass/numeric_types.h>
template void run_flash_splitkv_fwd_template<
576,
32,
16,
2,
true,
true,
cutlass::half_t,
true,
512
>(Flash_fwd_params &params, mcFlashAttn::Flash_launch_params& launch_params,cudaStream_t stream);

csrc/flash_parameter.h (new file, +233 lines)

@@ -0,0 +1,233 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
constexpr int maxValidBlockSizeM = 128;
namespace mcFlashAttn {
struct Qkv_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precomputed h / h_k
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int * __restrict__ leftpad_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// kv cache dequant
index_t kscale_batch_stride;
index_t vscale_batch_stride;
index_t kscale_row_stride;
index_t vscale_row_stride;
index_t kscale_head_stride;
index_t vscale_head_stride;
// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
// When paged attention is not enabled, page_block_size defaults to 0.
int page_block_size;
// KV Cache dequant
int dequant_group;
void *__restrict__ k_scale_ptr;
void *__restrict__ v_scale_ptr;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Local window size
int window_size_left, window_size_right;
// ratio of softcapping attention
// S = exp2(log2(e) * softcap * tanh(S * softmax_scale / softcap))
// only values > 0.0 take effect
float softcap;
// Random state.
// at::PhiloxCudaState philox_args;
// The RNG seed and offset.
uint64_t rng_state_seed = 0;
uint64_t rng_state_offset = 0;
bool is_bf16;
bool is_causal;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
bool is_rotary_interleaved;
int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
// attn_mask support for BERT models (Jira C500-21935)
bool has_attn_mask;
void * __restrict__ attn_mask_ptr = nullptr;
index_t attn_mask_batch_stride = 0;
index_t attn_mask_nheads_stride = 0;
index_t attn_mask_row_stride = 0;
index_t attn_mask_col_stride = 1;
index_t attn_mask_batch_shape = 1;
index_t attn_mask_nheads_shape = 1;
index_t attn_mask_row_shape = 1;
index_t attn_mask_col_shape = 1;
bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv * ngroups), d) to (b, ngroups, nheads_kv, d).
int d_value;
int d_value_rounded;
bool is_support_splitkv = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension:
// void *__restrict__ dk_accum_ptr;
// void *__restrict__ dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
bool deterministic;
index_t dq_accum_split_stride;
// asm mha_bwd kernel needs packed_seqlen
int packed_seqlen;
};
struct Flash_launch_params {
bool is_balance;
int rowblock_parallel;
int block_type;
bool performance_mode; // from offline
Flash_launch_params():
  is_balance(false), rowblock_parallel(0), block_type(0), performance_mode(false) {}
};
}
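
For orientation (not part of the diff), here is a minimal host-side sketch of how the stride fields in Qkv_params / Flash_fwd_params map onto a contiguous [batch, seqlen, nheads, headdim] buffer, together with the derived h_h_k_ratio and log2-domain softmax scale. The helper, names and concrete sizes below are illustrative assumptions; the extension itself fills these fields from the incoming torch tensors.

```cpp
// Illustrative only: compute the batch/row/head strides the params struct
// expects for a contiguous [b, s, h, d] buffer, plus the derived GQA ratio
// and the log2-domain softmax scale used by the kernels.
#include <cstdint>
#include <cmath>
#include <cstdio>

struct StrideSet {
    int64_t batch_stride, row_stride, head_stride;  // element strides, like index_t above
};

// Element (b_i, s_i, h_i, d_i) lives at
// b_i * batch_stride + s_i * row_stride + h_i * head_stride + d_i.
static StrideSet contiguous_strides(int64_t seqlen, int64_t nheads, int64_t headdim) {
    return {seqlen * nheads * headdim, nheads * headdim, headdim};
}

int main() {
    const int64_t s_q = 1, h = 16, h_k = 1, d = 576;    // assumed MLA-style decode shapes
    StrideSet q = contiguous_strides(s_q, h, d);
    const int h_h_k_ratio = static_cast<int>(h / h_k);  // MQA/GQA: query heads per KV head
    const float scale_softmax = 1.0f / std::sqrt(static_cast<float>(d));
    const float scale_softmax_log2 = scale_softmax * 1.4426950408889634f;  // * log2(e)
    std::printf("q strides b/r/h = %lld/%lld/%lld, h/h_k = %d, scale_log2 = %f\n",
                static_cast<long long>(q.batch_stride), static_cast<long long>(q.row_stride),
                static_cast<long long>(q.head_stride), h_h_k_ratio, scale_softmax_log2);
    return 0;
}
```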

View File

@ -1,15 +0,0 @@
#pragma once
#include "cutlass/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
enum class NamedBarriers {
SReady = 1,
SoftmaxReady = 2,
};
} // flash

View File

@ -1,197 +0,0 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't wrap M_LOG2E in float(), the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
}
}
return tensor;
}
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
SumOp<float> sum_op;
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
}
}
template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
// Reshape acc_o from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scale_o); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scale_o;
clear(scale_o);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scale_o(mi) = scores_scale;
row_sum(mi) *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
return scale_o;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Reshape acc_o from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
};
} // namespace flash
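
The deleted softmax.h implements the streaming (online) softmax: each row keeps a running max and sum, the partial output is rescaled whenever a later block raises the max, and the log-sum-exp is recovered at the end. A minimal host-side sketch of that bookkeeping for a single row, using plain exp instead of the exp2/log2 domain and ignoring the warp-level reductions:

```cpp
// Online-softmax bookkeeping for one row, processed block by block.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    // Two score blocks for one row, as if produced by successive K tiles.
    std::vector<std::vector<float>> blocks = {{0.5f, 2.0f, -1.0f}, {3.0f, 0.25f}};
    float row_max = -std::numeric_limits<float>::infinity();
    float row_sum = 0.0f;
    float acc_o = 0.0f;  // stand-in for one output accumulator (pretend V == 1)

    for (const auto& blk : blocks) {
        const float prev_max = row_max;
        for (float s : blk) row_max = std::max(row_max, s);
        // Rescale earlier partial results when the running max grows,
        // mirroring scale_o / rescale_o in the Softmax struct.
        const float rescale = std::isinf(prev_max) ? 0.0f : std::exp(prev_max - row_max);
        row_sum *= rescale;
        acc_o *= rescale;
        for (float s : blk) {
            const float p = std::exp(s - row_max);  // exp(x - max): no overflow
            row_sum += p;
            acc_o += p;                             // += p * v with v == 1
        }
    }
    // Matches normalize_softmax_lse (ignoring the softmax scale for brevity).
    const float lse = row_max + std::log(row_sum);
    std::printf("out = %f, lse = %f\n", acc_o / row_sum, lse);
    return 0;
}
```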

View File

@ -1,65 +0,0 @@
#pragma once
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while(0)
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
#define FLASH_ASSERT(cond) \
do { \
if (not (cond)) { \
fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
exit(1); \
} \
} while(0)
#define FLASH_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while(0)
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()
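
The deleted static_switch.h relies on immediately-invoked lambdas to promote runtime flags into constexpr names, so each branch instantiates a specialized template. A self-contained sketch of the same pattern; the run_kernel launcher and the hard-coded split bucket are made up for illustration:

```cpp
// Standalone illustration of the BOOL_SWITCH dispatch pattern.
#include <cstdio>

template <bool IsCausal, int NumSplits>
void run_kernel() {  // stand-in for a templated kernel launcher (hypothetical)
    std::printf("instantiated with IsCausal=%d NumSplits=%d\n", int(IsCausal), NumSplits);
}

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

int main(int argc, char**) {
    const bool is_causal = (argc > 1);   // runtime value, e.g. taken from the params
    BOOL_SWITCH(is_causal, kIsCausal, [&] {
        // MLA_NUM_SPLITS_SWITCH plays the same trick for num_splits, rounding
        // it up to the next supported bucket (32/64/96/128/160); hard-coded here.
        constexpr int kNumSplits = 64;
        run_kernel<kIsCausal, kNumSplits>();
    });
    return 0;
}
```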

View File

@ -1,238 +0,0 @@
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_bf16.h>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
} else {
static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
// This combines the first two modes (<0, 0> and <0, 1>) into one mode.
// Will require register shuffling later to be correct.
return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
get<1>(acc_layout),
coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
// This combination is right but doesn't work with register shuffling.
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
// get<1>(acc_layout),
// coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
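
Allreduce<THREADS> above is a log2(THREADS)-round XOR-butterfly reduction built on __shfl_xor_sync, after which every lane holds the full result. A host-side sketch of the same exchange pattern, with the warp shuffle replaced by an array read and arbitrary example values:

```cpp
// XOR-butterfly allreduce over 8 "lanes" with MaxOp as the operator.
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int kThreads = 8;
    float vals[kThreads] = {3, 1, 4, 1, 5, 9, 2, 6};
    for (int offset = kThreads / 2; offset >= 1; offset /= 2) {
        float next[kThreads];
        for (int lane = 0; lane < kThreads; ++lane)
            next[lane] = std::max(vals[lane], vals[lane ^ offset]);  // partner = lane ^ offset
        std::copy(next, next + kThreads, vals);
    }
    std::printf("every lane now holds the max: %f\n", vals[0]);  // prints 9
    return 0;
}
```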

View File

@ -4,7 +4,11 @@ import random
import torch
import triton
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
# from flash_mla import get_mla_metadata, flash_mla_with_kvcache
from flash_attn import (
get_mla_metadata,
flash_mla_with_kvcache
)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@ -32,12 +36,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < 1e-5
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, block_size):
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
@ -51,7 +55,6 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
@ -107,10 +110,10 @@ if __name__ == "__main__":
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
for block_size in [1,4,16,64]:
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1]: # TODO: to support MTP=2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, block_size)
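
The new block_size loop exercises the paged KV cache path for several power-of-two page sizes. As a reference for what block_table and page_block_size encode, a small sketch with made-up values: logical token t of a sequence lives in physical block block_table[t / block_size] at row t % block_size.

```cpp
// Paged-KV indexing: map a logical token position to (physical block, row).
#include <cstdio>
#include <vector>

int main() {
    const int block_size = 4;                       // one of the tested page sizes
    // One batch entry whose logical KV sequence is scattered over physical blocks.
    std::vector<int> block_table_row = {7, 2, 5};   // 3 blocks -> up to 12 tokens
    const int tokens[] = {0, 3, 4, 9};
    for (int t : tokens) {
        const int physical_block = block_table_row[t / block_size];
        const int row_in_block = t % block_size;
        std::printf("token %d -> block %d, row %d\n", t, physical_block, row_in_block);
    }
    return 0;
}
```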