Merge pull request #45 from deepseek-ai/ar-support

Fix AR bugs for normal kernels
Chenggang Zhao 2025-03-06 09:48:17 +08:00 committed by GitHub
commit 41385ba5b3
5 changed files with 17 additions and 21 deletions


@@ -91,9 +91,7 @@ For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_
### Adaptive routing
-Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Currently, low-latency kernels support adaptive routing, while normal kernels do not (support may be added soon). **Enabling adaptive routing for normal internode kernels may lead to deadlocks or data corruption issues**.
-For low-latency kernels, enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
+Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads
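
In practice the guidance above comes down to steering traffic onto an appropriate InfiniBand virtual lane before NVSHMEM is initialized. A minimal sketch, assuming the truncated environment variable above is `NVSHMEM_IB_SL` and that the subnet manager maps service level 1 to an adaptive-routed lane and service level 0 to a statically-routed one (both are fabric-side assumptions, not DeepEP settings):

```python
# Sketch only: choose a virtual lane that matches the routing recommendation above.
# Assumptions: NVSHMEM_IB_SL is the variable referenced earlier, and the subnet
# manager maps SL 1 to an adaptive-routed lane and SL 0 to a statically-routed one.
import os

def select_service_level(low_latency_mode: bool, heavy_network_load: bool) -> None:
    if low_latency_mode and heavy_network_load:
        os.environ['NVSHMEM_IB_SL'] = '1'   # hypothetical AR-enabled service level
    else:
        os.environ['NVSHMEM_IB_SL'] = '0'   # hypothetical statically-routed service level

# Must be called before the Buffer (and therefore NVSHMEM) is initialized.
select_service_level(low_latency_mode=True, heavy_network_load=True)
```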
@@ -134,7 +132,6 @@ def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
    # Allocate a buffer if none exists yet or the existing one is too small
-    # NOTES: the adaptive routing configuration of the network **must be off**
    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer
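
For context, a minimal usage sketch of the `get_buffer` helper above; the process-group setup and the hidden size are assumptions made for illustration, not part of this commit:

```python
# Illustrative only: how the get_buffer helper above would typically be called.
import torch.distributed as dist

dist.init_process_group(backend='nccl')                      # assumes torchrun-style env vars
group = dist.new_group(list(range(dist.get_world_size())))   # EP group over all ranks

hidden = 7168                                                # hypothetical model hidden size
hidden_bytes = hidden * 2                                    # two bytes per BF16 element
buffer = get_buffer(group, hidden_bytes)                     # lazily (re)allocates the shared Buffer
```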
@@ -282,7 +279,8 @@ For two micro-batch overlapping, you can refer to the following figure. With our
## Roadmap
-- [ ] AR support (releasing soon)
+- [x] AR support
+- [ ] Refactor low-latency mode AR code
- [ ] A100 support (intranode only)
- [ ] Support BF16 for the low-latency dispatch kernel
- [ ] Support NVLink protocol for intranode low-latency kernels


@@ -372,7 +372,6 @@ nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
__device__ static __forceinline__ void
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
    // NOTES: only one thread can run this function
-    // TODO: consider this assertion for normal AR
    EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
}


@@ -925,9 +925,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
break;
// Update remote head
-if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks)
-    nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_head = min_head,
-                     translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
+if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks) {
+    nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
+                       translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
+    last_head = min_head;
+}
// Nanosleep and let other warps work
__nanosleep(NUM_WAIT_NANOSECONDS);
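
This replacement is the core of the fix. With adaptive routing, RDMA messages may arrive out of order, so overwriting the remote head with an absolute value (`nvshmem_uint64_p`, last writer wins) can move it backwards, whereas posting the increment with `nvshmemx_signal_op` and `NVSHMEM_SIGNAL_ADD` is commutative and therefore insensitive to delivery order. A small Python model of that argument (the hard-coded arrival order stands in for AR reordering; nothing below is a DeepEP or NVSHMEM API):

```python
# Toy model: why the kernel now signals head *deltas* instead of absolute head values.
head_values = [4, 7, 9, 12]                                  # successive head positions at the sender
deltas = [b - a for a, b in zip([0] + head_values[:-1], head_values)]

def deliver_absolute(messages):
    # Old scheme: every message overwrites the remote head with an absolute value.
    head = 0
    for value in messages:
        head = value                                         # a late, stale write regresses the head
    return head

def deliver_additive(messages):
    # New scheme: every message atomically adds its delta to the remote head.
    head = 0
    for delta in messages:
        head += delta                                        # addition commutes, order is irrelevant
    return head

arrival_order = [0, 2, 3, 1]                                 # the 2nd update arrives last (AR-style reordering)
print(deliver_absolute([head_values[i] for i in arrival_order]))   # 7  -- head moved backwards
print(deliver_additive([deltas[i] for i in arrival_order]))        # 12 -- correct despite reordering
```

The same reasoning applies to the combine-side hunk below, which makes the identical substitution.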
@@ -1653,9 +1655,11 @@ combine(int4* combined_x, float* combined_topk_weights,
#pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
-if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks)
-    nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_rdma_head = min_head,
-                     translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
+if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks) {
+    nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
+                       translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
+    last_rdma_head = min_head;
+}
} else {
// Find minimum head for NVL ranks
#pragma unroll


@@ -14,8 +14,8 @@ class Buffer:
"""
The core expert-parallel (EP) communication buffers for Mixture-of-Experts (MoE) models, which support:
- high-throughput intranode all-to-all (dispatch and combine, using NVLink)
-- high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
-- low-latency all-to-all (dispatch and combine, using RDMA, AR supported)
+- high-throughput internode all-to-all (dispatch and combine, using RDMA and NVLink)
+- low-latency all-to-all (dispatch and combine, using RDMA)
Attributes:
num_sms: the SMs used in high-throughput kernels.
@@ -78,10 +78,6 @@ class Buffer:
# NOTES: NVSHMEM initialization requires at least 256 MiB
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
-# Disable PCIe relaxed ordering to avoid out-of-order messages
-os.environ['NVSHMEM_IB_ENABLE_RELAXED_ORDERING'] = '0'
-# NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
# Synchronize using the root ID
nvshmem_unique_ids = [None, ] * self.group_size
if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
@@ -247,7 +243,7 @@ class Buffer:
Dispatch tokens to different ranks; both intranode and internode settings are supported.
Intranode kernels require that all ranks be visible via NVLink.
Internode kernels require that the ranks in a node be visible via NVLink, while the ranks with the same GPU
-index should be visible via RDMA. AR must be disabled.
+index should be visible via RDMA.
Arguments:
x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
@@ -319,7 +315,7 @@ class Buffer:
settings are supported.
Intranode kernels require that all ranks be visible via NVLink.
Internode kernels require that the ranks in a node be visible via NVLink, while the ranks with the same GPU
-index should be visible via RDMA. AR must be disabled.
+index should be visible via RDMA.
Arguments:
x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks.
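
The docstring edits above only drop the "AR must be disabled" caveat; the visibility requirement itself is unchanged. A toy illustration of that requirement, with node-major rank numbering assumed for the sketch (an assumption, not something stated in the diff):

```python
# Toy illustration of the visibility rule in the docstrings above: ranks on the same node
# communicate via NVLink, ranks sharing a GPU index across nodes communicate via RDMA.
def peers(rank: int, num_local_ranks: int, num_nodes: int):
    node, gpu = divmod(rank, num_local_ranks)                # assumes node-major rank numbering
    nvlink_peers = [node * num_local_ranks + g for g in range(num_local_ranks)]
    rdma_peers = [n * num_local_ranks + gpu for n in range(num_nodes)]
    return nvlink_peers, rdma_peers

# Rank 9 on a 2-node, 8-GPU-per-node job: NVLink peers 8..15, RDMA peers 1 and 9.
print(peers(rank=9, num_local_ranks=8, num_nodes=2))
```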


@@ -218,7 +218,6 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
-    # Please make sure AR (Adaptive Routing) is turned off when running normal internode kernels,
    num_nodes = int(os.getenv('WORLD_SIZE', 1))
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    test_ll_compatibility = False