Performance optimization for compute-bound cases

This commit is contained in:
Shengyu Liu
2025-04-21 17:22:59 +08:00
parent 063ffa8ec1
commit 287061ec34
20 changed files with 1799 additions and 1217 deletions

13
csrc/kernels/config.h Normal file
View File

@@ -0,0 +1,13 @@
#pragma once
namespace Config {
static constexpr int BLOCK_SIZE_M = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5;
}

View File

@@ -0,0 +1,82 @@
#include "get_mla_metadata.h"
#include <cuda_runtime_api.h>
#include <cutlass/fast_math.h>
#include "utils.h"
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.batch_size;
int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
}
for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
}
__syncwarp();
if (threadIdx.x == 0) {
int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n;
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
int num_blocks = num_blocks_shared[now_idx];
int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_idx;
now_block = 0;
now_n_split_idx = 0;
} else {
if (remain_payload - fixed_overhead_num_blocks > 0) {
now_block += remain_payload - fixed_overhead_num_blocks;
++now_n_split_idx;
remain_payload = 0;
}
break;
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncwarp();
for (int i = threadIdx.x; i <= batch_size; i += 32) {
num_splits_ptr[i] = num_splits_shared[i];
}
}
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*2+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include "flash_mla.h"
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream);

207
csrc/kernels/mla_combine.cu Normal file
View File

@@ -0,0 +1,207 @@
#include "mla_combine.h"
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "flash_mla.h"
#include "utils.h"
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
using namespace cute;
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
const int batch_idx = blockIdx.x;
const int m_block_idx = blockIdx.y;
const int warp_idx = threadIdx.x / 32;
const int lane_idx = threadIdx.x % 32;
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx;
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
if (my_num_splits == 1) {
return;
}
const int num_q_seqs = params.q_seq_per_hk * params.h_k;
const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M);
Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
make_stride(num_q_seqs, _1{})
);
Tensor gLse = make_tensor(
make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
Shape<Int<BLOCK_SIZE_M>>{},
Stride<_1>{}
);
extern __shared__ float smem_buf[];
Tensor sLseScale = make_tensor(
make_smem_ptr(smem_buf),
Shape<Int<BLOCK_SIZE_M>, Int<MAX_SPLITS>>{},
Stride<Int<MAX_SPLITS+1>, _1>{} // +1 to avoid bank conflict
);
// Wait for the previous kernel (the MLA kernel) to finish
cudaGridDependencySynchronize();
// Read gLseAccum into sLseScale
{
#pragma unroll 4
for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) {
int split_idx = elem_idx / BLOCK_SIZE_M;
int seq_idx = elem_idx % BLOCK_SIZE_M;
sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY;
}
__syncthreads();
}
if (warp_idx >= num_cur_valid_q_seqs)
return;
// Warp #i gathers LseAccum for seq #i
{
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
float local_lse[NUM_LSE_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY;
}
float max_lse = -INFINITY;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
max_lse = max(max_lse, local_lse[i]);
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2)
max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2)
sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse;
if (lane_idx == 0)
gLse(warp_idx) = global_lse / (float)M_LOG2E;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse);
}
}
__syncwarp();
// Warp #i accumulates activation for seq #i
{
const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V;
Tensor gOaccum = make_tensor(
make_gmem_ptr(reinterpret_cast<float *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<MAX_SPLITS>, Int<HEAD_DIM_V>>{},
make_stride(num_q_seqs*HEAD_DIM_V, _1{})
);
static_assert(HEAD_DIM_V % 32 == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32;
float result[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
result[i] = 0.0f;
#pragma unroll 2
for (int split = 0; split < my_num_splits; ++split) {
float lse_scale = sLseScale(warp_idx, split);
if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
result[i] += lse_scale * gOaccum(split, lane_idx + i*32);
}
}
}
cudaTriggerProgrammaticLaunchCompletion();
const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx;
const int k_head_idx = q_seq_idx / params.q_seq_per_hk;
auto o_ptr = reinterpret_cast<ElementT *>(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride;
Tensor gO = make_tensor(
make_gmem_ptr(o_ptr),
Shape<Int<HEAD_DIM_V>>{},
Stride<_1>{}
);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
gO(lane_idx+i*32) = (ElementT)result[i];
}
}
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 8;
constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, Config::HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t combine_kernel_config = {
dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
dim3(NUM_THREADS, 1, 1),
smem_size,
stream,
attribute,
1
};
cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#endif

View File

@@ -0,0 +1,6 @@
#pragma once
#include "flash_mla.h"
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);

1349
csrc/kernels/splitkv_mla.cu Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
#pragma once
#include "flash_mla.h"
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);

106
csrc/kernels/traits.h Normal file
View File

@@ -0,0 +1,106 @@
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>
#include "config.h"
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
using namespace cute;
template<typename InputT_>
struct Traits {
using InputT = InputT_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
using TiledMMA_QK_sQ = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using SmemLayoutQ = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutK = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutV = decltype(composition(
SmemLayoutK{},
make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})
)); // A transposed version of SmemLayoutK
using SmemLayoutP0 = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
));
using rP0Layout = decltype(layout(partition_fragment_C(
TiledMMA_QK_sQ{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
)));
struct SharedMemoryPlan {
cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;
cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;
cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;
TMABarrier barriers_K0[HEAD_DIM_K/64];
TMABarrier barriers_K1[HEAD_DIM_K/64];
TMABarrier barrier_Q;
};
};
template<
typename ShapeQ, typename TMA_Q,
typename ShapeK, typename TMA_K,
typename ShapeO, typename TMA_O
>
struct TmaParams {
ShapeQ shape_Q;
TMA_Q tma_Q;
ShapeK shape_K;
TMA_K tma_K;
ShapeO shape_O;
TMA_O tma_O;
};
enum NamedBarriers : int {
sScale0Ready = 0,
sScale1Ready = 1,
sP0Ready = 2,
rO1sP0sV0RIssued = 3
};

32
csrc/kernels/utils.h Normal file
View File

@@ -0,0 +1,32 @@
#pragma once
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while(0)
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
#define FLASH_ASSERT(cond) \
do { \
if (not (cond)) { \
fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
exit(1); \
} \
} while(0)
#define FLASH_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while(0)
#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); }