Merge branch 'main' into will_fp8_mr

chenhongmin.will 2025-02-28 22:07:03 +08:00
commit c7143a7bda
10 changed files with 697 additions and 54 deletions

.gitignore vendored

@@ -3,3 +3,5 @@ build
 *.egg-info/
 __pycache__/
 dist/
+*perf.csv
+*.png


@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.

 Currently released:
-- BF16
+- BF16, FP16
 - Paged kvcache with block size of 64

 ## Quick start
@@ -20,7 +20,7 @@ python setup.py install
 python tests/test_flash_mla.py
 ```

-Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6.
+Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.

 ### Usage
@@ -42,17 +42,49 @@ for i in range(num_layers):
 - Hopper GPUs
 - CUDA 12.3 and above
+  - **But we highly recommend 12.8 or above for the best performance**
 - PyTorch 2.0 and above

 ## Acknowledgement

 FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.

+## Community Support
+
+### MetaX
+For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com).
+The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)
+
+### Moore Threads
+For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/).
+The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA).
+
+### Hygon DCU
+For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).
+The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention).
+
+### Intellifusion
+For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com).
+The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py).
+
+### Iluvatar Corex
+For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com).
+The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)
+
 ## Citation

 ```bibtex
 @misc{flashmla2025,
-      title={FlashMLA: Efficient MLA decoding kernel},
+      title={FlashMLA: Efficient MLA decoding kernels},
       author={Jiashi Li},
       year={2025},
       publisher = {GitHub},


@@ -0,0 +1,520 @@
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import math
import random

import flashinfer
import torch
import triton
import triton.language as tl

# pip install flashinfer-python
from flash_mla import flash_mla_with_kvcache, get_mla_metadata


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q, h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_torch, lse_torch = ref_mla()
    t = triton.testing.do_bench(ref_mla)
    return out_torch, lse_torch, t


@torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

    def flash_mla():
        return flash_mla_with_kvcache(
            q, blocked_k, block_table, cache_seqlens, dv,
            tile_scheduler_metadata, num_splits, causal=causal,
        )

    out_flash, lse_flash = flash_mla()
    t = triton.testing.do_bench(flash_mla)
    return out_flash, lse_flash, t


@torch.inference_mode()
def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    kv_indptr = [0]
    kv_indices = []
    for i in range(b):
        seq_len = cache_seqlens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_table[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
    for seq_len in cache_seqlens[1:]:
        kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])
    q_indptr = torch.arange(0, b + 1).int() * s_q
    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)

    mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.int8),
        backend="fa3"
    )
    mla_wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        cache_seqlens,
        h_q,
        dv,
        d-dv,
        block_size,
        causal,
        1 / math.sqrt(d),
        q.dtype,
        blocked_k.dtype,
    )

    def flash_infer():
        output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True)
        return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)

    out_flash, lse_flash = flash_infer()
    t = triton.testing.do_bench(flash_infer)
    return out_flash, lse_flash, t
@triton.jit
def _mla_attn_kernel(
    Q_nope,
    Q_pe,
    Kv_c_cache,
    K_pe_cache,
    Req_to_tokens,
    B_seq_len,
    O,
    sm_scale,
    stride_q_nope_bs,
    stride_q_nope_h,
    stride_q_pe_bs,
    stride_q_pe_h,
    stride_kv_c_bs,
    stride_k_pe_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM_CKV: tl.constexpr,
    HEAD_DIM_KPE: tl.constexpr,
):
    cur_batch = tl.program_id(1)
    cur_head_id = tl.program_id(0)
    split_kv_id = tl.program_id(2)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)

    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
    offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
    q_nope = tl.load(Q_nope + offs_q_nope)

    offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
    offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]
    q_pe = tl.load(Q_pe + offs_q_pe)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        kv_page_number = tl.load(
            Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,
            mask=offs_n < split_kv_end,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]
        k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)
        qk = tl.dot(q_nope, k_c.to(q_nope.dtype))

        offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]
        k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)
        qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))
        qk *= sm_scale
        qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf"))

        v_c = tl.trans(k_c)
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc += tl.dot(p.to(v_c.dtype), v_c)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max

    offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
    tl.store(O + offs_o, acc / e_sum[:, None])
    offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
    tl.store(O + offs_o_1, e_max + tl.log(e_sum))


def _mla_attn(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    attn_logits,
    req_to_tokens,
    b_seq_len,
    num_kv_splits,
    sm_scale,
    page_size,
):
    batch_size, head_num = q_nope.shape[0], q_nope.shape[1]
    head_dim_ckv = q_nope.shape[-1]
    head_dim_kpe = q_pe.shape[-1]

    BLOCK_H = 16
    BLOCK_N = 64
    grid = (
        triton.cdiv(head_num, BLOCK_H),
        batch_size,
        num_kv_splits,
    )
    _mla_attn_kernel[grid](
        q_nope,
        q_pe,
        kv_c_cache,
        k_pe_cache,
        req_to_tokens,
        b_seq_len,
        attn_logits,
        sm_scale,
        # stride
        q_nope.stride(0),
        q_nope.stride(1),
        q_pe.stride(0),
        q_pe.stride(1),
        kv_c_cache.stride(-2),
        k_pe_cache.stride(-2),
        req_to_tokens.stride(0),
        attn_logits.stride(0),
        attn_logits.stride(1),
        attn_logits.stride(2),
        BLOCK_H=BLOCK_H,
        BLOCK_N=BLOCK_N,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        HEAD_DIM_CKV=head_dim_ckv,
        HEAD_DIM_KPE=head_dim_kpe,
    )
@triton.jit
def _mla_softmax_reducev_kernel(
    Logits,
    B_seq_len,
    O,
    stride_l_b,
    stride_l_h,
    stride_l_s,
    stride_o_b,
    stride_o_h,
    NUM_KV_SPLITS: tl.constexpr,
    HEAD_DIM_CKV: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)

    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)

    offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv
    offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)
            logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)

            n_e_max = tl.maximum(logits_1, e_max)
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(logits_1 - n_e_max)
            acc += exp_logic * logits

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
        acc / e_sum,
    )


def _mla_softmax_reducev(
    logits,
    o,
    b_seq_len,
    num_kv_splits,
):
    batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2]
    grid = (batch_size, head_num)
    _mla_softmax_reducev_kernel[grid](
        logits,
        b_seq_len,
        o,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=num_kv_splits,
        HEAD_DIM_CKV=head_dim_ckv,
        num_warps=4,
        num_stages=2,
    )


def mla_decode_triton(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    o,
    req_to_tokens,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
):
    assert num_kv_splits == attn_logits.shape[2]
    _mla_attn(
        q_nope,
        q_pe,
        kv_c_cache,
        k_pe_cache,
        attn_logits,
        req_to_tokens,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
    )
    _mla_softmax_reducev(
        attn_logits,
        o,
        b_seq_len,
        num_kv_splits,
    )


@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]
    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    def flash_mla_triton():
        num_kv_splits = 32
        o = torch.empty([b * s_q, h_q, dv])
        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
        mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size)
        return o.view([b, s_q, h_q, dv])

    out_flash = flash_mla_triton()
    t = triton.testing.do_bench(flash_mla_triton)
    return out_flash, None, t
FUNC_TABLE = {
    "torch": run_torch_mla,
    "flash_mla": run_flash_mla,
    "flash_infer": run_flash_infer,
    "flash_mla_triton": run_flash_mla_triton,
}


def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)
    assert baseline in FUNC_TABLE
    assert target in FUNC_TABLE
    baseline_func = FUNC_TABLE[baseline]
    target_func = FUNC_TABLE[target]

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

    out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
    if target not in ["flash_infer", "flash_mla_triton"]:
        # flash_infer has a different lse return value
        # flash_mla_triton doesn't return lse
        torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"

    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
    return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b


def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
    torch.set_default_dtype(dtype)
    device = torch.device("cuda:0")
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)
    assert target in FUNC_TABLE
    target_func = FUNC_TABLE[target]

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)

    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
    return bytes / 10 ** 6 / perf_b


available_targets = [
    "torch",
    "flash_mla",
    "flash_infer",
    "flash_mla_triton",
]

shape_configs = [
    {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16}
    for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128]
]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline", type=str, default="torch")
    parser.add_argument("--target", type=str, default="flash_mla")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--one", action="store_true")
    parser.add_argument("--compare", action="store_true")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
    with open(f"{benchmark_type}_perf.csv", "w") as fout:
        fout.write("name,batch,seqlen,head,bw\n")
        for shape in shape_configs:
            if args.all:
                for target in available_targets:
                    perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                    fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
            elif args.compare:
                perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
            elif args.one:
                perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
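A usage sketch for this benchmark (the script's path is not shown in this view; benchmark/bench_flash_mla.py is assumed here). The argparse flags above select three modes:

python benchmark/bench_flash_mla.py --one --target flash_mla
python benchmark/bench_flash_mla.py --compare --baseline torch --target flash_mla
python benchmark/bench_flash_mla.py --all

--compare additionally checks the target's output against the baseline with assert_close; --all sweeps every entry in available_targets and writes all_perf.csv.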

benchmark/visualize.py Normal file

@@ -0,0 +1,29 @@
import argparse

import matplotlib.pyplot as plt
import pandas as pd


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize benchmark results')
    parser.add_argument('--file', type=str, default='all_perf.csv',
                        help='Path to the CSV file with benchmark results (default: all_perf.csv)')
    return parser.parse_args()


args = parse_args()
file_path = args.file

df = pd.read_csv(file_path)
names = df['name'].unique()

for name in names:
    subset = df[df['name'] == name]
    plt.plot(subset['seqlen'], subset['bw'], label=name)

plt.title('bandwidth')
plt.xlabel('seqlen')
plt.ylabel('bw (GB/s)')
plt.legend()
plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png')


@@ -61,7 +61,7 @@ std::vector<at::Tensor>
 mha_fwd_kvcache_mla(
         at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &kcache,  // num_blocks x page_block_size x num_heads_k x head_size
-        c10::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
+        std::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
         const int head_size_v,
         const at::Tensor &seqlens_k,  // batch_size
         const at::Tensor &block_table,  // batch_size x max_num_blocks_per_seq
@@ -79,9 +79,8 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;

     auto q_dtype = q.scalar_type();
-    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf || q_dtype == torch::kFloat8_e4m3fn);
     TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
-    bool is_fp8 = q_dtype == torch::kFloat8_e4m3fn;

     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -108,7 +107,7 @@ mha_fwd_kvcache_mla(
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

-    if (is_fp8) {
+    if (q_dtype == torch::kFloat8_e4m3fn) {
         TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8");
         auto descale_q = descale_q_.value();
         auto descale_k = descale_k_.value();
@@ -145,7 +144,7 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-    auto out_type = is_fp8 ? torch::kBFloat16 : q_dtype;
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
     at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -186,7 +185,7 @@ mha_fwd_kvcache_mla(
     params.block_table_batch_stride = block_table.stride(0);
     params.page_block_size = page_block_size;

-    if (is_fp8) {
+    if (q_dtype == torch::kFloat8_e4m3fn) {
         params.descale_q_ptr = reinterpret_cast<float*>(descale_q_.value().data_ptr());
         params.descale_k_ptr = reinterpret_cast<float*>(descale_k_.value().data_ptr());
     }
@@ -210,10 +209,19 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);

-    if (is_fp8) {
+    if (q_dtype == torch::kBFloat16) {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+    }
+#ifndef FLASH_MLA_DISABLE_FP16
+    else if (q_dtype == torch::kHalf) {
+        run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
+    }
+#endif
+    else if (q_dtype == torch::kFloat8_e4m3fn) {
         run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
     } else {
-        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }

     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)


@@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);


@@ -1,12 +1,4 @@
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
-#include <cutlass/numeric_types.h>
-using namespace cute;
-#include "flash_mla.h"
-#include "static_switch.h"
-#include "utils.h"
+#include "flash_fwd_mla_kernel.h"

 static constexpr int MaxBatchSize = 4096;


@@ -16,7 +16,7 @@ def get_mla_metadata(
         num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
         num_heads_k: num_heads_k.

-    Return:
+    Returns:
         tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
         num_splits: (batch_size + 1), dtype torch.int32.
     """
@@ -42,10 +42,10 @@ def flash_mla_with_kvcache(
         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
         block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
         cache_seqlens: (batch_size), torch.int32.
-        head_dim_v: Head_dim of v.
-        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
-        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
-        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
+        head_dim_v: Head dimension of v.
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
+        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
+        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
         causal: bool. Whether to apply causal attention mask.
         descale_q: (batch_size), torch.float. dequant scale for query
         descale_k: (batch_size), torch.float. dequant scale for key
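A minimal decode-loop sketch of how these two calls fit together, assembled from the docstrings above and the README usage (s_q, h_q, h_kv, num_layers and the tensor variables are illustrative placeholders, not part of this diff):

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,        # (batch_size), torch.int32
    s_q * h_q // h_kv,    # num_heads_per_head_k
    h_kv,                 # num_heads_k
)
# metadata is computed once per batch, then reused for every layer
for i in range(num_layers):
    o, lse = flash_mla_with_kvcache(
        q, k_cache, block_table, cache_seqlens, head_dim_v,
        tile_scheduler_metadata, num_splits, causal=True,
    )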


@@ -11,12 +11,35 @@ from torch.utils.cpp_extension import (
     IS_WINDOWS,
 )

+DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+

 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]


+def get_sources():
+    sources = [
+        "csrc/flash_api.cpp",
+        "csrc/flash_fwd_mla_bf16_sm90.cu",
+        "csrc/flash_fwd_mla_fp8_sm90.cu",
+        "csrc/flash_fwd_mla_metadata.cu",
+    ]
+    if not DISABLE_FP16:
+        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+    return sources
+
+
+def get_features_args():
+    features_args = []
+    if DISABLE_FP16:
+        features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    return features_args
+

 subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

 cc_flag = []
@@ -34,14 +57,9 @@ ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla_cuda",
-        sources=[
-            "csrc/flash_api.cpp",
-            "csrc/flash_mla_utils.cu",
-            "csrc/flash_fwd_mla_bf16_sm90.cu",
-            "csrc/flash_fwd_mla_fp8_sm90.cu",
-        ],
+        sources=get_sources(),
         extra_compile_args={
-            "cxx": cxx_args,
+            "cxx": cxx_args + get_features_args(),
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
@@ -60,7 +78,7 @@ ext_modules.append(
                     "--ftemplate-backtrace-limit=0"
                 ]
                 + cc_flag
-            ),
+            ) + get_features_args(),
         },
         include_dirs=[
             Path(this_dir) / "csrc",


@@ -1,10 +1,11 @@
+import argparse
 import math
 import random

 import torch
 import triton

-from flash_mla import get_mla_metadata, flash_mla_with_kvcache
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata


 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@@ -42,7 +43,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
-    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
+    print(
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {use_fp8=}"
+    )

     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
@@ -56,15 +59,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
     q = torch.randn(b, s_q, h_q, d)
     block_size = 64
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32
+    ).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
     for i in range(b):
-        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
+        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
+            float("nan")
+        )
     blocked_v = blocked_k[..., :dv]

-    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
-
-    init_dtype = q.dtype
+    tile_scheduler_metadata, num_splits = get_mla_metadata(
+        cache_seqlens, s_q * h_q // h_kv, h_kv
+    )

     def prepare_fp8_input():
         q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
@ -90,9 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
def flash_mla(): def flash_mla():
return flash_mla_with_kvcache( return flash_mla_with_kvcache(
q, blocked_k, block_table, cache_seqlens, dv, q,
tile_scheduler_metadata, num_splits, causal=causal, blocked_k,
descale_q=descale_q, descale_k=descale_k, block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
) )
def ref_mla(): def ref_mla():
@ -124,14 +138,18 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
t = triton.testing.do_bench(flash_mla) t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") torch.finfo(q.dtype).bits // 8
)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
if __name__ == "__main__": def main(torch_dtype):
dtype = torch.bfloat16
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.set_default_dtype(dtype) init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
torch.manual_seed(0) torch.manual_seed(0)
@@ -140,11 +158,32 @@ if __name__ == "__main__":
     h_kv = 1
     d, dv = 576, 512
     causal = False
-    use_fp8 = True
+    use_fp8 = torch_dtype == torch.float8_e4m3fn

-    for b in [16]:
-        for s in [4096]:
-            for h_q in [128]:  # TP = 8, 4, 2, 1
-                for s_q in [2]:  # MTP = 1, 2
-                    for varlen in [False]:
+    for b in [128]:
+        for s in [4096, 8192]:
+            for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
+                for s_q in [1, 2]:  # MTP = 1, 2
+                    for varlen in [False, True]:
                         test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["bf16", "fp16", "e4m3"],
+        default="bf16",
+        help="Data type to use for testing (bf16/fp16/e4m3)",
+    )
+    args = parser.parse_args()
+    torch_dtype = torch.bfloat16
+    if args.dtype == "fp16":
+        torch_dtype = torch.float16
+    elif args.dtype == "e4m3":
+        torch_dtype = torch.float8_e4m3fn
+    main(torch_dtype)
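Assuming this hunk belongs to the tests/test_flash_mla.py referenced in the README Quick start, the new entry point is driven as follows (flag values mirror the choices list above):

python tests/test_flash_mla.py                # bf16 (default)
python tests/test_flash_mla.py --dtype fp16
python tests/test_flash_mla.py --dtype e4m3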