diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 5a1cb8e..1f44b68 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -186,7 +186,7 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
 
-    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+    run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
 
     out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu
index 35691f2..4990c48 100644
--- a/csrc/flash_fwd_mla_bf16_sm90.cu
+++ b/csrc/flash_fwd_mla_bf16_sm90.cu
@@ -1,3 +1,3 @@
 #include "flash_fwd_mla_kernel.h"
 
-template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
diff --git a/csrc/flash_fwd_mla_fp8_sm90.cu b/csrc/flash_fwd_mla_fp8_sm90.cu
index 2384a30..b678962 100644
--- a/csrc/flash_fwd_mla_fp8_sm90.cu
+++ b/csrc/flash_fwd_mla_fp8_sm90.cu
@@ -1,3 +1,3 @@
 #include "flash_fwd_mla_kernel.h"
 
-template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
\ No newline at end of file
+template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index 9262632..e83e9cc 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -28,9 +28,10 @@ constexpr auto getSmemLayoutK() {
     }
 }
 
-template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o=elem_type, int kHeadDimV_ = 0>
 struct Flash_fwd_kernel_traits_mla {
     using Element = elem_type;
+    using ElementO = elem_type_o;
     using ElementAccum = float;
     using index_t = int64_t;
 
@@ -48,8 +49,10 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
     static_assert(kHeadDimV % 32 == 0);
     static_assert(kHeadDimV <= kHeadDim);
-    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
-    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+    static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32);
+    static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32;
+    static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3;
     static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
@@ -81,12 +84,12 @@ struct Flash_fwd_kernel_traits_mla {
     using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
 
     using SmemLayoutAtomO = decltype(composition(
-            Swizzle<kSwizzle, 3, 3>{},
-            Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
+            Swizzle<kSwizzleO, 3, 3>{},
+            Layout<Shape<Int<8>, Int<kBlockKSmemO>>, Stride<Int<kBlockKSmemO>, _1>>{}));
     using SmemLayoutO = decltype(tile_to_shape(
             SmemLayoutAtomO{},
             Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
-    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
+    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, ElementO>;
     using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
 
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@@ -96,31 +99,38 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
     static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
 
+    static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO);
+    static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO");
+    static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO;
+    static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO");
+
     using GmemLayoutAtom = Layout<
             Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
             Stride<Int<kGmemThreadsPerRow>, _1>>;
+
+
     using GmemTiledCopy = decltype(make_tiled_copy(
             Copy_Atom<Gmem_copy_struct, Element>{},
             GmemLayoutAtom{},
-            Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+            Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 vals per read
 
     using GmemLayoutAtomO = Layout<
-            Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-            Stride<Int<kGmemThreadsPerRow>, _1>>;
+            Shape<Int<kNThreadsS / kGmemThreadsPerRowO>, Int<kGmemThreadsPerRowO>>,
+            Stride<Int<kGmemThreadsPerRowO>, _1>>;
     using GmemTiledCopyO = decltype(make_tiled_copy(
-            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
+            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementO>{},
             GmemLayoutAtomO{},
-            Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+            Layout<Shape<_1, Int<kGmemElemsPerLoadO>>>{}));  // Val layout, 8 vals per store
 
     static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
-    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
+    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum;
     using GmemLayoutAtomOaccum = Layout<
             Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
             Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
     using GmemTiledCopyOaccum = decltype(make_tiled_copy(
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
-            Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+            Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store
 };
 
 namespace flash {
@@ -597,12 +607,12 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream
     CHECK_CUDA_KERNEL_LAUNCH();
 }
 
-template<typename T, int Headdim>
+template<typename T, typename To, int Headdim>
 void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
     static_assert(Headdim == 576);
     FLASH_ASSERT(params.d_v == 512);
     FLASH_ASSERT(params.k_ptr == params.v_ptr);  // Shared_KV
-    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
+    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>;
     run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
 }
diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h
index 2994cb7..a2ef414 100644
--- a/csrc/flash_mla.h
+++ b/csrc/flash_mla.h
@@ -47,7 +47,7 @@ static constexpr int TileSchedulerMetaDataSize = 8;
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T, int Headdim>
+template<typename T, typename To, int Headdim>
 void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
 
 struct Mla_metadata_params {