mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix combine
This commit is contained in:
parent
fd1e662deb
commit
bfe38ab106
@ -688,7 +688,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream
|
|||||||
dim3 grid_combine(params.b * params.h * params.seqlen_q);
|
dim3 grid_combine(params.b * params.h * params.seqlen_q);
|
||||||
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
|
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
|
||||||
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
|
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
|
||||||
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
|
typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
|
||||||
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
|
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
|
||||||
});
|
});
|
||||||
CHECK_CUDA_KERNEL_LAUNCH();
|
CHECK_CUDA_KERNEL_LAUNCH();
|
||||||
|
Loading…
Reference in New Issue
Block a user