mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix Vt illegal
This commit is contained in:
parent
29de9e0c79
commit
59f691763e
@ -149,13 +149,13 @@ using namespace cute;
|
||||
template<typename Kernel_traits>
|
||||
struct SharedStorageMLA {
|
||||
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa> * 2>,
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa>>,
|
||||
cute::array_aligned<typename Kernel_traits::Element, 0>>;
|
||||
union {
|
||||
struct {
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||
SmemV_t smem_vt; // Double buffer
|
||||
SmemV_t smem_vt;
|
||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||
};
|
||||
@ -309,7 +309,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
// Double buffer for sK
|
||||
constexpr int sK_offset = size(sK);
|
||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
@ -366,7 +366,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
cute::copy(softmax.row_max, tRow_maxsRow_max);
|
||||
@ -408,7 +408,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
// Double buffer for sK
|
||||
constexpr int sK_offset = size(sK);
|
||||
tKsK.data() = tKsK.data() + sK_offset;
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
// We need to clear the sK smem tiles because K is V.
|
||||
@ -460,7 +460,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
||||
|
||||
// Double buffer for sK
|
||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||
}
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||
|
Loading…
Reference in New Issue
Block a user