mirror of https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix AR bugs for normal kernels
@@ -14,8 +14,8 @@ class Buffer:
  """
  The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports:
  - high-throughput intranode all-to-all (dispatch and combine, using NVLink)
- - high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
- - low-latency all-to-all (dispatch and combine, using RDMA, AR supported)
+ - high-throughput internode all-to-all (dispatch and combine, using RDMA and NVLink)
+ - low-latency all-to-all (dispatch and combine, using RDMA)

  Attributes:
  num_sms: the SMs used in high-throughput kernels.
@@ -78,10 +78,6 @@ class Buffer:
  # NOTES: NVSHMEM initialization requires at least 256 MiB
  os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'

- # Disable PCIe relaxed ordering to avoid out-of-order messages
- os.environ['NVSHMEM_IB_ENABLE_RELAXED_ORDERING'] = '0'
-
- # NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
  # Synchronize using the root ID
  nvshmem_unique_ids = [None, ] * self.group_size
  if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
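Review note: the hunk above leaves `NVSHMEM_CUMEM_GRANULARITY` as the only NVSHMEM variable set at this point and stops forcing `NVSHMEM_IB_ENABLE_RELAXED_ORDERING` to `'0'`. The sketch below only illustrates that resulting state on a plain dictionary. The helper name `nvshmem_env_after_commit` is hypothetical and not part of DeepEP; it also shows that `2 ** 29` bytes is 512 MiB, which satisfies the "at least 256 MiB" note.

```python
import os

# Value written by the hunk above: 2 ** 29 bytes.
GRANULARITY_BYTES = 2 ** 29


def nvshmem_env_after_commit(env):
    """Hypothetical helper (not DeepEP API): return a copy of `env` with the
    single NVSHMEM setting that this hunk still applies."""
    updated = dict(env)
    updated['NVSHMEM_CUMEM_GRANULARITY'] = f'{GRANULARITY_BYTES}'
    # Note: NVSHMEM_IB_ENABLE_RELAXED_ORDERING is intentionally not touched,
    # mirroring the removal in the diff; any value the caller already set stays.
    return updated


if __name__ == '__main__':
    env = nvshmem_env_after_commit({})
    print(env['NVSHMEM_CUMEM_GRANULARITY'])            # granularity in bytes
    print(GRANULARITY_BYTES // (1024 ** 2), 'MiB')     # same value in MiB
    print('NVSHMEM_IB_ENABLE_RELAXED_ORDERING' in env)  # no longer forced
```

In the real code the assignment goes to `os.environ` before NVSHMEM is initialized; the dictionary copy here just keeps the example free of side effects.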
@@ -247,7 +243,7 @@ class Buffer:
  Dispatch tokens to different ranks, both intranode and internode settings are supported.
  Intranode kernels require all the ranks should be visible via NVLink.
  Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
- index should be visible via RDMA. AR must be disabled.
+ index should be visible via RDMA.

  Arguments:
  x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
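Review note: the docstring's internode requirement (ranks in a node reachable via NVLink, ranks with the same GPU index reachable via RDMA) can be pictured with a small sketch. It assumes ranks are numbered node by node with a fixed number of GPUs per node (8 here); that layout, and the function name `visibility_groups`, are assumptions made for illustration and are not taken from this diff.

```python
def visibility_groups(rank, num_ranks, gpus_per_node=8):
    """Illustrative only: return (nvlink_group, rdma_group) for `rank`,
    assuming rank = node_index * gpus_per_node + gpu_index."""
    if num_ranks % gpus_per_node != 0:
        raise ValueError('num_ranks must be a multiple of gpus_per_node')
    if not 0 <= rank < num_ranks:
        raise ValueError('rank out of range')

    node_index, gpu_index = divmod(rank, gpus_per_node)
    num_nodes = num_ranks // gpus_per_node

    # Ranks on the same node: expected to be reachable via NVLink.
    nvlink_group = [node_index * gpus_per_node + i for i in range(gpus_per_node)]
    # Ranks with the same GPU index on every node: expected to be reachable via RDMA.
    rdma_group = [n * gpus_per_node + gpu_index for n in range(num_nodes)]
    return nvlink_group, rdma_group


if __name__ == '__main__':
    nvlink_group, rdma_group = visibility_groups(rank=9, num_ranks=16)
    print(nvlink_group)  # ranks 8..15 (second node)
    print(rdma_group)    # ranks 1 and 9 (GPU index 1 on each node)
```

Both groups include the rank itself, so for a single node the RDMA group collapses to just that rank.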
@@ -319,7 +315,7 @@ class Buffer:
  settings are supported.
  Intranode kernels require all the ranks should be visible via NVLink.
  Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
- index should be visible via RDMA. AR must be disabled.
+ index should be visible via RDMA.

  Arguments:
  x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks.