mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
update gmem
This commit is contained in:
parent
d833dbd711
commit
b67a18f850
@ -186,7 +186,7 @@ mha_fwd_kvcache_mla(
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(head_size == 576);
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
|
||||
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
|
||||
|
||||
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
|
||||
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
|
||||
|
@ -1,3 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
@ -1,3 +1,3 @@
|
||||
#include "flash_fwd_mla_kernel.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
@ -28,9 +28,10 @@ constexpr auto getSmemLayoutK() {
|
||||
}
|
||||
}
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
struct Flash_fwd_kernel_traits_mla {
|
||||
using Element = elem_type;
|
||||
using ElementO = elem_type_o;
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
|
||||
@ -48,8 +49,10 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
|
||||
static_assert(kHeadDimV % 32 == 0);
|
||||
static_assert(kHeadDimV <= kHeadDim);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3;
|
||||
|
||||
static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
|
||||
|
||||
@ -81,12 +84,12 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(composition(
|
||||
Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
Swizzle<kSwizzleO, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmemO>>, Stride<Int<kBlockKSmemO>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
|
||||
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, ElementO>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
@ -96,31 +99,38 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
|
||||
static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO");
|
||||
static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO");
|
||||
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomO = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRowO>, Int<kGmemThreadsPerRowO>>,
|
||||
Stride<Int<kGmemThreadsPerRowO>, _1>>;
|
||||
using GmemTiledCopyO = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementO>{},
|
||||
GmemLayoutAtomO{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
Layout<Shape<_1, Int<kGmemElemsPerLoadO>>>{})); // Val layout, 8 vals per store
|
||||
|
||||
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
||||
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
|
||||
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum;
|
||||
using GmemLayoutAtomOaccum = Layout<
|
||||
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
|
||||
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
|
||||
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
|
||||
Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
|
||||
};
|
||||
|
||||
namespace flash {
|
||||
@ -597,12 +607,12 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
|
||||
template<typename T, int Headdim>
|
||||
template<typename T, typename To, int Headdim>
|
||||
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
static_assert(Headdim == 576);
|
||||
FLASH_ASSERT(params.d_v == 512);
|
||||
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
|
||||
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
|
||||
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>;
|
||||
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ static constexpr int TileSchedulerMetaDataSize = 8;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim>
|
||||
template<typename T, typename To, int Headdim>
|
||||
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
||||
struct Mla_metadata_params {
|
||||
|
Loading…
Reference in New Issue
Block a user