From dfe8ffc75abbd1ac2a1f3b342777ab7354099fc4 Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Tue, 25 Feb 2025 22:34:01 +0800
Subject: [PATCH] enable fp8 api

---
 csrc/flash_api.cpp               | 21 +++++++++++++++------
 csrc/flash_fwd_mla_kernel.h      |  4 ++--
 flash_mla/flash_mla_interface.py |  1 +
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 1f44b68..4be3c1c 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -68,7 +68,10 @@ mha_fwd_kvcache_mla(
         const float softmax_scale,
         bool is_causal,
         const at::Tensor &tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
-        const at::Tensor &num_splits                // batch_size + 1
+        const at::Tensor &num_splits,               // batch_size + 1
+        c10::optional<at::Tensor> &descale_q,       // batch_size
+        c10::optional<at::Tensor> &descale_k,       // batch_size
+        c10::optional<at::Tensor> &descale_v        // batch_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -76,9 +79,9 @@ mha_fwd_kvcache_mla(
 
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
-    auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16);
-    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    auto q_dtype = q.scalar_type();
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -128,7 +131,8 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
 
     Flash_fwd_mla_params params = {};
@@ -186,7 +190,12 @@ mha_fwd_kvcache_mla(
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(params, stream);
+    } else {
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    }
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index d4940f1..261a275 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -278,8 +278,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
     auto sVt = cute::conditional_return(
-            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
-            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
+            make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
+            make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 2f3aa46..33c0657 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -63,5 +63,6 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
+        None, None, None,
     )
     return out, softmax_lse
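
Usage note (not part of the patch): a minimal sketch of how the fp8 path enabled above
could be driven from Python once this change is applied. The flash_mla_with_kvcache
signature and the get_mla_metadata helper follow the public FlashMLA interface; the
fp8 tensor setup, shapes, and block-table layout below are illustrative assumptions.
The Python wrapper in this patch still forwards None for all three descale tensors.

    # Hypothetical sketch, assuming the public FlashMLA decode API.
    import torch
    from flash_mla import get_mla_metadata, flash_mla_with_kvcache

    b, s_q, h_q, h_kv = 4, 1, 128, 1          # decode: one query token per sequence
    d, dv = 576, 512                          # MLA head dims (head_size == 576 is checked in C++)
    block_size, blocks_per_seq = 64, 16

    cache_seqlens = torch.full((b,), 1024, dtype=torch.int32, device="cuda")
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv)

    # fp8 e4m3 inputs, cast from bf16; a real setup would also keep per-batch descales.
    q = torch.randn(b, s_q, h_q, d, device="cuda",
                    dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    k_cache = torch.randn(b * blocks_per_seq, block_size, h_kv, d, device="cuda",
                          dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32,
                               device="cuda").view(b, blocks_per_seq)

    out, lse = flash_mla_with_kvcache(
        q, k_cache, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True)

    # With fp8 inputs the patched binding allocates the output in bf16.
    assert out.dtype == torch.bfloat16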