From f6fab1b915eb398ebc0ced017e2300b57c62eff3 Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Wed, 26 Feb 2025 10:17:29 +0800
Subject: [PATCH] change to use per_tensor

---
 csrc/flash_api.cpp          |  4 ++--
 csrc/flash_fwd_mla_kernel.h | 10 ++++++++--
 tests/test_flash_mla.py     |  4 ++--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 5a0caa1..9631b32 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -118,8 +118,8 @@ mha_fwd_kvcache_mla(
         TORCH_CHECK(descale_k.stride(-1) == 1);
         TORCH_CHECK(descale_q.dtype() == torch::kFloat);
         TORCH_CHECK(descale_k.dtype() == torch::kFloat);
-        CHECK_SHAPE(descale_q, batch_size);
-        CHECK_SHAPE(descale_k, batch_size);
+        CHECK_SHAPE(descale_q, 1);
+        CHECK_SHAPE(descale_k, 1);
     }

     if (seqlen_q_ori == 1) { is_causal = false; }
diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 261a275..874aded 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -260,7 +260,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                                                                     const int bidb, const int bidh, const int m_block,
                                                                     const int n_split_idx, const int seqlen_k,
                                                                     const int n_block_min, const int n_block_max, const bool NoSplit,
-                                                                    SharedStorage &shared_storage) {
+                                                                    SharedStorage &shared_storage, const float descale_q, const float descale_k) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -494,6 +494,12 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     if (begin_idx >= params.b) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);

+    float descale_q, descale_k;
+    if constexpr (Kernel_traits::Is_FP8) {
+        descale_q = __ldg(params.descale_q_ptr);
+        descale_k = __ldg(params.descale_k_ptr);
+    }
+
 #pragma unroll 1
     for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
         const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
@@ -504,7 +510,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
             __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
+        flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k);
     }
 }

diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index f700864..5c68dba 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -67,8 +67,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
     if use_fp8:
         nonlocal q, blocked_k, blocked_v
         fp8_dtype = torch.float8_e4m3fn
-        descale_q = torch.ones((b), dtype=torch.float32)
-        descale_k = torch.ones((b), dtype=torch.float32)
+        descale_q = torch.ones((1), dtype=torch.float32)
+        descale_k = torch.ones((1), dtype=torch.float32)

         q_fp8 = q.to(fp8_dtype)
         blocked_k_fp8 = blocked_k.to(fp8_dtype)
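
Note (not part of the patch): after this change the kernel reads one descale value per tensor via __ldg(params.descale_q_ptr) / __ldg(params.descale_k_ptr), instead of one value per batch entry. The test above passes torch.ones((1), ...) because it quantizes with an implicit scale of 1. As a hedged illustration only, a per-tensor FP8 scale/descale pair is typically derived from the tensor-wide absolute maximum; the helper name below and the float8_e4m3fn maximum of 448.0 are assumptions for this sketch, not code from this repository.

    import torch

    def quantize_per_tensor_fp8(x: torch.Tensor):
        # Hypothetical helper: one scale for the whole tensor, matching the
        # CHECK_SHAPE(descale_q, 1) / CHECK_SHAPE(descale_k, 1) contract above.
        fp8_max = 448.0                          # largest finite float8_e4m3fn value (assumed here)
        amax = x.abs().max().clamp(min=1e-12)    # tensor-wide absolute maximum
        scale = fp8_max / amax                   # applied before the FP8 cast
        x_fp8 = (x * scale).to(torch.float8_e4m3fn)
        descale = (1.0 / scale).reshape(1).to(torch.float32)  # shape-(1,) descale for the kernel
        return x_fp8, descale

A real caller would pass the shape-(1,) descale tensors returned by a helper like this in place of the torch.ones((1), ...) placeholders used in the test.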