Support statistics tensor for low-latency kernels (#196)

This commit is contained in:
Chenggang Zhao
2025-06-09 15:50:56 +08:00
committed by GitHub
parent 0d1a855d81
commit 5a2e37fa28
6 changed files with 27 additions and 3 deletions

View File

@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(num_experts % num_ranks == 0);
if (cumulative_local_expert_recv_stats.has_value()) {
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),