diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index e83e9cc..1af3eb7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -80,6 +80,10 @@ struct Flash_fwd_kernel_traits_mla { Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtMMa = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int >{})); + using SmemLayoutP = Layout, Int, _1, Int>>; using SmemLayoutRow = Layout>, Stride<_1, _2>>; @@ -139,10 +143,14 @@ using namespace cute; template struct SharedStorageMLA { + using SmemV_t = std::conditional_t * 2>, + cute::array_aligned>; union { struct { cute::array_aligned> smem_q; cute::array_aligned * 2> smem_k; // Double buffer + SmemV_t smem_vt; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_scale; }; diff --git a/setup.py b/setup.py index c622b7c..bfe931f 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ ext_modules.append( sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_fp8_sm90.cu", + #"csrc/flash_fwd_mla_fp8_sm90.cu", ], extra_compile_args={ "cxx": cxx_args,