mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
add transv barrier
This commit is contained in:
parent
59f691763e
commit
6a4eb631e2
@@ -360,6 +360,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 
 flash::rescale_o(tOrO, scale_o);
 
+if constexpr (Kernel_traits::Is_FP8) {
+    cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
+}
+
 Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
 
@@ -440,6 +444,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
 cute::cp_async_fence();
 }
 
+if constexpr (Kernel_traits::Is_FP8) {
+    cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
+}
+
 cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
 
 if (n_block - 2 >= n_block_min) {
@@ -10,6 +10,7 @@ namespace flash {
 enum class NamedBarriers {
     SReady = 1,
     SoftmaxReady = 2,
+    TransVReady = 3,
 };
 
 } // flash
Loading…
Reference in New Issue
Block a user