mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix Vt illegal
This commit is contained in:
parent
29de9e0c79
commit
59f691763e
@ -149,13 +149,13 @@ using namespace cute;
|
|||||||
template<typename Kernel_traits>
|
template<typename Kernel_traits>
|
||||||
struct SharedStorageMLA {
|
struct SharedStorageMLA {
|
||||||
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
|
using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa> * 2>,
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa>>,
|
||||||
cute::array_aligned<typename Kernel_traits::Element, 0>>;
|
cute::array_aligned<typename Kernel_traits::Element, 0>>;
|
||||||
union {
|
union {
|
||||||
struct {
|
struct {
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
|
||||||
SmemV_t smem_vt; // Double buffer
|
SmemV_t smem_vt;
|
||||||
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
|
||||||
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
|
||||||
};
|
};
|
||||||
@ -309,7 +309,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
constexpr int sK_offset = size(sK);
|
constexpr int sK_offset = size(sK);
|
||||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||||
@ -366,7 +366,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||||
tSrK.data() = tSrK.data() + sK_offset / 8;
|
tSrK.data() = tSrK.data() + sK_offset / 8;
|
||||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
cute::copy(softmax.row_max, tRow_maxsRow_max);
|
cute::copy(softmax.row_max, tRow_maxsRow_max);
|
||||||
@ -408,7 +408,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
constexpr int sK_offset = size(sK);
|
constexpr int sK_offset = size(sK);
|
||||||
tKsK.data() = tKsK.data() + sK_offset;
|
tKsK.data() = tKsK.data() + sK_offset;
|
||||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need to clear the sK smem tiles because K is V.
|
// We need to clear the sK smem tiles because K is V.
|
||||||
@ -460,7 +460,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f
|
|||||||
|
|
||||||
// Double buffer for sK
|
// Double buffer for sK
|
||||||
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
|
||||||
tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user