mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix compile
This commit is contained in:
parent
855c985b00
commit
1df91aff33
@ -395,7 +395,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||
} else {
|
||||
const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
|
||||
int cur_block_table = __ldg(&block_table[n_block]);
|
||||
|
||||
|
@ -16,8 +16,7 @@ struct SmemTransposeFp8_64x64 {
|
||||
// for fp8 in-kernel transpose -- src layout
|
||||
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
|
||||
using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
|
||||
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
|
||||
shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
|
||||
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
|
||||
using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
|
||||
|
||||
// For fp8, this is the memory transpose.
|
||||
@ -28,11 +27,10 @@ struct SmemTransposeFp8_64x64 {
|
||||
|
||||
// for fp8 in-kernel transpose -- dst layout
|
||||
using SmemLayoutVtTrans = decltype(composition(
|
||||
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
|
||||
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
|
||||
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
|
||||
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
|
||||
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}),
|
||||
shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
|
||||
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
|
||||
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user