/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

#include "flash_api_mla.h"

std::vector<at::Tensor>
get_mla_metadata(
    at::Tensor &seqlens_k,
    const int num_heads_per_head_k,
    const int num_heads_k
);

std::vector<at::Tensor>
mha_fwd_kvcache_mla(
    at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size
    c10::optional<const at::Tensor> &vcache_,    // num_blocks x page_block_size x num_heads_k x head_size_v
    const int head_size_v,
    const at::Tensor &seqlens_k,                 // batch_size
    const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits                 // batch_size + 1
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    // FlashMLA
    m.def("get_mla_metadata", &get_mla_metadata);
    m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla);
}
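
// A minimal Python usage sketch for the two bindings above. It assumes the
// extension is built under the hypothetical module name `flash_mla_cuda`
// (the actual name depends on TORCH_EXTENSION_NAME at build time), and that
// num_heads_per_head_k is seqlen_q * num_heads / num_heads_k; the tensor
// shapes follow the parameter comments above.
//
//     import torch
//     import flash_mla_cuda  # hypothetical module name
//
//     seqlens_k = torch.full((batch_size,), cache_len,
//                            dtype=torch.int32, device="cuda")
//     tile_scheduler_metadata, num_splits = flash_mla_cuda.get_mla_metadata(
//         seqlens_k, seqlen_q * num_heads // num_heads_k, num_heads_k)
//     out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
//         q, kcache, None, head_size_v, seqlens_k, block_table,
//         softmax_scale, True, tile_scheduler_metadata, num_splits)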