From bfe38ab10649b55c275d454afe93a346e8d5bf20 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 18:45:09 +0800 Subject: [PATCH] fix combine --- csrc/flash_fwd_mla_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index ad52b3c..bef20ee 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -688,7 +688,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream dim3 grid_combine(params.b * params.h * params.seqlen_q); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< - typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH();