From 15a82b81b868049677bc86ed19c56809b1863511 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Mon, 24 Feb 2025 00:25:25 -0800
Subject: [PATCH 1/4] replace c10 optional with std optional

---
 csrc/flash_api.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 5a1cb8e..3184465 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -61,7 +61,7 @@ std::vector<at::Tensor>
 mha_fwd_kvcache_mla(
         at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size
-        c10::optional<at::Tensor> &vcache_,          // num_blocks x page_block_size x num_heads_k x head_size_v
+        std::optional<at::Tensor> &vcache_,          // num_blocks x page_block_size x num_heads_k x head_size_v
         const int head_size_v,
         const at::Tensor &seqlens_k,                 // batch_size
         const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq

From 65fb7732fc5f95edea30019296b1940c759ae321 Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Mon, 24 Feb 2025 01:58:53 -0800
Subject: [PATCH 2/4] support fp16

---
 README.md                       |  2 +-
 csrc/flash_api.cpp              |  9 +++-
 csrc/flash_fwd_mla_fp16_sm90.cu |  3 ++
 csrc/flash_fwd_mla_kernel.h     | 76 --------------------------------
 csrc/flash_fwd_mla_metadata.cu  | 77 +++++++++++++++++++++++++++++++++
 setup.py                        |  2 +
 tests/test_flash_mla.py         | 61 +++++++++++++++++++++-----
 7 files changed, 139 insertions(+), 91 deletions(-)
 create mode 100644 csrc/flash_fwd_mla_fp16_sm90.cu
 create mode 100644 csrc/flash_fwd_mla_metadata.cu

diff --git a/README.md b/README.md
index bb55395..4027334 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 
 Currently released:
-- BF16
+- BF16, FP16
 - Paged kvcache with block size of 64
 
 ## Quick start
diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 3184465..b7b11f1 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -77,7 +77,7 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16);
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16);
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -186,7 +186,12 @@ mha_fwd_kvcache_mla(
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+
+    if (q_dtype == torch::kBFloat16) {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    } else {
+        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
+    }
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu
new file mode 100644
index 0000000..abdaf7b
--- /dev/null
+++ b/csrc/flash_fwd_mla_fp16_sm90.cu
@@ -0,0 +1,3 @@
+#include "flash_fwd_mla_kernel.h"
+
+template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 55f6811..d96acd8 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream)
     using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
     run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
 }
-
-static constexpr int MaxBatchSize = 4096;
-
-__global__ void __launch_bounds__(256, 1, 1)
-get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
-    int *seqlens_k_ptr = params.seqlens_k_ptr;
-    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
-    int *num_splits_ptr = params.num_splits_ptr;
-    int batch_size = params.batch_size;
-    int block_size_n = params.block_size_n;
-    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
-    int num_sm_parts = params.num_sm_parts;
-
-    __shared__ int num_blocks_shared[MaxBatchSize];
-    __shared__ int num_splits_shared[MaxBatchSize];
-
-    int total_num_blocks = 0;
-    for (int i = threadIdx.x; i < batch_size; i += 32) {
-        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
-        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
-        num_blocks_shared[i] = num_blocks;
-    }
-    for (int offset = 16; offset >= 1; offset /= 2) {
-        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
-    }
-    __syncwarp();
-
-    if (threadIdx.x == 0) {
-        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
-
-        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
-        num_splits_shared[0] = 0;
-        for (int i = 0; i < num_sm_parts; ++i) {
-            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
-            tile_scheduler_metadata0[0] = now_idx;
-            tile_scheduler_metadata0[1] = now_block * block_size_n;
-            tile_scheduler_metadata1 = now_n_split_idx;
-            int remain_payload = payload;
-            while (now_idx < batch_size) {
-                int num_blocks = num_blocks_shared[now_idx];
-                int now_remain_blocks = num_blocks - now_block;
-                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
-                    cum_num_splits += now_n_split_idx + 1;
-                    num_splits_shared[now_idx + 1] = cum_num_splits;
-                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
-                    ++now_idx;
-                    now_block = 0;
-                    now_n_split_idx = 0;
-                } else {
-                    if (remain_payload - fixed_overhead_num_blocks > 0) {
-                        now_block += remain_payload - fixed_overhead_num_blocks;
-                        ++now_n_split_idx;
-                        remain_payload = 0;
-                    }
-                    break;
-                }
-            }
-            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
-            tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
-            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
-            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
-        }
-        FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
-    }
-    __syncwarp();
-
-    for (int i = threadIdx.x; i <= batch_size; i += 32) {
-        num_splits_ptr[i] = num_splits_shared[i];
-    }
-}
-
-void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
-    FLASH_ASSERT(params.batch_size < MaxBatchSize);
-    get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
-    CHECK_CUDA_KERNEL_LAUNCH();
-}
diff --git a/csrc/flash_fwd_mla_metadata.cu b/csrc/flash_fwd_mla_metadata.cu
new file mode 100644
index 0000000..82f5b5a
--- /dev/null
+++ b/csrc/flash_fwd_mla_metadata.cu
@@ -0,0 +1,77 @@
+#include "flash_fwd_mla_kernel.h"
+
+static constexpr int MaxBatchSize = 4096;
+
+__global__ void __launch_bounds__(256, 1, 1)
+get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
+    int *seqlens_k_ptr = params.seqlens_k_ptr;
+    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
+    int *num_splits_ptr = params.num_splits_ptr;
+    int batch_size = params.batch_size;
+    int block_size_n = params.block_size_n;
+    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
+    int num_sm_parts = params.num_sm_parts;
+
+    __shared__ int num_blocks_shared[MaxBatchSize];
+    __shared__ int num_splits_shared[MaxBatchSize];
+
+    int total_num_blocks = 0;
+    for (int i = threadIdx.x; i < batch_size; i += 32) {
+        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
+        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
+        num_blocks_shared[i] = num_blocks;
+    }
+    for (int offset = 16; offset >= 1; offset /= 2) {
+        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
+    }
+    __syncwarp();
+
+    if (threadIdx.x == 0) {
+        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
+
+        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
+        num_splits_shared[0] = 0;
+        for (int i = 0; i < num_sm_parts; ++i) {
+            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
+            tile_scheduler_metadata0[0] = now_idx;
+            tile_scheduler_metadata0[1] = now_block * block_size_n;
+            tile_scheduler_metadata1 = now_n_split_idx;
+            int remain_payload = payload;
+            while (now_idx < batch_size) {
+                int num_blocks = num_blocks_shared[now_idx];
+                int now_remain_blocks = num_blocks - now_block;
+                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
+                    cum_num_splits += now_n_split_idx + 1;
+                    num_splits_shared[now_idx + 1] = cum_num_splits;
+                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
+                    ++now_idx;
+                    now_block = 0;
+                    now_n_split_idx = 0;
+                } else {
+                    if (remain_payload - fixed_overhead_num_blocks > 0) {
+                        now_block += remain_payload - fixed_overhead_num_blocks;
+                        ++now_n_split_idx;
+                        remain_payload = 0;
+                    }
+                    break;
+                }
+            }
+            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
+            tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
+            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
+            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
+        }
+        FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
+    }
+    __syncwarp();
+
+    for (int i = threadIdx.x; i <= batch_size; i += 32) {
+        num_splits_ptr[i] = num_splits_shared[i];
+    }
+}
+
+void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
+    FLASH_ASSERT(params.batch_size < MaxBatchSize);
+    get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
+    CHECK_CUDA_KERNEL_LAUNCH();
+}
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0a3bd17..662a301 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,8 @@ ext_modules.append(
         sources=[
             "csrc/flash_api.cpp",
             "csrc/flash_fwd_mla_bf16_sm90.cu",
+            "csrc/flash_fwd_mla_fp16_sm90.cu",
+            "csrc/flash_fwd_mla_metadata.cu",
         ],
         extra_compile_args={
             "cxx": cxx_args,
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 8db5db0..e676fa7 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -1,10 +1,11 @@
+import argparse
 import math
 import random
 
 import torch
 import triton
 
-from flash_mla import get_mla_metadata, flash_mla_with_kvcache
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
 
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
 
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
-    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
+    print(
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
+    )
 
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
@@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
 
     q = torch.randn(b, s_q, h_q, d)
     block_size = 64
-    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    block_table = torch.arange(
+        b * max_seqlen_pad // block_size, dtype=torch.int32
+    ).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
     for i in range(b):
-        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
+        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
+            float("nan")
+        )
     blocked_v = blocked_k[..., :dv]
 
-    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
+    tile_scheduler_metadata, num_splits = get_mla_metadata(
+        cache_seqlens, s_q * h_q // h_kv, h_kv
+    )
 
     def flash_mla():
         return flash_mla_with_kvcache(
-            q, blocked_k, block_table, cache_seqlens, dv,
-            tile_scheduler_metadata, num_splits, causal=causal,
+            q,
+            blocked_k,
+            block_table,
+            cache_seqlens,
+            dv,
+            tile_scheduler_metadata,
+            num_splits,
+            causal=causal,
         )
 
     def ref_mla():
@@ -91,14 +106,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
 
     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
+        torch.finfo(q.dtype).bits // 8
+    )
+    print(
+        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
+    )
 
 
-if __name__ == "__main__":
-    dtype = torch.bfloat16
+def main(torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    torch.set_default_dtype(torch_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
@@ -114,3 +132,22 @@ if __name__ == "__main__":
                 for s_q in [1, 2]:  # MTP = 1, 2
                     for varlen in [False, True]:
                         test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["bf16", "fp16"],
+        default="bf16",
+        help="Data type to use for testing (bf16 or fp16)",
+    )
+
+    args = parser.parse_args()
+
+    torch_dtype = torch.bfloat16
+    if args.dtype == "fp16":
+        torch_dtype = torch.float16
+
+    main(torch_dtype)

From a3b74b85749f0b450f5fda1c423b1ad56df704db Mon Sep 17 00:00:00 2001
From: Sijia Chen
Date: Mon, 24 Feb 2025 10:01:59 -0800
Subject: [PATCH 3/4] add flag to disable FP16 compile

---
 csrc/flash_api.cpp |  9 +++++++--
 setup.py           | 29 +++++++++++++++++++++--------
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index b7b11f1..d2567fe 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -77,7 +77,6 @@ mha_fwd_kvcache_mla(
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat16);
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -189,9 +188,15 @@ mha_fwd_kvcache_mla(
 
     if (q_dtype == torch::kBFloat16) {
         run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
-    } else {
+    }
+    #ifndef FLASH_MLA_DISABLE_FP16
+    else if (q_dtype == torch::kHalf) {
         run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
     }
+    #endif
+    else {
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
+    }
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
diff --git a/setup.py b/setup.py
index 662a301..6377b1e 100644
--- a/setup.py
+++ b/setup.py
@@ -11,11 +11,29 @@ from torch.utils.cpp_extension import (
     IS_WINDOWS,
 )
 
+DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+
 
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
+
+def get_sources():
+    sources = [
+        "csrc/flash_api.cpp",
+        "csrc/flash_fwd_mla_bf16_sm90.cu",
+        "csrc/flash_fwd_mla_metadata.cu",
+    ]
+
+    if not DISABLE_FP16:
+        sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+
+    return sources
+
+
+def get_features_args():
+    features_args = []
+    if DISABLE_FP16:
+        features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    return features_args
+
 
 subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
@@ -34,14 +52,9 @@ ext_modules = []
 ext_modules.append(
     CUDAExtension(
         name="flash_mla_cuda",
-        sources=[
-            "csrc/flash_api.cpp",
-            "csrc/flash_fwd_mla_bf16_sm90.cu",
-            "csrc/flash_fwd_mla_fp16_sm90.cu",
-            "csrc/flash_fwd_mla_metadata.cu",
-        ],
+        sources=get_sources(),
         extra_compile_args={
-            "cxx": cxx_args,
+            "cxx": cxx_args + get_features_args(),
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
@@ -59,7 +72,7 @@ ext_modules.append(
                     "--ptxas-options=-v,--register-usage-level=10"
                 ]
                 + cc_flag
-            ),
+            ) + get_features_args(),
         },
         include_dirs=[
             Path(this_dir) / "csrc",

From e1e9fa98f80f34c3b155fd483c38227abb5f400d Mon Sep 17 00:00:00 2001
From: ljss <450993438@qq.com>
Date: Tue, 25 Feb 2025 09:18:11 +0800
Subject: [PATCH 4/4] Style fix

---
 tests/test_flash_mla.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index e676fa7..0abe9d2 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -60,7 +60,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
     ).view(b, max_seqlen_pad // block_size)
     blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
     for i in range(b):
-        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
+        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
             float("nan")
         )
     blocked_v = blocked_k[..., :dv]
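
Usage sketch for the options introduced above, assuming the repository's usual `python setup.py install` build flow (the install and test commands themselves come from the FlashMLA quick start, not from these patches; only `--dtype` and `FLASH_MLA_DISABLE_FP16` are new here):

    # default build: compiles both the BF16 and FP16 kernels (patches 2-3)
    python setup.py install
    python tests/test_flash_mla.py --dtype bf16
    python tests/test_flash_mla.py --dtype fp16

    # opt out of the FP16 kernel at compile time (patch 3); FP16 queries then fail the dtype check
    FLASH_MLA_DISABLE_FP16=TRUE python setup.py install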