fix combine

This commit is contained in:
chenhongmin.will 2025-02-28 18:45:09 +08:00
parent fd1e662deb
commit bfe38ab106

View File

@ -688,7 +688,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream
dim3 grid_combine(params.b * params.h * params.seqlen_q);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();