Support CUDA graph for intranode normal kernels (#203)

This commit is contained in:
Chenggang Zhao
2025-06-11 11:08:54 +08:00
committed by GitHub
parent 8da2d7b38d
commit a8299ca7c2
7 changed files with 86 additions and 38 deletions

View File

@@ -162,6 +162,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
allocate_on_comm_stream=previous_event is not None)
# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
# Unless you specify `num_worst_tokens`, but this flag is for intranode only
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
_buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,