Add automatic warp count control for low-latency kernels (#213)

* Add automatic warp count control for low-latency dispatch

* Add automatic warp count control for low-latency combine

* More assertions
This commit is contained in:
Chenggang Zhao
2025-06-16 11:56:43 +08:00
committed by GitHub
parent 4e923188f7
commit 1b92be8a71
6 changed files with 83 additions and 65 deletions

View File

@@ -333,18 +333,18 @@ __device__ __forceinline__ void tma_store_wait() {
#endif
// Computes (a + b - 1) / b. For integer operands with a >= 0 and b > 0 this is
// a / b rounded up to the next whole number. Callable from host and device code.
//   a: dividend
//   b: divisor; must be non-zero because it is used directly in the division
// NOTE(review): a + b - 1 is evaluated in dtype_t, so it can overflow for values
// near the type's maximum.
// NOTE(review): this is a diff view without +/- markers; the two signature lines
// below appear to be the removed (`cell_div`) and added (`ceil_div`) versions of
// the same line, i.e. a rename fixing the spelling. Confirm against the commit.
template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
// Returns the rounded-up quotient of a by b, multiplied back by b. For integer
// operands with a >= 0 and b > 0 this is the smallest multiple of b that is >= a.
// Callable from host and device code.
//   a: value to round up
//   b: alignment step; must be non-zero (it is the divisor of the helper call)
// NOTE(review): the two return lines below appear to be the removed (`cell_div`)
// and added (`ceil_div`) sides of the same diff line; only the callee name differs.
template <typename dtype_t>
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
return ceil_div<dtype_t>(a, b) * b;
}
// Splits num_tokens into num_sms consecutive chunks of equal (rounded-up) size
// and writes the bounds of chunk sm_id into the two output references.
//   num_tokens:      total number of tokens to distribute
//   num_sms:         number of chunks; must be non-zero (used as a divisor)
//   sm_id:           index of the chunk whose bounds are requested
//   token_start_idx: (out) first index of the chunk, clamped to num_tokens
//   token_end_idx:   (out) start + chunk size, clamped to num_tokens
// Both outputs are clamped, so chunks past the end of the data come out empty
// (start == end == num_tokens).
// NOTE(review): the end index looks exclusive, i.e. the range is [start, end) —
// confirm against the callers.
// NOTE(review): the two `num_tokens_per_sm` lines below appear to be the removed
// (`cell_div`) and added (`ceil_div`) sides of the same diff line.
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
int& token_start_idx, int& token_end_idx) {
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
int num_tokens_per_sm = ceil_div(num_tokens, num_sms);
// Clamp so that trailing chunks never start past the end of the token list.
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
// The last non-empty chunk may be shorter than num_tokens_per_sm.
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
}