|
|
|
|
@@ -28,9 +28,10 @@ constexpr auto getSmemLayoutK() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
|
|
|
|
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
|
|
|
|
struct Flash_fwd_kernel_traits_mla {
|
|
|
|
|
using Element = elem_type;
|
|
|
|
|
using ElementO = elem_type_o;
|
|
|
|
|
using ElementAccum = float;
|
|
|
|
|
using index_t = int64_t;
|
|
|
|
|
|
|
|
|
|
@@ -48,8 +49,10 @@ struct Flash_fwd_kernel_traits_mla {
|
|
|
|
|
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
|
|
|
|
|
static_assert(kHeadDimV % 32 == 0);
|
|
|
|
|
static_assert(kHeadDimV <= kHeadDim);
|
|
|
|
|
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
|
|
|
|
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
|
|
|
|
|
|
|
|
|
static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32);
|
|
|
|
|
static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32;
|
|
|
|
|
static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3;
|
|
|
|
|
|
|
|
|
|
static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
|
|
|
|
|
|
|
|
|
|
@@ -81,12 +84,12 @@ struct Flash_fwd_kernel_traits_mla {
|
|
|
|
|
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
|
|
|
|
|
|
|
|
|
using SmemLayoutAtomO = decltype(composition(
|
|
|
|
|
Swizzle<kSwizzle, 3, 3>{},
|
|
|
|
|
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
|
|
|
|
|
Swizzle<kSwizzleO, 3, 3>{},
|
|
|
|
|
Layout<Shape<Int<8>, Int<kBlockKSmemO>>, Stride<Int<kBlockKSmemO>, _1>>{}));
|
|
|
|
|
using SmemLayoutO = decltype(tile_to_shape(
|
|
|
|
|
SmemLayoutAtomO{},
|
|
|
|
|
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
|
|
|
|
|
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
|
|
|
|
|
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, ElementO>;
|
|
|
|
|
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
|
|
|
|
|
|
|
|
|
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
|
|
|
|
@@ -96,31 +99,38 @@ struct Flash_fwd_kernel_traits_mla {
|
|
|
|
|
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
|
|
|
|
|
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
|
|
|
|
|
|
|
|
|
static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO);
|
|
|
|
|
static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO");
|
|
|
|
|
static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO;
|
|
|
|
|
static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO");
|
|
|
|
|
|
|
|
|
|
using GmemLayoutAtom = Layout<
|
|
|
|
|
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
|
|
|
|
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using GmemTiledCopy = decltype(make_tiled_copy(
|
|
|
|
|
Copy_Atom<Gmem_copy_struct, Element>{},
|
|
|
|
|
GmemLayoutAtom{},
|
|
|
|
|
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
|
|
|
|
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 vals per read
|
|
|
|
|
|
|
|
|
|
using GmemLayoutAtomO = Layout<
|
|
|
|
|
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
|
|
|
|
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
|
|
|
|
Shape<Int<kNThreadsS / kGmemThreadsPerRowO>, Int<kGmemThreadsPerRowO>>,
|
|
|
|
|
Stride<Int<kGmemThreadsPerRowO>, _1>>;
|
|
|
|
|
using GmemTiledCopyO = decltype(make_tiled_copy(
|
|
|
|
|
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
|
|
|
|
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementO>{},
|
|
|
|
|
GmemLayoutAtomO{},
|
|
|
|
|
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
|
|
|
|
Layout<Shape<_1, Int<kGmemElemsPerLoadO>>>{})); // Val layout, 8 vals per store
|
|
|
|
|
|
|
|
|
|
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
|
|
|
|
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
|
|
|
|
|
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum;
|
|
|
|
|
using GmemLayoutAtomOaccum = Layout<
|
|
|
|
|
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
|
|
|
|
|
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
|
|
|
|
|
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
|
|
|
|
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
|
|
|
|
GmemLayoutAtomOaccum{},
|
|
|
|
|
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
|
|
|
|
|
Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
namespace flash {
|
|
|
|
|
@@ -597,12 +607,12 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream
|
|
|
|
|
CHECK_CUDA_KERNEL_LAUNCH();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T, int Headdim>
|
|
|
|
|
template<typename T, typename To, int Headdim>
|
|
|
|
|
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
|
|
|
|
static_assert(Headdim == 576);
|
|
|
|
|
FLASH_ASSERT(params.d_v == 512);
|
|
|
|
|
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
|
|
|
|
|
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
|
|
|
|
|
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>;
|
|
|
|
|
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|