#pragma once

// NOTE(review): this copy of the file appears to have been passed through an HTML-style
// extractor. The targets of the four system includes below, and every template-argument
// list of the form `<Identifier...>` (e.g. after `template`, `reinterpret_cast`, `Shape`,
// `Int`, `make_tensor`, `BlockInfo`, `flash::copy*`, `flash::gemm*`, `flash::Mask`,
// `flash::Softmax`, `apply_mask`, `softmax_rescale_o`, `normalize_softmax_lse`), are missing.
// They are deliberately NOT guessed here; restore them from the canonical header.
#include
#include
#include
#include

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "mask.h"
#include "dropout.h"
#include "rotary.h"
#include "attn_mask.h"

namespace flash {

using namespace cute;

// Computes one query row-block of split-KV attention for the (bidb, bidh, m_block) tile,
// restricted to the key/value blocks that belong to split `n_split_idx` out of `num_n_splits`.
//
// Outline of what the body does:
//   1. Returns immediately if the row-block starts past actual_seqlen_q.
//   2. Derives [n_block_min, n_block_max) for this split (narrowed by the causal / local window).
//      If the range is empty, writes zeros to the output tile and +/-inf to the LSE rows, then returns.
//   3. If Append_KV: copies knew/vnew into the K/V cache (optionally applying rotary embedding),
//      followed by a __syncthreads() so the updated K is visible before it is re-read.
//   4. Loads Q (optionally applying rotary embedding) and walks the K/V blocks in reverse order:
//      first the "masking" iterations, then the remaining iterations.
//   5. Epilogue: normalizes the softmax, stages the output through shared memory, and writes
//      the LSE rows and the output tile to global memory.
//
// Compile-time switches referenced in the body (their declaration is missing, see the note below):
//   Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split,
//   Append_KV, Is_page_attn, Params.
//
// Parameters:
//   params        - kernel arguments (pointers, strides, sequence lengths, softmax scale, ...).
//   bidb          - batch index.
//   bidh          - head index.
//   m_block       - index of the query row-block (kBlockM rows each).
//   n_split_idx   - index of this split along the key/value sequence.
//   num_n_splits  - total number of splits along the key/value sequence.
//
// Requirements visible in the code: dynamic shared memory (`extern __shared__ char smem_[]`),
// Kernel_traits::kBlockKSmem == 64 and kBlockM divisible by kNWarps * 16 (static_asserts below).
//
// NOTE(review): the template parameter list after `template` is missing in this copy.
template
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_k64_mla_V1x8(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
    constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int kBlockKSmem = Kernel_traits::kBlockKSmem;
    constexpr int kBlockKGmem = Kernel_traits::UseWarpsNx1 ? Kernel_traits::kBlockKSmem : 128;
    static_assert(kBlockKSmem == 64);
    static_assert(kBlockM % (kNWarps * 16) == 0);

    // Output goes straight to O when !Split, otherwise to the per-split accumulator buffer.
    using GmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::GmemTiledCopyO,
        typename Kernel_traits::GmemTiledCopyOaccum
    >;
    using ElementO = std::conditional_t;

    const BlockInfo binfo(params, bidb);
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    // Key/value blocks are divided evenly (rounding up) across the splits.
    const int n_blocks_per_split = ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
    const int n_block_min = !Is_local
        ? n_split_idx * n_blocks_per_split
        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
    if (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
    }
    if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
        // We exit early and write 0 to gOaccum and -inf (Split) / +inf (!Split) to gLSEaccum.
        // Otherwise we might read OOB elements from gK and gV,
        // or get wrong results when we combine gOaccum from different blocks.
        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
        const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
            + m_block * kBlockM) * params.d_rounded;
        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;

        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                     Shape, Int>{},
                                     make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                      Shape>{}, Stride<_1>{});

        GmemTiledCopyO gmem_tiled_copy_Oaccum;
        auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
        Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
        Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
        clear(tOrOaccum);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
        Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum)));
        // Same compile-time guard as the K-bounds predicate setup further down.
        if constexpr (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d_value; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy(
            gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        // Only the thread whose slice starts at column 0 of a row writes that row's LSE.
        #pragma unroll
        for (int m = 0; m < size<1>(tOgOaccum); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
        }
        return;
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    // We move K and V to the last block.
    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
    const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
    // With a block table, only the head offset is applied here; the paged loads below receive
    // gK / gV together with block_table and the block index.
    const index_t row_offset_k = block_table == nullptr
        ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
        : (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = block_table == nullptr
        ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
        : (bidh / params.h_h_k_ratio) * params.v_head_stride;

    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q),
                            Shape, Int>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k),
                            Shape, Int>{},
                            make_stride(params.k_row_stride, _1{}));
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v),
                            Shape, Int>{},
                            make_stride(params.v_row_stride, _1{}));

    // Shared-memory views: sQ at the start of smem_, sK either aliasing sQ (Share_Q_K_smem)
    // or placed right after it, sV right after sK.
    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    //Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), typename Kernel_traits::SmemLayoutK{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutVtNoSwizzle{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtNoSwizzle{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QK;
    auto gmem_thr_copy_QK = gmem_tiled_copy_QK.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QK.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QK.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QK.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QK.partition_D(sK);

    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_V;
    auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tVgV = gmem_thr_copy_V.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_V.partition_D(sV);
    Tensor tVrV = make_fragment_like(tVgV);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);                // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);     // (MMA, MMA_K,MMA_N)

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomB64{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomB64{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK_ori = smem_thr_copy_K.partition_S(sK);
    // Per-lane element offset applied to the K smem source pointer: 0 for even lanes,
    // +4 / -4 for odd lanes depending on bit 4 of the lane id.
    // NOTE(review): presumably pairs with swap_swz334() applied to tKrK below - confirm.
    const int offset_swz334 = (__lane_id() % 2) ? ((__lane_id() / 16 % 2) ? -4 : 4) : 0;
    Tensor tSsK = make_tensor(make_smem_ptr(reinterpret_cast(tSsK_ori.data().get()) + offset_swz334), tSsK_ori.layout());

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVtNoSwizzle);

    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QK.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QK.partition_S(cKV);    // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
    Tensor tVcV = gmem_thr_copy_V.partition_S(cKV);       // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if constexpr (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue
    // Copy from Knew to K, optionally apply rotary embedding.
    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
    if constexpr (Append_KV) {
        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
        // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(params.rotary_dim / 2, _1{}));
        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(params.rotary_dim / 2, _1{}));
        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(params.rotary_dim / 2, _1{}));
        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(params.rotary_dim / 2, _1{}));
        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
        // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
        // if (cute::thread(8, 0)) { print_tensor(gCos); }
        // if (cute::thread(0, 0)) { print_tensor(tRgCos); }

        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
        // This maps to accessing the first 64 rows of knew_ptr.
        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr)
                                                + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
                                  Shape, Int>{},
                                  make_stride(params.knew_row_stride, _1{}));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr)
                                                + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
                                  Shape, Int>{},
                                  make_stride(params.vnew_row_stride, _1{}));
        Tensor tKgKnew = gmem_thr_copy_QK.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
        Tensor tVgVnew = gmem_thr_copy_QK.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
        Tensor tVgVcache = gmem_thr_copy_QK.partition_S(gV);   // (VCPY, VCPY_N, VCPY_K)

        // Only blocks that overlap the newly appended rows need to be written.
        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
        // Remember where tKgK points so it can be restored after the append loop moves it.
        auto tKgK_data = tKgK.data();
        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
            flash::copy_w_min_idx(
                tVgVnew, tVgVcache, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
            );
            tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
            if (params.rotary_dim == 0) {
                flash::copy_w_min_idx(
                    tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
                );
            } else {
                if (params.is_rotary_interleaved) {
                    // Don't clear OOB_K because we're writing to global memory
                    flash::copy_rotary_interleaved(
                        tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                    );
                    tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
                    tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
                } else {
                    // Don't clear OOB_K because we're writing to global memory
                    flash::copy_rotary_contiguous(
                        tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                    );
                    tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
                    tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
                }
            }
            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
            if constexpr (!Is_page_attn) {
                tVgVcache.data() = tVgVcache.data() + (-int(kBlockN * params.v_row_stride));
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            } else {
                // Paged cache: step from the page/offset of the current block to that of the previous one.
                // NOTE(review): row_offset_k / row_offset_v carry no page offset when block_table != nullptr,
                // while this branch only applies *relative* page deltas - confirm the cache pointers start
                // on the page of block (n_block_max - 1) when Append_KV is combined with paging.
                if (n_block > n_block_copy_min) {
                    const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
                    const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
                    const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
                    const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                    const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
                    const int offset_diff = block_table_offset_next - block_table_offset_cur;
                    tVgVcache.data() = tVgVcache.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
                    tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
                }
            }
        }
        // Need this before we can read in K again, so that we'll see the updated K values.
        __syncthreads();
        tKgK.data() = tKgK_data;
    }

    // Read Q from gmem to smem, optionally apply rotary embedding.
    Tensor tQrQ = make_fragment_like(tQgQ);
    if (!Append_KV || params.rotary_dim == 0) {
        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
        flash::copy_b128(tQgQ, tQrQ, tQcQ, params.d, binfo.actual_seqlen_q - m_block * kBlockM);
    } else {
        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
        // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
        // We do this by setting the row stride of gCos / gSin to 0.
        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
        if (params.is_rotary_interleaved) {
            flash::copy_rotary_interleaved(
                tQgQ, tQrQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                0, params.d, params.rotary_dim
            );
        } else {
            flash::copy_rotary_contiguous(
                tQgQ, tQrQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                0, params.d, params.rotary_dim
            );
        }
    }
    // Q was staged in registers (tQrQ); publish it to shared memory.
    cute::copy(tQrQ, tQsQ);
    if (Kernel_traits::Is_Q_in_regs) {
        flash::sync_threads();
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ);
        flash::sync_threads();
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    Tensor tKrK = make_fragment_like(tKgK);
    if constexpr (!Is_page_attn) {
        flash::copy_b128(tKgK, tKrK, tKVcKV, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
    } else {
        flash::copy_b128_page_one(gK, tKgK, tKrK, tKVcKV, params.d, n_block, block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size, binfo.actual_seqlen_k - n_block * kBlockN);
    }
    // flash::cp_async_wait<0>();
    // __syncthreads();
    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
    // __syncthreads();

    clear(acc_o);

    flash::Softmax(acc_o)> softmax;

    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

    // For performance reason, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
    constexpr int n_masking_steps = (!Is_causal && !Is_local)
        ? 1
        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});  // (MMA=4, MMA_M, MMA_N)
        // Move the K tile prefetched into registers to shared memory.
        // NOTE(review): no block-level barrier is visible between the previous iteration's reads of
        // sK / sV (gemm_opt / gemm_rs) and this write, nor between cute::copy(tVrV, tVsV) and gemm_rs;
        // presumably the helpers synchronize internally or the layout makes it thread-private - confirm.
        swap_swz334(tKrK);
        cute::copy(tKrK, tKsK);
        clear(acc_s);

        // Advance gV
        if (masking_step > 0) {
            if constexpr (!Is_page_attn) {
                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
                flash::copy_b128(tVgV, tVrV, tVcV, params.d_value);
            } else {
                flash::copy_b128_page_one(gV, tVgV, tVrV, tVcV, params.d_value, n_block, block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size);
            }
        } else {
            // First step: the last block may be partial, so the row bound is passed along.
            if constexpr (!Is_page_attn) {
                // Clear the smem tiles to account for predicated off loads
                flash::copy_b128(
                    tVgV, tVrV, tVcV, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN
                );
            } else {
                flash::copy_b128_page_one(gV, tVgV, tVrV, tVcV, params.d_value, n_block, block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size, binfo.actual_seqlen_k - n_block * kBlockN);
            }
        }
        flash::sync_threads();

        flash::gemm_opt(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
        if constexpr (Is_softcap){
            flash::apply_softcap(acc_s, params.softcap);
        }

        // NOTE(review): the row index uses tidx / 64, i.e. presumably 64 threads per warp - confirm.
        mask.template apply_mask(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 64) * 16 + (tidx & 0xf), kNWarps * 16
        );

        cute::copy(tVrV, tVsV);

        if (n_block > n_block_min) {
            // Advance gK
            if constexpr (!Is_page_attn) {
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
                flash::copy_b128(tKgK, tKrK, tKVcKV, params.d);
            } else {
                flash::copy_b128_page_one(gK, tKgK, tKrK, tKVcKV, params.d, n_block - 1, block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size);
            }
        }

        // We have key_padding_mask so we'll need to Check_inf
        // (the two branches differ only in template arguments that are missing from this copy)
        masking_step == 0
            ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2)
            : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);
        // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

        // Convert acc_s from fp32 to fp16/bf16
        //Tensor rP = flash::convert_type(acc_s);
        CONVERT_TENSOR_TYPE(ElementAccum, Element, acc_s, rP)
        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        //Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout()));
        Tensor tOrP = make_tensor(rP.data(), acc_s.layout());

        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= n_block_min) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});  // (MMA=4, MMA_M, MMA_N)
        swap_swz334(tKrK);
        cute::copy(tKrK, tKsK);
        clear(acc_s);

        // Advance gV
        if constexpr (!Is_page_attn) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
            flash::copy_b128(tVgV, tVrV, tVcV, params.d_value);
        } else {
            flash::copy_b128_page_one(gV, tVgV, tVrV, tVcV, params.d_value, n_block, block_table, params.v_batch_stride, params.v_row_stride, params.page_block_size);
        }
        flash::sync_threads();

        flash::gemm_opt(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );

        cute::copy(tVrV, tVsV);

        if (n_block > n_block_min) {
            // Advance gK
            if constexpr (!Is_page_attn) {
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
                flash::copy_b128(tKgK, tKrK, tKVcKV, params.d);
            } else {
                flash::copy_b128_page_one(gK, tKgK, tKrK, tKVcKV, params.d, n_block - 1, block_table, params.k_batch_stride, params.k_row_stride, params.page_block_size);
            }
        }

        if constexpr (Is_softcap){
            flash::apply_softcap(acc_s, params.softcap);
        }

        mask.template apply_mask(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 64) * 16 + (tidx & 0xf), kNWarps * 16
        );
        softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

        //Tensor rP = flash::convert_type(acc_s);
        CONVERT_TENSOR_TYPE(ElementAccum, Element, acc_s, rP)
        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        //Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout()));
        Tensor tOrP = make_tensor(rP.data(), acc_s.layout());

        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }

    // Epilogue

    Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax);
    // if (cute::thread0()) { print(lse); }

    // The output staging buffer reuses the start of smem_ (the same address sQ was built on).
    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    using SmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::SmemCopyAtomO,
        typename Kernel_traits::SmemCopyAtomOaccum
    >;
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
    //Tensor rO = flash::convert_type(acc_o);
    CONVERT_TENSOR_TYPE(ElementAccum, ElementO, acc_o, rO)
    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);             // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sOaccum is larger than sQ, so we need to syncthreads here
    // TODO: allocate enough smem for sOaccum
    if constexpr (Split || Kernel_traits::Share_Q_K_smem) { flash::sync_threads(); }

    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);

    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
        + m_block * kBlockM) * params.d_rounded;
    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;

    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                 Shape, Int>{},
                                 make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                   Shape>{}, Stride<_1>{});
    // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }

    GmemTiledCopyO gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

    // Make every thread's smem writes of the output tile visible before it is re-read below.
    flash::sync_threads();

    Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

    Tensor caccO = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);           // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_4>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
    // Only threads whose accumulator fragment starts at column 0 write the LSE rows.
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);    // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy_reg_to_global(
        tOrOaccum, tOgOaccum, tOcO, params.d_value, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

} // namespace flash