mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
commit c7143a7bda
Merge branch 'main' into will_fp8_mr

.gitignore (vendored): 2 lines changed
@@ -3,3 +3,5 @@ build
 *.egg-info/
 __pycache__/
 dist/
+*perf.csv
+*.png

README.md: 38 lines changed
@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 
 Currently released:
-- BF16
+- BF16, FP16
 - Paged kvcache with block size of 64
 
 ## Quick start
@@ -20,7 +20,7 @@ python setup.py install
 python tests/test_flash_mla.py
 ```
 
-Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6.
+Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
 
 ### Usage
 
@@ -42,17 +42,49 @@ for i in range(num_layers):
 
 - Hopper GPUs
 - CUDA 12.3 and above
+  - **But we highly recommend 12.8 or above for the best performance**
 - PyTorch 2.0 and above
 
 ## Acknowledgement
 
 FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
 
+## Community Support
+
+### MetaX
+For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com).
+
+The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)
+
+### Moore Threads
+For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/).
+
+The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA).
+
+### Hygon DCU
+For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).
+
+The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention).
+
+### Intellifusion
+For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com).
+
+The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py).
+
+### Iluvatar Corex
+For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com).
+
+The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)
+
 ## Citation
 
 ```bibtex
 @misc{flashmla2025,
-      title={FlashMLA: Efficient MLA decoding kernel},
+      title={FlashMLA: Efficient MLA decoding kernels},
       author={Jiashi Li},
       year={2025},
       publisher = {GitHub},
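
The Usage snippet referenced in the README hunks above is not reproduced here (only the for i in range(num_layers): context line survives). A minimal sketch of the call pattern, pieced together from the test and benchmark code later in this commit; the shapes, dtype, and sequence length are illustrative assumptions, not part of the README:

import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

b, s_q, h_q, h_kv, d, dv = 128, 1, 128, 1, 576, 512   # assumed decoding shapes (d = dv + rope dim)
block_size = 64
cache_seqlens = torch.full((b,), 4096, dtype=torch.int32, device="cuda")
max_blocks_per_seq = 4096 // block_size
block_table = torch.arange(b * max_blocks_per_seq, dtype=torch.int32, device="cuda").view(b, max_blocks_per_seq)
kvcache = torch.randn(b * max_blocks_per_seq, block_size, h_kv, d, dtype=torch.bfloat16, device="cuda")
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")

# Scheduler metadata is computed once per batch and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

o, lse = flash_mla_with_kvcache(
    q, kvcache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)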

benchmark/bench_flash_mla.py (new file, 520 lines)
@@ -0,0 +1,520 @@
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import math
import random

import flashinfer
import torch
import triton
import triton.language as tl

# pip install flashinfer-python
from flash_mla import flash_mla_with_kvcache, get_mla_metadata


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    # Poison the unused tail of each sequence's kv cache so out-of-range reads show up as NaN.
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q, h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_torch, lse_torch = ref_mla()
    t = triton.testing.do_bench(ref_mla)
    return out_torch, lse_torch, t


@torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

    def flash_mla():
        return flash_mla_with_kvcache(
            q, blocked_k, block_table, cache_seqlens, dv,
            tile_scheduler_metadata, num_splits, causal=causal,
        )

    out_flash, lse_flash = flash_mla()
    t = triton.testing.do_bench(flash_mla)
    return out_flash, lse_flash, t


@torch.inference_mode()
def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")

    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    kv_indptr = [0]
    kv_indices = []
    for i in range(b):
        seq_len = cache_seqlens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_table[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
    for seq_len in cache_seqlens[1:]:
        kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])

    q_indptr = torch.arange(0, b + 1).int() * s_q
    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)

    mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.int8),
        backend="fa3"
    )
    mla_wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        cache_seqlens,
        h_q,
        dv,
        d-dv,
        block_size,
        causal,
        1 / math.sqrt(d),
        q.dtype,
        blocked_k.dtype,
    )

    def flash_infer():
        output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True)
        return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)

    out_flash, lse_flash = flash_infer()
    t = triton.testing.do_bench(flash_infer)
    return out_flash, lse_flash, t


@triton.jit
def _mla_attn_kernel(
    Q_nope,
    Q_pe,
    Kv_c_cache,
    K_pe_cache,
    Req_to_tokens,
    B_seq_len,
    O,
    sm_scale,
    stride_q_nope_bs,
    stride_q_nope_h,
    stride_q_pe_bs,
    stride_q_pe_h,
    stride_kv_c_bs,
    stride_k_pe_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM_CKV: tl.constexpr,
    HEAD_DIM_KPE: tl.constexpr,
):
    cur_batch = tl.program_id(1)
    cur_head_id = tl.program_id(0)
    split_kv_id = tl.program_id(2)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)

    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
    offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
    q_nope = tl.load(Q_nope + offs_q_nope)

    offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
    offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]
    q_pe = tl.load(Q_pe + offs_q_pe)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)

    # Each program handles one contiguous slice of the kv sequence (split-KV decoding).
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        kv_page_number = tl.load(
            Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,
            mask=offs_n < split_kv_end,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]
        k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)

        qk = tl.dot(q_nope, k_c.to(q_nope.dtype))

        offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]
        k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)

        qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))
        qk *= sm_scale

        qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf"))

        v_c = tl.trans(k_c)

        # Online softmax update: rescale the running accumulator by the new running max.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc += tl.dot(p.to(v_c.dtype), v_c)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
    # Store this split's partial output and its log-sum-exp for the reduce kernel.
    offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
    tl.store(O + offs_o, acc / e_sum[:, None])
    offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
    tl.store(O + offs_o_1, e_max + tl.log(e_sum))


def _mla_attn(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    attn_logits,
    req_to_tokens,
    b_seq_len,
    num_kv_splits,
    sm_scale,
    page_size,
):
    batch_size, head_num = q_nope.shape[0], q_nope.shape[1]
    head_dim_ckv = q_nope.shape[-1]
    head_dim_kpe = q_pe.shape[-1]

    BLOCK_H = 16
    BLOCK_N = 64
    grid = (
        triton.cdiv(head_num, BLOCK_H),
        batch_size,
        num_kv_splits,
    )
    _mla_attn_kernel[grid](
        q_nope,
        q_pe,
        kv_c_cache,
        k_pe_cache,
        req_to_tokens,
        b_seq_len,
        attn_logits,
        sm_scale,
        # stride
        q_nope.stride(0),
        q_nope.stride(1),
        q_pe.stride(0),
        q_pe.stride(1),
        kv_c_cache.stride(-2),
        k_pe_cache.stride(-2),
        req_to_tokens.stride(0),
        attn_logits.stride(0),
        attn_logits.stride(1),
        attn_logits.stride(2),
        BLOCK_H=BLOCK_H,
        BLOCK_N=BLOCK_N,
        NUM_KV_SPLITS=num_kv_splits,
        PAGE_SIZE=page_size,
        HEAD_DIM_CKV=head_dim_ckv,
        HEAD_DIM_KPE=head_dim_kpe,
    )


@triton.jit
def _mla_softmax_reducev_kernel(
    Logits,
    B_seq_len,
    O,
    stride_l_b,
    stride_l_h,
    stride_l_s,
    stride_o_b,
    stride_o_h,
    NUM_KV_SPLITS: tl.constexpr,
    HEAD_DIM_CKV: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_batch_seq_len = tl.load(B_seq_len + cur_batch)

    offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)

    offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv
    offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV

    # Merge the per-split partial outputs using their stored log-sum-exp values.
    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)
            logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)

            n_e_max = tl.maximum(logits_1, e_max)
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(logits_1 - n_e_max)
            acc += exp_logic * logits

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
        acc / e_sum,
    )


def _mla_softmax_reducev(
    logits,
    o,
    b_seq_len,
    num_kv_splits,
):
    batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2]
    grid = (batch_size, head_num)
    _mla_softmax_reducev_kernel[grid](
        logits,
        b_seq_len,
        o,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=num_kv_splits,
        HEAD_DIM_CKV=head_dim_ckv,
        num_warps=4,
        num_stages=2,
    )


def mla_decode_triton(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    o,
    req_to_tokens,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
):
    assert num_kv_splits == attn_logits.shape[2]
    _mla_attn(
        q_nope,
        q_pe,
        kv_c_cache,
        k_pe_cache,
        attn_logits,
        req_to_tokens,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
    )
    _mla_softmax_reducev(
        attn_logits,
        o,
        b_seq_len,
        num_kv_splits,
    )


@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]

    assert d > dv, "mla with rope dim should be larger than no rope dim"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    def flash_mla_triton():
        num_kv_splits = 32
        o = torch.empty([b * s_q, h_q, dv])
        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
        mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size)
        return o.view([b, s_q, h_q, dv])

    out_flash = flash_mla_triton()
    t = triton.testing.do_bench(flash_mla_triton)
    return out_flash, None, t


FUNC_TABLE = {
    "torch": run_torch_mla,
    "flash_mla": run_flash_mla,
    "flash_infer": run_flash_infer,
    "flash_mla_triton": run_flash_mla_triton,
}


def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)
    assert baseline in FUNC_TABLE
    assert target in FUNC_TABLE
    baseline_func = FUNC_TABLE[baseline]
    target_func = FUNC_TABLE[target]

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

    out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)

    torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
    if target not in ["flash_infer", "flash_mla_triton"]:
        # flash_infer has a different lse return value
        # flash_mla_triton doesn't return lse
        torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"

    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
    return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b


def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
    torch.set_default_dtype(dtype)
    device = torch.device("cuda:0")
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)
    assert target in FUNC_TABLE
    target_func = FUNC_TABLE[target]

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

    out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)

    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
    return bytes / 10 ** 6 / perf_b


available_targets = [
    "torch",
    "flash_mla",
    "flash_infer",
    "flash_mla_triton",
]

shape_configs = [
    {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16}
    for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128]
]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline", type=str, default="torch")
    parser.add_argument("--target", type=str, default="flash_mla")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--one", action="store_true")
    parser.add_argument("--compare", action="store_true")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()
    benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
    with open(f"{benchmark_type}_perf.csv", "w") as fout:
        fout.write("name,batch,seqlen,head,bw\n")
        for shape in shape_configs:
            if args.all:
                for target in available_targets:
                    perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                    fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
            elif args.compare:
                perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
            elif args.one:
                perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
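
Judging from the flags defined in get_args(), the script appears to be run as python benchmark/bench_flash_mla.py --one --target flash_mla, as python benchmark/bench_flash_mla.py --compare --baseline torch --target flash_mla, or as python benchmark/bench_flash_mla.py --all; each mode writes a {benchmark_type}_perf.csv file that benchmark/visualize.py (below) turns into a bandwidth-versus-seqlen plot.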

benchmark/visualize.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import argparse

import matplotlib.pyplot as plt
import pandas as pd


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize benchmark results')
    parser.add_argument('--file', type=str, default='all_perf.csv',
                        help='Path to the CSV file with benchmark results (default: all_perf.csv)')
    return parser.parse_args()


args = parse_args()
file_path = args.file

df = pd.read_csv(file_path)

names = df['name'].unique()

for name in names:
    subset = df[df['name'] == name]
    plt.plot(subset['seqlen'], subset['bw'], label=name)

plt.title('bandwidth')
plt.xlabel('seqlen')
plt.ylabel('bw (GB/s)')
plt.legend()

plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png')

@@ -61,7 +61,7 @@ std::vector<at::Tensor>
 mha_fwd_kvcache_mla(
         at::Tensor &q,                              // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &kcache,                   // num_blocks x page_block_size x num_heads_k x head_size
-        c10::optional<const at::Tensor> &vcache_,   // num_blocks x page_block_size x num_heads_k x head_size_v
+        std::optional<const at::Tensor> &vcache_,   // num_blocks x page_block_size x num_heads_k x head_size_v
         const int head_size_v,
         const at::Tensor &seqlens_k,                // batch_size
         const at::Tensor &block_table,              // batch_size x max_num_blocks_per_seq
@@ -79,9 +79,8 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
     auto q_dtype = q.scalar_type();
-    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf || q_dtype == torch::kFloat8_e4m3fn);
     TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
-    bool is_fp8 = q_dtype == torch::kFloat8_e4m3fn;
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
 
@@ -108,7 +107,7 @@ mha_fwd_kvcache_mla(
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-    if (is_fp8) {
+    if (q_dtype == torch::kFloat8_e4m3fn) {
         TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8");
         auto descale_q = descale_q_.value();
         auto descale_k = descale_k_.value();
@@ -145,7 +144,7 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-    auto out_type = is_fp8 ? torch::kBFloat16 : q_dtype;
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
     at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
 
@@ -186,7 +185,7 @@ mha_fwd_kvcache_mla(
     params.block_table_batch_stride = block_table.stride(0);
     params.page_block_size = page_block_size;
 
-    if (is_fp8) {
+    if (q_dtype == torch::kFloat8_e4m3fn) {
         params.descale_q_ptr = reinterpret_cast<float*>(descale_q_.value().data_ptr());
         params.descale_k_ptr = reinterpret_cast<float*>(descale_k_.value().data_ptr());
     }
@@ -210,10 +209,19 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
 
-    if (is_fp8) {
+    if (q_dtype == torch::kBFloat16) {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+    }
+#ifndef FLASH_MLA_DISABLE_FP16
+    else if (q_dtype == torch::kHalf) {
+        run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
+    }
+#endif
+    else if (q_dtype == torch::kFloat8_e4m3fn) {
         run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
     } else {
-        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
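
Taken together, these hunks replace the old is_fp8 flag with dispatch on the query dtype: BF16 runs the BF16 kernel, FP16 runs the new half_t kernel unless the extension was built with FLASH_MLA_DISABLE_FP16, FP8 (e4m3) runs an FP8 kernel that writes a BF16 output and requires descale_q/descale_k, and any other dtype now fails the explicit TORCH_CHECK.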

csrc/flash_fwd_mla_fp16_sm90.cu (new file, 3 lines)
@@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

@@ -1,12 +1,4 @@
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
-#include <cutlass/numeric_types.h>
-
-using namespace cute;
-
-#include "flash_mla.h"
-#include "static_switch.h"
-#include "utils.h"
 #include "flash_fwd_mla_kernel.h"
 
 static constexpr int MaxBatchSize = 4096;

@@ -16,7 +16,7 @@ def get_mla_metadata(
         num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
         num_heads_k: num_heads_k.
 
-    Return:
+    Returns:
         tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
         num_splits: (batch_size + 1), dtype torch.int32.
     """
@@ -42,10 +42,10 @@ def flash_mla_with_kvcache(
         k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
         block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
         cache_seqlens: (batch_size), torch.int32.
-        head_dim_v: Head_dim of v.
-        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
-        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
-        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
+        head_dim_v: Head dimension of v.
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
+        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
+        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
         causal: bool. Whether to apply causal attention mask.
         descale_q: (batch_size), torch.float. dequant scale for query
         descale_k: (batch_size), torch.float. dequant scale for key

setup.py: 34 lines changed
@@ -11,12 +11,35 @@ from torch.utils.cpp_extension import (
     IS_WINDOWS,
 )
 
+DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+
 
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
 
+def get_sources():
+    sources = [
+        "csrc/flash_api.cpp",
+        "csrc/flash_fwd_mla_bf16_sm90.cu",
+        "csrc/flash_fwd_mla_fp8_sm90.cu",
+        "csrc/flash_fwd_mla_metadata.cu",
+    ]
+
+    if not DISABLE_FP16:
+        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+
+    return sources
+
+
+def get_features_args():
+    features_args = []
+    if DISABLE_FP16:
+        features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    return features_args
+
+
 subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
 
 cc_flag = []
@@ -34,14 +57,9 @@ ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla_cuda",
-        sources=[
-            "csrc/flash_api.cpp",
-            "csrc/flash_mla_utils.cu",
-            "csrc/flash_fwd_mla_bf16_sm90.cu",
-            "csrc/flash_fwd_mla_fp8_sm90.cu",
-        ],
+        sources=get_sources(),
         extra_compile_args={
-            "cxx": cxx_args,
+            "cxx": cxx_args + get_features_args(),
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
@@ -60,7 +78,7 @@ ext_modules.append(
                     "--ftemplate-backtrace-limit=0"
                 ]
                 + cc_flag
-            ),
+            ) + get_features_args(),
         },
         include_dirs=[
             Path(this_dir) / "csrc",
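
Based on these setup.py changes, the FP16 kernel appears to be compiled by default; setting FLASH_MLA_DISABLE_FP16=TRUE in the environment before running python setup.py install should drop csrc/flash_fwd_mla_fp16_sm90.cu from the source list and add -DFLASH_MLA_DISABLE_FP16, which compiles out the FP16 dispatch branch shown in the C++ changes above.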

@@ -1,10 +1,11 @@
+import argparse
 import math
 import random
 
 import torch
 import triton
 
-from flash_mla import get_mla_metadata, flash_mla_with_kvcache
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
 
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@@ -42,7 +43,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
 
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
-    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
+    print(
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {use_fp8=}"
+    )
 
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
@@ -56,15 +59,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
 
     q = torch.randn(b, s_q, h_q, d)
     block_size = 64
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32
+    ).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
     for i in range(b):
-        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
+        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
+            float("nan")
+        )
     blocked_v = blocked_k[..., :dv]
 
-    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
-
     init_dtype = q.dtype
+    tile_scheduler_metadata, num_splits = get_mla_metadata(
+        cache_seqlens, s_q * h_q // h_kv, h_kv
+    )
 
     def prepare_fp8_input():
         q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
@@ -90,9 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
 
     def flash_mla():
         return flash_mla_with_kvcache(
-            q, blocked_k, block_table, cache_seqlens, dv,
-            tile_scheduler_metadata, num_splits, causal=causal,
-            descale_q=descale_q, descale_k=descale_k,
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            dv,
+            tile_scheduler_metadata,
+            num_splits,
+            causal=causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
         )
 
     def ref_mla():
@@ -124,14 +138,18 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
 
     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
+        torch.finfo(q.dtype).bits // 8
+    )
+    print(
+        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
+    )
 
 
-if __name__ == "__main__":
-    dtype = torch.bfloat16
+def main(torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
+    torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
@@ -140,11 +158,32 @@ if __name__ == "__main__":
     h_kv = 1
     d, dv = 576, 512
     causal = False
-    use_fp8 = True
+    use_fp8 = torch_dtype == torch.float8_e4m3fn
 
-    for b in [16]:
-        for s in [4096]:
-            for h_q in [128]:  # TP = 8, 4, 2, 1
-                for s_q in [2]:  # MTP = 1, 2
-                    for varlen in [False]:
+    for b in [128]:
+        for s in [4096, 8192]:
+            for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
+                for s_q in [1, 2]:  # MTP = 1, 2
+                    for varlen in [False, True]:
                         test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["bf16", "fp16", "e4m3"],
+        default="bf16",
+        help="Data type to use for testing (bf16/fp16/e4m3)",
+    )
+
+    args = parser.parse_args()
+
+    torch_dtype = torch.bfloat16
+    if args.dtype == "fp16":
+        torch_dtype = torch.float16
+    elif args.dtype == "e4m3":
+        torch_dtype = torch.float8_e4m3fn
+
+    main(torch_dtype)