update gmem

parent d833dbd711
commit b67a18f850
@@ -186,7 +186,7 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);

     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
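With the extra output-type template argument, the host API can keep writing bf16 output while the compute element type varies. For reference, a hypothetical dispatch sketch (not part of this hunk; the q_dtype check and the Float8_e4m3fn enum are assumptions and require a PyTorch build that exposes that dtype):

    // Sketch only: pick the kernel instantiation from the query dtype while
    // always producing bf16 output; q_dtype is assumed to hold q.dtype().
    if (q_dtype == torch::kBFloat16) {
        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
    } else if (q_dtype == at::ScalarType::Float8_e4m3fn) {
        run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
    } else {
        TORCH_CHECK(false, "FlashMLA: unsupported query dtype");
    }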
@@ -1,3 +1,3 @@
 #include "flash_fwd_mla_kernel.h"

-template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
@@ -1,3 +1,3 @@
 #include "flash_fwd_mla_kernel.h"

-template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
@@ -28,9 +28,10 @@ constexpr auto getSmemLayoutK() {
     }
 }

-template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
 struct Flash_fwd_kernel_traits_mla {
     using Element = elem_type;
+    using ElementO = elem_type_o;
     using ElementAccum = float;
     using index_t = int64_t;

@@ -48,8 +49,10 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
     static_assert(kHeadDimV % 32 == 0);
     static_assert(kHeadDimV <= kHeadDim);
-    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
-    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+    static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32);
+    static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32;
+    static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3;

     static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;

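The split into kBlockKSmem (inputs) and kBlockKSmemO / kSwizzleO (output) exists because a 128-bit load carries twice as many FP8 values as bf16 values, so the K-side shared-memory block width can differ from the O-side one. Below is a minimal constexpr sketch of how these constants work out for this kernel's head dim, assuming 1-byte e4m3 and 2-byte bf16 elements and that kGmemThreadsPerRow follows the same block-width / elems-per-load pattern as the kGmemThreadsPerRowO added later in this commit; the struct and its names are illustrative, not the kernel's.

    template <int kElemBytes, bool Is_FP8, int kHeadDim = 576>
    struct GmemBlockSketch {
        static constexpr int kElemsPerLoad = 16 / kElemBytes;  // values per 128-bit load
        static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64)
                                                  : (kHeadDim % 64 == 0 ? 64 : 32);
        static constexpr int kThreadsPerRow = kBlockKSmem / kElemsPerLoad;
    };

    // FP8 inputs (kHeadDim = 576): 16 values per load, 64-wide smem block, 4 threads per row.
    static_assert(GmemBlockSketch<1, true>::kElemsPerLoad == 16, "");
    static_assert(GmemBlockSketch<1, true>::kBlockKSmem == 64, "");
    static_assert(GmemBlockSketch<1, true>::kThreadsPerRow == 4, "");
    // bf16 path (and the bf16 output): 8 values per load, same 64-wide block, 8 threads per row.
    static_assert(GmemBlockSketch<2, false>::kElemsPerLoad == 8, "");
    static_assert(GmemBlockSketch<2, false>::kBlockKSmem == 64, "");
    static_assert(GmemBlockSketch<2, false>::kThreadsPerRow == 8, "");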
@@ -81,12 +84,12 @@ struct Flash_fwd_kernel_traits_mla {
     using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;

     using SmemLayoutAtomO = decltype(composition(
-            Swizzle<kSwizzle, 3, 3>{},
-            Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
+            Swizzle<kSwizzleO, 3, 3>{},
+            Layout<Shape<Int<8>, Int<kBlockKSmemO>>, Stride<Int<kBlockKSmemO>, _1>>{}));
     using SmemLayoutO = decltype(tile_to_shape(
             SmemLayoutAtomO{},
             Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
-    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
+    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, ElementO>;
     using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@@ -96,31 +99,38 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
     static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");

+    static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO);
+    static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO");
+    static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO;
+    static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO");
+
     using GmemLayoutAtom = Layout<
             Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
             Stride<Int<kGmemThreadsPerRow>, _1>>;

     using GmemTiledCopy = decltype(make_tiled_copy(
             Copy_Atom<Gmem_copy_struct, Element>{},
             GmemLayoutAtom{},
-            Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+            Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 vals per read

     using GmemLayoutAtomO = Layout<
-            Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-            Stride<Int<kGmemThreadsPerRow>, _1>>;
+            Shape<Int<kNThreadsS / kGmemThreadsPerRowO>, Int<kGmemThreadsPerRowO>>,
+            Stride<Int<kGmemThreadsPerRowO>, _1>>;
     using GmemTiledCopyO = decltype(make_tiled_copy(
-            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
+            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementO>{},
             GmemLayoutAtomO{},
-            Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+            Layout<Shape<_1, Int<kGmemElemsPerLoadO>>>{}));  // Val layout, 8 vals per store

     static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
-    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
+    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum;
     using GmemLayoutAtomOaccum = Layout<
             Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
             Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
     using GmemTiledCopyOaccum = decltype(make_tiled_copy(
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
-            Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+            Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store
 };

 namespace flash {
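The hard-coded Shape<_1, _8> and Shape<_1, _4> val layouts above assumed 2-byte and 4-byte elements; replacing them with Int<kGmemElemsPerLoad>, Int<kGmemElemsPerLoadO> and Int<kGmemElemsPerLoadAccum> keeps every tiled copy at one 128-bit vector per thread regardless of element width, which matters once Element can be a 1-byte FP8 type. A short sketch of the per-dtype vector widths (assuming cutlass/numeric_types.h provides these types at 1, 2 and 4 bytes):

    #include <cutlass/numeric_types.h>

    // 128-bit (16-byte) vectors per thread; these are the counts that replace
    // the old hard-coded _8 / _4 val layouts.
    static_assert(16 / sizeof(cutlass::float_e4m3_t) == 16, "FP8 inputs: 16 vals per load");
    static_assert(16 / sizeof(cutlass::bfloat16_t) == 8, "bf16 input/output: the old _8");
    static_assert(16 / sizeof(float) == 4, "fp32 O accumulator: the old _4");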
@@ -597,12 +607,12 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream
     CHECK_CUDA_KERNEL_LAUNCH();
 }

-template<typename T, int Headdim>
+template<typename T, typename To, int Headdim>
 void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
     static_assert(Headdim == 576);
     FLASH_ASSERT(params.d_v == 512);
     FLASH_ASSERT(params.k_ptr == params.v_ptr);  // Shared_KV
-    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
+    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>;
     run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
 }

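Together with the explicit instantiations earlier in the commit, this gives two concrete kernel configurations, sketched below; both keep the 576/512 head dims and bf16 output, and only the Q/K element type differs:

    // Sketch of the two trait instantiations implied by this commit.
    using TraitsBf16 = Flash_fwd_kernel_traits_mla<576, 64, 64, 8,
            cutlass::bfloat16_t, cutlass::bfloat16_t, 512>;   // bf16 in, bf16 out
    using TraitsFp8 = Flash_fwd_kernel_traits_mla<576, 64, 64, 8,
            cutlass::float_e4m3_t, cutlass::bfloat16_t, 512>; // FP8 in, bf16 out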
@@ -47,7 +47,7 @@ static constexpr int TileSchedulerMetaDataSize = 8;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename T, int Headdim>
+template<typename T, typename To, int Headdim>
 void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);

 struct Mla_metadata_params {