mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Add automatic warp count control for low-latency kernels (#213)
* Add automatic warp count control for low-latency dispatch
* Add automatic warp count control for low-latency combine
* More assertions
This commit is contained in:
@@ -333,18 +333,18 @@ __device__ __forceinline__ void tma_store_wait() {
|
||||
#endif
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
|
||||
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
|
||||
return cell_div<dtype_t>(a, b) * b;
|
||||
return ceil_div<dtype_t>(a, b) * b;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
|
||||
int& token_start_idx, int& token_end_idx) {
|
||||
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
|
||||
int num_tokens_per_sm = ceil_div(num_tokens, num_sms);
|
||||
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
|
||||
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user