mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
use mma1's Aregs instead of mma0's Cregs
parent 1757a6db07
commit d1689ab64f
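Summary of the change: the producer warpgroup now converts the MMA0 softmax accumulator into the A-operand (Aregs) layout of the output MMA (MMA1) before staging it in shared memory, instead of writing the raw Cregs layout and re-deriving the Aregs view on the consumer side. SmemLayoutP and the tPsP partition gain an extra mode to match (with separate fragment shapes for FP8 and 16-bit types), and the element-type conversion moves into a new convert_type_out helper in csrc/utils.h.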
@@ -100,7 +100,11 @@ struct Flash_fwd_kernel_traits_mla {
             getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
             Shape<Int<kHeadDimV>, Int<kBlockN> >{}));
 
-    using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
+    using SmemLayoutP = std::conditional_t<
+        Is_FP8,
+        Layout<Shape<Shape<_4, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 32>>>,
+        Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 16>>>
+    >;
     using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
 
     using SmemLayoutAtomO = decltype(composition(
@@ -297,7 +301,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
             make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
 
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
-    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
+    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
     Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
     Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
     Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
@@ -368,8 +372,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
                 : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
 
         if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
-        Tensor rP = flash::convert_type<Element>(tSrS);
-        cute::copy(rP, tPsP);
+        Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
+        Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
+        convert_type_out(tOrP_acc, tOrP);
+
+        cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0
         cute::copy(scale_o, tScale_osScale_o);
 
         cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
@@ -380,7 +387,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
             cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
         }
 
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
         flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
 
         // Double buffer for sK
@@ -529,15 +535,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         }
 
         typename Kernel_traits::TiledMma tiled_mma;
-        auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
-        Tensor rP = make_tensor<Element>(tSrS_layout);
+        auto tSrS_layout = flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout());
+        Tensor tOrP = make_tensor<Element>(tSrS_layout);
         Tensor scale_o = make_tensor<float>(Shape<_2>{});
         cute::copy(tScale_osScale_o, scale_o);
-        cute::copy(tPsP, rP);
+        cute::copy(tPsP, tOrP);
 
         flash::rescale_o(tOrO, scale_o);
-
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
         flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
 
         // Double buffer for sK
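On the consumer side (second warpgroup), tOrP is now allocated directly with the Aregs layout derived from TiledMmaO and filled by a single copy from sP, so the separate convert_layout_acc_Aregs relayout that used to follow rescale_o is gone. The element-type conversion itself is handled by the new helper added below.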
|
csrc/utils.h (+16 lines)
@@ -255,4 +255,20 @@ CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename Engine, typename Layout, typename EngineOut>
+CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
+    // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
+    using From_type = typename Engine::value_type;
+    using To_type = typename EngineOut::value_type;
+    static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));
+    static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
+    Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
+    Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
+    static_assert(size(frag) == size(out_frg));
+    cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
+    #pragma unroll
+    for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 } // namespace flash
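For reference, the pattern inside convert_type_out is the standard CUTLASS fragment-wise conversion: recast source and destination into cutlass::Array fragments of FragmentSize elements and run cutlass::NumericArrayConverter over them. Below is a minimal, hypothetical sketch of the same idea without CuTe tensors; the function name and the float-to-bf16 choice are illustrative, not part of the commit.

    #include <cutlass/array.h>
    #include <cutlass/numeric_conversion.h>
    #include <cutlass/numeric_types.h>

    // Hypothetical stand-alone analogue of convert_type_out: convert N floats to bf16,
    // FragmentSize values at a time, so the converter can emit packed conversions.
    template <int N>
    CUTLASS_HOST_DEVICE void convert_f32_to_bf16(float const (&src)[N], cutlass::bfloat16_t (&dst)[N]) {
        // Same FragmentSize rule as the diff: ratio of the wider type to the narrower one.
        constexpr int FragmentSize = sizeof(float) / sizeof(cutlass::bfloat16_t);  // == 2
        static_assert(N % FragmentSize == 0, "Fragment size does not vectorize properly");

        using SrcFrag = cutlass::Array<float, FragmentSize>;
        using DstFrag = cutlass::Array<cutlass::bfloat16_t, FragmentSize>;
        cutlass::NumericArrayConverter<cutlass::bfloat16_t, float, FragmentSize> convert_op;

        auto const *src_frag = reinterpret_cast<SrcFrag const *>(&src[0]);
        auto *dst_frag = reinterpret_cast<DstFrag *>(&dst[0]);
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / FragmentSize; ++i) { dst_frag[i] = convert_op(src_frag[i]); }
    }

The comment carried over from the diff explains the out-parameter design: converting into a caller-allocated tensor, rather than allocating and returning one inside the helper, avoided an end-to-end slowdown and incorrect outputs in the authors' testing.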