mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support CUDA graph for intranode normal kernels (#203)
This commit is contained in:
@@ -249,7 +249,8 @@ class Buffer:
|
||||
handle: Optional[Tuple] = None,
|
||||
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
|
||||
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
|
||||
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
|
||||
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
|
||||
expert_alignment: int = 1, num_worst_tokens: int = 0,
|
||||
config: Optional[Config] = None,
|
||||
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
|
||||
allocate_on_comm_stream: bool = False) -> \
|
||||
@@ -276,6 +277,8 @@ class Buffer:
|
||||
`-1` means no selections.
|
||||
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
|
||||
expert_alignment: align the number of tokens received by each local expert to this variable.
|
||||
num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
|
||||
will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
|
||||
config: the performance tuning config.
|
||||
previous_event: the event to wait before actually executing the kernel.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
@@ -296,6 +299,7 @@ class Buffer:
|
||||
|
||||
# Internode
|
||||
if self.runtime.get_num_rdma_ranks() > 1:
|
||||
assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
|
||||
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
|
||||
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
|
||||
|
||||
@@ -308,14 +312,16 @@ class Buffer:
|
||||
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
|
||||
x, x_scales, None, None,
|
||||
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
|
||||
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
expert_alignment, num_worst_tokens, config,
|
||||
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
|
||||
else:
|
||||
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
|
||||
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
|
||||
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
|
||||
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
|
||||
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
|
||||
expert_alignment, num_worst_tokens, config,
|
||||
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
|
||||
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user