diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 12e9883..3b3cd9c 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -367,6 +367,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f softmax.template softmax(tSrS, params.scale_softmax_log2) : softmax.template softmax(tSrS, params.scale_softmax_log2); + if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } Tensor rP = flash::convert_type(tSrS); cute::copy(rP, tPsP); cute::copy(scale_o, tScale_osScale_o); @@ -379,7 +380,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); } - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK @@ -536,7 +537,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f flash::rescale_o(tOrO, scale_o); - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK diff --git a/csrc/utils.h b/csrc/utils.h index 3b8dd52..854c75f 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -235,4 +235,24 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor +CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { + // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits + static_assert(decltype(size<0, 0>(frag))::value == 2); + static_assert(decltype(size<0, 1>(frag))::value == 2); + static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); + static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 4); + Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) + #pragma unroll + for (int mi = 0; mi < size<1>(frag_64b); ++mi) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { + cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash