fix Vt illegal

This commit is contained in:
chenhongmin.will 2025-02-26 17:25:52 +08:00
parent 29de9e0c79
commit 59f691763e

View File

@ -149,13 +149,13 @@ using namespace cute;
template<typename Kernel_traits>
struct SharedStorageMLA {
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa> * 2>,
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa>>,
cute::array_aligned<typename Kernel_traits::Element, 0>>;
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
SmemV_t smem_vt; // Double buffer
SmemV_t smem_vt;
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
};
@ -309,7 +309,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
// Double buffer for sK
constexpr int sK_offset = size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@ -366,7 +366,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cute::copy(softmax.row_max, tRow_maxsRow_max);
@ -408,7 +408,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
// Double buffer for sK
constexpr int sK_offset = size(sK);
tKsK.data() = tKsK.data() + sK_offset;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need to clear the sK smem tiles because K is V.
@ -460,7 +460,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tOrVt.data() = tOrVt.data() + sK_offset / 8;
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));