diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index ad52b3c..bef20ee 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -688,7 +688,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream dim3 grid_combine(params.b * params.h * params.seqlen_q); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< - typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH();