use mm1's Aregs instead of mma0's Cregs

This commit is contained in:
chenhongmin.will 2025-02-27 10:56:43 +08:00
parent 1757a6db07
commit d1689ab64f
2 changed files with 30 additions and 10 deletions

View File

@ -100,7 +100,11 @@ struct Flash_fwd_kernel_traits_mla {
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
using SmemLayoutP = std::conditional_t<
Is_FP8,
Layout<Shape<Shape<_4, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 32>>>,
Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 16>>>
>;
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
using SmemLayoutAtomO = decltype(composition(
@ -297,7 +301,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
@ -368,8 +372,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
Tensor rP = flash::convert_type<Element>(tSrS);
cute::copy(rP, tPsP);
Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
convert_type_out(tOrP_acc, tOrP);
cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0
cute::copy(scale_o, tScale_osScale_o);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
@ -380,7 +387,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
}
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
@ -529,15 +535,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
}
typename Kernel_traits::TiledMma tiled_mma;
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
Tensor rP = make_tensor<Element>(tSrS_layout);
auto tSrS_layout = flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout());
Tensor tOrP = make_tensor<Element>(tSrS_layout);
Tensor scale_o = make_tensor<float>(Shape<_2>{});
cute::copy(tScale_osScale_o, scale_o);
cute::copy(tPsP, rP);
cute::copy(tPsP, tOrP);
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK

View File

@ -255,4 +255,20 @@ CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout, typename EngineOut>
CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
// Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
using From_type = typename Engine::value_type;
using To_type = typename EngineOut::value_type;
static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));
static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
static_assert(size(frag) == size(out_frg));
cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
#pragma unroll
for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash