Hi, I found an issue in scale_apply_exp2. The code comments there also mention it: https://github.com/pytorch/pytorch/issues/121558. The problem is that the FFMA instruction introduces calculation errors in flash attention compared to using separate FMUL and FADD instructions.
With separate FMUL and FADD, the element equal to max(x) is computed as round_fp32(x_i * scale) - round_fp32(max(x) * scale), which is exactly 0. With FFMA, the computation becomes x_i * scale - round_fp32(max(x) * scale): the product is kept at full precision before the subtraction, so although FFMA is actually more accurate, the result is no longer exactly 0 for the max element, and this introduces errors in the output values.
The issue can be reproduced by changing the initialization values of q and k, after which the final outputs are all 0:
q = torch.full((b, s_q, h_q, d), 133120.0)
blocked_k = torch.full((block_table.numel(), block_size, h_kv, d), 133120.0)
If we define UNFUSE_FMA, the problem is alleviated, but the kernel still cannot pass the cal-diff test. I am not sure whether the remaining failure is just an accuracy issue, but I think the FMA bug should be fixed first.
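To make the difference concrete, here is a minimal numpy sketch of the two instruction sequences (not the kernel code): round-to-fp32 is emulated with np.float32, the full-precision FFMA product with np.float64, and the scale value 1/24 is only an illustrative assumption.

```python
import numpy as np

x = np.float32(133120.0)        # an element that equals the row max
scale = np.float32(1.0 / 24.0)  # illustrative softmax scale (assumed value)

# The kernel keeps round_fp32(max(x) * scale) around for the subtraction.
max_scaled = np.float32(x * scale)

# Separate FMUL + FADD: the product is rounded to fp32 before the subtraction,
# so the max element gives exactly 0 (and exp2(0) == 1).
unfused = np.float32(x * scale) - max_scaled

# FFMA: the product x * scale enters the add at full precision (emulated with
# fp64 here), so the max element gives a small nonzero residual instead of 0.
fused = np.float32(np.float64(x) * np.float64(scale) - np.float64(max_scaled))

print(unfused)  # 0.0
print(fused)    # small nonzero value, on the order of the fp32 rounding error
```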
FlashMLA
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for serving variable-length sequences.
Currently released:
- BF16, FP16
- Paged kvcache with block size of 64
Quick start
Install
python setup.py install
Benchmark
python tests/test_flash_mla.py
It achieves up to 3000 GB/s in the memory-bound configuration and 580 TFLOPS in the compute-bound configuration on an H800 SXM5 using CUDA 12.8.
Usage
```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    ...
```
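As a rough reference for how the arguments above fit together, below is a minimal setup sketch. The tensor shapes follow the repro tensors in the FFMA report above and tests/test_flash_mla.py (q as (b, s_q, h_q, d), the paged kvcache as (num_blocks, block_size, h_kv, d)); the concrete sizes (d = 576, dv = 512, max_seqlen, the head counts) are illustrative assumptions, not requirements documented here.

```python
import torch

b, s_q, h_q, h_kv = 4, 1, 128, 1   # batch, query length, query heads, kv heads (assumed)
d, dv = 576, 512                   # k and v head dims (assumed MLA sizes)
block_size = 64                    # paged kvcache block size supported by the kernel
max_seqlen = 1024                  # illustrative cache length per sequence

cache_seqlens = torch.full((b,), max_seqlen, dtype=torch.int32, device="cuda")
blocks_per_seq = (max_seqlen + block_size - 1) // block_size
block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32,
                           device="cuda").view(b, blocks_per_seq)

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d,
                        dtype=torch.bfloat16, device="cuda")
```

Note that get_mla_metadata depends only on cache_seqlens and the head configuration (s_q * h_q // h_kv, h_kv), which is why it is computed once and reused across all layers in the loop above.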
Requirements
- Hopper GPUs
- CUDA 12.3 and above
  - We highly recommend 12.8 or above for the best performance
- PyTorch 2.0 and above
Acknowledgement
FlashMLA is inspired by the FlashAttention 2&3 and cutlass projects.
Community Support
MetaX
For MetaX GPUs, visit the official website: MetaX.
The corresponding FlashMLA version can be found at: MetaX-MACA/FlashMLA
Moore Threads
For the Moore Threads GPU, visit the official website: Moore Threads.
The corresponding FlashMLA version is available on GitHub: MooreThreads/MT-flashMLA.
Hygon DCU
For the Hygon DCU, visit the official website: Hygon Developer.
The corresponding FlashMLA version is available here: OpenDAS/MLAttention.
Intellifusion
For the Intellifusion NNP, visit the official website: Intellifusion.
The corresponding FlashMLA version is available on Gitee: Intellifusion/tyllm.
Iluvatar Corex
For Iluvatar Corex GPUs, visit the official website: Iluvatar Corex.
The corresponding FlashMLA version is available on GitHub: Deep-Spark/FlashMLA
AMD Instinct
For AMD Instinct GPUs, visit the official website: AMD Instinct.
The corresponding FlashMLA version can be found at: AITER/MLA
Citation
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
author={Jiashi Li},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}