From d833dbd7111e44139dec9615bb544e7d956a856f Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 09:03:02 +0800 Subject: [PATCH] enable fp8 --- csrc/flash_fwd_mla_kernel.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 55f6811..9262632 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -33,6 +33,8 @@ struct Flash_fwd_kernel_traits_mla { using Element = elem_type; using ElementAccum = float; using index_t = int64_t; + + static constexpr bool Is_FP8 = cute::is_same_v; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 32; @@ -49,6 +51,8 @@ struct Flash_fwd_kernel_traits_mla { static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K; + using TiledMma = decltype(make_tiled_mma( cute::GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), @@ -57,7 +61,7 @@ struct Flash_fwd_kernel_traits_mla { static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; using TiledMmaO = decltype(make_tiled_mma( cute::GMMA::rs_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::MN>(), + GMMA::Major::K, MmaMajorV>(), Layout, Int, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape(