mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Initial commit
This commit is contained in:
534
deep_ep/buffer.py
Normal file
534
deep_ep/buffer.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Callable, List, Tuple, Optional, Union
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep_cpp
|
||||
# noinspection PyUnresolvedReferences
|
||||
from deep_ep_cpp import Config, EventHandle
|
||||
from .utils import EventOverlap
|
||||
|
||||
|
||||
class Buffer:
    """
    The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports:
        - high-throughput intranode all-to-all (dispatch and combine, using NVLink)
        - high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
        - low-latency all-to-all (dispatch and combine, using RDMA, AR supported)

    Attributes:
        num_sms: the SMs used in high-throughput kernels.
        rank: the local rank number.
        group_size: the number of ranks in the group.
        group: the communication group.
        num_nvl_bytes: the buffer size for intranode NVLink communication.
        num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
        runtime: the C++ runtime.
    """

    # Class-level default shared by all instances; adjust via `Buffer.set_num_sms`
    num_sms: int = 20

    def __init__(self, group: dist.ProcessGroup,
                 num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
                 low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
        """
        Initialize the communication buffer.

        Arguments:
            group: the communication group.
            num_nvl_bytes: the buffer size for intranode NVLink communication.
            num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
            low_latency_mode: whether to enable low-latency mode.
            num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
                to the number of local experts.
        """

        # TODO: argument docs
        # Initialize the CPP runtime
        self.rank = group.rank()
        self.group_size = group.size()
        self.group = group
        self.num_nvl_bytes = num_nvl_bytes
        self.num_rdma_bytes = num_rdma_bytes
        self.low_latency_mode = low_latency_mode
        self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)

        # Synchronize device IDs
        # NOTE: `all_gather_object` is a collective — every rank must construct `Buffer` with the same group,
        # in the same order, or these calls will deadlock/mismatch.
        device_ids = [None, ] * self.group_size
        local_device_id = self.runtime.get_local_device_id()
        dist.all_gather_object(device_ids, local_device_id, group)

        # Synchronize IPC handles
        ipc_handles = [None, ] * self.group_size
        local_ipc_handle = self.runtime.get_local_ipc_handle()
        dist.all_gather_object(ipc_handles, local_ipc_handle, group)

        # Synchronize NVSHMEM unique IDs
        root_unique_id = None
        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
            # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
            # NOTE(review): these env vars are presumably read when NVSHMEM is initialized inside
            # `self.runtime.sync` below, so they must be set before that call — confirm against the C++ side.
            if low_latency_mode:
                assert num_qps_per_rank > 0
                os.environ['NVSHMEM_DISABLE_P2P'] = '1'
                os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
                os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
                os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
                # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
                os.environ['NVSHMEM_QP_DEPTH'] = '1024'
                # NOTES: NVSHMEM initialization requires at least 256 MiB
                os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'

            # NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
            # Synchronize using the root ID
            # Only one rank produces the unique ID (global rank 0 in low-latency mode, otherwise the
            # RDMA-rank-0 peer); everyone else contributes `None` and picks the root's entry afterwards.
            nvshmem_unique_ids = [None, ] * self.group_size
            if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
                root_unique_id = self.runtime.get_local_nvshmem_unique_id()
            dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
            root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]

        # Make CPP runtime available
        self.runtime.sync(device_ids, ipc_handles, root_unique_id)
        assert self.runtime.is_available()

    @staticmethod
    def set_num_sms(new_num_sms: int) -> None:
        """
        Set the number of SMs to use in high-throughput kernels.

        Arguments:
            new_num_sms: the new number to be set.
        """

        # Kernels split SMs into send/receive halves, hence the even-count requirement
        assert new_num_sms % 2 == 0, 'The SM count must be even'
        Buffer.num_sms = new_num_sms

    @staticmethod
    def capture() -> EventOverlap:
        """
        Capture a CUDA event on the current stream, i.e. `torch.cuda.current_stream()`.

        Returns:
            event: the captured event.
        """
        return EventOverlap(EventHandle())

    @staticmethod
    def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int) -> int:
        """
        Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.

        Arguments:
            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
            hidden: the hidden dimension of each token.
            num_ranks: the number of EP group ranks.
            num_experts: the number of all experts.

        Returns:
            size: the RDMA buffer size recommended.
        """
        return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)

    def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
                                offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor:
        """
        Get the raw buffer (slice supported) as a PyTorch tensor.

        Argument:
            dtype: the data type (PyTorch `dtype`) for the tensor.
            size: the slice size (by elements) to get from the buffer.
            offset: the offset of the beginning element.
            use_rdma_buffer: whether to return the RDMA buffer.

        Returns:
            the flat buffer tensor, or a view reshaped to `size` when `size` is given.
        """
        tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)
        if size is None:
            return tensor

        # Slice first, then view, so the requested shape needs not cover the whole buffer
        assert tensor.numel() >= size.numel()
        return tensor[:size.numel()].view(size)

    @staticmethod
    def get_dispatch_config(num_ranks: int) -> Config:
        """
        Get a recommended dispatch config.

        Argument:
            num_ranks: the number of ranks.

        Returns:
            config: the recommended config.
        """
        # Intranode
        if num_ranks <= 8:
            return Config(Buffer.num_sms, 6, 256, 6, 128)

        # Internode
        # NOTE(review): per-rank-count tunings are empirical; unsupported counts fail fast below
        config_map = {
            16: Config(Buffer.num_sms, 16, 288, 20, 128),
            24: Config(Buffer.num_sms, 8, 288, 32, 128),
            32: Config(Buffer.num_sms, 8, 288, 32, 128),
            64: Config(Buffer.num_sms, 20, 288, 28, 128),
            128: Config(Buffer.num_sms, 20, 560, 32, 128),
            144: Config(Buffer.num_sms, 32, 720, 12, 128),
            160: Config(Buffer.num_sms, 28, 720, 12, 128),
        }
        assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
        return config_map[num_ranks]

    @staticmethod
    def get_combine_config(num_ranks: int) -> Config:
        """
        Get a recommended combine config.

        Argument:
            num_ranks: the number of ranks.

        Returns:
            config: the recommended config.
        """
        # Intranode
        if num_ranks <= 8:
            return Config(Buffer.num_sms, 6, 256, 6, 128)

        # Internode
        config_map = {
            16: Config(Buffer.num_sms, 2, 288, 28, 128),
            24: Config(Buffer.num_sms, 1, 288, 20, 128),
            32: Config(Buffer.num_sms, 1, 288, 20, 128),
            64: Config(Buffer.num_sms, 1, 288, 20, 128),
            128: Config(Buffer.num_sms, 1, 560, 12, 128),
            144: Config(Buffer.num_sms, 2, 720, 8, 128),
            160: Config(Buffer.num_sms, 2, 720, 8, 128),
        }
        assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
        return config_map[num_ranks]

    # noinspection PyTypeChecker
    def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
                            previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                            allocate_on_comm_stream: bool = False) -> \
            Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:
        """
        Calculate the layout required for later communication.

        Arguments:
            topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token,
                `-1` means no selections.
            num_experts: the number of experts.
            previous_event: the event to wait before actually executing the kernel.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.

        Returns:
            num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
            num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
                rank (with the same GPU index), return `None` for intranode settings.
            num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
            is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
            event: the event after executing the kernel (valid only if `async_finish` is set).
        """
        # `getattr(..., 'event', None)` gracefully handles a `None` previous_event
        num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
            self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
                                             async_finish, allocate_on_comm_stream)
        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)

    # noinspection PyTypeChecker
    def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 handle: Optional[Tuple] = None,
                 num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                 is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
                 topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
                 config: Optional[Config] = None,
                 previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                 allocate_on_comm_stream: bool = False) -> \
            Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
                  Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
        """
        Dispatch tokens to different ranks, both intranode and internode settings are supported.
        Intranode kernels require all the ranks should be visible via NVLink.
        Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
            index should be visible via RDMA. AR must be disabled.

        Arguments:
            x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
                and type must be `torch.bfloat16`; for the second type, the first element of the tuple must be shaped as
                `[num_tokens, hidden]` with type `torch.float8_e4m3fn`, the second must be `[num_tokens, hidden // 128]`
                (requiring divisible) with type `torch.float`.
            handle: an optional communication handle, if set, the CPU will reuse the layout information to save some time.
            num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
            num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
                rank (with the same GPU index), return `None` for intranode settings.
            is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
            num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
            topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
                `-1` means no selections.
            topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
            expert_alignment: align the number of tokens received by each local expert to this variable.
            config: the performance tuning config.
            previous_event: the event to wait before actually executing the kernel.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.

        Returns:
            recv_x: received tokens, the same type and tuple as the input `x`, but the number of tokens equals to the
                received token count.
            recv_topk_idx: received expert indices.
            recv_topk_weights: received expert weights.
            num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by
                each local expert, aligned to the input `expert_alignment`.
            handle: the returned communication handle.
            event: the event after executing the kernel (valid only if `async_finish` is set).
        """
        # Default config
        config = self.get_dispatch_config(self.group_size) if config is None else config

        # Internode
        if self.runtime.get_num_rdma_ranks() > 1:
            return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                                           topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)

        # Launch the kernel with cached or non-cached mode
        x, x_scales = x if isinstance(x, tuple) else (x, None)
        if handle is not None:
            # Cached mode: the layout is reused from `handle`, so top-k metadata must not be passed again
            assert topk_idx is None and topk_weights is None
            rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle
            num_recv_tokens = recv_src_idx.size(0)
            # The ignored slots mirror the non-cached return tuple; only data tensors and the event matter here
            recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                x, x_scales, None, None,
                None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
        else:
            assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
            recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
                self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
                                                num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
                                                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
            # Pack the layout so later cached dispatches and `combine` can reuse it
            handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)

    # noinspection PyTypeChecker
    def combine(self, x: torch.Tensor, handle: Tuple,
                topk_weights: Optional[torch.Tensor] = None,
                config: Optional[Config] = None,
                previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                allocate_on_comm_stream: bool = False) -> \
            Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
        """
        Combine (reduce) tokens (addition **without** weights) from different ranks, both intranode and internode
            settings are supported.
        Intranode kernels require all the ranks should be visible via NVLink.
        Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
            index should be visible via RDMA. AR must be disabled.

        Arguments:
            x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks.
            handle: a must-set communication handle, you can obtain this from the dispatch function.
            topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks.
            config: the performance tuning config.
            previous_event: the event to wait before actually executing the kernel.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.

        Returns:
            recv_x: the reduced token from its dispatched ranks.
            recv_topk_weights: the reduced top-k weights from its dispatch ranks.
            event: the event after executing the kernel (valid only if `async_finish` is set).
        """
        # Default config
        config = self.get_combine_config(self.group_size) if config is None else config

        # Internode
        if self.runtime.get_num_rdma_ranks() > 1:
            return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)

        # NOTES: the second `_` is for the sending side, so we should use the third one
        rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle

        # Launch the kernel
        recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
            x, topk_weights,
            src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
            getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
        return recv_x, recv_topk_weights, EventOverlap(event)

    # noinspection PyTypeChecker
    def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                           handle: Optional[Tuple] = None,
                           num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                           is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
                           topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
                           config: Optional[Config] = None,
                           previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                           allocate_on_comm_stream: bool = False) -> \
            Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
                  Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
        """
        Internode dispatch implementation, for more details, please refer to the `dispatch` docs.
        Normally, you should not directly call this function.
        """
        assert config is not None

        # Launch the kernel with cached or non-cached mode
        x, x_scales = x if isinstance(x, tuple) else (x, None)
        if handle is not None:
            # Cached mode: unpack the 10-element layout handle produced by the non-cached path below
            assert topk_idx is None and topk_weights is None
            is_token_in_rank, \
                rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
                recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
                recv_src_meta, send_rdma_head, send_nvl_head = handle
            num_recv_tokens = recv_src_meta.size(0)
            num_rdma_recv_tokens = send_nvl_head.size(0)
            recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
                x, x_scales, topk_idx, topk_weights,
                None, None, is_token_in_rank, None,
                num_recv_tokens, num_rdma_recv_tokens,
                rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
        else:
            assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
            recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
                rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
                recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
                recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
                recv_src_meta, send_rdma_head, send_nvl_head, event = self.runtime.internode_dispatch(
                    x, x_scales, topk_idx, topk_weights,
                    num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                    0, 0, None, None, None, None,
                    expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
            # Pack the layout for reuse by cached dispatches and `internode_combine`
            handle = (is_token_in_rank,
                      rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
                      recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                      recv_src_meta, send_rdma_head, send_nvl_head)
            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)

    # noinspection PyTypeChecker
    def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                          topk_weights: Optional[torch.Tensor] = None,
                          config: Optional[Config] = None,
                          previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                          allocate_on_comm_stream: bool = False) -> \
            Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
        """
        Internode combine implementation, for more details, please refer to the `combine` docs.
        Normally, you should not directly call this function.
        """
        assert config is not None

        # Unpack handle
        # The two skipped slots are the sender-side channel prefix matrices, unused when combining back
        is_combined_token_in_rank, \
            _, _, \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
            src_meta, send_rdma_head, send_nvl_head = handle

        # Launch the kernel
        combined_x, combined_topk_weights, event = self.runtime.internode_combine(
            x, topk_weights,
            src_meta, is_combined_token_in_rank,
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
            send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
            async_finish, allocate_on_comm_stream)
        return combined_x, combined_topk_weights, EventOverlap(event)

    def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
        """
        As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
            if the buffer is dirty at some time.
        For example, after running the normal dispatch/combine, you must run this function before executing any
            low-latency kernel.

        Arguments:
            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
            hidden: the hidden dimension of each token.
            num_experts: the number of all experts.
        """
        self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)

    # noinspection PyTypeChecker
    def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                             num_max_dispatch_tokens_per_rank: int, num_experts: int,
                             async_finish: bool = False, return_recv_hook: bool = False) -> \
            Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
        """
        A low-latency implementation for dispatching with IBGDA **with implicit FP8 casting**.
        This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
            (specifically, IBGDA must be enabled).
        Even for ranks in the same node, NVLink are fully disabled for simplicity.
        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
            low-latency kernels' result tensor at a single moment.

        Arguments:
            x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
                supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
            topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes
                are supported. `-1` indices (not selecting any expert) are supported.
            num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
            num_experts: the number of all experts.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
                If you not set this flag, the kernel will ensure the data's arrival.

        Returns:
            recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
                The second tensor is the corresponding scales for the first element with shape
                `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
                Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
                Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
                as we do not synchronize CPU received count with GPU (also not compatible with CUDA graph).
            recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
                expert receive. As mentioned before, not all tokens are valid in `recv_x`.
            handle: the communication handle to be used in the `low_latency_combine` function.
            event: the event after executing the kernel (valid only if `async_finish` is set).
            hook: the receiving hook function (valid only if `return_recv_hook` is set).
        """
        packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
            self.runtime.low_latency_dispatch(x, topk_idx,
                                              num_max_dispatch_tokens_per_rank, num_experts,
                                              async_finish, return_recv_hook)
        handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
        # Recording keeps the tensors alive until the async event completes (they alias the shared buffer)
        tensors_to_record = (x, topk_idx,
                             packed_recv_x, packed_recv_x_scales, packed_recv_count,
                             packed_recv_src_info, packed_recv_layout_range)
        return (packed_recv_x, packed_recv_x_scales), packed_recv_count, handle, \
            EventOverlap(event, tensors_to_record if async_finish else None), hook

    # noinspection PyTypeChecker
    def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
            Tuple[torch.Tensor, EventOverlap, Callable]:
        """
        A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
        This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
            (specifically, IBGDA must be enabled).
        Even for ranks in the same node, NVLink are fully disabled for simplicity.
        Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
            low-latency kernels' result tensor at a single moment.

        Arguments:
            x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
                the local calculated tokens to be sent to this original rank and reduced.
            topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched
                tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals
                to the number of dispatched tokens.
            topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
                tokens. The received tokens will be reduced with the weights in this tensor.
            handle: the communication handle given by the `dispatch` function.
            async_finish: the current stream will not wait for the communication kernels to be finished if set.
            return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
                If you not set this flag, the kernel will ensure the data's arrival.

        Returns:
            combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
            event: the event after executing the kernel (valid only if `async_finish` is set).
            hook: the receiving hook function (valid only if `return_recv_hook` is set).
        """
        src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
        combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                   num_max_dispatch_tokens_per_rank, num_experts,
                                                                   async_finish, return_recv_hook)
        # Recording keeps the tensors alive until the async event completes (they alias the shared buffer)
        tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
        return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
|
||||
Reference in New Issue
Block a user