// FlashMLA/csrc/flash_fwd_split_kernel_k64_V1x8.h
#pragma once
#include <cute/algorithm/copy.hpp>
#include <mctlass/mctlass.h>
#include <mctlass/array.h>
#include <mctlass/numeric_types.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "mask.h"
#include "dropout.h"
#include "rotary.h"
#include "attn_mask.h"
namespace flash {
using namespace cute;
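// Computes one kBlockM-row block of queries for a single (batch, head, KV-split) triple.
// This is the MLA decoding path (matrix absorption) implemented on top of the FlashAttention-2
// split-KV kernel structure: S = Q @ K^T with optional causal/local masking, ALiBi and softcap,
// online softmax, then P @ V accumulated into acc_o. Also supports appending new KV (with
// optional rotary embedding) to the cache and paged (block-table) KV caches.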
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Is_page_attn, typename Params>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_k64_mla_V1x8(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int kBlockKSmem = Kernel_traits::kBlockKSmem;
constexpr int kBlockKGmem = Kernel_traits::UseWarpsNx1 ? Kernel_traits::kBlockKSmem : 128;
static_assert(kBlockKSmem == 64);
static_assert(kBlockM % (kNWarps * 16) == 0);
using GmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::GmemTiledCopyO,
typename Kernel_traits::GmemTiledCopyOaccum
>;
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_blocks_per_split = ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
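// Each split handles a contiguous range of KV blocks: [n_split_idx * n_blocks_per_split,
// (n_split_idx + 1) * n_blocks_per_split), clamped below to the actual number of KV blocks.
// For example, with actual_seqlen_k = 300, kBlockN = 64 and num_n_splits = 2 there are 5 KV
// blocks and n_blocks_per_split = 3, so split 0 gets blocks [0, 3) and split 1 gets blocks [3, 5).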
const int n_block_min = !Is_local
? n_split_idx * n_blocks_per_split
: std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
if (Is_causal || Is_local) {
n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
}
if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
// We exit early and write 0 to gOaccum and -inf to gLSEaccum.
// Otherwise we might read OOB elements from gK and gV,
// or get wrong results when we combine gOaccum from different blocks.
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
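// When Split, partial results go to oaccum / softmax_lseaccum, laid out with the split index as
// the slowest dimension: oaccum is (num_n_splits, b, h, seqlen_q, d_rounded) and lseaccum is
// (num_n_splits, b, h, seqlen_q), as reflected in the offset computations above.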
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
clear(tOrOaccum);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d_value; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgOaccum); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
}
return;
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
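// With a paged KV cache, logical token index t lives in physical block block_table[t / page_block_size]
// at row t % page_block_size; block_table_idx / block_table_offset locate the first row of the last
// (highest-index) KV block processed by this split.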
const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.v_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
//Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutVtNoSwizzle{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtNoSwizzle{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
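// Shared-memory staging: sQ comes first, sK is either aliased onto sQ (Share_Q_K_smem) or placed
// right after it, and sV follows sK. sVtNoSwizzle is a transposed view of the sV storage that is
// used as the B operand of the second GEMM (P @ V) further below.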
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QK;
auto gmem_thr_copy_QK = gmem_tiled_copy_QK.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QK.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QK.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QK.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QK.partition_D(sK);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_V;
auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx);
Tensor tVgV = gmem_thr_copy_V.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_V.partition_D(sV);
Tensor tVrV = make_fragment_like(tVgV);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // MMA, MMA_M, MMA_K
//
// Copy Atom retiling
//
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomB64{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomB64{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK_ori = smem_thr_copy_K.partition_S(sK);
const int offset_swz334 = (__lane_id() % 2) ? ((__lane_id() / 16 % 2) ? -4 : 4) : 0;
Tensor tSsK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(tSsK_ori.data().get()) + offset_swz334), tSsK_ori.layout());
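// offset_swz334 shifts each odd lane's smem read pointer by +/-4 elements; this appears to undo the
// swap_swz334() element shuffle applied to tKrK before it is written into sK (see the main loops),
// so that the K fragments read for the GEMM come back in their original order.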
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVtNoSwizzle);
// PREDICATES
//
// // Allocate predicate tensors for m and n
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QK.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QK.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tVcV = gmem_thr_copy_V.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Allocate predicate tensors for k
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
// Set predicates for k bounds
if constexpr (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
}
// Prologue
// Copy from Knew to K, optionally apply rotary embedding.
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
if constexpr (Append_KV) {
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
// if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
// This maps to accessing the first 64 rows of knew_ptr.
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.knew_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor tKgKnew = gmem_thr_copy_QK.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew = gmem_thr_copy_QK.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
Tensor tVgVcache = gmem_thr_copy_QK.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
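// Only KV blocks that contain appended tokens (rows >= seqlen_k_cache) need to be written back to
// the cache, so the copy loop below starts at the block containing seqlen_k_cache rather than at
// n_block_min.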
auto tKgK_data = tKgK.data();
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
flash::copy_w_min_idx<Is_even_K>(
tVgVnew, tVgVcache, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
if (params.rotary_dim == 0) {
flash::copy_w_min_idx<Is_even_K>(
tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
} else {
if (params.is_rotary_interleaved) {
// Don't clear OOB_K because we're writing to global memory
flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
);
tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
} else {
// Don't clear OOB_K because we're writing to global memory
flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
);
tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
}
}
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
if constexpr (!Is_page_attn) {
tVgVcache.data() = tVgVcache.data() + (-int(kBlockN * params.v_row_stride));
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgVcache.data() = tVgVcache.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
}
}
}
// Need this before we can read in K again, so that we'll see the updated K values.
__syncthreads();
tKgK.data() = tKgK_data;
}
// Read Q from gmem to smem, optionally apply rotary embedding.
Tensor tQrQ = make_fragment_like(tQgQ);
if (!Append_KV || params.rotary_dim == 0) {
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy_b128<Is_even_MN, Is_even_K>(tQgQ, tQrQ, tQcQ, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
} else {
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
if (params.is_rotary_interleaved) {
flash::copy_rotary_interleaved<Is_even_K>(
tQgQ, tQrQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
0, params.d, params.rotary_dim
);
} else {
flash::copy_rotary_contiguous<Is_even_K>(
tQgQ, tQrQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
0, params.d, params.rotary_dim
);
}
}
cute::copy(tQrQ, tQsQ);
if (Kernel_traits::Is_Q_in_regs) {
flash::sync_threads();
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ);
flash::sync_threads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
Tensor tKrK = make_fragment_like(tKgK);
if constexpr (!Is_page_attn) {
flash::copy_b128<Is_even_MN, Is_even_K>(tKgK, tKrK, tKVcKV, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
} else {
flash::copy_b128_page_one<Kernel_traits, Is_even_MN, Is_even_K>(gK, tKgK, tKrK, tKVcKV, params.d, n_block,
block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size, binfo.actual_seqlen_k - n_block * kBlockN);
}
// flash::cp_async_wait<0>();
// __syncthreads();
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
// __syncthreads();
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
// For performance reasons, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when the K/V length is not a multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
// If not even_N, seqlen_k might end in the middle of a block; in that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
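// For example, with kBlockM == kBlockN and a causal mask: 1 masking step if seqlen_k is a multiple
// of kBlockN (Is_even_MN), 2 masking steps otherwise, since the ragged last block also needs masking.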
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
swap_swz334(tKrK);
cute::copy(tKrK, tKsK);
clear(acc_s);
// Advance gV
if (masking_step > 0) {
if constexpr (!Is_page_attn) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tVgV, tVrV, tVcV, params.d_value);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gV, tVgV, tVrV, tVcV, params.d_value, n_block,
block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size);
}
} else {
if constexpr (!Is_page_attn) {
// Clear the smem tiles to account for predicated off loads
flash::copy_b128<Is_even_MN, Is_even_K>(
tVgV, tVrV, tVcV, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN
);
} else {
flash::copy_b128_page_one<Kernel_traits, Is_even_MN, Is_even_K>(gV, tVgV, tVrV, tVcV, params.d_value, n_block,
block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size, binfo.actual_seqlen_k - n_block * kBlockN);
}
}
flash::sync_threads();
flash::gemm_opt</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
flash::apply_softcap(acc_s, params.softcap);
}
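// Row index for masking, reconstructed from the thread id: tidx / 64 picks the warp (wavefronts
// here are 64 lanes wide), each warp starts 16 rows apart, tidx & 0xf is the lane's row within its
// 16-row MMA tile, and successive MMA_M steps advance by kNWarps * 16 rows (the warp row stride).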
mask.template apply_mask<Is_causal, Is_even_MN>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 64) * 16 + (tidx & 0xf), kNWarps * 16
);
cute::copy(tVrV, tVsV);
if (n_block > n_block_min) {
// Advance gK
if constexpr (!Is_page_attn) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tKgK, tKrK, tKVcKV, params.d);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gK, tKgK, tKrK, tKVcKV, params.d, n_block - 1,
block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size);
}
}
// We have key_padding_mask so we'll need to Check_inf
masking_step == 0
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN, true, true>(acc_s, acc_o, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN, true, true>(acc_s, acc_o, params.scale_softmax_log2);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
// Convert acc_s from fp32 to fp16/bf16
//Tensor rP = flash::convert_type<Element>(acc_s);
CONVERT_TENSOR_TYPE(ElementAccum, Element, acc_s, rP)
// Upstream FlashAttention-2 reshapes rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// (m16n8k16) or keeps (4, MMA_M, MMA_N) (m16n8k8). Here the accumulator layout is reused directly
// as the A-operand layout for the P @ V GEMM.
//Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor tOrP = make_tensor(rP.data(), acc_s.layout());
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
swap_swz334(tKrK);
cute::copy(tKrK, tKsK);
clear(acc_s);
// Advance gV
if constexpr (!Is_page_attn) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tVgV, tVrV, tVcV, params.d_value);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gV, tVgV, tVrV, tVcV, params.d_value, n_block,
block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size);
}
flash::sync_threads();
flash::gemm_opt</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
cute::copy(tVrV, tVsV);
if (n_block > n_block_min) {
// Advance gK
if constexpr (!Is_page_attn) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy_b128</*Is_even_MN=*/true, Is_even_K>(tKgK, tKrK, tKVcKV, params.d);
} else {
flash::copy_b128_page_one<Kernel_traits, /*Is_even_MN=*/true, Is_even_K>(gK, tKgK, tKrK, tKVcKV, params.d, n_block - 1,
block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size);
}
}
if constexpr (Is_softcap){
flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask</*Causal_mask=*/false>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 64) * 16 + (tidx & 0xf), kNWarps * 16
);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local, true, true>(acc_s, acc_o, params.scale_softmax_log2);
//Tensor rP = flash::convert_type<Element>(acc_s);
CONVERT_TENSOR_TYPE(ElementAccum, Element, acc_s, rP)
// Upstream FlashAttention-2 reshapes rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// (m16n8k16) or keeps (4, MMA_M, MMA_N) (m16n8k8). Here the accumulator layout is reused directly
// as the A-operand layout for the P @ V GEMM.
//Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor tOrP = make_tensor(rP.data(), acc_s.layout());
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, /*Return_lse*/true, Split>(acc_o, params.scale_softmax);
// if (cute::thread0()) { print(lse); }
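// Write the output through shared memory: acc_o is converted to ElementO (fp16/bf16 for the final
// output, fp32 for the partial Oaccum when Split), staged into sOaccum, then copied to gmem together
// with the per-row LSE (lseaccum when Split).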
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
//Tensor rO = flash::convert_type<ElementO>(acc_o);
CONVERT_TENSOR_TYPE(ElementAccum, ElementO, acc_o, rO)
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sOaccum is larger than sQ, so we need to syncthreads here
// TODO: allocate enough smem for sOaccum
if constexpr (Split || Kernel_traits::Share_Q_K_smem) { flash::sync_threads(); }
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ m_block * kBlockM) * params.d_rounded;
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
// if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
flash::sync_threads();
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0>(taccOcO))::value == 4);
// Take only the row indices owned by this thread from the accumulator coordinate tensor.
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_4>{})(make_coord(0, _), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
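// taccOcO_row holds the M-tile row coordinate owned by this thread for each MMA_M step; only the
// threads that own column 0 write the LSE values, one per valid query row.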
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy_reg_to_global<Is_even_MN, Is_even_K>(
tOrOaccum, tOgOaccum, tOcO, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM
);
}
} // namespace flash