From 15a82b81b868049677bc86ed19c56809b1863511 Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 24 Feb 2025 00:25:25 -0800 Subject: [PATCH 01/15] replace c10 optional with std optional --- csrc/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 5a1cb8e..3184465 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -61,7 +61,7 @@ std::vector mha_fwd_kvcache_mla( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size - c10::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v + std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq From 65fb7732fc5f95edea30019296b1940c759ae321 Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 24 Feb 2025 01:58:53 -0800 Subject: [PATCH 02/15] support fp16 --- README.md | 2 +- csrc/flash_api.cpp | 9 +++- csrc/flash_fwd_mla_fp16_sm90.cu | 3 ++ csrc/flash_fwd_mla_kernel.h | 76 -------------------------------- csrc/flash_fwd_mla_metadata.cu | 77 +++++++++++++++++++++++++++++++++ setup.py | 2 + tests/test_flash_mla.py | 61 +++++++++++++++++++++----- 7 files changed, 139 insertions(+), 91 deletions(-) create mode 100644 csrc/flash_fwd_mla_fp16_sm90.cu create mode 100644 csrc/flash_fwd_mla_metadata.cu diff --git a/README.md b/README.md index bb55395..4027334 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. Currently released: -- BF16 +- BF16, FP16 - Paged kvcache with block size of 64 ## Quick start diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 3184465..b7b11f1 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -77,7 +77,7 @@ mha_fwd_kvcache_mla( at::Tensor vcache = vcache_.has_value() ? 
vcache_.value() : kcache; auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -186,7 +186,12 @@ mha_fwd_kvcache_mla( auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); - run_mha_fwd_splitkv_mla(params, stream); + + if (q_dtype == torch::kBFloat16) { + run_mha_fwd_splitkv_mla(params, stream); + } else { + run_mha_fwd_splitkv_mla(params, stream); + } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu new file mode 100644 index 0000000..abdaf7b --- /dev/null +++ b/csrc/flash_fwd_mla_fp16_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 55f6811..d96acd8 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; run_flash_splitkv_fwd_mla>(params, stream); } - -static constexpr int MaxBatchSize = 4096; - -__global__ void __launch_bounds__(256, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { - int *seqlens_k_ptr = params.seqlens_k_ptr; - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; - int *num_splits_ptr = params.num_splits_ptr; - int batch_size = params.batch_size; - int block_size_n = params.block_size_n; - int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; - int num_sm_parts = params.num_sm_parts; - - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; - - int total_num_blocks = 0; - for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); - total_num_blocks += num_blocks + fixed_overhead_num_blocks; - num_blocks_shared[i] = num_blocks; - } - for (int offset = 16; offset >= 1; offset /= 2) { - total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); - } - __syncwarp(); - - if (threadIdx.x == 0) { - int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; - - int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; - num_splits_shared[0] = 0; - for (int i = 0; i < num_sm_parts; ++i) { - int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; - tile_scheduler_metadata1 = now_n_split_idx; - int remain_payload = payload; - while (now_idx < batch_size) { - int num_blocks = num_blocks_shared[now_idx]; - int now_remain_blocks = num_blocks - now_block; - if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { - cum_num_splits += now_n_split_idx + 1; - num_splits_shared[now_idx + 1] = cum_num_splits; - remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; - ++now_idx; - now_block = 0; - now_n_split_idx = 0; - } else { - if (remain_payload - fixed_overhead_num_blocks > 0) { - now_block += remain_payload - 
fixed_overhead_num_blocks; - ++now_n_split_idx; - remain_payload = 0; - } - break; - } - } - tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; - *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; - } - FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); - } - __syncwarp(); - - for (int i = threadIdx.x; i <= batch_size; i += 32) { - num_splits_ptr[i] = num_splits_shared[i]; - } -} - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); - CHECK_CUDA_KERNEL_LAUNCH(); -} diff --git a/csrc/flash_fwd_mla_metadata.cu b/csrc/flash_fwd_mla_metadata.cu new file mode 100644 index 0000000..82f5b5a --- /dev/null +++ b/csrc/flash_fwd_mla_metadata.cu @@ -0,0 +1,77 @@ +#include "flash_fwd_mla_kernel.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? 
now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/setup.py b/setup.py index 0a3bd17..662a301 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,8 @@ ext_modules.append( sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_fp16_sm90.cu", + "csrc/flash_fwd_mla_metadata.cu", ], extra_compile_args={ "cxx": cxx_args, diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 8db5db0..e676fa7 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -1,10 +1,11 @@ +import argparse import math import random import torch import triton -from flash_mla import get_mla_metadata, flash_mla_with_kvcache +from flash_mla import flash_mla_with_kvcache, get_mla_metadata def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + ) cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: @@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): q = torch.randn(b, s_q, h_q, d) block_size = 64 - block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + float("nan") + ) blocked_v = blocked_k[..., :dv] - tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, s_q * h_q // h_kv, h_kv + ) def flash_mla(): return flash_mla_with_kvcache( - q, blocked_k, block_table, cache_seqlens, dv, - tile_scheduler_metadata, num_splits, causal=causal, + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, ) def ref_mla(): @@ -91,14 +106,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * 
s_q * h_q * dv) * ( + torch.finfo(q.dtype).bits // 8 + ) + print( + f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" + ) -if __name__ == "__main__": - dtype = torch.bfloat16 +def main(torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + torch.set_default_dtype(torch_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) @@ -114,3 +132,22 @@ if __name__ == "__main__": for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + choices=["bf16", "fp16"], + default="bf16", + help="Data type to use for testing (bf16 or fp16)", + ) + + args = parser.parse_args() + + torch_dtype = torch.bfloat16 + if args.dtype == "fp16": + torch_dtype = torch.float16 + + main(torch_dtype) From 4da4dbd303eabbcdb5806051d82430ded46625d6 Mon Sep 17 00:00:00 2001 From: zhengsize Date: Mon, 24 Feb 2025 22:34:22 +0800 Subject: [PATCH 03/15] feat: add benchmark for flash_infer vs flash_mla --- benchmark/bench_flash_mla.py | 514 +++++++++++++++++++++++++++++++++++ benchmark/visualize.py | 19 ++ 2 files changed, 533 insertions(+) create mode 100644 benchmark/bench_flash_mla.py create mode 100644 benchmark/visualize.py diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py new file mode 100644 index 0000000..7b0e7b4 --- /dev/null +++ b/benchmark/bench_flash_mla.py @@ -0,0 +1,514 @@ +# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a +import math +import random + +import torch +import triton +import triton.language as tl +import argparse + +# pip install flashinfer-python +from flash_mla import get_mla_metadata, flash_mla_with_kvcache +import flashinfer + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse 
+ + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + +@torch.inference_mode() +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, blocked_k, block_table, cache_seqlens, dv, + tile_scheduler_metadata, num_splits, causal=causal, + ) + + out_flash, lse_flash = flash_mla() + t = triton.testing.do_bench(flash_mla) + return out_flash, lse_flash, t + + +@torch.inference_mode() +def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + + kv_indptr = [0] + kv_indices = [] + for i in range(b): + seq_len = cache_seqlens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_table[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + for seq_len in cache_seqlens[1:]: + kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) + + q_indptr = torch.arange(0, b + 1).int() * s_q + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.int8), + backend="fa3" + ) + mla_wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + cache_seqlens, + h_q, + dv, + d-dv, + block_size, + causal, + 1 / math.sqrt(d), + q.dtype, + blocked_k.dtype, + ) + + def flash_infer(): + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True) + return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) + + out_flash, lse_flash = flash_infer() + t = triton.testing.do_bench(flash_infer) + return out_flash, lse_flash, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = 
tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + ) + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + 
logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + num_warps=4, + num_stages=2, + ) + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "flash_mla": run_flash_mla, + "flash_infer": run_flash_infer, + "flash_mla_triton": run_flash_mla_triton, +} + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // 
block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flash_infer", "flash_mla_triton"]: + # flash_infer has a different lse return value + # flash_mla_triton doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") + return bytes / 10 ** 6 / perf_b + + +available_targets = [ + "torch", + "flash_mla", + "flash_infer", + "flash_mla_triton", +] + +shape_configs = [ + {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.bfloat16} + for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 8192*2, 8192*4] for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="flash_mla") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + with open("all_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in 
available_targets: + perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') + elif args.compare: + compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + elif args.one: + compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) \ No newline at end of file diff --git a/benchmark/visualize.py b/benchmark/visualize.py new file mode 100644 index 0000000..db62519 --- /dev/null +++ b/benchmark/visualize.py @@ -0,0 +1,19 @@ +import matplotlib.pyplot as plt +import pandas as pd + +file_path = 'all_perf.csv' + +df = pd.read_csv(file_path) + +names = df['name'].unique() + +for name in names: + subset = df[df['name'] == name] + plt.plot(subset['seqlen'], subset['bw'], label=name) + +plt.title('bandwidth') +plt.xlabel('seqlen') +plt.ylabel('bw (GB/s)') +plt.legend() + +plt.savefig('bandwidth_vs_seqlen.png') \ No newline at end of file From c4c5912b058668e44ba1aebbebe1097305303cc7 Mon Sep 17 00:00:00 2001 From: "chunyang.wen" Date: Tue, 25 Feb 2025 00:11:57 +0800 Subject: [PATCH 04/15] Update docstring --- flash_mla/flash_mla_interface.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 2f3aa46..b2922af 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -16,7 +16,7 @@ def get_mla_metadata( num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. num_heads_k: num_heads_k. - Return: + Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ @@ -40,13 +40,13 @@ def flash_mla_with_kvcache( k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head_dim of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. - softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + head_dim_v: Head dimension of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. - Return: + Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
""" From 922f63bdaa64ad4ae37174269f35fd81c62d2952 Mon Sep 17 00:00:00 2001 From: zhengsize Date: Mon, 24 Feb 2025 23:58:52 +0800 Subject: [PATCH 05/15] add gitignore for png and csv files in benchmark --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index bfe80b5..5f9e980 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ build *.egg-info/ __pycache__/ dist/ +*perf.csv +*.png From a3b74b85749f0b450f5fda1c423b1ad56df704db Mon Sep 17 00:00:00 2001 From: Sijia Chen Date: Mon, 24 Feb 2025 10:01:59 -0800 Subject: [PATCH 06/15] add flag to disable FP16 compile --- csrc/flash_api.cpp | 9 +++++++-- setup.py | 29 +++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index b7b11f1..d2567fe 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -77,7 +77,6 @@ mha_fwd_kvcache_mla( at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -189,9 +188,15 @@ mha_fwd_kvcache_mla( if (q_dtype == torch::kBFloat16) { run_mha_fwd_splitkv_mla(params, stream); - } else { + } + #ifndef FLASH_MLA_DISABLE_FP16 + else if (q_dtype == torch::kHalf) { run_mha_fwd_splitkv_mla(params, stream); } + #endif + else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/setup.py b/setup.py index 662a301..6377b1e 100644 --- a/setup.py +++ b/setup.py @@ -11,11 +11,29 @@ from torch.utils.cpp_extension import ( IS_WINDOWS, ) +DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] +def get_sources(): + sources = [ + "csrc/flash_api.cpp", + "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_metadata.cu", + ] + + if not DISABLE_FP16: + sources.append("csrc/flash_fwd_mla_fp16_sm90.cu") + + return sources + +def get_features_args(): + features_args = [] + if DISABLE_FP16: + features_args.append("-DFLASH_MLA_DISABLE_FP16") + return features_args subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) @@ -34,14 +52,9 @@ ext_modules = [] ext_modules.append( CUDAExtension( name="flash_mla_cuda", - sources=[ - "csrc/flash_api.cpp", - "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_fp16_sm90.cu", - "csrc/flash_fwd_mla_metadata.cu", - ], + sources=get_sources(), extra_compile_args={ - "cxx": cxx_args, + "cxx": cxx_args + get_features_args(), "nvcc": append_nvcc_threads( [ "-O3", @@ -59,7 +72,7 @@ ext_modules.append( "--ptxas-options=-v,--register-usage-level=10" ] + cc_flag - ), + ) + get_features_args(), }, include_dirs=[ Path(this_dir) / "csrc", From e1e9fa98f80f34c3b155fd483c38227abb5f400d Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 25 Feb 2025 09:18:11 +0800 Subject: [PATCH 07/15] Style fix --- tests/test_flash_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index e676fa7..0abe9d2 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -60,7 +60,7 @@ def test_flash_mla(b, s_q, 
mean_sk, h_q, h_kv, d, dv, causal, varlen): ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( float("nan") ) blocked_v = blocked_k[..., :dv] From 4edea86f9e85eea6ea41dd14b2798fc6a0e2d80c Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Wed, 26 Feb 2025 00:05:57 +0800 Subject: [PATCH 08/15] cuda12.8 recommendation --- README.md | 5 +++-- setup.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4027334..6d0bcb6 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ python setup.py install python tests/test_flash_mla.py ``` -Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.6. +Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. ### Usage @@ -42,6 +42,7 @@ for i in range(num_layers): - Hopper GPUs - CUDA 12.3 and above + - **But we highly recommend 12.8 or above for the best performance** - PyTorch 2.0 and above ## Acknowledgement @@ -52,7 +53,7 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ```bibtex @misc{flashmla2025, - title={FlashMLA: Efficient MLA decoding kernel}, + title={FlashMLA: Efficient MLA decoding kernels}, author={Jiashi Li}, year={2025}, publisher = {GitHub}, diff --git a/setup.py b/setup.py index 6377b1e..cd311f2 100644 --- a/setup.py +++ b/setup.py @@ -13,10 +13,12 @@ from torch.utils.cpp_extension import ( DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] + def get_sources(): sources = [ "csrc/flash_api.cpp", @@ -29,12 +31,14 @@ def get_sources(): return sources + def get_features_args(): features_args = [] if DISABLE_FP16: features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) cc_flag = [] From b67980309b4140dd18edd8b50919243c4dbf0415 Mon Sep 17 00:00:00 2001 From: "yangsijia.614" Date: Tue, 25 Feb 2025 23:52:54 +0800 Subject: [PATCH 09/15] fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them --- benchmark/bench_flash_mla.py | 18 ++++++++++++------ benchmark/visualize.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py index 7b0e7b4..14e1352 100644 --- a/benchmark/bench_flash_mla.py +++ b/benchmark/bench_flash_mla.py @@ -1,15 +1,16 @@ # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a +import argparse import math import random +import flashinfer import torch import triton import triton.language as tl -import argparse # pip install flashinfer-python -from flash_mla import get_mla_metadata, flash_mla_with_kvcache -import flashinfer +from flash_mla import flash_mla_with_kvcache, get_mla_metadata + def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): query = query.float() @@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, 
h_kv, d, dv, causal bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") + return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): @@ -501,7 +503,8 @@ def get_args(): if __name__ == "__main__": args = get_args() - with open("all_perf.csv", "w") as fout: + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: fout.write("name,batch,seqlen,head,bw\n") for shape in shape_configs: if args.all: @@ -509,6 +512,9 @@ if __name__ == "__main__": perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') elif args.compare: - compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n') + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') elif args.one: - compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) \ No newline at end of file + perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') \ No newline at end of file diff --git a/benchmark/visualize.py b/benchmark/visualize.py index db62519..c1fb37e 100644 --- a/benchmark/visualize.py +++ b/benchmark/visualize.py @@ -1,7 +1,17 @@ +import argparse + import matplotlib.pyplot as plt import pandas as pd -file_path = 'all_perf.csv' + +def parse_args(): + parser = argparse.ArgumentParser(description='Visualize benchmark results') + parser.add_argument('--file', type=str, default='all_perf.csv', + help='Path to the CSV file with benchmark results (default: all_perf.csv)') + return parser.parse_args() + +args = parse_args() +file_path = args.file df = pd.read_csv(file_path) @@ -16,4 +26,4 @@ plt.xlabel('seqlen') plt.ylabel('bw (GB/s)') plt.legend() -plt.savefig('bandwidth_vs_seqlen.png') \ No newline at end of file +plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png') \ No newline at end of file From 6492cabb28d0319f4068a5e8a393b11d5f9bb896 Mon Sep 17 00:00:00 2001 From: hpp Date: Wed, 26 Feb 2025 11:26:42 +0800 Subject: [PATCH 10/15] add Community Support of [MetaX] and [Moore Threads] --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md 
index 6d0bcb6..95aed11 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,17 @@ for i in range(num_layers): FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. +## Community Support + +### MetaX + +For the MetaX GPU【https://www.metax-tech.com】, the corresponding FlashMLA version link is as follows: +GitHub - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) + +### Moore Threads (WIP) +For the Moore Threads GPU【https://www.mthreads.com/】, the corresponding FlashMLA version link is as follows: +GitHub - [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) + ## Citation ```bibtex From 966eedc2f72f832083e4011c3f73f2ebcf501856 Mon Sep 17 00:00:00 2001 From: Jiashi Li <31004720+beginlner@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:30:45 +0800 Subject: [PATCH 11/15] Fix readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 95aed11..252bae1 100644 --- a/README.md +++ b/README.md @@ -53,12 +53,12 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ### MetaX -For the MetaX GPU【https://www.metax-tech.com】, the corresponding FlashMLA version link is as follows: -GitHub - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) +For [MetaX](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +- [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads (WIP) -For the Moore Threads GPU【https://www.mthreads.com/】, the corresponding FlashMLA version link is as follows: -GitHub - [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) +For [Moore Threads](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +- [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) ## Citation From 480405ada9beff03b3e99a1dc28c9a35deb8a05c Mon Sep 17 00:00:00 2001 From: Jiashi Li <31004720+beginlner@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:32:39 +0800 Subject: [PATCH 12/15] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 252bae1..0bb3e52 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ For [MetaX](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads (WIP) -For [Moore Threads](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +For [Moore Threads](https://www.mthreads.com) GPUs, the corresponding FlashMLA version can be found at: - [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) ## Citation From 4430e398d9228fd0e04b4ef97d8b7c388760c4e6 Mon Sep 17 00:00:00 2001 From: hpp Date: Thu, 27 Feb 2025 09:39:18 +0800 Subject: [PATCH 13/15] add Community Support of [Hygon DCU] [Intellifusion] [Iluvatar Corex] --- README.md | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0bb3e52..ae9f900 100644 --- a/README.md +++ b/README.md @@ -53,12 +53,41 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ### MetaX -For [MetaX](https://www.metax-tech.com) GPUs, the corresponding FlashMLA version can be found at: +For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). 
+ +The corresponding FlashMLA version can be found at: - [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) -### Moore Threads (WIP) -For [Moore Threads](https://www.mthreads.com) GPUs, the corresponding FlashMLA version can be found at: -- [MooreThreads/MT-DeepSeek](https://github.com/MooreThreads/MT-DeepSeek) + +### Moore Threads +For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). + +The corresponding FlashMLA version is available on GitHub: +[MooreThreads/MT-flashMLA](GitHub - MooreThreads/MT-flashMLA: Fork from https://github.com/deepseek-ai/FlashMLA). + + +### Hygon DCU + +For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/). + +The corresponding FlashMLA version is available here: +[OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). + + +### Intellifusion + +For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com). + +The corresponding FlashMLA version is available on Gitee: +[Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). + + +### Iluvatar Corex + +For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com). + +The corresponding FlashMLA version is available on GitHub: +[Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) ## Citation From 77d9d8d21bea77c7a6dba2fb86bb993af7063c1d Mon Sep 17 00:00:00 2001 From: hpp Date: Thu, 27 Feb 2025 09:40:47 +0800 Subject: [PATCH 14/15] add Community Support of [Hygon DCU] [Intellifusion] [Iluvatar Corex] --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ae9f900..2d2cfd6 100644 --- a/README.md +++ b/README.md @@ -56,14 +56,14 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). The corresponding FlashMLA version can be found at: -- [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) +[MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). The corresponding FlashMLA version is available on GitHub: -[MooreThreads/MT-flashMLA](GitHub - MooreThreads/MT-flashMLA: Fork from https://github.com/deepseek-ai/FlashMLA). +[MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). ### Hygon DCU From 1aef31d163f11c0f2c7329fcda71d2985bfb47a0 Mon Sep 17 00:00:00 2001 From: hpp Date: Thu, 27 Feb 2025 09:42:09 +0800 Subject: [PATCH 15/15] reformat Community Support section --- README.md | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 2d2cfd6..b79757c 100644 --- a/README.md +++ b/README.md @@ -51,43 +51,34 @@ FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash- ## Community Support -### MetaX - +### MetaX For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com). -The corresponding FlashMLA version can be found at: -[MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) +The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA) ### Moore Threads For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/). 
-The corresponding FlashMLA version is available on GitHub: -[MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). +The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA). ### Hygon DCU - For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/). -The corresponding FlashMLA version is available here: -[OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). +The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention). ### Intellifusion - For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com). -The corresponding FlashMLA version is available on Gitee: -[Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). +The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py). ### Iluvatar Corex - For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com). -The corresponding FlashMLA version is available on GitHub: -[Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) +The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla) ## Citation
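
---

Editor's note: the sketch below is not part of any patch above. It is a minimal usage example tying the series together — the FP16 path added in PATCH 02 plus the metadata/paged-KV API documented in PATCH 04 — under the following assumptions: the extension was built from this series without `FLASH_MLA_DISABLE_FP16=TRUE` (PATCH 06), a Hopper GPU is available, and shapes mirror the defaults in `tests/test_flash_mla.py` (the chosen sequence length of 4096 is an arbitrary illustration).

```python
import torch
from flash_mla import flash_mla_with_kvcache, get_mla_metadata

dtype = torch.float16                      # torch.bfloat16 also works; FP16 requires the
torch.set_default_dtype(dtype)             # extension built without FLASH_MLA_DISABLE_FP16
torch.set_default_device("cuda:0")

b, s_q, h_q, h_kv = 4, 1, 128, 1           # decode: 1 query token, 128 query heads, 1 KV head
d, dv, block_size = 576, 512, 64           # 576 = 512 (no-RoPE) + 64 (RoPE); paged-KV block size 64

cache_seqlens = torch.full((b,), 4096, dtype=torch.int32)   # per-request KV lengths
max_seqlen_pad = 4096                                        # already a multiple of block_size here

q = torch.randn(b, s_q, h_q, d)
block_table = torch.arange(b * max_seqlen_pad // block_size,
                           dtype=torch.int32).view(b, -1)    # (batch, max_num_blocks_per_seq)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

# Scheduler metadata is computed once per set of cache_seqlens and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

out, softmax_lse = flash_mla_with_kvcache(
    q, blocked_k, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)
print(out.shape, softmax_lse.shape)        # (b, s_q, h_q, dv) and (b, h_q, s_q)
```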