Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00
Commit 0337732dc1 (parent 1df91aff33): reorg
@@ -150,10 +150,11 @@ struct Flash_fwd_kernel_traits_mla {
     // ------ for f8 ------
-    using SmemLayoutVtMMa = decltype(tile_to_shape(
-        getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-        Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
-    using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, Element>;
+    using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, SmemLayoutK>;
+    // using SmemLayoutVtMMa = decltype(tile_to_shape(
+    //     getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
+    //     Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
+    using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
 };

 namespace flash {
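Net effect of this hunk: the traits stop building SmemLayoutVtMMa themselves, hand their own SmemLayoutK to SmemTransposeFp8_64x64, and read the derived Vt layout back out of the helper, so the K layout is defined in exactly one place. A minimal standalone C++17 sketch of that single-source-of-truth pattern (every name below is an illustrative stand-in, not a FlashMLA type):

#include <type_traits>

// Stand-in for SmemTransposeFp8_64x64: derives a "Vt" layout from the
// K layout it is handed, instead of rebuilding K from an element type.
template <class SmemLayoutK>
struct TransposeHelper {
    using SmemLayoutVt = SmemLayoutK;  // placeholder for the real derivation
};

// Stand-in for Flash_fwd_kernel_traits_mla: defines the K layout once
// and reads the layouts derived from it back out of the helper.
struct Traits {
    struct SmemLayoutK {};             // placeholder for the CuTe layout type
    using Transpose       = TransposeHelper<SmemLayoutK>;
    using SmemLayoutVtMMa = typename Transpose::SmemLayoutVt;
};

static_assert(std::is_same_v<Traits::SmemLayoutVtMMa, Traits::SmemLayoutK>);
int main() {}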
@@ -292,9 +293,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});

     auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
-        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
-        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
+    auto sVt = [&]() {
+        if constexpr (Kernel_traits::Is_FP8) {
+            return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{});
+        } else {
+            return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+        }
+    }();

     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
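Both forms pick one of two differently-typed tensors at compile time, but cute::conditional_return is an ordinary function call, so both tensor expressions must be well-formed in every instantiation, whereas the immediately-invoked lambda only instantiates the branch that if constexpr keeps, and each branch may return its own type. A standalone C++17 sketch of the idiom (toy types in place of the CuTe tensors):

#include <cstdio>
#include <type_traits>

template <bool kIsFp8>
void demo() {
    // Immediately-invoked lambda: only the branch kept by `if constexpr`
    // is instantiated, and the two branches return different types.
    auto v = [&]() {
        if constexpr (kIsFp8) {
            return 1.5f;   // stands in for the SmemLayoutVtMMa tensor
        } else {
            return 2;      // stands in for the SmemLayoutVtransposed tensor
        }
    }();
    static_assert(std::is_same_v<decltype(v), std::conditional_t<kIsFp8, float, int>>);
    std::printf(kIsFp8 ? "fp8 path: %g\n" : "non-fp8 path: %g\n", double(v));
}

int main() {
    demo<true>();
    demo<false>();
}

Whether the non-FP8 kernel traits even define shared_storage.smem_vt is not visible in this diff; if they do not, the lambda form is what lets the non-FP8 instantiation compile at all.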
@@ -381,8 +386,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f

     if constexpr (Kernel_traits::Is_FP8) {
-        cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
+        __syncthreads();
     }

     flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

     // Double buffer for sK
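For reference: cutlass::arch::NamedBarrier::sync(kNThreads, id) makes kNThreads threads rendezvous on a specific named hardware barrier, while __syncthreads() is the plain barrier over every thread in the block; the two are equivalent here only if kNThreads covers the whole block. A self-contained CUDA sketch (not FlashMLA code) of the produce-then-consume handoff such a barrier protects:

#include <cstdio>
#include <cuda_runtime.h>

// Writers fill shared memory (stand-in for the FP8 V-transpose stage),
// then a block-wide barrier makes those writes visible before the next
// stage (stand-in: the read below) consumes them.
__global__ void handoff() {
    __shared__ int buf[128];
    buf[threadIdx.x] = int(threadIdx.x);      // produce
    __syncthreads();                          // barrier: all writes visible
    int v = buf[(threadIdx.x + 1) % 128];     // consume another lane's write
    if (threadIdx.x == 0) printf("thread 0 sees %d\n", v);
}

int main() {
    handoff<<<1, 128>>>();
    cudaDeviceSynchronize();
    return 0;
}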
@@ -504,6 +509,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     cute::copy(tPsP, tOrP);

     flash::rescale_o(tOrO, scale_o);
+
+    if constexpr (Kernel_traits::Is_FP8) __syncthreads();
     flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

     // Double buffer for sK
@@ -1,13 +1,10 @@
 #pragma once

-template <int kBlockN, int kHeadDim, typename Element>
+template <int kBlockN, int kHeadDim, typename SmemLayoutK>
 struct SmemTransposeFp8_64x64 {
-    static_assert(sizeof(Element) == 1);
     static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));

-    using SmemLayoutK = decltype(tile_to_shape(
-        GMMA::Layout_K_SW64_Atom<Element>{},
-        Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+    using Element = cutlass::float_e4m3_t;
     using SmemLayoutV = decltype(composition(
         SmemLayoutK{},
         Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
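The surviving SmemLayoutV alias is built with cute::composition, which chains layouts as functions: composition(A, B) maps a coordinate c to A(B(c)), so composing with a re-indexing layout yields a different view of the same storage. A host-side CuTe sketch of that semantics (assumes CUTLASS's CuTe headers are on the include path; toy 4x8 shapes rather than the kernel's tiles):

#include <cstdio>
#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    auto A  = make_layout(Shape<_4, _8>{}, Stride<_8, _1>{});  // row-major: A(i,j) = i*8 + j
    auto Bt = make_layout(Shape<_8, _4>{}, Stride<_4, _1>{});  // re-indexing layout
    auto At = composition(A, Bt);            // At(j,i) == A(i,j): a transposed view
    printf("A(2,3) = %d, At(3,2) = %d\n", int(A(2, 3)), int(At(3, 2)));  // both 19
}

In the hunk above, Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>> plays the role of the re-indexing layout, so the swizzled storage described by the injected SmemLayoutK can be addressed in V's coordinate order.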