mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
reorg
This commit is contained in:
parent
1df91aff33
commit
0337732dc1
@ -150,10 +150,11 @@ struct Flash_fwd_kernel_traits_mla {
|
|||||||
|
|
||||||
|
|
||||||
// ------ for f8 ------
|
// ------ for f8 ------
|
||||||
using SmemLayoutVtMMa = decltype(tile_to_shape(
|
using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, SmemLayoutK>;
|
||||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
// using SmemLayoutVtMMa = decltype(tile_to_shape(
|
||||||
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
|
// getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||||
using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, Element>;
|
// Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
|
||||||
|
using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace flash {
|
namespace flash {
|
||||||
@ -292,9 +293,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
|
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
|
||||||
|
|
||||||
auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
|
auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
|
||||||
auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
|
auto sVt = [&](){
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
|
if constexpr(Kernel_traits::Is_FP8){
|
||||||
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
|
return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{});
|
||||||
|
} else {
|
||||||
|
return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
|
||||||
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
|
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
|
||||||
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
|
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
|
||||||
@ -381,8 +386,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
|
|
||||||
if constexpr (Kernel_traits::Is_FP8) {
|
if constexpr (Kernel_traits::Is_FP8) {
|
||||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||||
|
|
||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
@ -504,6 +509,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
cute::copy(tPsP, tOrP);
|
cute::copy(tPsP, tOrP);
|
||||||
|
|
||||||
flash::rescale_o(tOrO, scale_o);
|
flash::rescale_o(tOrO, scale_o);
|
||||||
|
|
||||||
|
if constexpr (Kernel_traits::Is_FP8) __syncthreads();
|
||||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||||
|
|
||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
template <int kBlockN, int kHeadDim, typename Element>
|
template <int kBlockN, int kHeadDim, typename SmemLayoutK>
|
||||||
struct SmemTransposeFp8_64x64 {
|
struct SmemTransposeFp8_64x64 {
|
||||||
static_assert(sizeof(Element) == 1);
|
|
||||||
static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
|
static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
|
||||||
|
|
||||||
using SmemLayoutK = decltype(tile_to_shape(
|
using Element = cutlass::float_e4m3_t;
|
||||||
GMMA::Layout_K_SW64_Atom<Element>{},
|
|
||||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
|
||||||
using SmemLayoutV = decltype(composition(
|
using SmemLayoutV = decltype(composition(
|
||||||
SmemLayoutK{},
|
SmemLayoutK{},
|
||||||
Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
|
Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
|
||||||
|
Loading…
Reference in New Issue
Block a user