mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
add transv barrier
This commit is contained in:
parent
59f691763e
commit
6a4eb631e2
@ -360,6 +360,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
|
||||
flash::rescale_o(tOrO, scale_o);
|
||||
|
||||
if constexpr (Kernel_traits::Is_FP8) {
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
||||
}
|
||||
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
|
||||
|
||||
@ -440,6 +444,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
if constexpr (Kernel_traits::Is_FP8) {
|
||||
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
|
||||
|
||||
if (n_block - 2 >= n_block_min) {
|
||||
|
@ -10,6 +10,7 @@ namespace flash {
|
||||
// Identifiers for the named barriers used by the attention kernel. Call sites
// convert an enumerator with static_cast<int> and pass it as the barrier id to
// cutlass::arch::NamedBarrier::sync / cutlass::arch::NamedBarrier::arrive.
// NOTE(review): ids start at 1; presumably id 0 is left for the default
// block-wide barrier — confirm before adding or renumbering entries.
enum class NamedBarriers {
|
||||
SReady = 1,  // Waited on with NamedBarrier::sync(kNThreads, ...) in compute_attn_1rowblock_splitkv_mla.
|
||||
SoftmaxReady = 2,
|
||||
TransVReady = 3,  // Used only under Kernel_traits::Is_FP8: one path arrives, the path issuing the O gemm syncs before reading tOrVt. Presumably marks the transposed V tile as ready — confirm.
|
||||
};
|
||||
|
||||
} // flash
|
||||
|
Loading…
Reference in New Issue
Block a user