update gmem

This commit is contained in:
chenhongmin.will 2025-02-25 09:40:56 +08:00
parent d833dbd711
commit b67a18f850
5 changed files with 29 additions and 19 deletions

View File

@ -186,7 +186,7 @@ mha_fwd_kvcache_mla(
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});

View File

@ -1,3 +1,3 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -1,3 +1,3 @@
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);

View File

@ -28,9 +28,10 @@ constexpr auto getSmemLayoutK() {
}
}
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla {
using Element = elem_type;
using ElementO = elem_type_o;
using ElementAccum = float;
using index_t = int64_t;
@ -48,8 +49,10 @@ struct Flash_fwd_kernel_traits_mla {
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 32 == 0);
static_assert(kHeadDimV <= kHeadDim);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3;
static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
@ -81,12 +84,12 @@ struct Flash_fwd_kernel_traits_mla {
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
Swizzle<kSwizzleO, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmemO>>, Stride<Int<kBlockKSmemO>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, ElementO>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@ -96,31 +99,38 @@ struct Flash_fwd_kernel_traits_mla {
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO);
static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO");
static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO;
static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomO = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
Shape<Int<kNThreadsS / kGmemThreadsPerRowO>, Int<kGmemThreadsPerRowO>>,
Stride<Int<kGmemThreadsPerRowO>, _1>>;
using GmemTiledCopyO = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementO>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
Layout<Shape<_1, Int<kGmemElemsPerLoadO>>>{})); // Val layout, 8 vals per store
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum;
using GmemLayoutAtomOaccum = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
};
namespace flash {
@ -597,12 +607,12 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream
CHECK_CUDA_KERNEL_LAUNCH();
}
template<typename T, int Headdim>
template<typename T, typename To, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
static_assert(Headdim == 576);
FLASH_ASSERT(params.d_v == 512);
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>;
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}

View File

@ -47,7 +47,7 @@ static constexpr int TileSchedulerMetaDataSize = 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim>
template<typename T, typename To, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
struct Mla_metadata_params {