mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
try fix
This commit is contained in:
parent
dbd8c307eb
commit
1757a6db07
@ -367,6 +367,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
|
||||||
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
|
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
|
||||||
|
|
||||||
|
if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
|
||||||
Tensor rP = flash::convert_type<Element>(tSrS);
|
Tensor rP = flash::convert_type<Element>(tSrS);
|
||||||
cute::copy(rP, tPsP);
|
cute::copy(rP, tPsP);
|
||||||
cute::copy(scale_o, tScale_osScale_o);
|
cute::copy(scale_o, tScale_osScale_o);
|
||||||
@ -379,7 +380,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
|
||||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||||
|
|
||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
@ -536,7 +537,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
|
|
||||||
flash::rescale_o(tOrO, scale_o);
|
flash::rescale_o(tOrO, scale_o);
|
||||||
|
|
||||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaO>(rP.layout()));
|
||||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||||
|
|
||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
|
20
csrc/utils.h
20
csrc/utils.h
@ -235,4 +235,24 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
|
|||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename Fragment>
|
||||||
|
CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
|
||||||
|
// frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
|
||||||
|
static_assert(decltype(size<0, 0>(frag))::value == 2);
|
||||||
|
static_assert(decltype(size<0, 1>(frag))::value == 2);
|
||||||
|
static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
|
||||||
|
static_assert(decltype(stride<0, 0>(frag))::value == 1);
|
||||||
|
static_assert(sizeof(typename Fragment::value_type) == 4);
|
||||||
|
Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N))
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
|
||||||
|
cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
} // namespace flash
|
} // namespace flash
|
||||||
|
Loading…
Reference in New Issue
Block a user