mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix sV
This commit is contained in:
parent
6dcea4952c
commit
dbd8c307eb
@ -14,18 +14,29 @@ using namespace cute;
|
||||
#include "flash_mla.h"
|
||||
|
||||
|
||||
template<typename PrecType, int DIM, int DIM2 = DIM>
|
||||
template<typename PrecType, int DIM, int DIM2 = DIM, cute::GMMA::Major major = GMMA::Major::K>
|
||||
constexpr auto getSmemLayoutK() {
|
||||
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
|
||||
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
|
||||
|
||||
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
||||
return GMMA::Layout_K_SW128_Atom<PrecType>{};
|
||||
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
||||
return GMMA::Layout_K_SW64_Atom<PrecType>{};
|
||||
if constexpr (major == GMMA::Major::K) {
|
||||
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
||||
return GMMA::Layout_K_SW128_Atom<PrecType>{};
|
||||
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
||||
return GMMA::Layout_K_SW64_Atom<PrecType>{};
|
||||
} else {
|
||||
return GMMA::Layout_K_SW32_Atom<PrecType>{};
|
||||
}
|
||||
} else {
|
||||
return GMMA::Layout_K_SW32_Atom<PrecType>{};
|
||||
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
|
||||
return GMMA::Layout_MN_SW128_Atom<PrecType>{};
|
||||
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
|
||||
return GMMA::Layout_MN_SW64_Atom<PrecType>{};
|
||||
} else {
|
||||
return GMMA::Layout_MN_SW32_Atom<PrecType>{};
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o = cutlass::bfloat16_t, int kHeadDimV_ = 0>
|
||||
@ -75,11 +86,16 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
// ------ for f16 ------
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
||||
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
|
||||
// ------ for f8 ------
|
||||
using SmemLayoutVtLoad = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV, GMMA::Major::MN>(),
|
||||
Shape<Int<kHeadDimV>, Int<kBlockN>>{}));
|
||||
using SmemLayoutVtMMa = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
|
||||
@ -135,11 +151,6 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
|
||||
|
||||
|
||||
|
||||
// for fp8 trans-v
|
||||
|
||||
};
|
||||
|
||||
namespace flash {
|
||||
@ -278,7 +289,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
|
||||
|
||||
auto sV = cute::conditional_return<Kernel_traits::Is_FP8>(
|
||||
cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{})),
|
||||
cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtLoad{})),
|
||||
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}));
|
||||
|
||||
auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
|
||||
@ -476,9 +487,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx);
|
||||
auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx);
|
||||
// flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8)
|
||||
Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32)
|
||||
Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sV, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32)
|
||||
// flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64))
|
||||
Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64))
|
||||
Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sVt, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64))
|
||||
CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
|
||||
|
Loading…
Reference in New Issue
Block a user