Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00
Changes: (1) implement FlashMLA with the matrix absorption algorithm via FlashAttention-2; (2) add a golden test on the MXMACA platform.
49 lines
1.7 KiB
C++
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

#include <mctlass/numeric_types.h>

#include "flash_api_mla.h"

std::vector<at::Tensor>
get_mla_metadata(
    at::Tensor &seqlens_k,
    const int num_heads_per_head_k,
    const int num_heads_k
);

std::vector<at::Tensor>
mha_fwd_kvcache_mla(
    at::Tensor &q,                                   // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                        // num_blocks x page_block_size x num_heads_k x head_size
    c10::optional<const at::Tensor> &vcache_,        // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                     // batch_size
    const at::Tensor &block_table,                   // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,       // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits                     // batch_size + 1
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";

    // FlashMLA
    m.def("get_mla_metadata", &get_mla_metadata);
    m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
}
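As a usage illustration, below is a minimal C++ sketch (not part of the repository) of how the two declared entry points might be wired together for a single decode step. The tensor shapes follow the comments in the declarations above; everything else is an assumption made for illustration: the concrete sizes (head_size = 576, head_size_v = 512, page_block_size = 64), the choice num_heads_per_head_k = seqlen_q * num_heads / num_heads_k, passing an empty vcache, the softmax scale, and the interpretation of the returned tensors.

// mla_decode_sketch.cpp -- hypothetical example, not part of the repository.
#include <cmath>
#include <vector>
#include <ATen/ATen.h>
#include <torch/torch.h>

// Prototypes copied from the binding file above.
std::vector<at::Tensor> get_mla_metadata(at::Tensor &seqlens_k, const int num_heads_per_head_k, const int num_heads_k);
std::vector<at::Tensor> mha_fwd_kvcache_mla(at::Tensor &q, const at::Tensor &kcache,
                                            c10::optional<const at::Tensor> &vcache_, const int head_size_v,
                                            const at::Tensor &seqlens_k, const at::Tensor &block_table,
                                            const float softmax_scale, bool is_causal,
                                            const at::Tensor &tile_scheduler_metadata, const at::Tensor &num_splits);

void run_mla_decode_sketch() {
    // Placeholder sizes; head_size = 576 and head_size_v = 512 follow the usual MLA layout (assumption).
    const int batch_size = 4, seqlen_q = 1, num_heads = 128, num_heads_k = 1;
    const int head_size = 576, head_size_v = 512, page_block_size = 64, cache_len = 1024;
    const int blocks_per_seq = cache_len / page_block_size;
    const int num_blocks = batch_size * blocks_per_seq;

    auto i32  = torch::dtype(torch::kInt32).device(torch::kCUDA);
    auto bf16 = torch::dtype(torch::kBFloat16).device(torch::kCUDA);

    // Per-sequence cached lengths: batch_size
    at::Tensor seqlens_k = torch::full({batch_size}, cache_len, i32);

    // Assumption: num_heads_per_head_k = seqlen_q * num_heads / num_heads_k,
    // mirroring how the public FlashMLA Python wrapper calls get_mla_metadata.
    auto sched = get_mla_metadata(seqlens_k, seqlen_q * num_heads / num_heads_k, num_heads_k);
    at::Tensor tile_scheduler_metadata = sched[0];   // num_sm_parts x TileSchedulerMetaDataSize
    at::Tensor num_splits = sched[1];                // batch_size + 1

    at::Tensor q      = torch::randn({batch_size, seqlen_q, num_heads, head_size}, bf16);
    at::Tensor kcache = torch::randn({num_blocks, page_block_size, num_heads_k, head_size}, bf16);
    c10::optional<const at::Tensor> vcache = c10::nullopt;   // assumption: V is read from the K cache
    at::Tensor block_table = torch::arange(num_blocks, i32).reshape({batch_size, blocks_per_seq});

    auto outs = mha_fwd_kvcache_mla(q, kcache, vcache, head_size_v, seqlens_k, block_table,
                                    1.0f / std::sqrt(static_cast<float>(head_size)),
                                    /*is_causal=*/true,
                                    tile_scheduler_metadata, num_splits);
    at::Tensor out = outs[0];   // expected batch_size x seqlen_q x num_heads x head_size_v (assumption)
    (void)out;
}

The sketch would be compiled and linked against the FlashMLA library that defines these two functions; it is meant only to make the argument shapes and call order concrete, not to document the actual MXMACA test setup.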