mirror of https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00

add TransV

commit 6dcea4952c
parent 6a4eb631e2
@@ -276,7 +276,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
-    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
+    auto sV = cute::conditional_return<Kernel_traits::Is_FP8>(
+        cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{})),
+        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}));
+
     auto sVt = cute::conditional_return<Kernel_traits::Is_FP8>(
         make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}),
         make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}));
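Note: the new sV binding uses cute::conditional_return so the two branches can yield tensors of different types (the FP8 branch carries the SmemLayoutVtMMa swizzle, the fallback the plain SmemLayoutV), which a plain ternary cannot express. A minimal standalone sketch of the idea, under the assumption that CuTe's shipped helper is equivalent up to value-category details:

    // Sketch only (assumed simplification of cute::conditional_return):
    // an if-constexpr select whose branches may differ in type.
    template <bool Cond, class A, class B>
    constexpr auto conditional_return_sketch(A const& a, B const& b) {
        if constexpr (Cond) { return a; } else { return b; }
    }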
@@ -377,6 +381,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         cute::copy(softmax.row_sum, tRow_sumsRow_sum);
         cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
     } else {
+        const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
         const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
         int cur_block_table = __ldg(&block_table[n_block]);
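Note: warp_group_thread_idx is the thread's lane within its warp group; the TransV tiled copies added below are sliced per warp group, so they need this local index rather than the raw threadIdx.x. A standalone sketch of the decomposition (assuming CUTLASS's NumThreadsPerWarpGroup == 128, i.e. four warps per warp group):

    __device__ void warp_group_indices_sketch() {
        constexpr int kThreadsPerWarpGroup = 128;  // cutlass::NumThreadsPerWarpGroup
        const int warp_group_idx        = threadIdx.x / kThreadsPerWarpGroup;  // which warp group
        const int warp_group_thread_idx = threadIdx.x % kThreadsPerWarpGroup;  // lane 0..127 within it
        (void)warp_group_idx; (void)warp_group_thread_idx;  // silence unused warnings in this sketch
    }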
@@ -412,7 +417,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
             // Double buffer for sK
             constexpr int sK_offset = size(sK);
             tKsK.data() = tKsK.data() + sK_offset;
-            if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            if constexpr (!Kernel_traits::Is_FP8) {
+                tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            }
+            else {
+                sV.data() = sV.data() + sK_offset;
+            }
         }

         // We need to clear the sK smem tiles because K is V.
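Note: in the FP8 path, sV aliases the double-buffered sK storage, so it must advance by the same tile offset; the non-FP8 path instead advances the register-side view tOrVt (the / 8 plausibly converts the 2-byte-element offset into the GMMA shared-memory descriptor's coarser addressing units). A minimal sketch, with hypothetical names, of the ping-pong scheme this hunk and the one at -468 below share:

    // Two tiles of tile_elems elements sit back to back in shared memory;
    // flipping the offset's sign each iteration ping-pongs between them.
    template <class T>
    __device__ T* toggle_tile(T* cur, int n_block, int tile_elems) {
        return cur + (n_block % 2 == 0 ? tile_elems : -tile_elems);
    }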
@@ -445,6 +455,58 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
         }

         if constexpr (Kernel_traits::Is_FP8) {
+            auto TransV = [&]() {
+                // refer to fa3's TransV: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L697
+                using LDSM_divide_shape = Shape<_64, _8>;
+                using S2RTiledCopyVt = decltype(make_tiled_copy(
+                    Copy_Atom<SM75_U16x8_LDSM_T, Element>{},
+                    Layout<Shape<_32, _4, _1, _1>, Stride<_4, _1, _0, _0>>{},  // thread layout
+                    Layout<Shape<_2, _2, _1, _4>, Stride<_1, _2, _16, _4>>{}   // val layout
+                ));
+
+                using STSM_divide_shape = Shape<_8, _16>;
+                using R2STiledCopyV = decltype(make_tiled_copy(
+                    Copy_Atom<SM90_U32x4_STSM_N, Element>{},
+                    Layout<Shape<_8, _4, _4, _1>, Stride<_4, _1, _32, _0>>{},  // thread layout
+                    Layout<Shape<_1, _4, _2, _2>, Stride<_0, _1, _4, _8>>{}    // val layout
+                ));
+
+                S2RTiledCopyVt s2r_tiled_copy_vt;
+                R2STiledCopyV r2s_tiled_copy_v;
+                auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx);
+                auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx);
+                // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8)
+                Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32)
+                // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64))
+                Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64))
+                CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
+                CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
+
+                static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
+                Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_)>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
+                Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_)>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2))
+                #pragma unroll
+                for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
+                    Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{})));
+                    static_assert(size<0>(tTransrV) == 16);
+                    Tensor tTransrV_64 = recast<uint2>(tTransrV);
+                    cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i)), tTransrV);
+                    #pragma unroll
+                    for (int j = 0; j < size(tTransrV_64); ++j) {
+                        uint32_t upper = tTransrV_64[j].x;
+                        uint32_t lower = tTransrV_64[j].y;
+                        tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
+                        tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
+                    }
+                    cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i)));
+                }
+            };
+
+            TransV();
             cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
         }

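Note: the core trick in TransV is the __byte_perm pair. SM75_U16x8_LDSM_T transposes shared memory at 16-bit granularity, so for 8-bit FP8 data every transposed 16-bit unit still holds two elements that end up interleaved; selector 0x6420 gathers the even bytes of the two 32-bit halves and 0x7531 the odd bytes, restoring element order. A host-side emulation (simplified: it ignores the hardware instruction's sign-replication selector bit, which these selectors do not use) makes the two selectors concrete:

    #include <stdint.h>
    #include <stdio.h>

    // Host-side emulation of CUDA's __byte_perm(a, b, s): each selector nibble
    // (LSB first) indexes the 8-byte pool {a: bytes 0-3, b: bytes 4-7}.
    static uint32_t byte_perm_emu(uint32_t a, uint32_t b, uint32_t s) {
        uint8_t pool[8];
        for (int i = 0; i < 4; ++i) { pool[i] = (uint8_t)(a >> (8 * i)); pool[4 + i] = (uint8_t)(b >> (8 * i)); }
        uint32_t r = 0;
        for (int i = 0; i < 4; ++i) r |= (uint32_t)pool[(s >> (4 * i)) & 0x7] << (8 * i);
        return r;
    }

    int main() {
        // Label the eight FP8 elements e0..e7: upper holds e0..e3, lower e4..e7 (LSB first).
        uint32_t upper = 0x03020100, lower = 0x07060504;
        printf("%08x\n", byte_perm_emu(upper, lower, 0x6420));  // 06040200: even elements e0 e2 e4 e6
        printf("%08x\n", byte_perm_emu(upper, lower, 0x7531));  // 07050301: odd elements  e1 e3 e5 e7
        return 0;
    }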
@@ -468,7 +530,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f

             // Double buffer for sK
             const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
-            if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            if constexpr (!Kernel_traits::Is_FP8) {
+                tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            } else {
+                sV.data() = sV.data() + sK_offset;
+            }
         }

         cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
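Note: SoftmaxReady and TransVReady act as one-directional handshakes between warp groups: the producing side calls arrive (non-blocking), the consuming side calls sync (arrive plus wait until the count is reached). A hedged sketch of the pattern with a free-standing barrier id parameter; the kernel's real NamedBarriers enum is defined elsewhere in FlashMLA:

    #include <cutlass/arch/barrier.h>

    // Producer: signal that shared-memory results are ready (does not block).
    __device__ void producer_signal(int num_threads, int barrier_id) {
        cutlass::arch::NamedBarrier::arrive(num_threads, barrier_id);
    }

    // Consumer: arrive and block until num_threads threads have arrived.
    __device__ void consumer_wait(int num_threads, int barrier_id) {
        cutlass::arch::NamedBarrier::sync(num_threads, barrier_id);
    }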