mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Performance optimization for compute-bound cases
This commit is contained in:
@@ -10,8 +10,10 @@
|
||||
|
||||
#include <cutlass/fast_math.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "static_switch.h"
|
||||
#include "kernels/config.h"
|
||||
#include "kernels/get_mla_metadata.h"
|
||||
#include "kernels/mla_combine.h"
|
||||
#include "kernels/splitkv_mla.h"
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
@@ -23,11 +25,6 @@ get_mla_metadata(
|
||||
const int num_heads_per_head_k,
|
||||
const int num_heads_k
|
||||
) {
|
||||
// This should match the logic in the MLA kernel.
|
||||
static constexpr int block_size_m = 64;
|
||||
static constexpr int block_size_n = 64;
|
||||
static constexpr int fixed_overhead_num_blocks = 5;
|
||||
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
TORCH_CHECK(seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
|
||||
@@ -38,7 +35,7 @@ get_mla_metadata(
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
int sm_count = dprops->multiProcessorCount;
|
||||
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
|
||||
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M);
|
||||
|
||||
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
|
||||
auto num_splits = torch::empty({batch_size + 1}, options);
|
||||
@@ -52,10 +49,10 @@ get_mla_metadata(
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
|
||||
params.num_splits_ptr = num_splits_ptr;
|
||||
params.batch_size = batch_size;
|
||||
params.block_size_n = block_size_n;
|
||||
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
|
||||
params.block_size_n = Config::PAGE_BLOCK_SIZE;
|
||||
params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS;
|
||||
params.num_sm_parts = num_sm_parts;
|
||||
get_mla_metadata_func(params, stream);
|
||||
run_get_mla_metadata_kernel(params, stream);
|
||||
|
||||
return {tile_scheduler_metadata, num_splits};
|
||||
}
|
||||
@@ -64,7 +61,6 @@ std::vector<at::Tensor>
|
||||
mha_fwd_kvcache_mla(
|
||||
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
|
||||
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
|
||||
const int head_size_v,
|
||||
const at::Tensor &seqlens_k, // batch_size
|
||||
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
|
||||
@@ -73,138 +69,141 @@ mha_fwd_kvcache_mla(
|
||||
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor &num_splits // batch_size + 1
|
||||
) {
|
||||
// Check the architecture
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90);
|
||||
|
||||
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
|
||||
|
||||
// Check data types
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
|
||||
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
||||
|
||||
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
|
||||
// Check device
|
||||
CHECK_DEVICE(q);
|
||||
CHECK_DEVICE(kcache);
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
CHECK_DEVICE(block_table);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
CHECK_DEVICE(num_splits);
|
||||
|
||||
// Check layout
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
CHECK_CONTIGUOUS(seqlens_k);
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
||||
CHECK_CONTIGUOUS(num_splits);
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q_ori = sizes[1];
|
||||
const int num_heads_ori = sizes[2];
|
||||
const int head_size = sizes[3];
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
|
||||
const int num_heads_q = sizes[2];
|
||||
const int head_size_k = sizes[3];
|
||||
TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
|
||||
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
|
||||
|
||||
const int max_num_blocks_per_seq = block_table.size(1);
|
||||
const int num_blocks = kcache.size(0);
|
||||
const int page_block_size = kcache.size(1);
|
||||
const int num_heads_k = kcache.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q_ori == 1) { is_causal = false; }
|
||||
|
||||
const int ngroups = num_heads_ori / num_heads_k;
|
||||
const int seqlen_q = seqlen_q_ori * ngroups;
|
||||
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
|
||||
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
|
||||
const int num_heads = num_heads_k;
|
||||
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q, num_heads, head_size});
|
||||
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
|
||||
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
|
||||
|
||||
int head_size_k = head_size;
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
|
||||
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
|
||||
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
|
||||
|
||||
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
|
||||
CHECK_DEVICE(seqlens_k);
|
||||
CHECK_CONTIGUOUS(seqlens_k);
|
||||
CHECK_SHAPE(seqlens_k, batch_size);
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_SHAPE(num_splits, batch_size+1);
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
|
||||
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
|
||||
CHECK_CONTIGUOUS(softmax_lse);
|
||||
|
||||
Flash_fwd_mla_params params = {};
|
||||
// Set the sizes.
|
||||
params.b = batch_size;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
|
||||
params.h = num_heads;
|
||||
params.h_h_k_ratio = num_heads / num_heads_k;
|
||||
params.ngroups = ngroups;
|
||||
params.s_q = seqlen_q_ori;
|
||||
params.q_seq_per_hk = q_seq_per_hk;
|
||||
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
|
||||
params.h_q = num_heads_q;
|
||||
params.h_k = num_heads_k;
|
||||
params.num_blocks = num_blocks;
|
||||
params.q_head_per_hk = num_q_heads_per_hk;
|
||||
params.is_causal = is_causal;
|
||||
params.d = head_size;
|
||||
params.d = head_size_k;
|
||||
params.d_v = head_size_v;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = kcache.data_ptr();
|
||||
params.v_ptr = vcache.data_ptr();
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
||||
// All strides are in elements, not bytes.
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = kcache.stride(0);
|
||||
params.v_batch_stride = vcache.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = kcache.stride(-3);
|
||||
params.v_row_stride = vcache.stride(-3);
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = kcache.stride(-2);
|
||||
params.v_head_stride = vcache.stride(-2);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
|
||||
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
|
||||
CHECK_DEVICE(tile_scheduler_metadata);
|
||||
CHECK_CONTIGUOUS(tile_scheduler_metadata);
|
||||
|
||||
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
|
||||
params.num_sm_parts = tile_scheduler_metadata.size(0);
|
||||
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
|
||||
CHECK_DEVICE(num_splits);
|
||||
CHECK_CONTIGUOUS(num_splits);
|
||||
params.num_splits_ptr = num_splits.data_ptr<int>();
|
||||
|
||||
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
|
||||
const int total_num_splits = batch_size + params.num_sm_parts;
|
||||
at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
|
||||
CHECK_CONTIGUOUS(softmax_lse_accum);
|
||||
CHECK_CONTIGUOUS(out_accum);
|
||||
params.total_num_splits = total_num_splits;
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
||||
params.oaccum_ptr = out_accum.data_ptr();
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
|
||||
TORCH_CHECK(head_size_k == 576);
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
}
|
||||
#ifndef FLASH_MLA_DISABLE_FP16
|
||||
else if (q_dtype == torch::kHalf) {
|
||||
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
|
||||
run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
|
||||
} else if (q_dtype == torch::kHalf) {
|
||||
#ifdef FLASH_MLA_DISABLE_FP16
|
||||
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
|
||||
#else
|
||||
run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
|
||||
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported tensor dtype for query");
|
||||
}
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
|
||||
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
|
||||
out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
|
||||
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
|
||||
.reshape({batch_size, num_heads_q, seqlen_q_ori});
|
||||
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user