mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support statistics tensor for low-latency kernels (#196)
This commit is contained in:
@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
|
||||
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
|
||||
EP_HOST_ASSERT(num_experts % num_ranks == 0);
|
||||
if (cumulative_local_expert_recv_stats.has_value()) {
|
||||
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
|
||||
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
|
||||
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
|
||||
}
|
||||
|
||||
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
|
||||
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
|
||||
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
|
||||
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
|
||||
packed_recv_count.data_ptr<int>(),
|
||||
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
|
||||
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
|
||||
buffer.dispatch_rdma_send_buffer,
|
||||
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
|
||||
|
||||
Reference in New Issue
Block a user