mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fp8 shared mem
This commit is contained in:
parent
b67a18f850
commit
fed0499301
@ -80,6 +80,10 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
||||
// Layout alias formed by composing SmemLayoutV with a row-major
// (kHeadDimV, kBlockN) layout: coordinates are given in (kHeadDimV, kBlockN)
// order and are then mapped through SmemLayoutV.
// NOTE(review): presumably a transposed view over the same V tile rather than
// a separate buffer — confirm at the use sites.
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
|
||||
// Layout obtained by tiling whatever getSmemLayoutK<Element, kHeadDim, kHeadDimV>()
// returns up to a (kHeadDimV, kBlockN) shape.
// NOTE(review): the name suggests this is the V-transposed layout consumed by
// the MMA — confirm against the kernel code that uses it.
using SmemLayoutVtMMa = decltype(tile_to_shape(
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
|
||||
|
||||
// Layout with shape ((2, 2), kNThreadsS, 1, kBlockN / 8). No Stride argument
// is supplied, so the Layout template's default stride for this shape is used.
// NOTE(review): the (2, 2) x thread x (kBlockN / 8) split looks like a
// per-thread fragment arrangement for P — confirm against the code that fills it.
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
||||
// Layout with shape (2, kNThreadsS) and strides (1, 2): 2 * kNThreadsS
// elements in total, with the size-2 mode contiguous.
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||
|
||||
@ -139,10 +143,14 @@ using namespace cute;
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct SharedStorageMLA {
|
||||
// Storage type selected at compile time by Kernel_traits::Is_FP8:
//   - true:  an aligned array of Element holding 2 * cosize(SmemLayoutVtMMa) entries;
//   - false: an aligned array of Element declared with 0 entries.
// NOTE(review): verify how much space array_aligned<..., 0> actually occupies
// (alignment may keep it non-empty) if the intent is to add nothing outside FP8.
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa> * 2>,
|
||||
cute::array_aligned<typename Kernel_traits::Element, 0>>;
|
||||
union {
|
||||
// Anonymous struct grouping the buffers below as a single member of the
// enclosing union. Each array is sized from the cosize of the corresponding
// Kernel_traits layout.
struct {
|
||||
// Element array sized by cosize(SmemLayoutQ).
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||
// Element array sized by 2 * cosize(SmemLayoutK).
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||
// Member of type SmemV_t, the Is_FP8-dependent storage type.
SmemV_t smem_vt; // Double buffer
|
||||
// Element array sized by cosize(SmemLayoutP).
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||
// ElementAccum (not Element) array sized by cosize(SmemLayoutRow).
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user