mirror of https://github.com/deepseek-ai/FlashMLA (synced 2025-06-26 18:15:54 +00:00)
Stage accumulator fragment to shared memory using tiled copy
This commit is contained in:
parent 414a2f3eed
commit 9f361aa02e
@@ -13,10 +13,14 @@ using namespace cute;
#include "static_switch.h"
#include "flash_mla.h"

////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper: Decide K-Layout at SMEM level given type and dimension.
/// Swizzling is determined primarily by alignment constraints.
/// Return GMMA Layout at compile time.
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {
    constexpr int headSizeBytes = sizeof(PrecType) * DIM;
    constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;

    if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
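Only the 128-byte branch of getSmemLayoutK() is visible in this hunk. As a rough standalone illustration of the alignment-driven choice (the 64- and 32-byte fallbacks below are assumptions, and the real function returns a CUTE GMMA smem layout rather than a width), the selection logic can be sketched in plain C++ as:

#include <cstdint>

// Hypothetical illustration only: returns the swizzle width (in bytes) that rows
// of DIM / DIM2 elements of PrecType can support, mirroring the alignment checks
// in getSmemLayoutK(). The real helper returns a CUTE GMMA smem layout instead.
template <typename PrecType, int DIM, int DIM2 = DIM>
constexpr int pickSwizzleBytes() {
    constexpr int headSizeBytes  = sizeof(PrecType) * DIM;
    constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
    if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
        return 128;  // widest swizzle: both rows are 128B-aligned
    } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
        return 64;   // assumed fallback
    } else {
        return 32;   // assumed fallback
    }
}

static_assert(pickSwizzleBytes<std::uint16_t, 64>() == 128);  // 64 * 2B = 128B rows
static_assert(pickSwizzleBytes<std::uint16_t, 48>() == 32);   // 96B rows: only 32B-aligned

The point is simply that the swizzle width is the largest power-of-two byte count dividing both head-dimension row sizes.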
@@ -28,466 +32,462 @@ constexpr auto getSmemLayoutK() {
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
/// Kernel Trait: FWD MLA for Flash Attention
/// - Templated on HeadDim (kHeadDim_), block tiling, warp usage, etc.
/// - Provides all necessary sub-layouts for Q/K/V, softmax partials, etc.
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
    int kHeadDim_,
    int kBlockM_,
    int kBlockN_,
    int kNumWarps_,
    typename ElemType = cutlass::bfloat16_t,
    int kHeadDimV_ = 0
>
struct FlashFwdKernelTraitsMLA {
    using Element = ElemType;
    using ElementAccum = float;
    using IndexT = int64_t;

    // Warp organization
    static constexpr int kNumWarps = kNumWarps_;
    static constexpr int kNumThreads = kNumWarps * 32;
    static constexpr int kNumWarpsSoftmax = 4;
    static constexpr int kNumThreadsSoftmax = kNumWarpsSoftmax * 32;

    // Tiling in M, N, K
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);

    // Possibly distinct V-dimension
    static constexpr int kHeadDimV = (kHeadDimV_ != 0) ? kHeadDimV_ : kHeadDim;
    static_assert(kHeadDimV % 32 == 0);
    static_assert(kHeadDimV <= kHeadDim);

    // SMEM swizzling for partial K/V
    static constexpr int kBlockKSmem = (kHeadDim % 64 == 0) ? 64 : 32;
    static constexpr int kSwizzle = (kBlockKSmem == 32) ? 2 : 3;
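As a concrete check of the arithmetic above, here is a small compile-time sketch with hypothetical template arguments (8 warps and a 64-element head dimension are example values, not the configuration FlashMLA necessarily ships with); it reproduces the 256/128 thread split that the kernel later asserts:

// Hypothetical example instantiation, only to exercise the formulas above.
constexpr int kNumWarps          = 8;                        // kNumWarps_
constexpr int kNumThreads        = kNumWarps * 32;           // 256 threads per CTA
constexpr int kNumWarpsSoftmax   = 4;
constexpr int kNumThreadsSoftmax = kNumWarpsSoftmax * 32;    // 128 softmax threads

constexpr int kHeadDim    = 64;                              // example head size
constexpr int kBlockKSmem = (kHeadDim % 64 == 0) ? 64 : 32;
constexpr int kSwizzle    = (kBlockKSmem == 32) ? 2 : 3;

static_assert(kNumThreads == 256 && kNumThreadsSoftmax == 128);
static_assert(kBlockKSmem == 64 && kSwizzle == 3);

int main() { return 0; }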
    // GMMA Tiled Mma
    // Q*K -> S
    using TiledMma = decltype(make_tiled_mma(
        cute::GMMA::ss_op_selector<
            Element, Element, ElementAccum,
            Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
            GMMA::Major::K, GMMA::Major::K
        >(),
        Layout<Shape<Int<kNumWarpsSoftmax / 4>, _1, _1>>{}
    ));

    // S*V -> O
    // For the O “outer product,” we define the shape in [M, HeadDimV, N].
    static constexpr int AtomLayoutNO = kNumThreads / kNumThreadsSoftmax;
    using TiledMmaO = decltype(make_tiled_mma(
        cute::GMMA::rs_op_selector<
            Element, Element, ElementAccum,
            Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
            GMMA::Major::K, GMMA::Major::MN
        >(),
        Layout<Shape<Int<kNumWarpsSoftmax / 4>, Int<AtomLayoutNO>, _1>>{}
    ));

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    /// SMEM Layout definitions: Q/K/V, P, row-scale, etc.
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    using SmemLayoutQ = decltype(
        tile_to_shape(
            getSmemLayoutK<Element, kHeadDim>(),
            Shape<Int<kBlockM>, Int<kHeadDim>>{}
        )
    );

    using SmemLayoutK = decltype(
        tile_to_shape(
            getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
            Shape<Int<kBlockN>, Int<kHeadDim>>{}
        )
    );

    using SmemLayoutV = decltype(
        tile_to_shape(
            getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
            Shape<Int<kBlockN>, Int<kHeadDimV>>{}
        )
    );
    using SmemLayoutVtransposed = decltype(
        composition(
            SmemLayoutV{},
            make_layout(
                Shape<Int<kHeadDimV>, Int<kBlockN>>{},
                GenRowMajor{}
            )
        )
    );

    // For partial S data (softmax region)
    using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNumThreadsSoftmax>, _1, Int<kBlockN / 8>>>;
    using SmemLayoutRow = Layout<Shape<_2, Int<kNumThreadsSoftmax>>, Stride<_1, _2>>;

    // Layout for the O tile in smem
    using SmemLayoutAtomO = decltype(
        composition(
            Swizzle<kSwizzle, 3, 3>{},
            Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}
        )
    );
    using SmemLayoutO = decltype(
        tile_to_shape(
            SmemLayoutAtomO{},
            Shape<Int<kBlockM>, Int<kHeadDimV>>{}
        )
    );

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    /// Copy Atoms for SMEM read/write
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    /// GMEM Tiled Copies for Q/K/V
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must align with vector load size");
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    using GmemCopyStruct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
    static constexpr int kNumThreadsLoad = kNumThreads - kNumThreadsSoftmax;
    static_assert(kNumThreadsLoad % kGmemThreadsPerRow == 0, "Thread counts must match row partitions");

    using GmemLayoutAtom = Layout<
        Shape<Int<kNumThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
        Stride<Int<kGmemThreadsPerRow>, _1>
    >;
    using GmemTiledCopy = decltype(
        make_tiled_copy(
            Copy_Atom<GmemCopyStruct, Element>{},
            GmemLayoutAtom{},
            Layout<Shape<_1, _8>>{}  // 8 vals per read
        )
    );

    // For storing O to GMEM
    using GmemLayoutAtomO = Layout<
        Shape<Int<kNumThreadsSoftmax / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
        Stride<Int<kGmemThreadsPerRow>, _1>
    >;
    using GmemTiledCopyO = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
            GmemLayoutAtomO{},
            Layout<Shape<_1, _8>>{}  // 8 vals per store
        )
    );

    // For accumulation path (split)
    static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
    using GmemLayoutAtomOaccum = Layout<
        Shape<Int<kNumThreadsSoftmax / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
        Stride<Int<kGmemThreadsPerRowAccum>, _1>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(
            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
            GmemLayoutAtomOaccum{},
            Layout<Shape<_1, _4>>{}  // 4 vals per store
        )
    );
};

////////////////////////////////////////////////////////////////////////////////////////////////////
/// Shared Storage Container for MLA
/// - Re-used union across Q/K/P/O or row sums, etc.
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {

using namespace cute;

template <typename KernelTraits>
struct SharedStorageMLA {
    union {
        struct {
            cute::array_aligned<typename KernelTraits::Element,
                                cute::cosize_v<typename KernelTraits::SmemLayoutQ>> smem_q;
            cute::array_aligned<typename KernelTraits::Element,
                                cute::cosize_v<typename KernelTraits::SmemLayoutK> * 2> smem_k;  // double buffer
            cute::array_aligned<typename KernelTraits::Element,
                                cute::cosize_v<typename KernelTraits::SmemLayoutP>> smem_p;
            cute::array_aligned<typename KernelTraits::ElementAccum,
                                cute::cosize_v<typename KernelTraits::SmemLayoutRow>> smem_scale;
        };
        struct {
            cute::array_aligned<typename KernelTraits::ElementAccum,
                                cute::cosize_v<typename KernelTraits::SmemLayoutRow>> smem_max;
            cute::array_aligned<typename KernelTraits::ElementAccum,
                                cute::cosize_v<typename KernelTraits::SmemLayoutRow>> smem_sum;
            cute::array_aligned<typename KernelTraits::ElementAccum,
                                cute::cosize_v<typename KernelTraits::SmemLayoutO>> smem_o;
        };
    };
};
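Since the two anonymous structs live in one union, sizeof(SharedStorageMLA<...>) is the size of the larger group, and that single figure is what a launcher would request as dynamic shared memory. A hedged host-side sketch, with a placeholder kernel handle (the real dispatch code is not part of this header):

#include <cuda_runtime.h>

// Hedged sketch: opt a kernel into enough dynamic shared memory for the union
// above. KernelTraits and the kernel handle are placeholders for illustration.
template <typename KernelTraits, typename KernelFn>
cudaError_t opt_in_shared_memory(KernelFn kernel) {
    // The union sizes to the larger of the Q/K/P/scale group and the
    // max/sum/O group, so one allocation serves both phases of the kernel.
    constexpr int smem_size = int(sizeof(flash::SharedStorageMLA<KernelTraits>));
    return cudaFuncSetAttribute(reinterpret_cast<const void *>(kernel),
                                cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}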
////////////////////////////////////////////////////////////////////////////////////////////////////
/// store() Epilogue for partial or non-partial results
/// - Manages writing O/accumulation to global memory + writing out LSE for row block.
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
    typename KernelTraits,
    bool Split,
    typename SharedStorage,
    typename AccO,
    typename Softmax
>
__forceinline__ __device__
void store(
    const Flash_fwd_mla_params &params,
    const int batch_id,
    const int head_id,
    const int m_block,
    const int n_split_idx,
    SharedStorage &shared_storage,
    AccO tOrO,
    Softmax softmax
) {
    constexpr int kBlockM = KernelTraits::kBlockM;
    constexpr int kHeadDimV = KernelTraits::kHeadDimV;
    constexpr int kNumThreadsS = KernelTraits::kNumThreadsSoftmax;
    using Element = typename KernelTraits::Element;
    using ElementAccum = typename KernelTraits::ElementAccum;
    using IndexT = typename KernelTraits::IndexT;

    const int tidx = threadIdx.x;

    typename KernelTraits::TiledMmaO tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);

    // Softmax LSE for final normalization
    auto lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);

    // Decide if writing ephemeral partial results (float accumulation) or final (Element).
    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;

    // Prepare SMEM for O
    Tensor sOaccum = make_tensor(
        make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())),
        typename KernelTraits::SmemLayoutO{}
    );
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(
        std::conditional_t<!Split,
                           typename KernelTraits::SmemCopyAtomO,
                           typename KernelTraits::SmemCopyAtomOaccum>{},
        tiled_mma_o
    );

    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor rO = flash::convert_type<ElementO>(tOrO);
    Tensor taccOrO = smem_thr_copy_Oaccum.retile_S(rO);
    Tensor taccOsO = smem_thr_copy_Oaccum.partition_D(sOaccum);

    __syncthreads();
    cute::copy(smem_tiled_copy_Oaccum, taccOrO, taccOsO);
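This staging copy is what the commit title refers to: the accumulator fragment is converted, retiled, and written to shared memory with a tiled copy before the global store. For readers unfamiliar with CUTE tiled copies, a much-simplified CUDA sketch of the same stage-then-store pattern (hypothetical sizes, plain pointers, not the actual implementation) is:

// Hypothetical, simplified illustration of "stage accumulator to shared memory,
// then write coalesced to global memory". The real kernel does this with CUTE
// tiled-copy atoms (smem_tiled_copy_Oaccum / gmem_tiled_copy_Oaccum).
template <int kBlockM, int kHeadDimV, int kThreads>
__device__ void stage_and_store(float (&acc)[kBlockM * kHeadDimV / kThreads],
                                float *smem_o,   // kBlockM * kHeadDimV floats
                                float *gmem_o,   // row-major [kBlockM, kHeadDimV]
                                int row_stride) {
    const int tid = threadIdx.x;
    constexpr int kPerThread = kBlockM * kHeadDimV / kThreads;

    // 1) Stage: each thread writes its register fragment to its slot in smem.
    #pragma unroll
    for (int i = 0; i < kPerThread; ++i) {
        smem_o[tid * kPerThread + i] = acc[i];
    }
    __syncthreads();

    // 2) Store: threads now read smem in a coalesced, row-contiguous order.
    for (int idx = tid; idx < kBlockM * kHeadDimV; idx += kThreads) {
        const int r = idx / kHeadDimV;
        const int c = idx % kHeadDimV;
        gmem_o[r * row_stride + c] = smem_o[idx];
    }
}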
    // Compute GMEM offsets
    const IndexT row_offset_o = batch_id * params.o_batch_stride
                              + m_block * kBlockM * params.o_row_stride
                              + head_id * params.o_head_stride;
    const IndexT row_offset_oaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx)
                                           * params.h + head_id)
                                          * params.seqlen_q + (m_block * kBlockM)) * params.d_v;
    const IndexT row_offset_lse = (batch_id * params.h + head_id) * params.seqlen_q + m_block * kBlockM;
    const IndexT row_offset_lseaccum = (((__ldg(params.num_splits_ptr + batch_id) + n_split_idx)
                                             * params.h + head_id)
                                            * params.seqlen_q + (m_block * kBlockM));
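Restated as plain index arithmetic, with split_offset standing for the value loaded from params.num_splits_ptr + batch_id, the two O offsets are (illustrative helper, not part of the kernel):

#include <cstdint>

// Plain restatement of the offset formulas used by store(); the field names
// mirror the kernel parameters, but this helper itself is illustrative only.
struct OffsetIn {
    int64_t o_batch_stride, o_row_stride, o_head_stride;
    int64_t h, seqlen_q, d_v;
};

inline int64_t row_offset_o(const OffsetIn &p, int batch_id, int head_id, int m_block, int kBlockM) {
    return batch_id * p.o_batch_stride
         + int64_t(m_block) * kBlockM * p.o_row_stride
         + head_id * p.o_head_stride;
}

inline int64_t row_offset_oaccum(const OffsetIn &p, int64_t split_offset, int n_split_idx,
                                 int head_id, int m_block, int kBlockM) {
    // Partial (split-KV) results are laid out as [split, head, row, d_v], densely packed.
    return (((split_offset + n_split_idx) * p.h + head_id) * p.seqlen_q
            + int64_t(m_block) * kBlockM) * p.d_v;
}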
    // Prepare GMEM for final or partial O
    Tensor gOaccum = make_tensor(
        make_gmem_ptr(
            reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr)
            + (Split ? row_offset_oaccum : row_offset_o)
        ),
        Shape<Int<kBlockM>, Int<kHeadDimV>>{},
        make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})
    );

    // Prepare GMEM LSE
    Tensor gLSEaccum = make_tensor(
        make_gmem_ptr(
            reinterpret_cast<ElementAccum *>(
                Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr
            ) + (Split ? row_offset_lseaccum : row_offset_lse)
        ),
        Shape<Int<kBlockM>>{},
        Stride<_1>{}
    );

    // Tiled copy from SMEM -> GMEM for O
    using GmemTiledCopyOAccum = std::conditional_t<
        !Split,
        typename KernelTraits::GmemTiledCopyO,
        typename KernelTraits::GmemTiledCopyOaccum
    >;
    GmemTiledCopyOAccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);

    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

    __syncthreads();

    // If out of range of the "softmax" portion, do not store
    if (tidx >= kNumThreadsS) { return; }

    // Load from SMEM
    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

    // Write out the LSE
    auto caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{});
    auto taccOcO = thr_mma_o.partition_C(caccO);
    auto taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));

    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < params.seqlen_q - m_block * kBlockM) {
                gLSEaccum(row) = lse(mi);
            }
        }
    }

    // Identity layout for sO
    auto cO = make_identity_tensor(
        make_shape(size<0>(sOaccum), size<1>(sOaccum))
    );
    auto tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
    auto tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));

    // Copy final O back to GMEM
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO,
        params.seqlen_q - m_block * kBlockM
    );
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// compute_attn_1rowblock_splitkv_mla()
/// - Core logic for Q*K -> S -> Softmax -> S*V -> O
/// - Includes partial accumulation for splits and optional causal masking.
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
                                                                   const int bidb, const int bidh, const int m_block,
                                                                   const int n_split_idx, const int seqlen_k,
                                                                   const int n_block_min, const int n_block_max, const bool NoSplit,
                                                                   SharedStorage &shared_storage) {
    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr int kNThreads = Kernel_traits::kNThreads;
    constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
    static_assert(kNThreads == 256 and kNThreadsS == 128);
    using Element = typename Kernel_traits::Element;
    using index_t = typename Kernel_traits::index_t;

    const int tidx = threadIdx.x;
    int n_block = n_block_max - 1;

    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});

    Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
    Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
    Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
    Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
    Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);

    typename Kernel_traits::TiledMmaO tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
    Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt);  // (MMA, MMA_K, MMA_N)
    Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});  // ((MMA=4, X), MMA_M, MMA_N=1)
    clear(tOrO);

    flash::Softmax<2 * size<1>(tOrO)> softmax;

    int warp_group_idx = cutlass::canonical_warp_group_idx();
    if (warp_group_idx == 0) {
        typename Kernel_traits::TiledMma tiled_mma;
        auto thr_mma = tiled_mma.get_thread_slice(tidx);
        Tensor tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA, MMA_M, MMA_K)
        Tensor tSrK = thr_mma.partition_fragment_B(sK);  // (MMA, MMA_N, MMA_K)

        if (n_block % 2 == 1) {
            // Double buffer for sK
            constexpr int sK_offset = size(sK);
            tSrK.data() = tSrK.data() + sK_offset / 8;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }
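K and V share a double-buffered smem region, so when the starting n_block is odd the fragment iterators are pre-shifted into the second buffer; the main loop later flips the offset sign every iteration. A hedged sketch of the same ping-pong selection with plain pointers (the kernel's division by 8 converts the element count into the larger units its fragment iterators step in):

// Illustrative only: ping-pong between two halves of a double-buffered region.
// kBufElems is a hypothetical per-buffer element count.
template <int kBufElems>
__device__ inline float *select_buffer(float *base, int n_block) {
    // Odd n_block starts in the second half; even starts in the first.
    return base + (n_block % 2 == 1 ? kBufElems : 0);
}

// Inside the main loop, the active half then alternates every iteration:
//   const int offset = (n_block % 2 == 0) ? kBufElems : -kBufElems;
//   buf += offset;   // mirrors: tSrK.data() = tSrK.data() + sK_offset / 8; etc.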
        // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
        // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
        // We will have at least 1 "masking" iteration.
        // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
        // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
#pragma unroll 1
        for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
            __syncthreads();

            Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // ((MMA=4, X), MMA_M, MMA_N=1)
            flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);

            const bool is_masking_step = masking_step > 0;
            const bool is_first_masking_step = masking_step == n_masking_steps;

            if (is_masking_step) {
                Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
                Tensor tScS = thr_mma.partition_C(cS);
#pragma unroll
                for (int i = 0; i < size(tSrS); ++i) {
                    if constexpr (!Is_causal) {  // Just masking based on col
                        if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
                    } else {
                        // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                        // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                        int row = int(get<0>(tScS(i)));
                        int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
                        if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
                    }
                }
            }
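To make the causal bound concrete, the same arithmetic as a standalone function plus one hypothetical worked case (numbers chosen only for illustration):

#include <cassert>

// Rightmost key column (within the current n_block tile) that a query row may
// attend to under the grouped causal rule above. Illustrative restatement only.
inline int col_limit_right(int seqlen_k, int seqlen_q, int n_block, int kBlockN,
                           int m_block, int kBlockM, int row, int ngroups) {
    return seqlen_k - 1 - n_block * kBlockN
         - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups;
}

int main() {
    // Hypothetical example: 1 query group, last key block, single query row.
    // seqlen_k=100, kBlockN=64, n_block=1 -> this tile covers keys 64..99.
    // With seqlen_q=1 the query is the final token, so every key up to absolute
    // index 99 is visible, i.e. local column limit 35 (= 99 - 64).
    assert(col_limit_right(/*seqlen_k=*/100, /*seqlen_q=*/1, /*n_block=*/1, /*kBlockN=*/64,
                           /*m_block=*/0, /*kBlockM=*/64, /*row=*/0, /*ngroups=*/1) == 35);
    return 0;
}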

            // We have key_padding_mask so we'll need to Check_inf
            Tensor scale_o = is_first_masking_step
                ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
                : is_masking_step ?
                    softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
                    : softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);

            Tensor rP = flash::convert_type<Element>(tSrS);
            cute::copy(rP, tPsP);
            cute::copy(scale_o, tScale_osScale_o);

            cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));

            flash::rescale_o(tOrO, scale_o);

            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

            // Double buffer for sK
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tSrK.data() = tSrK.data() + sK_offset / 8;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        cute::copy(softmax.row_max, tRow_maxsRow_max);
        cute::copy(softmax.row_sum, tRow_sumsRow_sum);
        cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
    } else {
        const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
        int cur_block_table = __ldg(&block_table[n_block]);

        const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
        Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                make_stride(params.q_row_stride, _1{}));
        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
        auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
        Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
        Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));

        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                                                              params.seqlen_q - m_block * kBlockM);

        const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
        Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                                Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                make_stride(params.k_row_stride, _1{}));
        typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
        auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
        Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
        Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
        Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        Tensor tKcK = gmem_thr_copy_K.partition_S(cK);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
        Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));

        if (n_block % 2 == 1) {
            // Double buffer for sK
            constexpr int sK_offset = size(sK);
            tKsK.data() = tKsK.data() + sK_offset;
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        // We need to clear the sK smem tiles because K is V.
        const index_t offset_k = cur_block_table * params.k_batch_stride;
        tKgK.data() = tKgK.data() + offset_k;
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
                                                                                     seqlen_k - n_block * kBlockN);
        tKgK.data() = tKgK.data() + -offset_k;
        cute::cp_async_fence();

        if (n_block - 1 >= n_block_min) {
            cur_block_table = __ldg(&block_table[n_block - 1]);
        }

#pragma unroll 1
        for (; n_block >= n_block_min; --n_block) {
            flash::cp_async_wait<0>();
            __syncthreads();

            if (n_block - 1 >= n_block_min) {
                // Double buffer for sK
                const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
                tKsK.data() = tKsK.data() + sK_offset;

                const index_t offset_k = cur_block_table * params.k_batch_stride;
                tKgK.data() = tKgK.data() + offset_k;
                flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
                tKgK.data() = tKgK.data() + -offset_k;
                cute::cp_async_fence();
            }

            cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));

            if (n_block - 2 >= n_block_min) {
                cur_block_table = __ldg(&block_table[n_block - 2]);
            }

            typename Kernel_traits::TiledMma tiled_mma;
            auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
            Tensor rP = make_tensor<Element>(tSrS_layout);
            Tensor scale_o = make_tensor<float>(Shape<_2>{});
            cute::copy(tScale_osScale_o, scale_o);
            cute::copy(tPsP, rP);

            flash::rescale_o(tOrO, scale_o);

            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);

            // Double buffer for sK
            const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
            tOrVt.data() = tOrVt.data() + sK_offset / 8;
        }

        cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
        cute::copy(tRow_maxsRow_max, softmax.row_max);
        cute::copy(tRow_sumsRow_sum, softmax.row_sum);
    }

    if (NoSplit)
        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
    else
        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
}

template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    extern __shared__ char shared_memory[];
    auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);

    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
    int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
    int begin_idx = tile_scheduler_metadata.x;
    int begin_seqlen = tile_scheduler_metadata.y;
    int end_idx = tile_scheduler_metadata.z;
    int end_seqlen = tile_scheduler_metadata.w;
    if (begin_idx >= params.b) return;
    int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);

#pragma unroll 1
    for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
        const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
        const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
        const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
        const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
        const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
        if (batch_id > begin_idx) {
            __syncthreads();  // Barrier between two tiles.
        }
        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
    constexpr int kNThreads = 128;

    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;
    const int hs = params.h * params.seqlen_q;
    const int batch_idx = bidx / hs;
    const int hs_idx = bidx % hs;