From 8b939854d8869d94c9e63299c098ca64535c2173 Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Fri, 28 Feb 2025 19:43:24 +0800
Subject: [PATCH] enable scale

---
 csrc/flash_fwd_mla_kernel.h | 29 +++++++++++++++--------------
 csrc/softmax.h              |  4 ++--
 tests/test_flash_mla.py     |  2 +-
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index bef20ee..6e92b5e 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -151,9 +151,6 @@ struct Flash_fwd_kernel_traits_mla {
     // ------ for f8 ------
     using SmemFp8Tranpose = SmemTransposeFp8_64x64;
-    // using SmemLayoutVtMMa = decltype(tile_to_shape(
-    //     getSmemLayoutK(),
-    //     Shape, Int >{}));
     using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
 };
@@ -186,7 +183,7 @@ struct SharedStorageMLA {
 template
 __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
-                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
+                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
@@ -203,7 +200,7 @@
     const int split_offset = __ldg(params.num_splits_ptr + bidb);
-    Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax);
+    Tensor lse = softmax.template normalize_softmax_lse(tOrO, scale_softmax, descale_k);
     using ElementO = std::conditional_t;
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
@@ -275,7 +272,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                                                                    const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k,
                                                                    const int n_block_min, const int n_block_max, const bool NoSplit,
-                                                                   SharedStorage &shared_storage, const float descale_q, const float descale_k) {
+                                                                   SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -372,10 +369,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         // We have key_padding_mask so we'll need to Check_inf
         Tensor scale_o = is_first_masking_step
-                         ? softmax.template softmax(tSrS, params.scale_softmax_log2)
+                         ? softmax.template softmax(tSrS, scale_softmax_log2)
                          : is_masking_step ?
-                           softmax.template softmax(tSrS, params.scale_softmax_log2)
-                         : softmax.template softmax(tSrS, params.scale_softmax_log2);
+                           softmax.template softmax(tSrS, scale_softmax_log2)
+                         : softmax.template softmax(tSrS, scale_softmax_log2);
         if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
         Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout()));
@@ -535,9 +532,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     }
     if (NoSplit)
-        store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
     else
-        store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
 }
 template
@@ -560,10 +557,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
     if (begin_idx >= params.b) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
-    float descale_q, descale_k;
+    float descale_k = 1.f;
+    float scale_softmax = params.scale_softmax;
+    float scale_softmax_log2 = params.scale_softmax_log2;
     if constexpr (Kernel_traits::Is_FP8) {
-        descale_q = __ldg(params.descale_q_ptr);
+        float descale_q = __ldg(params.descale_q_ptr);
         descale_k = __ldg(params.descale_k_ptr);
+        scale_softmax = scale_softmax * descale_q * descale_k;
+        scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k;
     }
 #pragma unroll 1
@@ -576,7 +577,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
         if (batch_id > begin_idx) {
            __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k);
+        flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
     }
 }
diff --git a/csrc/softmax.h b/csrc/softmax.h
index 4ab6ae9..bcb8cac 100644
--- a/csrc/softmax.h
+++ b/csrc/softmax.h
@@ -174,7 +174,7 @@ struct Softmax {
     };
     template
-    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) {
         SumOp sum_op;
         quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
@@ -184,7 +184,7 @@ struct Softmax {
 #pragma unroll
         for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
             float sum = row_sum(mi);
-            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
             lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
             float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
 #pragma unroll
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index b840a97..ff7cd27 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -35,7 +35,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
     # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
     if use_fp8:
-        assert cos_diff < 1e-3
+        assert cos_diff < 1e-2
     else:
         assert cos_diff < 1e-5
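Note on the scaling math: with FP8 inputs, the kernel now folds the per-tensor dequantization factors into the softmax scale once before the batch loop (scale_softmax * descale_q * descale_k, and likewise for the log2 variant), and applies the V descale while normalizing the output in normalize_softmax_lse (inv_sum = descale_v / row_sum; the patch passes descale_k as descale_v, consistent with K and V coming from the same MLA cache tensor). A minimal NumPy sketch of that algebra, using made-up per-tensor scales and plain float32 in place of real FP8 casts (illustrative only, not FlashMLA API):

import numpy as np

np.random.seed(0)
s_q, s_k, d, d_v = 4, 8, 16, 16
softmax_scale = 1.0 / np.sqrt(d)

q = np.random.randn(s_q, d).astype(np.float32)
k = np.random.randn(s_k, d).astype(np.float32)
v = np.random.randn(s_k, d_v).astype(np.float32)

# Full-precision reference attention.
s = (q @ k.T) * softmax_scale
p = np.exp(s - s.max(axis=-1, keepdims=True))
o_ref = (p @ v) / p.sum(axis=-1, keepdims=True)

# Stand-in for FP8 quantization: store tensors divided by a per-tensor scale
# and keep the descale factor needed to undo it.
descale_q, descale_k = 0.37, 0.52   # descale_v == descale_k, mirroring the patch
q_q, k_q, v_q = q / descale_q, k / descale_k, v / descale_k

# What the kernel does after this patch:
#   folded scale for Q K^T:  softmax_scale * descale_q * descale_k
#   output normalization:    inv_sum = descale_v / row_sum
scale_folded = softmax_scale * descale_q * descale_k
s2 = (q_q @ k_q.T) * scale_folded
p2 = np.exp(s2 - s2.max(axis=-1, keepdims=True))
o = (p2 @ v_q) * (descale_k / p2.sum(axis=-1, keepdims=True))

print(np.abs(o - o_ref).max())   # float32 rounding level: folded scales match the reference

The sketch only checks the descale bookkeeping; real FP8 casts add quantization error, which is why the cos_diff bound for the FP8 path in tests/test_flash_mla.py is loosened from 1e-3 to 1e-2.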