fix compile

This commit is contained in:
chenhongmin.will 2025-02-27 23:40:02 +08:00
parent 855c985b00
commit 1df91aff33
2 changed files with 3 additions and 6 deletions

View File

@ -395,7 +395,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
} else {
const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);

View File

@ -16,8 +16,7 @@ struct SmemTransposeFp8_64x64 {
// for fp8 in-kernel transpose -- src layout
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
// For fp8, this is the memory transpose.
@ -28,11 +27,10 @@ struct SmemTransposeFp8_64x64 {
// for fp8 in-kernel transpose -- dst layout
using SmemLayoutVtTrans = decltype(composition(
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}),
shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));