mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
use mm1's Aregs instead of mma0's Cregs
This commit is contained in:
parent
1757a6db07
commit
d1689ab64f
@ -100,7 +100,11 @@ struct Flash_fwd_kernel_traits_mla {
|
||||
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
|
||||
Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
|
||||
|
||||
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
|
||||
using SmemLayoutP = std::conditional_t<
|
||||
Is_FP8,
|
||||
Layout<Shape<Shape<_4, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 32>>>,
|
||||
Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 16>>>
|
||||
>;
|
||||
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
|
||||
|
||||
using SmemLayoutAtomO = decltype(composition(
|
||||
@ -297,7 +301,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
|
||||
|
||||
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
|
||||
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
|
||||
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
|
||||
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
|
||||
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
|
||||
@ -368,8 +372,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
|
||||
|
||||
if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
|
||||
Tensor rP = flash::convert_type<Element>(tSrS);
|
||||
cute::copy(rP, tPsP);
|
||||
Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
|
||||
Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
|
||||
convert_type_out(tOrP_acc, tOrP);
|
||||
|
||||
cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0
|
||||
cute::copy(scale_o, tScale_osScale_o);
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
||||
@ -380,7 +387,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
||||
}
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
// Double buffer for sK
|
||||
@ -529,15 +535,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
}
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
|
||||
Tensor rP = make_tensor<Element>(tSrS_layout);
|
||||
auto tSrS_layout = flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout());
|
||||
Tensor tOrP = make_tensor<Element>(tSrS_layout);
|
||||
Tensor scale_o = make_tensor<float>(Shape<_2>{});
|
||||
cute::copy(tScale_osScale_o, scale_o);
|
||||
cute::copy(tPsP, rP);
|
||||
cute::copy(tPsP, tOrP);
|
||||
|
||||
flash::rescale_o(tOrO, scale_o);
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
// Double buffer for sK
|
||||
|
16
csrc/utils.h
16
csrc/utils.h
@ -255,4 +255,20 @@ CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Engine, typename Layout, typename EngineOut>
|
||||
CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
|
||||
// Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
|
||||
using From_type = typename Engine::value_type;
|
||||
using To_type = typename EngineOut::value_type;
|
||||
static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));
|
||||
static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
|
||||
Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
|
||||
Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
|
||||
static_assert(size(frag) == size(out_frg));
|
||||
cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
|
Loading…
Reference in New Issue
Block a user