Support CUDA graph for intranode normal kernels (#203)

This commit is contained in:
Chenggang Zhao
2025-06-11 11:08:54 +08:00
committed by GitHub
parent 8da2d7b38d
commit a8299ca7c2
7 changed files with 86 additions and 38 deletions

View File

@@ -249,7 +249,8 @@ class Buffer:
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
expert_alignment: int = 1, num_worst_tokens: int = 0,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
@@ -276,6 +277,8 @@ class Buffer:
`-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this variable.
num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
@@ -296,6 +299,7 @@ class Buffer:
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
@@ -308,14 +312,16 @@ class Buffer:
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
x, x_scales, None, None,
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
expert_alignment, num_worst_tokens, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
expert_alignment, num_worst_tokens, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)