mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
enable fp8
This commit is contained in:
parent
dae0690055
commit
d833dbd711
@ -33,6 +33,8 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
using Element = elem_type;
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
|
||||
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t>;
|
||||
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
@ -49,6 +51,8 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
|
||||
|
||||
using TiledMma = decltype(make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
|
||||
GMMA::Major::K, GMMA::Major::K>(),
|
||||
@ -57,7 +61,7 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
|
||||
using TiledMmaO = decltype(make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
GMMA::Major::K, MmaMajorV>(),
|
||||
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
|
Loading…
Reference in New Issue
Block a user