enable fp8

This commit is contained in:
chenhongmin.will 2025-02-25 09:03:02 +08:00
parent dae0690055
commit d833dbd711

View File

@ -33,6 +33,8 @@ struct Flash_fwd_kernel_traits_mla {
// Operand element type (elem_type is presumably a template parameter of the
// enclosing traits struct — declared outside this view).
using Element = elem_type;
// Accumulator type handed to the GMMA op selectors alongside Element.
using ElementAccum = float;
// 64-bit signed integer type used for indices.
using index_t = int64_t;
// True exactly when Element is cutlass::float_e4m3_t; other traits below
// branch on this flag to pick FP8-specific settings.
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t>;
// Warp count (forwarded from kNWarps_) and the derived thread count,
// computed as kNWarps * 32.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
@ -49,6 +51,8 @@ struct Flash_fwd_kernel_traits_mla {
// 64 when kHeadDim is a multiple of 64, otherwise 32.
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
// 2 when kBlockKSmem is 32, otherwise 3.
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
// Major mode for the second operand of TiledMmaO: MN-major for non-FP8
// element types, K-major when Is_FP8 is set.
// NOTE(review): presumably FP8 GMMA requires K-major operands — confirm
// against the CUTLASS/PTX documentation.
static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
using TiledMma = decltype(make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
GMMA::Major::K, GMMA::Major::K>(),
@ -57,7 +61,7 @@ struct Flash_fwd_kernel_traits_mla {
// Ratio of the total thread count to kNThreadsS (defined outside this view).
// Used below as the N extent of TiledMmaO's atom layout and as the divisor
// of kHeadDimV in its tile shape.
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
using TiledMmaO = decltype(make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
GMMA::Major::K, GMMA::Major::MN>(),
GMMA::Major::K, MmaMajorV>(),
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(