Support statistics tensor for low-latency kernels (#196)

Chenggang Zhao
2025-06-09 15:50:56 +08:00
committed by GitHub
parent 0d1a855d81
commit 5a2e37fa28
6 changed files with 27 additions and 3 deletions

@@ -142,6 +142,7 @@ public:
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                         const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool async, bool return_recv_hook);
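
The hunk above only shows the new optional argument, not how it is consumed. As a rough illustration, the sketch below assumes (this is not confirmed by the visible diff) that `cumulative_local_expert_recv_stats` holds one running counter per local expert and that each dispatch call adds the number of tokens that expert received, with `std::nullopt` meaning "do not collect statistics". The helper `accumulate_recv_stats` and the use of `std::vector<int>` in place of `torch::Tensor` are hypothetical stand-ins chosen to keep the example self-contained; this is not the commit's code.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical host-side stand-in for the statistics update (assumption, not the commit's kernel).
// recv_counts[i]   : tokens received by local expert i during one dispatch call.
// cumulative_stats : optional running totals, one slot per local expert.
void accumulate_recv_stats(const std::vector<int>& recv_counts,
                           std::optional<std::vector<int>>& cumulative_stats) {
    // The caller opted out of statistics: leave everything untouched.
    if (!cumulative_stats.has_value())
        return;

    // The totals must have exactly one slot per local expert.
    assert(cumulative_stats->size() == recv_counts.size());

    // Add this call's counts onto the running totals in place.
    for (std::size_t i = 0; i < recv_counts.size(); ++i)
        (*cumulative_stats)[i] += recv_counts[i];
}
```

Under these assumptions, a caller would allocate the totals once, zero them, and pass the same buffer to every dispatch call; passing an empty optional preserves the pre-commit behavior of collecting nothing.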