Update README

This commit is contained in:
Chenggang Zhao 2025-06-05 14:41:51 +08:00
parent de8cfca3cf
commit d8dd185c68

View File

@ -1,6 +1,6 @@
# DeepEP
DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also as known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control.
@ -29,7 +29,7 @@ We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400
| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
| 8 | 77 us | 98 GB/s | 8 | 114 us | 127 GB/s |
| 8 | 77 us | 98 GB/s | 8 | 114 us | 127 GB/s |
| 16 | 118 us | 63 GB/s | 16 | 195 us | 74 GB/s |
| 32 | 155 us | 48 GB/s | 32 | 273 us | 53 GB/s |
| 64 | 173 us | 43 GB/s | 64 | 314 us | 46 GB/s |
@ -243,7 +243,7 @@ def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int,
# Allocate a buffer if not existed or not enough buffer size
if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
# NOTES: for the best performance, the QP number **must** be equal to the number of the local experts
assert num_experts % group.size() == 0
_buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
return _buffer
@ -277,7 +277,7 @@ def low_latency_combine(hidden_states: torch.Tensor,
return combined_hidden_states, event_overlap, hook
```
For two micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffics are happening in the background, without costing any GPU SMs from the computation part. But notice, the overlapped parts can be adjusted, i.e. the 4 parts of attention/dispatch/MoE/combine may not have the exact same execution time. You may adjust the stage settings according to your workload.
For two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic is happening in the background, without costing any GPU SMs from the computation part. But notice, the overlapped parts can be adjusted, i.e., the 4 parts of attention/dispatch/MoE/combine may not have the exact same execution time. You may adjust the stage settings according to your workload.
![low-latency](figures/low-latency.png)
@ -287,14 +287,15 @@ For two micro-batch overlapping, you can refer to the following figure. With our
- [x] Refactor low-latency mode AR code
- [ ] A100 support (intranode only)
- [x] Support BF16 for the low-latency dispatch kernel
- [x] ~~Support NVLink protocol for intranode low-latency kernels~~ (conflict with hook-based overlapping)
- [ ] SM-free normal kernels
- [x] Support NVLink protocol for intranode low-latency kernels
- [ ] TMA copy instead of LD/ST
- [ ] SM-free normal kernels and refactors
## Notices
#### Easier potential overall design
Current DeepEP implementation uses queues for communication buffers which saves memory but introduces complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39.
The current DeepEP implementation uses queues for communication buffers which save memory but introduce complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39.
#### Undefined-behavior PTX usage
@ -312,11 +313,11 @@ This code repository is released under [the MIT License](LICENSE), except for co
## Community Forks
- [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-qp solution and dual-port NIC support in IBRC transport
- [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-QP solution and dual-port NIC support in IBRC transport
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
If you use this codebase or otherwise find our work valuable, please cite:
```bibtex
@misc{deepep2025,