diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index fb53f79..d4940f1 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -630,79 +630,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; run_flash_splitkv_fwd_mla>(params, stream); } - -static constexpr int MaxBatchSize = 4096; - -__global__ void __launch_bounds__(256, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { - int *seqlens_k_ptr = params.seqlens_k_ptr; - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; - int *num_splits_ptr = params.num_splits_ptr; - int batch_size = params.batch_size; - int block_size_n = params.block_size_n; - int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; - int num_sm_parts = params.num_sm_parts; - - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; - - int total_num_blocks = 0; - for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); - total_num_blocks += num_blocks + fixed_overhead_num_blocks; - num_blocks_shared[i] = num_blocks; - } - for (int offset = 16; offset >= 1; offset /= 2) { - total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); - } - __syncwarp(); - - if (threadIdx.x == 0) { - int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; - - int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; - num_splits_shared[0] = 0; - for (int i = 0; i < num_sm_parts; ++i) { - int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; - tile_scheduler_metadata1 = now_n_split_idx; - int remain_payload = payload; - while (now_idx < batch_size) { - int num_blocks = num_blocks_shared[now_idx]; - int now_remain_blocks = num_blocks - now_block; - if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { - cum_num_splits += now_n_split_idx + 1; - num_splits_shared[now_idx + 1] = cum_num_splits; - remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; - ++now_idx; - now_block = 0; - now_n_split_idx = 0; - } else { - if (remain_payload - fixed_overhead_num_blocks > 0) { - now_block += remain_payload - fixed_overhead_num_blocks; - ++now_n_split_idx; - remain_payload = 0; - } - break; - } - } - tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; - *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; - } - FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); - } - __syncwarp(); - - for (int i = threadIdx.x; i <= batch_size; i += 32) { - num_splits_ptr[i] = num_splits_shared[i]; - } -} - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); - CHECK_CUDA_KERNEL_LAUNCH(); -} diff --git a/csrc/flash_mla_utils.cu b/csrc/flash_mla_utils.cu new file mode 100644 index 0000000..38c74e4 --- /dev/null +++ b/csrc/flash_mla_utils.cu @@ -0,0 +1,85 @@ +#include +#include +#include + +using namespace cute; + +#include "flash_mla.h" +#include "static_switch.h" +#include "utils.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/setup.py b/setup.py index 8b11e00..c184953 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ ext_modules.append( name="flash_mla_cuda", sources=[ "csrc/flash_api.cpp", + "csrc/flash_mla_utils.cu", "csrc/flash_fwd_mla_bf16_sm90.cu", "csrc/flash_fwd_mla_fp8_sm90.cu", ],