FlashMLA/csrc/kernels/params.h
Shengyu Liu c2067be3ea
Performance Update (2025.04.22) (#71)
* Fix benchmark script

* Performance optimization for compute-bound cases

* Add new testcase (s_k = 16384)

* Update README.md

* Update comment

* Update README.md

* Add the deep-dive blog

* Add background color for MLA Kernel Sched.drawio.svg

* Use relative path for the schedule image

* Move flash_mla.h to kernels/params.h
2025-04-22 17:50:57 +08:00

59 lines
1.7 KiB
C

#pragma once

#include <cstdint>
////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameter bundle for the forward MLA kernels.
// Plain aggregate shared by every translation unit that includes this header:
// keep the field order and types stable, since any positional aggregate
// initialization or layout-dependent code elsewhere relies on them.
struct Flash_fwd_mla_params {
    // 64-bit index type for the strides below, so stride * index products do
    // not overflow 32-bit int arithmetic. Spelled std::int64_t because that is
    // the name <cstdint> guarantees.
    using index_t = std::int64_t;

    // ---- Problem shape ----
    int b;              // Batch size.
    int s_q;            // Query sequence length (number of q(s) per batch element).
    int q_seq_per_hk;   // The number of q(s) per KV head, = h_q / h_k * s_q.
    int d, d_v;         // K / V head dimension.
    int h_q, h_k;       // The number of Q / K heads.
    int num_blocks;     // Number of blocks in total (presumably KV-cache pages -- TODO confirm).
    int q_head_per_hk;  // The number of q_head(s) per KV head, = h_q / h_k.
    bool is_causal;     // Causal-attention flag.

    // Softmax scaling factors. scale_softmax_log2 is presumably
    // scale_softmax * log2(e) for an exp2-based softmax -- confirm at the call site.
    float scale_softmax, scale_softmax_log2;

    // ---- Type-erased tensor base pointers ----
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ o_ptr;
    void *__restrict__ softmax_lse_ptr;  // Softmax log-sum-exp output.

    // ---- Strides (units -- elements vs bytes -- are set by the caller; TODO confirm) ----
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t o_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t o_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t o_head_stride;

    // ---- Paged K/V storage (presumably block_table maps each sequence's
    //      logical blocks to physical pages of page_block_size tokens) ----
    int *__restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;
    int *__restrict__ seqlens_k_ptr;  // K sequence lengths (presumably one entry per batch element).

    // ---- Split scheduling ----
    // Records of TileSchedulerMetaDataSize ints each (presumably one per SM part).
    int *__restrict__ tile_scheduler_metadata_ptr;
    int num_sm_parts;
    int *__restrict__ num_splits_ptr;
    int total_num_splits;
    void *__restrict__ softmax_lseaccum_ptr;  // LSE accumulation buffer (presumably per split).
    void *__restrict__ oaccum_ptr;            // Output accumulation buffer (presumably per split).
};
// Each tile-scheduler metadata record is TileSchedulerMetaDataSize ints, laid out as
//   [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
// where the trailing `_` slots carry no named field.
static constexpr int TileSchedulerMetaDataSize{8};
// Parameter bundle for computing the tile-scheduler metadata (record layout is
// described next to TileSchedulerMetaDataSize). Field order is part of the
// struct's layout -- do not reorder.
struct Mla_metadata_params {
    // Buffers.
    int* __restrict__ seqlens_k_ptr;                // K sequence lengths (presumably one per batch element).
    int* __restrict__ tile_scheduler_metadata_ptr;  // Records of TileSchedulerMetaDataSize ints each.
    int* __restrict__ num_splits_ptr;               // Split counts (exact indexing: TODO confirm).

    // Scalars.
    int batch_size;
    int block_size_n;               // Tile size along the K dimension (presumably).
    int fixed_overhead_num_blocks;  // Presumably a fixed per-split cost, in blocks, for load balancing.
    int num_sm_parts;               // Number of SM partitions (presumably).
};