#pragma once #include #include #include #include #include "config.h" using TMABarrier = cutlass::arch::ClusterTransactionBarrier; using namespace cute; template struct Traits { using InputT = InputT_; static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; static constexpr int NUM_THREADS = 256; static_assert(std::is_same_v || std::is_same_v); using TiledMMA_QK_sQ = decltype(make_tiled_mma( GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), Layout>{} )); using TiledMMA_QK_rQ = decltype(make_tiled_mma( GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), Layout>{} )); using TiledMMA_PV_LocalP = decltype(make_tiled_mma( GMMA::rs_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), Layout>{} )); using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::MN>(), Layout>{} )); using SmemLayoutQ = decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); using SmemLayoutK = decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); using SmemLayoutV = decltype(composition( SmemLayoutK{}, make_layout(Shape, Int>{}, GenRowMajor{}) )); // A transposed version of SmemLayoutK using SmemLayoutP0 = decltype(tile_to_shape( GMMA::Layout_K_SW128_Atom{}, Shape, Int>{} )); using rP0Layout = decltype(layout(partition_fragment_C( TiledMMA_QK_sQ{}, Shape, Int>{} ))); struct SharedMemoryPlan { cute::array_aligned> smem_sQ; cute::array_aligned> smem_sK0; cute::array_aligned> smem_sK1; cute::array_aligned> smem_sP0; cute::array_aligned smem_sM; cute::array_aligned sL_reduction_wksp; cute::array_aligned smem_sScale0; cute::array_aligned smem_sScale1; TMABarrier barriers_K0[HEAD_DIM_K/64]; TMABarrier barriers_K1[HEAD_DIM_K/64]; TMABarrier barrier_Q; }; }; template< typename ShapeQ, typename TMA_Q, typename ShapeK, typename TMA_K, typename ShapeO, typename TMA_O > struct TmaParams { ShapeQ shape_Q; TMA_Q tma_Q; ShapeK shape_K; TMA_K tma_K; ShapeO shape_O; TMA_O tma_O; }; enum NamedBarriers : int { sScale0Ready = 0, sScale1Ready = 1, sP0Ready = 2, rO1sP0sV0RIssued = 3, sMInitialized = 4, };