Add automatic warp count control for low-latency dispatch

This commit is contained in:
Chenggang Zhao
2025-06-16 11:31:38 +08:00
parent 4e923188f7
commit 632c81f1d7
6 changed files with 59 additions and 45 deletions

View File

@@ -41,6 +41,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
// Get device info
cudaDeviceProp device_prop = {};
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
num_device_sms = device_prop.multiProcessorCount;
if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handles
@@ -1142,7 +1143,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
workspace, low_latency_usage_flag_mapped, launch_stream, phases);
workspace, low_latency_usage_flag_mapped,
num_device_sms, launch_stream,
phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
@@ -1234,7 +1237,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, low_latency_usage_flag_mapped, launch_stream,
workspace, low_latency_usage_flag_mapped,
num_device_sms, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));