mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fp8 shared mem
This commit is contained in:
parent
b67a18f850
commit
fed0499301
@ -80,6 +80,10 @@ struct Flash_fwd_kernel_traits_mla {
|
|||||||
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
|
||||||
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||||
|
|
||||||
|
using SmemLayoutVtMMa = decltype(tile_to_shape(
|
||||||
|
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||||
|
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
|
||||||
|
|
||||||
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
||||||
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||||
|
|
||||||
@ -139,10 +143,14 @@ using namespace cute;
|
|||||||
|
|
||||||
template<typename Kernel_traits>
|
template<typename Kernel_traits>
|
||||||
struct SharedStorageMLA {
|
struct SharedStorageMLA {
|
||||||
|
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
|
||||||
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa> * 2>,
|
||||||
|
cute::array_aligned<typename Kernel_traits::Element, 0>>;
|
||||||
union {
|
union {
|
||||||
struct {
|
struct {
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||||
|
SmemV_t smem_vt; // Double buffer
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||||
};
|
};
|
||||||
|
|||||||
2
setup.py
2
setup.py
@ -37,7 +37,7 @@ ext_modules.append(
|
|||||||
sources=[
|
sources=[
|
||||||
"csrc/flash_api.cpp",
|
"csrc/flash_api.cpp",
|
||||||
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
"csrc/flash_fwd_mla_bf16_sm90.cu",
|
||||||
"csrc/flash_fwd_mla_fp8_sm90.cu",
|
#"csrc/flash_fwd_mla_fp8_sm90.cu",
|
||||||
],
|
],
|
||||||
extra_compile_args={
|
extra_compile_args={
|
||||||
"cxx": cxx_args,
|
"cxx": cxx_args,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user