mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Performance optimization for compute-bound cases
This commit is contained in:
13
csrc/kernels/config.h
Normal file
13
csrc/kernels/config.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
namespace Config {
|
||||
|
||||
static constexpr int BLOCK_SIZE_M = 64;
|
||||
static constexpr int PAGE_BLOCK_SIZE = 64;
|
||||
|
||||
static constexpr int HEAD_DIM_K = 576;
|
||||
static constexpr int HEAD_DIM_V = 512;
|
||||
|
||||
static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5;
|
||||
|
||||
}
|
||||
82
csrc/kernels/get_mla_metadata.cu
Normal file
82
csrc/kernels/get_mla_metadata.cu
Normal file
@@ -0,0 +1,82 @@
|
||||
#include "get_mla_metadata.h"
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cutlass/fast_math.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
__global__ void __launch_bounds__(32, 1, 1)
|
||||
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
|
||||
int *seqlens_k_ptr = params.seqlens_k_ptr;
|
||||
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
|
||||
int *num_splits_ptr = params.num_splits_ptr;
|
||||
int batch_size = params.batch_size;
|
||||
int block_size_n = params.block_size_n;
|
||||
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
|
||||
int num_sm_parts = params.num_sm_parts;
|
||||
|
||||
extern __shared__ int shared_mem[];
|
||||
int* num_blocks_shared = shared_mem; // [batch_size]
|
||||
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
|
||||
|
||||
int total_num_blocks = 0;
|
||||
for (int i = threadIdx.x; i < batch_size; i += 32) {
|
||||
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
|
||||
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
|
||||
num_blocks_shared[i] = num_blocks;
|
||||
}
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
|
||||
|
||||
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
|
||||
num_splits_shared[0] = 0;
|
||||
for (int i = 0; i < num_sm_parts; ++i) {
|
||||
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
|
||||
tile_scheduler_metadata0[0] = now_idx;
|
||||
tile_scheduler_metadata0[1] = now_block * block_size_n;
|
||||
tile_scheduler_metadata1 = now_n_split_idx;
|
||||
int remain_payload = payload;
|
||||
while (now_idx < batch_size) {
|
||||
int num_blocks = num_blocks_shared[now_idx];
|
||||
int now_remain_blocks = num_blocks - now_block;
|
||||
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
|
||||
cum_num_splits += now_n_split_idx + 1;
|
||||
num_splits_shared[now_idx + 1] = cum_num_splits;
|
||||
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
|
||||
++now_idx;
|
||||
now_block = 0;
|
||||
now_n_split_idx = 0;
|
||||
} else {
|
||||
if (remain_payload - fixed_overhead_num_blocks > 0) {
|
||||
now_block += remain_payload - fixed_overhead_num_blocks;
|
||||
++now_n_split_idx;
|
||||
remain_payload = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
|
||||
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
|
||||
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
|
||||
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
|
||||
}
|
||||
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
for (int i = threadIdx.x; i <= batch_size; i += 32) {
|
||||
num_splits_ptr[i] = num_splits_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream) {
|
||||
int smem_size = sizeof(int) * (params.batch_size*2+1);
|
||||
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
5
csrc/kernels/get_mla_metadata.h
Normal file
5
csrc/kernels/get_mla_metadata.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "flash_mla.h"
|
||||
|
||||
void run_get_mla_metadata_kernel(Mla_metadata_params ¶ms, cudaStream_t stream);
|
||||
207
csrc/kernels/mla_combine.cu
Normal file
207
csrc/kernels/mla_combine.cu
Normal file
@@ -0,0 +1,207 @@
|
||||
#include "mla_combine.h"
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "flash_mla.h"
|
||||
#include "utils.h"
|
||||
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
|
||||
__global__ void __launch_bounds__(NUM_THREADS)
|
||||
flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
|
||||
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
|
||||
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
|
||||
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
|
||||
const int batch_idx = blockIdx.x;
|
||||
const int m_block_idx = blockIdx.y;
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
const int lane_idx = threadIdx.x % 32;
|
||||
|
||||
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
|
||||
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
|
||||
const int my_num_splits = end_split_idx - start_split_idx;
|
||||
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
|
||||
if (my_num_splits == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int num_q_seqs = params.q_seq_per_hk * params.h_k;
|
||||
const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M);
|
||||
Tensor gLseAccum = make_tensor(
|
||||
make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
|
||||
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
|
||||
make_stride(num_q_seqs, _1{})
|
||||
);
|
||||
Tensor gLse = make_tensor(
|
||||
make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
|
||||
Shape<Int<BLOCK_SIZE_M>>{},
|
||||
Stride<_1>{}
|
||||
);
|
||||
|
||||
extern __shared__ float smem_buf[];
|
||||
Tensor sLseScale = make_tensor(
|
||||
make_smem_ptr(smem_buf),
|
||||
Shape<Int<BLOCK_SIZE_M>, Int<MAX_SPLITS>>{},
|
||||
Stride<Int<MAX_SPLITS+1>, _1>{} // +1 to avoid bank conflict
|
||||
);
|
||||
|
||||
// Wait for the previous kernel (the MLA kernel) to finish
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
// Read gLseAccum into sLseScale
|
||||
{
|
||||
#pragma unroll 4
|
||||
for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) {
|
||||
int split_idx = elem_idx / BLOCK_SIZE_M;
|
||||
int seq_idx = elem_idx % BLOCK_SIZE_M;
|
||||
sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (warp_idx >= num_cur_valid_q_seqs)
|
||||
return;
|
||||
|
||||
// Warp #i gathers LseAccum for seq #i
|
||||
{
|
||||
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
|
||||
float local_lse[NUM_LSE_PER_THREAD];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
|
||||
const int split_idx = i*32 + lane_idx;
|
||||
local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY;
|
||||
}
|
||||
|
||||
float max_lse = -INFINITY;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
|
||||
max_lse = max(max_lse, local_lse[i]);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2)
|
||||
max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
|
||||
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
|
||||
|
||||
float sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
|
||||
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2)
|
||||
sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
|
||||
|
||||
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse;
|
||||
if (lane_idx == 0)
|
||||
gLse(warp_idx) = global_lse / (float)M_LOG2E;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
|
||||
const int split_idx = i*32 + lane_idx;
|
||||
if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp();
|
||||
|
||||
// Warp #i accumulates activation for seq #i
|
||||
{
|
||||
const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V;
|
||||
Tensor gOaccum = make_tensor(
|
||||
make_gmem_ptr(reinterpret_cast<float *>(params.oaccum_ptr) + row_offset_oaccum),
|
||||
Shape<Int<MAX_SPLITS>, Int<HEAD_DIM_V>>{},
|
||||
make_stride(num_q_seqs*HEAD_DIM_V, _1{})
|
||||
);
|
||||
|
||||
static_assert(HEAD_DIM_V % 32 == 0);
|
||||
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32;
|
||||
float result[ELEMS_PER_THREAD];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
|
||||
result[i] = 0.0f;
|
||||
|
||||
#pragma unroll 2
|
||||
for (int split = 0; split < my_num_splits; ++split) {
|
||||
float lse_scale = sLseScale(warp_idx, split);
|
||||
if (lse_scale != 0.f) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
|
||||
result[i] += lse_scale * gOaccum(split, lane_idx + i*32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
|
||||
const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx;
|
||||
const int k_head_idx = q_seq_idx / params.q_seq_per_hk;
|
||||
auto o_ptr = reinterpret_cast<ElementT *>(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride;
|
||||
Tensor gO = make_tensor(
|
||||
make_gmem_ptr(o_ptr),
|
||||
Shape<Int<HEAD_DIM_V>>{},
|
||||
Stride<_1>{}
|
||||
);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
|
||||
gO(lane_idx+i*32) = (ElementT)result[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
|
||||
[&] { \
|
||||
if (NUM_SPLITS <= 32) { \
|
||||
constexpr static int NAME = 32; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 64) { \
|
||||
constexpr static int NAME = 64; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 96) { \
|
||||
constexpr static int NAME = 96; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 128) { \
|
||||
constexpr static int NAME = 128; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (NUM_SPLITS <= 160) { \
|
||||
constexpr static int NAME = 160; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
FLASH_ASSERT(false); \
|
||||
} \
|
||||
}()
|
||||
|
||||
|
||||
template<typename ElementT>
|
||||
void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) {
|
||||
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
|
||||
constexpr int BLOCK_SIZE_M = 8;
|
||||
constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
|
||||
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
|
||||
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, Config::HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
|
||||
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
|
||||
cudaLaunchAttribute attribute[1];
|
||||
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attribute[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
cudaLaunchConfig_t combine_kernel_config = {
|
||||
dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
|
||||
dim3(NUM_THREADS, 1, 1),
|
||||
smem_size,
|
||||
stream,
|
||||
attribute,
|
||||
1
|
||||
};
|
||||
cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params);
|
||||
});
|
||||
CHECK_CUDA_KERNEL_LAUNCH();
|
||||
}
|
||||
|
||||
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
|
||||
#ifndef FLASH_MLA_DISABLE_FP16
|
||||
template void run_flash_mla_combine_kernel<cutlass::half_t>(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
#endif
|
||||
6
csrc/kernels/mla_combine.h
Normal file
6
csrc/kernels/mla_combine.h
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "flash_mla.h"
|
||||
|
||||
template<typename ElementT>
|
||||
void run_flash_mla_combine_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
1349
csrc/kernels/splitkv_mla.cu
Normal file
1349
csrc/kernels/splitkv_mla.cu
Normal file
File diff suppressed because it is too large
Load Diff
6
csrc/kernels/splitkv_mla.h
Normal file
6
csrc/kernels/splitkv_mla.h
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "flash_mla.h"
|
||||
|
||||
template<typename InputT>
|
||||
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream);
|
||||
106
csrc/kernels/traits.h
Normal file
106
csrc/kernels/traits.h
Normal file
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/barrier.h>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
using namespace cute;
|
||||
|
||||
template<typename InputT_>
|
||||
struct Traits {
|
||||
using InputT = InputT_;
|
||||
|
||||
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
|
||||
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
|
||||
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
|
||||
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
|
||||
|
||||
static constexpr int NUM_THREADS = 256;
|
||||
|
||||
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
|
||||
|
||||
using TiledMMA_QK_sQ = decltype(make_tiled_mma(
|
||||
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
|
||||
Layout<Shape<_1, _1, _1>>{}
|
||||
));
|
||||
|
||||
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
|
||||
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
|
||||
Layout<Shape<_1, _1, _1>>{}
|
||||
));
|
||||
|
||||
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
|
||||
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
|
||||
Layout<Shape<_1, _1, _1>>{}
|
||||
));
|
||||
|
||||
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
|
||||
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
|
||||
Layout<Shape<_1, _1, _1>>{}
|
||||
));
|
||||
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
GMMA::Layout_K_SW128_Atom<InputT>{},
|
||||
Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}
|
||||
));
|
||||
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
GMMA::Layout_K_SW128_Atom<InputT>{},
|
||||
Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}
|
||||
));
|
||||
|
||||
using SmemLayoutV = decltype(composition(
|
||||
SmemLayoutK{},
|
||||
make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})
|
||||
)); // A transposed version of SmemLayoutK
|
||||
|
||||
using SmemLayoutP0 = decltype(tile_to_shape(
|
||||
GMMA::Layout_K_SW128_Atom<InputT>{},
|
||||
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
|
||||
));
|
||||
|
||||
using rP0Layout = decltype(layout(partition_fragment_C(
|
||||
TiledMMA_QK_sQ{},
|
||||
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
|
||||
)));
|
||||
|
||||
struct SharedMemoryPlan {
|
||||
cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;
|
||||
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;
|
||||
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;
|
||||
cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
|
||||
cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;
|
||||
cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;
|
||||
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;
|
||||
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;
|
||||
TMABarrier barriers_K0[HEAD_DIM_K/64];
|
||||
TMABarrier barriers_K1[HEAD_DIM_K/64];
|
||||
TMABarrier barrier_Q;
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
template<
|
||||
typename ShapeQ, typename TMA_Q,
|
||||
typename ShapeK, typename TMA_K,
|
||||
typename ShapeO, typename TMA_O
|
||||
>
|
||||
struct TmaParams {
|
||||
ShapeQ shape_Q;
|
||||
TMA_Q tma_Q;
|
||||
ShapeK shape_K;
|
||||
TMA_K tma_K;
|
||||
ShapeO shape_O;
|
||||
TMA_O tma_O;
|
||||
};
|
||||
|
||||
enum NamedBarriers : int {
|
||||
sScale0Ready = 0,
|
||||
sScale1Ready = 1,
|
||||
sP0Ready = 2,
|
||||
rO1sP0sV0RIssued = 3
|
||||
};
|
||||
32
csrc/kernels/utils.h
Normal file
32
csrc/kernels/utils.h
Normal file
@@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#define CHECK_CUDA(call) \
|
||||
do { \
|
||||
cudaError_t status_ = call; \
|
||||
if (status_ != cudaSuccess) { \
|
||||
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
|
||||
|
||||
|
||||
#define FLASH_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
|
||||
#define FLASH_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); }
|
||||
Reference in New Issue
Block a user