This commit is contained in:
chenhongmin.will 2025-02-27 09:11:17 +08:00
parent dbd8c307eb
commit 1757a6db07
2 changed files with 23 additions and 2 deletions

View File

@ -367,6 +367,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
Tensor rP = flash::convert_type<Element>(tSrS);
cute::copy(rP, tPsP);
cute::copy(scale_o, tScale_osScale_o);
@ -379,7 +380,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
}
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
@ -536,7 +537,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK

View File

@ -235,4 +235,24 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Fragment>
CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
// frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
static_assert(decltype(size<0, 0>(frag))::value == 2);
static_assert(decltype(size<0, 1>(frag))::value == 2);
static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
static_assert(decltype(stride<0, 0>(frag))::value == 1);
static_assert(sizeof(typename Fragment::value_type) == 4);
Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N))
#pragma unroll
for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
#pragma unroll
for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash