mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix compile
This commit is contained in:
parent
855c985b00
commit
1df91aff33
@@ -395,7 +395,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
|
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
|
||||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||||
} else {
|
} else {
|
||||||
const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
|
||||||
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
|
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
|
||||||
int cur_block_table = __ldg(&block_table[n_block]);
|
int cur_block_table = __ldg(&block_table[n_block]);
|
||||||
|
|
||||||
|
@@ -16,8 +16,7 @@ struct SmemTransposeFp8_64x64 {
|
|||||||
// for fp8 in-kernel transpose -- src layout
|
// for fp8 in-kernel transpose -- src layout
|
||||||
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
|
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
|
||||||
using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
|
using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
|
||||||
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
|
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
|
||||||
shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
|
|
||||||
using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
|
using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
|
||||||
|
|
||||||
// For fp8, this is the memory transpose.
|
// For fp8, this is the memory transpose.
|
||||||
@@ -28,11 +27,10 @@ struct SmemTransposeFp8_64x64 {
|
|||||||
|
|
||||||
// for fp8 in-kernel transpose -- dst layout
|
// for fp8 in-kernel transpose -- dst layout
|
||||||
using SmemLayoutVtTrans = decltype(composition(
|
using SmemLayoutVtTrans = decltype(composition(
|
||||||
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
|
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
|
||||||
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
|
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
|
||||||
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
|
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
|
||||||
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}),
|
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
|
||||||
shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
|
|
||||||
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
|
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user