Performance: Async copy for q in advance, roughly 0.5% performance gain (zhouzihan30, commit 8997d13745)
Hi, I think we can optimize the process of loading q from gmem to smem.
In compute_attn_1rowblock_splitkv_mla, q is no longer needed after the last
q @ k_T computation, so we can issue an async copy before the next call to
compute_attn_1rowblock_splitkv_mla runs, i.e., load the next q into smem in
advance.

To keep the prefetched data from overwriting values that are still live in
smem, I adjusted the layout of SharedStorageMLA and verified the change with
test_flash_mla.py; the tests pass with no numerical errors.
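
For readers less familiar with the pattern, below is a minimal, self-contained
CUDA sketch of double-buffered asynchronous prefetching into shared memory.
It is not the actual FlashMLA kernel; the kernel name, the 128-float tile size,
the 32-thread launch, and the stand-in compute loop are illustrative
assumptions only.

// Sketch only: overlap the gmem->smem copy of the NEXT q tile with the
// compute that consumes the CURRENT one (cp.async hardware path needs sm_80+).
#include <cuda_pipeline_primitives.h>

constexpr int TILE_ELEMS = 128;   // assumed tile size: 32 threads x 4 floats

// Launch as: prefetch_q_demo<<<num_blocks, 32>>>(q_gmem, out, num_tiles);
__global__ void prefetch_q_demo(const float* __restrict__ q_gmem,
                                float* __restrict__ out,
                                int num_tiles) {
    // Two buffers, so the prefetched tile never overwrites data still in use.
    __shared__ __align__(16) float q_smem[2][TILE_ELEMS];
    const int tid = threadIdx.x;

    // Bring in the first tile through the async-copy pipeline.
    __pipeline_memcpy_async(&q_smem[0][tid * 4], &q_gmem[tid * 4], 16);
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();

    float acc = 0.0f;
    for (int t = 0; t < num_tiles; ++t) {
        const int cur = t & 1, nxt = cur ^ 1;

        // As soon as the current tile no longer needs to be read from gmem,
        // kick off the copy of the next one; it overlaps with the compute below.
        if (t + 1 < num_tiles) {
            const float* next_q = q_gmem + (size_t)(t + 1) * TILE_ELEMS;
            __pipeline_memcpy_async(&q_smem[nxt][tid * 4], &next_q[tid * 4], 16);
            __pipeline_commit();
        }

        // Stand-in for the q @ k_T work that consumes the current tile.
        for (int i = 0; i < 4; ++i) acc += q_smem[cur][tid * 4 + i];

        // Make sure the prefetched tile has landed before the next iteration.
        __pipeline_wait_prior(0);
        __syncthreads();
    }
    out[blockIdx.x * blockDim.x + tid] = acc;
}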

I tested this on an H800, taking the average of 10 runs as the final result,
with a 3-second interval between runs to stabilize the GPU frequency.

q is loaded only a small number of times, so this change does not bring a
large improvement. Under some parameter settings performance drops slightly,
but overall there is a roughly 0.5% improvement.

batch,seqlen,heads,bw_orig (GB/s),bw_opt (GB/s),diff (%)
64,1087,128,1384,1407,1.66%
64,2111,128,1744,1761,0.97%
64,4159,128,2188,2197,0.41%
64,8255,128,2341,2345,0.17%
64,16447,128,2330,2338,0.34%
64,32831,128,2374,2374,0.0%
128,1151,128,1756,1763,0.4%
128,2175,128,2066,2072,0.29%
128,4223,128,2284,2290,0.26%
128,8319,128,2343,2349,0.26%
128,16511,128,2375,2373,-0.08%
128,32895,128,2351,2358,0.3%
256,1279,128,2033,2035,0.1%
256,2303,128,2232,2228,-0.18%
256,4351,128,2322,2340,0.78%
256,8447,128,2371,2367,-0.17%
256,16639,128,2359,2394,1.48%
256,33023,128,2381,2392,0.46%

Thanks!

FlashMLA

FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for serving variable-length sequences.

Currently released:

  • BF16, FP16
  • Paged kvcache with block size of 64

Quick start

Install

python setup.py install

Benchmark

python tests/test_flash_mla.py

It achieves up to 3000 GB/s in memory-bound configurations and 580 TFLOPS in compute-bound configurations on the H800 SXM5, using CUDA 12.8.

Usage

from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# cache_seqlens: per-sequence KV cache lengths; s_q: query length;
# h_q / h_kv: number of query / key-value heads
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

for i in range(num_layers):
    ...
    # q_i: queries for layer i; kvcache_i: paged KV cache; dv: head dim of v
    # Returns the attention output o_i and its log-sum-exp lse_i
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    ...
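
Note that get_mla_metadata depends only on cache_seqlens and the head configuration, so it is computed once outside the layer loop and the resulting tile_scheduler_metadata and num_splits are reused for every flash_mla_with_kvcache call.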

Requirements

  • Hopper GPUs
  • CUDA 12.3 and above
    • We highly recommend 12.8 or above for the best performance
  • PyTorch 2.0 and above

Acknowledgement

FlashMLA is inspired by the FlashAttention 2 & 3 and CUTLASS projects.

Community Support

MetaX

For MetaX GPUs, visit the official website: MetaX.

The corresponding FlashMLA version can be found at: MetaX-MACA/FlashMLA.

Moore Threads

For the Moore Threads GPU, visit the official website: Moore Threads.

The corresponding FlashMLA version is available on GitHub: MooreThreads/MT-flashMLA.

Hygon DCU

For the Hygon DCU, visit the official website: Hygon Developer.

The corresponding FlashMLA version is available here: OpenDAS/MLAttention.

Intellifusion

For the Intellifusion NNP, visit the official website: Intellifusion.

The corresponding FlashMLA version is available on Gitee: Intellifusion/tyllm.

Iluvatar Corex

For Iluvatar Corex GPUs, visit the official website: Iluvatar Corex.

The corresponding FlashMLA version is available on GitHub: Deep-Spark/FlashMLA.

AMD Instinct

For AMD Instinct GPUs, visit the official website: AMD Instinct.

The corresponding FlashMLA version can be found at: AITER/MLA.

Citation

@misc{flashmla2025,
      title={FlashMLA: Efficient MLA decoding kernels},
      author={Jiashi Li},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}