Add automatic warp count control for low-latency kernels (#213)

* Add automatic warp count control for low-latency dispatch

* Add automatic warp count control for low-latency combine

* More assertions
This commit is contained in:
Chenggang Zhao
2025-06-16 11:56:43 +08:00
committed by GitHub
parent 4e923188f7
commit 1b92be8a71
6 changed files with 83 additions and 65 deletions

View File

@@ -148,7 +148,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int* usage_flag,
cudaStream_t stream, int phases);
int num_device_sms, cudaStream_t stream,
int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
@@ -158,7 +159,8 @@ void combine(void* combined_x,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, int* usage_flag,
cudaStream_t stream, int phases, bool zero_copy);
int num_device_sms, cudaStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll