chenhongmin.will 2025-02-28 08:09:02 +08:00
parent 1df91aff33
commit 0337732dc1
2 changed files with 17 additions and 13 deletions

View File

@@ -150,10 +150,11 @@ struct Flash_fwd_kernel_traits_mla {
     // ------ for f8 ------
-    using SmemLayoutVtMMa = decltype(tile_to_shape(
-        getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-        Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
-    using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, Element>;
+    using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, SmemLayoutK>;
+    // using SmemLayoutVtMMa = decltype(tile_to_shape(
+    //     getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
+    //     Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
+    using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
 };

 namespace flash {
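
A note on this hunk (interpretation, not part of the commit message): the kernel traits now pass their own SmemLayoutK into the transpose helper and read the MMA-facing V^T layout back from it, instead of both sides rebuilding the layout independently with tile_to_shape. Below is a standalone, hedged sketch of this "pass the layout type, not the element type" refactor; TransposeHelper, LayoutK, and KernelTraits are stand-ins, not the real CUTE types (SmemFp8Tranpose keeps the diff's spelling):

    #include <type_traits>

    // Stand-in for the swizzled K layout built once in the kernel traits.
    struct LayoutK {};

    // The helper consumes the layout type instead of rebuilding it from an
    // element type (stand-in for SmemTransposeFp8_64x64 after this commit).
    template <class SmemLayoutK>
    struct TransposeHelper {
        using SmemLayoutVt = SmemLayoutK;  // stand-in for the real transposed-layout math
    };

    // Stand-in for Flash_fwd_kernel_traits_mla: the V^T layout is taken from
    // the helper, so there is one definition of the K layout to keep in sync.
    struct KernelTraits {
        using SmemLayoutK     = LayoutK;
        using SmemFp8Tranpose = TransposeHelper<SmemLayoutK>;
        using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
    };

    static_assert(std::is_same_v<KernelTraits::SmemLayoutVtMMa, LayoutK>);
    int main() {}

Threading a single layout type through producer and consumer removes the risk of the two tile_to_shape definitions drifting apart.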
@@ -292,9 +293,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
     auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
-        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
-        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
+    auto sVt = [&](){
+        if constexpr(Kernel_traits::Is_FP8){
+            return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{});
+        } else {
+            return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+        }
+    }();

     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
     Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
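
On the sVt change: with cute::conditional_return, both make_tensor arguments are evaluated regardless of which one is returned, whereas the immediately-invoked lambda with if constexpr instantiates only the taken branch, and each branch can deduce its own (distinct) tensor type. A minimal C++17 sketch of the idiom, using std::string and int as stand-ins for the two tensor types:

    #include <string>
    #include <type_traits>

    template <bool kUseString>
    auto pick() {
        auto v = [&]() {
            if constexpr (kUseString) {
                return std::string("fp8 path");  // branch A: std::string
            } else {
                return 42;                       // branch B: int
            }
        }();  // only the taken branch is instantiated, so deduction succeeds
        return v;
    }

    static_assert(std::is_same_v<decltype(pick<true>()), std::string>);
    static_assert(std::is_same_v<decltype(pick<false>()), int>);
    int main() {}

An ordinary ?: operator cannot express this, since both of its operands must convert to one common type.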
@@ -381,8 +386,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     if constexpr (Kernel_traits::Is_FP8) {
-        cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
+        __syncthreads();
     }
     flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

     // Double buffer for sK
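
For context (a hedged reading of the barrier names, not stated in the diff): NamedBarrier::sync(n, id) synchronizes only the n threads that arrive at barrier id, while __syncthreads() is a full-block barrier that also makes prior shared-memory writes visible to every thread in the block. Switching to __syncthreads() here presumably ensures that all consumers of the transposed V tile wait for it, not just the warps enrolled in TransVReady. A small device-code sketch of the full-block barrier's visibility guarantee (barrier_sketch is a hypothetical fragment, not from the repo):

    // smem_tile points to a block-shared buffer (e.g., a transposed V tile).
    __device__ void barrier_sketch(float* smem_tile) {
        smem_tile[threadIdx.x] = 1.0f;                       // each thread writes one element
        __syncthreads();                                     // full-block barrier: all writes
                                                             // above are visible to every thread
        float v = smem_tile[(threadIdx.x + 1) % blockDim.x]; // safe cross-thread read
        (void)v;
    }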
@@ -504,6 +509,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     cute::copy(tPsP, tOrP);
     flash::rescale_o(tOrO, scale_o);
+    if constexpr (Kernel_traits::Is_FP8) __syncthreads();
+
     flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

     // Double buffer for sK
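
This hunk mirrors the barrier fix above for the second P * V^T GEMM, so both GEMMs sit behind the same FP8 synchronization. Because the guard is if constexpr, the extra barrier exists only in FP8 instantiations; a minimal sketch:

    // The non-FP8 template instantiation discards the branch entirely,
    // so the BF16 path pays no extra synchronization cost.
    template <bool kIsFp8>
    __device__ void maybe_sync_for_fp8() {
        if constexpr (kIsFp8) __syncthreads();  // emitted only when kIsFp8
    }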

View File

@@ -1,13 +1,10 @@
 #pragma once

-template <int kBlockN, int kHeadDim, typename Element>
+template <int kBlockN, int kHeadDim, typename SmemLayoutK>
 struct SmemTransposeFp8_64x64 {
-    static_assert(sizeof(Element) == 1);
     static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));

-    using SmemLayoutK = decltype(tile_to_shape(
-        GMMA::Layout_K_SW64_Atom<Element>{},
-        Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+    using Element = cutlass::float_e4m3_t;

     using SmemLayoutV = decltype(composition(
         SmemLayoutK{},
         Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
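
Net effect of this file's change (a hedged reading of the diff): SmemTransposeFp8_64x64 no longer rebuilds its own swizzled K layout from an element type; the caller injects SmemLayoutK, and Element is fixed to cutlass::float_e4m3_t, which makes the dropped sizeof(Element) == 1 assertion hold by construction. A standalone sketch of the new interface contract; all names below are stand-ins, not the real CUTE/CUTLASS types:

    #include <type_traits>

    struct FakeLayoutK {};                    // stand-in for the swizzled K smem layout

    template <int kBlockN, int kHeadDim, typename SmemLayoutK>
    struct TransposeFp8Sketch {
        static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
        using Element = unsigned char;        // stand-in for cutlass::float_e4m3_t (1 byte)
        static_assert(sizeof(Element) == 1);  // the dropped assert now holds by construction
        using LayoutK = SmemLayoutK;          // consumed from the caller, not rebuilt
    };

    using S = TransposeFp8Sketch<64, 64, FakeLayoutK>;
    static_assert(std::is_same_v<S::LayoutK, FakeLayoutK>);
    int main() {}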