FlashMLA/csrc/flash_api.cpp
Kevin Zhang e0557deb3a Feature: Support FlashMLA decoding via FlashAttn2 (#29)
Changes:
1. Implement FlashMLA with the matrix-absorption algorithm on top of FlashAttn2
2. Add a golden test on the MXMACA platform
2025-02-24 23:56:05 +08:00

/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include <mctlass/numeric_types.h> // mctlass: MXMACA counterpart of cutlass/numeric_types.h
#include "flash_api_mla.h"
std::vector<at::Tensor>
get_mla_metadata(
    at::Tensor &seqlens_k,
    const int num_heads_per_head_k,
    const int num_heads_k
);
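
// MLA decode forward over a paged KV cache. In the reference FlashMLA semantics
// (an assumption for this port), when vcache_ is empty the value vectors are read
// from the leading head_size_v channels of kcache, i.e. the matrix-absorption
// layout mentioned in the commit message.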
std::vector<at::Tensor>
mha_fwd_kvcache_mla(
    at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size
    c10::optional<const at::Tensor> &vcache_,    // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                 // batch_size
    const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits                 // batch_size + 1
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";

    // FlashMLA
    m.def("get_mla_metadata", &get_mla_metadata);
    m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
}
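
// ---------------------------------------------------------------------------
// Usage sketch (added for illustration; not part of the upstream file). It shows
// how the two entry points are expected to be chained for one decode step: the
// scheduler tensors from get_mla_metadata() feed mha_fwd_kvcache_mla(). All
// concrete shapes, the num_heads_per_head_k formula, and the meaning of the
// returned tensors follow the reference FlashMLA API and are assumptions here,
// not guarantees of this port.
[[maybe_unused]] static std::vector<at::Tensor> mla_decode_example() {
    const int batch_size = 2, seqlen_q = 1, num_heads = 128, num_heads_k = 1;
    const int head_size = 576, head_size_v = 512;   // assumed MLA layout: 512 latent dims + 64 RoPE dims
    const int num_blocks = 32, page_block_size = 64, max_blocks_per_seq = 8;

    auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
    auto q      = at::randn({batch_size, seqlen_q, num_heads, head_size}, opts);
    auto kcache = at::randn({num_blocks, page_block_size, num_heads_k, head_size}, opts);
    // Matrix absorption: V lives inside the K cache, so no separate V cache is passed.
    c10::optional<const at::Tensor> vcache = c10::nullopt;

    auto iopts = at::TensorOptions().dtype(at::kInt).device(at::kCUDA);
    auto seqlens_k   = at::full({batch_size}, 200, iopts);                     // batch_size
    auto block_table = at::arange(batch_size * max_blocks_per_seq, iopts)
                           .reshape({batch_size, max_blocks_per_seq});         // batch_size x max_num_blocks_per_seq

    // Assumption: num_heads_per_head_k = seqlen_q * num_heads / num_heads_k,
    // as in the reference FlashMLA Python wrapper.
    auto sched = get_mla_metadata(seqlens_k, seqlen_q * num_heads / num_heads_k, num_heads_k);

    const float softmax_scale = 1.0f / 24.0f;   // 1 / sqrt(head_size), sqrt(576) == 24

    // Expected to return the attention output (and, in the reference API, the softmax LSE).
    return mha_fwd_kvcache_mla(q, kcache, vcache, head_size_v, seqlens_k, block_table,
                               softmax_scale, /*is_causal=*/true,
                               /*tile_scheduler_metadata=*/sched[0], /*num_splits=*/sched[1]);
}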