Hi, I think we can optimize the process that loads q from gmem to smem. In compute_attn_1rowblock_splitkv_mla, q is no longer needed after the last q @ k_T computation, so we can issue an async copy there and load the next q into smem in advance, before the next invocation of compute_attn_1rowblock_splitkv_mla runs. To keep the prefetched data from overwriting values that are still live in smem, I adjusted the layout of SharedStorageMLA. I verified the change with test_flash_mla.py; the test passes with no numerical errors.

I tested on an H800, taking the average of 10 runs as the final result, with a 3-second pause between runs to stabilize the GPU frequency. Since q is loaded only a small number of times, the change does not bring a large performance improvement; a few configurations show a slight regression, but overall there is roughly a 0.5% improvement.

| batch | seqlen | head | bw_orig (GB/s) | bw_opt (GB/s) | bw_diff |
| --- | --- | --- | --- | --- | --- |
| 64 | 1087 | 128 | 1384 | 1407 | 1.66% |
| 64 | 2111 | 128 | 1744 | 1761 | 0.97% |
| 64 | 4159 | 128 | 2188 | 2197 | 0.41% |
| 64 | 8255 | 128 | 2341 | 2345 | 0.17% |
| 64 | 16447 | 128 | 2330 | 2338 | 0.34% |
| 64 | 32831 | 128 | 2374 | 2374 | 0.0% |
| 128 | 1151 | 128 | 1756 | 1763 | 0.4% |
| 128 | 2175 | 128 | 2066 | 2072 | 0.29% |
| 128 | 4223 | 128 | 2284 | 2290 | 0.26% |
| 128 | 8319 | 128 | 2343 | 2349 | 0.26% |
| 128 | 16511 | 128 | 2375 | 2373 | -0.08% |
| 128 | 32895 | 128 | 2351 | 2358 | 0.3% |
| 256 | 1279 | 128 | 2033 | 2035 | 0.1% |
| 256 | 2303 | 128 | 2232 | 2228 | -0.18% |
| 256 | 4351 | 128 | 2322 | 2340 | 0.78% |
| 256 | 8447 | 128 | 2371 | 2367 | -0.17% |
| 256 | 16639 | 128 | 2359 | 2394 | 1.48% |
| 256 | 33023 | 128 | 2381 | 2392 | 0.46% |

Thanks!
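To make the idea concrete, here is a minimal sketch of the double-buffered cp.async pattern being described: start the asynchronous copy of the next q as soon as the current q has been consumed by its last q @ k_T, and let the copy overlap with the remaining softmax / @ V work. This is only an illustration under assumed names and sizes (qk_gemm, softmax_and_pv, and kQTileElems are placeholders), not the actual FlashMLA kernel, which is written with CUTLASS/CuTe; the real change prefetches q for the next compute_attn_1rowblock_splitkv_mla call rather than inside a tile loop like this.

```cuda
// Minimal sketch of the idea, not the actual FlashMLA kernel.
// Tile size, stub functions, and layout are illustrative assumptions only.
#include <cuda_bf16.h>
#include <cuda_pipeline.h>  // __pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior

constexpr int kQTileElems = 64 * 64;  // assumed q tile (rows x dims); the real tile is larger

// Placeholder for the q @ k^T GEMM: the last consumer of the q tile in smem.
__device__ void qk_gemm(const __nv_bfloat16* /*q_smem*/) { /* ... */ }
// Placeholder for the rest of the per-tile work (softmax, @ V, rescaling).
__device__ void softmax_and_pv() { /* ... */ }

__global__ void decode_like_kernel(const __nv_bfloat16* __restrict__ q_gmem,
                                   int num_q_tiles) {
    // Two q buffers so the prefetch can never overwrite a tile that is still live.
    __shared__ __align__(16) __nv_bfloat16 q_smem[2][kQTileElems];
    const int tid = threadIdx.x, nthreads = blockDim.x;

    // Stage the first tile (shown synchronously for brevity).
    for (int i = tid; i < kQTileElems; i += nthreads) q_smem[0][i] = q_gmem[i];
    __syncthreads();

    for (int t = 0; t < num_q_tiles; ++t) {
        const int cur = t & 1, nxt = cur ^ 1;

        qk_gemm(q_smem[cur]);  // after this, q tile t is dead in smem

        if (t + 1 < num_q_tiles) {
            // Kick off the async copy of the next q tile into the other buffer;
            // it overlaps with the remaining math for the current tile.
            const __nv_bfloat16* src = q_gmem + (size_t)(t + 1) * kQTileElems;
            for (int i = tid * 8; i < kQTileElems; i += nthreads * 8)
                __pipeline_memcpy_async(&q_smem[nxt][i], &src[i],
                                        8 * sizeof(__nv_bfloat16));  // 16-byte chunks
            __pipeline_commit();
        }

        softmax_and_pv();  // overlaps with the in-flight cp.async

        if (t + 1 < num_q_tiles) __pipeline_wait_prior(0);  // next q has landed
        __syncthreads();
    }
}
```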
FlashMLA
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for serving variable-length sequences.
Currently released:
- BF16, FP16
- Paged kvcache with block size of 64
Quick start
Install
python setup.py install
Benchmark
python tests/test_flash_mla.py
It achieves up to 3000 GB/s in memory-bound configurations and 580 TFLOPS in compute-bound configurations on an H800 SXM5 with CUDA 12.8.
Usage
```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# The scheduling metadata depends only on the KV sequence lengths and head
# counts, so it is computed once and reused for every layer.
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

for i in range(num_layers):
    ...
    # Returns the attention output and the log-sum-exp of the attention scores.
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    ...
```
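On the kernel side, the block_table passed above maps each request's logical token positions to physical 64-token pages of the KV cache. The sketch below shows the typical addressing arithmetic for such a paged layout; the names and the [num_blocks, 64, h_kv, d] layout are assumptions for illustration, not FlashMLA's actual indexing code.

```cuda
// Illustration only: addressing a paged KV cache with 64-token blocks.
// Layout and names are assumptions, not FlashMLA's implementation.
#include <cuda_bf16.h>

// Returns a pointer to the KV entry for `token_idx` of one request, given the
// request's row of the block table and a cache laid out as
// [num_blocks, 64 (block size), h_kv, d].
__device__ const __nv_bfloat16* kv_token_ptr(const __nv_bfloat16* kv_cache,
                                             const int* block_table_row,
                                             int token_idx, int h_kv, int d) {
    const int block_size = 64;
    int physical_block  = block_table_row[token_idx / block_size];  // which 64-token page
    int offset_in_block = token_idx % block_size;                   // position inside the page
    size_t block_stride = (size_t)block_size * h_kv * d;
    size_t token_stride = (size_t)h_kv * d;
    return kv_cache + physical_block * block_stride + offset_in_block * token_stride;
}
```

Because the page size is a fixed power of two (64 tokens), this index arithmetic reduces to cheap shifts and masks inside the kernel.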
Requirements
- Hopper GPUs
- CUDA 12.3 and above
  - 12.8 or above is highly recommended for the best performance
- PyTorch 2.0 and above
Acknowledgement
FlashMLA is inspired by the FlashAttention 2&3 and CUTLASS projects.
Community Support
MetaX
For MetaX GPUs, visit the official website: MetaX.
The corresponding FlashMLA version can be found at: MetaX-MACA/FlashMLA
Moore Threads
For the Moore Threads GPU, visit the official website: Moore Threads.
The corresponding FlashMLA version is available on GitHub: MooreThreads/MT-flashMLA.
Hygon DCU
For the Hygon DCU, visit the official website: Hygon Developer.
The corresponding FlashMLA version is available here: OpenDAS/MLAttention.
Intellifusion
For the Intellifusion NNP, visit the official website: Intellifusion.
The corresponding FlashMLA version is available on Gitee: Intellifusion/tyllm.
Iluvatar Corex
For Iluvatar Corex GPUs, visit the official website: Iluvatar Corex.
The corresponding FlashMLA version is available on GitHub: Deep-Spark/FlashMLA
AMD Instinct
For AMD Instinct GPUs, visit the official website: AMD Instinct.
The corresponding FlashMLA version can be found at: AITER/MLA
Citation
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
author={Jiashi Li},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}