fp8 shared mem

This commit is contained in:
chenhongmin.will 2025-02-25 11:08:28 +08:00
parent b67a18f850
commit fed0499301
2 changed files with 9 additions and 1 deletions

View File

@ -80,6 +80,10 @@ struct Flash_fwd_kernel_traits_mla {
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtMMa = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
@ -139,10 +143,14 @@ using namespace cute;
template<typename Kernel_traits>
struct SharedStorageMLA {
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa> * 2>,
cute::array_aligned<typename Kernel_traits::Element, 0>>;
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
SmemV_t smem_vt; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
};

View File

@ -37,7 +37,7 @@ ext_modules.append(
sources=[
"csrc/flash_api.cpp",
"csrc/flash_fwd_mla_bf16_sm90.cu",
"csrc/flash_fwd_mla_fp8_sm90.cu",
#"csrc/flash_fwd_mla_fp8_sm90.cu",
],
extra_compile_args={
"cxx": cxx_args,