Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00
use 64x64 transpose_v

commit 855c985b00
parent d1689ab64f
@@ -12,6 +12,7 @@ using namespace cute;
 #include "softmax.h"
 #include "static_switch.h"
 #include "flash_mla.h"
+#include "fp8_transpose_v.h"
 
 
 template<typename PrecType, int DIM, int DIM2 = DIM, cute::GMMA::Major major = GMMA::Major::K>
@@ -86,20 +87,11 @@ struct Flash_fwd_kernel_traits_mla {
         getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
         Shape<Int<kBlockN>, Int<kHeadDim>>{}));
 
-    // ------ for f16 ------
     using SmemLayoutV = decltype(tile_to_shape(
         getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
         Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
     using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
-
-    // ------ for f8 ------
-    using SmemLayoutVtLoad = decltype(tile_to_shape(
-        getSmemLayoutK<Element, kHeadDim, kHeadDimV, GMMA::Major::MN>(),
-        Shape<Int<kHeadDimV>, Int<kBlockN>>{}));
-    using SmemLayoutVtMMa = decltype(tile_to_shape(
-        getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
-        Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
 
     using SmemLayoutP = std::conditional_t<
         Is_FP8,
         Layout<Shape<Shape<_4, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 32>>>,
@@ -155,6 +147,13 @@ struct Flash_fwd_kernel_traits_mla {
         Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
         GmemLayoutAtomOaccum{},
         Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
+
+
+    // ------ for f8 ------
+    using SmemLayoutVtMMa = decltype(tile_to_shape(
+        getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
+        Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
+    using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV, Element>;
 };
 
 namespace flash {
@@ -292,10 +291,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
 
-    auto sV = cute::conditional_return<Kernel_traits::Is_FP8>(
-        cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtLoad{})),
-        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}));
-
+    auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
     auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
         make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
         make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
@@ -438,9 +434,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         if constexpr (!Kernel_traits::Is_FP8) {
             tOrVt.data() = tOrVt.data() + sK_offset / 8;
         }
-        else {
-            sV.data() = sV.data() + sK_offset;
-        }
     }
 
     // We need to clear the sK smem tiles because K is V.
@@ -474,53 +467,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
     if constexpr (Kernel_traits::Is_FP8) {
         auto TransV = [&]() {
-            // refer to fa3's TransV: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L697
-            using LDSM_divide_shape = Shape<_64, _8>;
-            using S2RTiledCopyVt = decltype(make_tiled_copy(
-                Copy_Atom<SM75_U16x8_LDSM_T, Element>{},
-                Layout<Shape<_32, _4, _1, _1>, Stride<_4, _1, _0, _0>>{}, // thread layout
-                Layout<Shape<_2, _2, _1, _4>, Stride<_1, _2, _16, _4>>{} // val layout
-            ));
-
-            using STSM_divide_shape = Shape<_8, _16>;
-            using R2STiledCopyV = decltype(make_tiled_copy(
-                Copy_Atom<SM90_U32x4_STSM_N, Element>{},
-                Layout<Shape<_8, _4, _4, _1>, Stride<_4, _1, _32, _0>>{}, // thread layout
-                Layout<Shape<_1, _4, _2, _2>, Stride<_0, _1, _4, _8>>{} // val layout
-            ));
-
-            S2RTiledCopyVt s2r_tiled_copy_vt;
-            R2STiledCopyV r2s_tiled_copy_v;
-            auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx);
-            auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx);
-            // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8)
-            Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sV, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32)
-            // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64))
-            Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sVt, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64))
-            CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
-            CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
-
-            static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
-            Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_)>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
-            Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_)>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
-            #pragma unroll
-            for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
-                Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{})));
-                static_assert(size<0>(tTransrV) == 16);
-                Tensor tTransrV_64 = recast<uint2>(tTransrV);
-                cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i)), tTransrV);
-                #pragma unroll
-                for (int j = 0; j < size(tTransrV_64); ++j) {
-                    uint32_t upper = tTransrV_64[j].x;
-                    uint32_t lower = tTransrV_64[j].y;
-                    tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
-                    tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
+            using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose;
+            SmemFp8Tranpose smem_transpose_V;
+            Tensor sV_divide = as_position_independent_swizzle_tensor(
+                make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{}));
+            Tensor sVt_divide = as_position_independent_swizzle_tensor(
+                make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{}));
+
+            if (n_block % 2 == 1) {
+                sV_divide.data() = sV_divide.data() + size(sK);
+            }
+
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) {
+                CUTLASS_PRAGMA_UNROLL
+                for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) {
+                    smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j)));
                 }
-                cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i)));
             }
         };
 
@@ -548,8 +511,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
         if constexpr (!Kernel_traits::Is_FP8) {
             tOrVt.data() = tOrVt.data() + sK_offset / 8;
-        } else {
-            sV.data() = sV.data() + sK_offset;
         }
     }
 
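The kernel-side hunks above drop the inline fa3-style LDSM/STSM transpose and instead walk the FP8 V tile in 64x64 blocks, calling SmemTransposeFp8_64x64::transpose (added in the new file below) once per block. As a rough reference for what each such call computes, here is a naive 64x64 byte-tile transpose through shared memory in plain CUDA; the kernel name and launch are illustrative only and do not reflect FlashMLA's swizzled CuTe layouts or LDSM/STSM copies.

// Rough reference sketch only: computes out(r, c) = in(c, r) for one 64x64
// tile of 1-byte elements, the same result one SmemTransposeFp8_64x64 call
// targets, but without swizzling, LDSM/STSM, or byte permutes.
#include <cuda_runtime.h>
#include <cstdint>

__global__ void transpose_tile_64x64_naive(const uint8_t* __restrict__ in,
                                           uint8_t* __restrict__ out) {
    // +1 padding avoids shared-memory bank conflicts on the transposed reads.
    __shared__ uint8_t tile[64][64 + 1];
    for (int idx = threadIdx.x; idx < 64 * 64; idx += blockDim.x) {
        int r = idx / 64, c = idx % 64;
        tile[r][c] = in[r * 64 + c];   // stage the source tile row-major
    }
    __syncthreads();
    for (int idx = threadIdx.x; idx < 64 * 64; idx += blockDim.x) {
        int r = idx / 64, c = idx % 64;
        out[r * 64 + c] = tile[c][r];  // write the transposed tile
    }
}

// Example launch for a single tile with one 128-thread block (illustrative):
//   transpose_tile_64x64_naive<<<1, 128>>>(d_in, d_out);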
csrc/fp8_transpose_v.h (new file, 82 lines)
@@ -0,0 +1,82 @@
+#pragma once
+
+template <int kBlockN, int kHeadDim, typename Element>
+struct SmemTransposeFp8_64x64 {
+  static_assert(sizeof(Element) == 1);
+  static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
+
+  using SmemLayoutK = decltype(tile_to_shape(
+      GMMA::Layout_K_SW64_Atom<Element>{},
+      Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+  using SmemLayoutV = decltype(composition(
+      SmemLayoutK{},
+      Layout<Shape<Int<kBlockN>, Int<kHeadDim>>, Stride<_1, Int<kBlockN>>>{}));
+  using TransposeShapeAtomV = Shape<_64, _64>;
+
+  // for fp8 in-kernel transpose -- src layout
+  using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
+  using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
+  using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
+      shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
+  using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
+
+  // For fp8, this is the memory transpose.
+  using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+  using SmemLayoutVt = decltype(tile_to_shape(
+      SmemLayoutAtomVt{},
+      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+
+  // for fp8 in-kernel transpose -- dst layout
+  using SmemLayoutVtTrans = decltype(composition(
+      SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
+  using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
+  using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
+  using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}),
+      shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
+  using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
+
+
+  using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
+  using ldsm_value_shape = Shape<_2, _8, _2, _1>;
+  using ldsm_value_stride = Stride<_2, _4, _1, _0>;
+  using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
+      Layout<ldsm_value_shape, ldsm_value_stride>{}));
+  TiledCopyLDSM tiled_copy_ldsm;
+
+  using stsm_thread_shape = Shape<_4, _1, _8, _4>;
+  // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
+  using stsm_value_shape = Shape<_4, _4, _1, _2>;
+  using stsm_value_stride = Stride<_1, _8, _0, _4>;
+
+  using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
+      Layout<stsm_value_shape, stsm_value_stride>{}));
+  TiledCopySTSM tiled_copy_stsm;
+
+  template <class SmemTensor, class SmemTensorOut>
+  CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
+    using namespace cute;
+
+    auto tid = threadIdx.x;
+    auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
+    auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
+
+    auto tXsX = thr_copy_ldsm.partition_S(s_in);
+    auto tXrX = make_tensor<Element>(shape(tXsX));
+    auto tXsX_out = thr_copy_stsm.partition_D(s_out);
+
+    cute::copy(tiled_copy_ldsm, tXsX, tXrX);
+
+    auto data = tXrX.data();
+    CUTLASS_PRAGMA_UNROLL
+    for (int n = 0; n < size(tXrX); n += 8) {
+      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+      auto upper = data_32bit[0];
+      auto lower = data_32bit[1];
+      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
+      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
+    }
+
+    cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
+  }
+};
+
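Because SM75_U16x8_LDSM_T moves data at 16-bit granularity while Element is 1 byte, the fragment loaded in transpose() still needs a byte-level fixup: the paired __byte_perm calls with selectors 0x6420 and 0x7531 gather, respectively, the even bytes and the odd bytes of two consecutive 32-bit words, de-interleaving the two byte streams. A minimal host-side sketch of that selector behavior, using a hypothetical emulate_byte_perm helper to mirror CUDA's __byte_perm(x, y, s):

// Illustrative only: emulate __byte_perm and show what selectors 0x6420 and
// 0x7531 do to two words whose bytes interleave streams "a" and "b".
#include <cstdint>
#include <cstdio>

// Bytes 0-3 of the 8-byte source come from x, bytes 4-7 from y; each selector
// nibble (LSB first) picks one source byte for the corresponding result byte.
static uint32_t emulate_byte_perm(uint32_t x, uint32_t y, uint32_t s) {
    uint8_t src[8];
    for (int i = 0; i < 4; ++i) {
        src[i]     = static_cast<uint8_t>(x >> (8 * i));
        src[i + 4] = static_cast<uint8_t>(y >> (8 * i));
    }
    uint32_t r = 0;
    for (int i = 0; i < 4; ++i) {
        r |= static_cast<uint32_t>(src[(s >> (4 * i)) & 0x7]) << (8 * i);
    }
    return r;
}

int main() {
    // upper = [a0 b0 a1 b1], lower = [a2 b2 a3 b3] (byte 0 is the lowest byte),
    // i.e. two byte streams interleaved at 16-bit granularity.
    uint32_t upper = 0xB1A1B0A0u;
    uint32_t lower = 0xB3A3B2A2u;
    printf("0x6420 -> %08X (expect A3A2A1A0)\n", (unsigned)emulate_byte_perm(upper, lower, 0x6420));
    printf("0x7531 -> %08X (expect B3B2B1B0)\n", (unsigned)emulate_byte_perm(upper, lower, 0x7531));
    return 0;
}

With these inputs the first call yields 0xA3A2A1A0 and the second 0xB3B2B1B0, i.e. the interleaved streams come out separated, which is the permutation the register-resident fragment receives between the LDSM load and the STSM store.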