Initial commit

This commit is contained in:
Chenggang Zhao
2025-02-24 21:56:14 +08:00
commit ebfe47e46f
32 changed files with 8461 additions and 0 deletions

8
.gitignore vendored Normal file
View File

@@ -0,0 +1,8 @@
compile_commands.json
.idea
.DS_Store
*.pyc
build/
.cache/
.vscode/
*/cmake-build-*/

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

304
README.md Normal file
View File

@@ -0,0 +1,304 @@
# DeepEP
DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also as known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.
To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control.
For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource.
Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.
## Performance
### Normal kernels with NVLink and RDMA forwarding
We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).
| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
### Low-latency kernels with pure RDMA
We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining).
| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
| 8 | 163 us | 46 GB/s | 8 | 318 us | 46 GB/s |
| 16 | 173 us | 43 GB/s | 16 | 329 us | 44 GB/s |
| 32 | 182 us | 41 GB/s | 32 | 350 us | 41 GB/s |
| 64 | 186 us | 40 GB/s | 64 | 353 us | 41 GB/s |
| 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s |
| 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s |
## Quick start
### Requirements
- Hopper GPUs (may support more architectures or devices later)
- Python 3.8 and above
- CUDA 12.3 and above
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication
### Download and install NVSHMEM dependency
DeepEP also depends on our modified NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.
### Development
```bash
# Build and make symbolic links for SO files
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
# You may modify the specific SO names according to your own platform
ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so
# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch into multiple nodes
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
```
### Installation
```bash
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```
Then, import `deep_ep` in your Python project, and enjoy!
## Network configurations
DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.
### Traffic isolation
Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).
To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:
- workloads using normal kernels
- workloads using low-latency kernels
- other workloads
For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.
### Adaptive routing
Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Currently, low-latency kernels support adaptive routing, while normal kernels do not (support may be added soon). **Enabling adaptive routing for normal internode kernels may lead to deadlocks or data corruption issues**.
For low-latency kernels, enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:
- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads
### Congestion control
Congestion control is disabled as we have not observed significant congestion in our production environment.
## Interfaces and examples
### Example use in model training or inference prefilling
The normal kernels can be used in model training or the inference prefilling phase (without the backward part) as the below example code shows.
```python
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union
from deep_ep import Buffer, EventOverlap
# Communication buffer (will allocate at runtime)
_buffer: Optional[Buffer] = None
# Set the number of SMs to use
# NOTES: this is a static variable
Buffer.set_num_sms(24)
# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
global _buffer
# NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
# Allocate a buffer if not existed or not enough buffer size
# NOTES: the adaptive routing configuration of the network **must be off**
if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer
def get_hidden_bytes(x: torch.Tensor) -> int:
t = x[0] if isinstance(x, tuple) else x
return t.size(1) * max(t.element_size(), 2)
def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
# refer to the docs of `Buffer.dispatch`
global _buffer
# Calculate layout before actual dispatch
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
_buffer.get_dispatch_layout(topk_idx, num_experts,
previous_event=previous_event, async_finish=True,
allocate_on_comm_stream=previous_event is not None)
# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
# For more advanced usages, please refer to the docs of the `dispatch` function
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
_buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event, async_finish=True,
allocate_on_comm_stream=True)
# For event management, please refer to the docs of the `EventOverlap` class
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
global _buffer
# The backward process of MoE dispatch is actually a combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_grad_x, combined_grad_recv_topk_weights, event = \
_buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)
# For event management, please refer to the docs of the `EventOverlap` class
return combined_grad_x, combined_grad_recv_topk_weights, event
def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
Tuple[torch.Tensor, EventOverlap]:
global _buffer
# Do MoE combine
# For more advanced usages, please refer to the docs of the `combine` function
combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None)
# For event management, please refer to the docs of the `EventOverlap` class
return combined_x, event
def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
global _buffer
# The backward process of MoE combine is actually a dispatch
# For more advanced usages, please refer to the docs of the `combine` function
grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None)
# For event management, please refer to the docs of the `EventOverlap` class
return grad_x, event
```
Moreover, inside the dispatch function, we may not know how many tokens to receive for the current rank. So an implicit CPU wait for GPU received count signal will be involved, as the following figure shows.
![normal](figures/normal.png)
### Example use in inference decoding
The low latency kernels can be used in the inference decoding phase as the below example code shows.
```python
import torch
import torch.distributed as dist
from typing import Tuple, Optional
from deep_ep import Buffer
# Communication buffer (will allocate at runtime)
# NOTES: there is no SM control API for the low-latency kernels
_buffer: Optional[Buffer] = None
# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
# NOTES: the low-latency mode will consume much more space than the normal mode
# So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
global _buffer
num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)
# Allocate a buffer if not existed or not enough buffer size
if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
# NOTES: for best performance, the QP number **must** be equal to the number of the local experts
assert num_experts % group.size() == 0
_buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
return _buffer
def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
global _buffer
# Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
recv_hidden_states, recv_expert_count, handle, event, hook = \
_buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
async_finish=False, return_recv_hook=True)
# NOTES: the actual tensor will not be received only if you call `hook()`,
# it is useful for double-batch overlapping, but **without any SM occupation**
# If you don't want to overlap, please set `return_recv_hook=False`
# Later, you can use our GEMM library to do the computation with this specific format
return recv_hidden_states, recv_expert_count, handle, event, hook
def low_latency_combine(hidden_states: torch.Tensor,
topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
global _buffer
# Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
combined_hidden_states, event_overlap, hook = \
_buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
async_finish=False, return_recv_hook=True)
# NOTES: the same behavior as described in the dispatch kernel
return combined_hidden_states, event_overlap, hook
```
For two micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffics are happening in the background, without costing any GPU SMs from the computation part. But notice, the overlapped parts can be adjusted, i.e. the 4 parts of attention/dispatch/MoE/combine may not have the exact same execution time. You may adjust the stage settings according to your workload.
![low-latency](figures/low-latency.png)
## Notices
- For extreme performance, we discover and use an out-of-doc PTX instruction: `ld.global.nc.L1::no_allocate.L2::256B`. This instruction will lead to an undefined behavior: accessing volatile GPU memory with non-coherent read-only PTX modifiers `.nc`. But the correctness is tested to be guaranteed with `.L1::no_allocate` on Hopper architectures, and performance will be much better. If you find kernels not working on some other platforms, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` and disable this, or file an issue.
- For better performance on your cluster, we recommend to run all the tests and use the best auto-tuned configuration. The default configurations are optimized on the DeepSeek's internal cluster.
## License
This code repository is released under [the MIT License](LICENSE), except for codes that reference NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which are subject to [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```bibtex
@misc{deepep2025,
title={DeepEP: an efficient expert-parallel communication library},
author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepEP}},
}
```

33
csrc/CMakeLists.txt Normal file
View File

@@ -0,0 +1,33 @@
# NOTES: this CMake is only for debugging; for setup, please use Torch extension
cmake_minimum_required(VERSION 3.10)
project(deep_ep LANGUAGES CUDA CXX)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
set(TORCH_CUDA_ARCH_LIST "9.0")
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)
add_library(nvshmem ALIAS nvshmem::nvshmem)
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
# Seems bugs with CMake, NVCC 12 and C++ 17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 14)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
add_subdirectory(kernels)
# Link CPP and CUDA together
pybind11_add_module(deep_ep_cpp deep_ep.cpp)
target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)

177
csrc/config.hpp Normal file
View File

@@ -0,0 +1,177 @@
#pragma once
#include "kernels/api.cuh"
#include "kernels/exception.cuh"
namespace deep_ep {
template <typename dtype_t>
dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
}
struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
int num_max_nvl_chunked_recv_tokens;
int num_max_rdma_chunked_send_tokens;
int num_max_rdma_chunked_recv_tokens;
Config(int num_sms,
int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
num_sms(num_sms),
num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
EP_HOST_ASSERT(num_sms >= 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
}
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
return 0;
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
}
};
struct LowLatencyBuffer {
int num_clean_int = 0;
void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr;
int* dispatch_rdma_atomic_token_counter = nullptr;
void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
int* combine_rdma_recv_flag_buffer = nullptr;
std::pair<int*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
}
};
struct LowLatencyLayout {
size_t total_bytes = 0;
LowLatencyBuffer buffers[2];
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
const int num_scales = hidden / 128;
const int num_local_experts = num_experts / num_ranks;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers
// - 2 symmetric odd/even signaling buffers
// Message sizes
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);
// Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
total_bytes += send_buffer_bytes * 2;
// Symmetric receive buffers
// TODO: optimize memory usages
size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
total_bytes += recv_buffer_bytes * 2;
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes,
combine_recv_flag_buffer_bytes);
total_bytes += signaling_buffer_bytes * 2;
// Assign pointers
// NOTES: we still leave some space for distinguishing dispatch/combine buffer,
// so you may see some parameters are duplicated
for (int i = 0; i < 2; ++ i) {
buffers[i] = {
static_cast<int>(signaling_buffer_bytes / sizeof(int)),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
advance(rdma_buffer, send_buffer_bytes * i),
advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
};
}
}
};
size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}
} // namespace deep_ep

1208
csrc/deep_ep.cpp Normal file

File diff suppressed because it is too large Load Diff

149
csrc/deep_ep.hpp Normal file
View File

@@ -0,0 +1,149 @@
#pragma once
// Forcibly disable NDEBUG
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>
#include <tuple>
#include <vector>
#include "config.hpp"
#include "event.hpp"
#include "kernels/configs.cuh"
#include "kernels/exception.cuh"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif
namespace deep_ep {
struct Buffer {
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
private:
// Low-latency mode buffer
int low_latency_buffer_idx = 0;
bool low_latency_mode = false;
// NVLink Buffer
int64_t num_nvl_bytes;
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void* rdma_buffer_ptr = nullptr;
// Device info and communication
int device_id;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
// Stream for communication
at::cuda::CUDAStream comm_stream;
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Task fifo
int head = 0;
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** task_fifo_ptrs_gpu = nullptr;
// Workspace
void* workspace = nullptr;
// Host-side MoE info
volatile int* moe_recv_counter = nullptr;
int* moe_recv_counter_mapped = nullptr;
// Host-side expert-level MoE info
volatile int* moe_recv_expert_counter = nullptr;
int* moe_recv_expert_counter_mapped = nullptr;
// Host-side RDMA-level MoE info
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;
private:
void move_fifo_slots(int num_slots = 1);
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
~Buffer() noexcept(false);
bool is_available() const;
bool is_internode_available() const;
int get_num_rdma_ranks() const;
int get_rdma_rank() const;
int get_root_rdma_rank(bool global) const;
int get_local_device_id() const;
pybind11::bytearray get_local_ipc_handle() const;
pybind11::bytearray get_local_nvshmem_unique_id() const;
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook);
};
} // namespace deep_ep

43
csrc/event.hpp Normal file
View File

@@ -0,0 +1,43 @@
#include <ATen/cuda/CUDAContext.h>
#include <memory>
#include "kernels/exception.cuh"
namespace deep_ep {
struct EventHandle {
std::shared_ptr<torch::Event> event;
EventHandle() {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(at::cuda::getCurrentCUDAStream());
}
explicit EventHandle(const at::cuda::CUDAStream& stream) {
event = std::make_shared<torch::Event>(torch::kCUDA);
event->record(stream);
}
EventHandle(const EventHandle& other) = default;
void current_stream_wait() const {
at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
}
};
torch::Event create_event(const at::cuda::CUDAStream &s) {
auto event = torch::Event(torch::kCUDA);
event.record(s);
return event;
}
void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
EP_HOST_ASSERT(s_0.id() != s_1.id());
s_0.unwrap().wait(create_event(s_1));
}
void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
s.unwrap().wait(*event.event);
}
} // namespace deep_ep

View File

@@ -0,0 +1,20 @@
function(add_deep_ep_library target_name source_file)
add_library(${target_name} STATIC ${source_file})
set_target_properties(${target_name} PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD_REQUIRED ON
CXX_STANDARD 14
CUDA_STANDARD 14
CUDA_SEPARABLE_COMPILATION ON
)
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction()
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)

153
csrc/kernels/api.cuh Normal file
View File

@@ -0,0 +1,153 @@
#pragma once
#include <vector>
namespace deep_ep {
// Intranode runtime
namespace intranode {
void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
} // namespace intranode
// Internode runtime
namespace internode {
std::vector<uint8_t> get_unique_id();
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
void *alloc(size_t size, size_t alignment);
void free(void *ptr);
void barrier();
void finalize();
} // namespace internode
// Intranode kernels
namespace intranode {
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int num_sms);
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens);
} // namespace intranode
// Internode kernels
namespace internode {
int get_source_meta_bytes();
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream);
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
const bool* is_token_in_rank, int num_tokens, int num_channels,
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
int* send_rdma_head, int* send_nvl_head,
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
const bool* is_token_in_rank,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, bool is_cached_dispatch,
cudaStream_t stream, int num_channels, bool low_latency_mode);
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode);
void combine(cudaDataType_t type,
void* combined_x, float* combined_topk_weights,
const bool* is_combined_token_in_rank,
const void* x, const float* topk_weights,
const int* combined_rdma_head, const int* combined_nvl_head,
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
} // namespace internode
// Internode low-latency kernels
namespace internode_ll {
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases);
} // namespace internode_ll
} // namespace deep_ep

138
csrc/kernels/buffer.cuh Normal file
View File

@@ -0,0 +1,138 @@
#pragma once
#include "configs.cuh"
#include "exception.cuh"
namespace deep_ep {
template <typename dtype_t>
struct Buffer {
private:
uint8_t* ptr;
public:
int total_bytes;
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer() {
return reinterpret_cast<dtype_t*>(ptr);
}
__device__ __forceinline__ dtype_t& operator[](int idx) {
return buffer()[idx];
}
};
template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
private:
uint8_t* ptrs[kNumRanks];
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "");
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++ i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
}
}
__device__ __forceinline__ void advance(int shift) {
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
}
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}
template<int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++ i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
return *this;
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case");
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
}
};
template <typename dtype_t, bool kDecoupled = true>
struct SymBuffer {
private:
// NOTES: for non-decoupled case, `recv_ptr` is not used
uint8_t* send_ptr;
uint8_t* recv_ptr;
int num_bytes;
public:
int total_bytes;
__device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
int sm_id = 0, int num_sms = 1) {
num_bytes = num_elems * sizeof(dtype_t);
int per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
}
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case");
return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
}
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case");
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
}
};
} // namespace deep_ep

50
csrc/kernels/configs.cuh Normal file
View File

@@ -0,0 +1,50 @@
#pragma once
#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128
#define FINISHED_SUM_TAG 1024
#define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#define NUM_WAIT_NANOSECONDS 500
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
// Make CLion CUDA indexing work
#ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif
// Remove Torch restrictions
#ifdef __CUDA_NO_HALF_CONVERSIONS__
#undef __CUDA_NO_HALF_CONVERSIONS__
#endif
#ifdef __CUDA_NO_HALF_OPERATORS__
#undef __CUDA_NO_HALF_OPERATORS__
#endif
#ifdef __CUDA_NO_HALF2_OPERATORS__
#undef __CUDA_NO_HALF2_OPERATORS__
#endif
#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
#undef __CUDA_NO_BFLOAT16_CONVERSIONS__
#endif
#ifdef __CUDA_NO_BFLOAT162_OPERATORS__
#undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>

View File

@@ -0,0 +1,51 @@
#pragma once
#include <string>
#include <exception>
#include "configs.cuh"
#ifndef EP_STATIC_ASSERT
#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
class EPException: public std::exception {
private:
std::string message = {};
public:
explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
}
const char *what() const noexcept override { return message.c_str(); }
};
#ifndef CUDA_CHECK
#define CUDA_CHECK(cmd) \
do { \
cudaError_t e = (cmd); \
if (e != cudaSuccess) { \
throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
} \
} while (0)
#endif
#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
throw EPException("Assertion", __FILE__, __LINE__, #cond); \
} \
} while (0)
#endif
#ifndef EP_DEVICE_ASSERT
#define EP_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif

View File

@@ -0,0 +1,423 @@
// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)
// Copyright (c) NVIDIA Corporation.
// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).
// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html
//
// Modified from original source:
// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "utils.cuh"
namespace deep_ep {
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
__device__ static __forceinline__
uint64_t HtoBE64(uint64_t x) {
uint64_t ret;
asm("{\n\t"
".reg .b32 ign;\n\t"
".reg .b32 lo;\n\t"
".reg .b32 hi;\n\t"
".reg .b32 new_lo;\n\t"
".reg .b32 new_hi;\n\t"
"mov.b64 {lo,hi}, %1;\n\t"
"prmt.b32 new_hi, lo, ign, 0x0123;\n\t"
"prmt.b32 new_lo, hi, ign, 0x0123;\n\t"
"mov.b64 %0, {new_lo,new_hi};\n\t"
"}" : "=l"(ret) : "l"(x));
return ret;
}
__device__ static __forceinline__
uint32_t HtoBE32(uint32_t x) {
uint32_t ret;
asm("{\n\t"
".reg .b32 ign;\n\t"
"prmt.b32 %0, %1, ign, 0x0123;\n\t"
"}" : "=r"(ret) : "r"(x));
return ret;
}
__device__ static __forceinline__
uint16_t HtoBE16(uint16_t x) {
// TODO: simplify PTX using 16-bit instructions
auto a = static_cast<uint32_t>(x);
uint32_t d;
asm volatile(
"{\n\t"
".reg .b32 mask;\n\t"
".reg .b32 ign;\n\t"
"mov.b32 mask, 0x4401;\n\t"
"mov.b32 ign, 0x0;\n\t"
"prmt.b32 %0, %1, ign, mask;\n\t"
"}"
: "=r"(d)
: "r"(a));
return static_cast<uint16_t>(d);
}
typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
__device__ static __forceinline__
nvshmemi_ibgda_device_state_t* ibgda_get_state() {
return &nvshmemi_ibgda_device_state_d;
}
__device__ static __forceinline__
nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
auto state = ibgda_get_state();
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe];
}
__device__ static __forceinline__
void ibgda_lock_acquire(int *lock) {
while (atomicCAS(lock, 0, 1) == 1);
// Prevent reordering before the lock is acquired
memory_fence_cta();
}
__device__ static __forceinline__
void ibgda_lock_release(int *lock) {
memory_fence_cta();
// Prevent reordering before lock is released
st_na_relaxed(lock, 0);
}
__device__ static __forceinline__
void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) {
// `DBREC` contains the index of the next empty `WQEBB`
__be32 dbrec_val;
__be32 *dbrec_ptr = qp->tx_wq.dbrec;
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`
asm("{\n\t"
".reg .b32 dbrec_head_16b;\n\t"
".reg .b32 ign;\n\t"
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
"}"
: "=r"(dbrec_val)
: "r"(dbrec_head));
st_na_release(dbrec_ptr, dbrec_val);
}
__device__ static __forceinline__
void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) {
auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);
ibgda_ctrl_seg_t ctrl_seg = {
.opmod_idx_opcode = HtoBE32(prod_idx << 8),
.qpn_ds = HtoBE32(qp->qpn << 8)
};
EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), "");
st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));
}
__device__ static __forceinline__
void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) {
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t old_prod_idx;
// Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence
ibgda_lock_acquire(&mvars->post_send_lock);
old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
if (new_prod_idx > old_prod_idx) {
ibgda_update_dbr(qp, new_prod_idx);
ibgda_ring_db(qp, new_prod_idx);
}
ibgda_lock_release(&mvars->post_send_lock);
}
template <bool kAlwaysDoPostSend>
__device__ static __forceinline__
void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
uint32_t num_wqes, int message_idx = 0) {
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
// WQE writes must be finished first
__threadfence();
// Wait for prior WQE slots to be filled first
auto *ready_idx = reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.ready_head);
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
// Always post, not in batch
constexpr int kNumRequestInBatch = 4;
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
ibgda_post_send(qp, new_wqe_idx);
}
__device__ static __forceinline__ void
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
__be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_inl_data_seg inl_seg;
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));
raddr_seg.raddr = HtoBE64(raddr);
raddr_seg.rkey = rkey;
raddr_seg.reserved = 0;
inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);
// `imm == std::numeric_limits<uint32_t>::max()` means no imm writes
ctrl_seg = {0};
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));
if (imm != std::numeric_limits<uint32_t>::max())
ctrl_seg.imm = HtoBE32(imm);
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4");
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));
st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));
}
__device__ static __forceinline__
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
auto log2_cumem_granularity = state->log2_cumem_granularity;
// Local key
uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity;
auto device_key = state->constmem.lkeys[idx];
auto lchunk_size = device_key.next_addr - laddr;
*lkey = device_key.key;
// Remote key
uint64_t roffset = raddr - heap_start;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
device_key = state->constmem.rkeys[idx];
} else {
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
}
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
*out_rkey = device_key.key;
// Return the minimum of local and remote chunk sizes
auto rchunk_size = device_key.next_addr - roffset;
return min(lchunk_size, rchunk_size);
}
__device__ static __forceinline__ void
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
uint64_t roffset = addr - heap_start;
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
nvshmemi_ibgda_device_key_t device_key;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
device_key = state->constmem.rkeys[idx];
else
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
*out_rkey = device_key.key;
}
__device__ static __forceinline__ uint64_t
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
auto mvars = &qp->mvars;
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
}
__device__ static __forceinline__ void*
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
uint16_t cnt = qp->tx_wq.nwqes;
uint16_t idx = wqe_idx & (cnt - 1);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
}
// Wait until wqe `idx - 1` is completed.
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
// Because we post recv and poll recv in the same thread, so we don't need to maintain queue status.
__device__ static __forceinline__ void
nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
auto qp = ibgda_get_rc(dst_pe, qp_id);
auto cq = qp->rx_wq.cq;
const uint32_t ncqes = cq->ncqes;
auto *cqe64 = reinterpret_cast<struct mlx5_cqe64*>(cq->cqe);
auto old_cons_idx = *cq->cons_idx;
*cq->cons_idx = old_cons_idx + 1;
// Wait until `wqe_counter >= old_cons_idx`
while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
}
__device__ static __forceinline__ void
nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
// Get rkey
// NOTES: the `p` operation will not cross multiple remote chunks
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey);
// Write WQEs
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs;
wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);
// Submit requests
ibgda_submit_requests<true>(qp, base_wqe_idx, 1);
}
__device__ static __forceinline__ void
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
void **out_wqes) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_data_seg data_seg;
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
struct mlx5_wqe_raddr_seg *raddr_seg_ptr;
struct mlx5_wqe_data_seg *data_seg_ptr;
raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
raddr_seg.raddr = HtoBE64(raddr);
raddr_seg.rkey = rkey;
raddr_seg.reserved = 0;
data_seg.byte_count = HtoBE32(bytes);
data_seg.lkey = lkey;
data_seg.addr = HtoBE64(laddr);
ctrl_seg = {0};
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
__device__ static __forceinline__ void
ibgda_write_empty_recv_wqe(void *out_wqe) {
auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
struct mlx5_wqe_data_seg data_seg;
// Make the first segment in the WQE invalid, then the entire list will be invalid
data_seg.byte_count = 0;
data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
data_seg.addr = 0;
EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
__device__ static __forceinline__ uint64_t
nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
auto mvars = &qp->mvars;
// Allocate if not enough
constexpr int kMinIBGDARecvs = 32;
auto resv_head = mvars->rx_wq.resv_head;
auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
if (num_valid_slots < kMinIBGDARecvs) {
resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
mvars->rx_wq.resv_head = resv_head;
// Ensure WQE is written before `dbrec`
__be32 dbrec_val;
__be32 *dbrec_ptr = qp->rx_wq.dbrec;
// Compared to sending, for each QP, we only post recv in a single thread,
// so we don't need to do synchronization here
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
asm("{\n\t"
".reg .b32 dbrec_head_16b;\n\t"
".reg .b32 ign;\n\t"
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
"}" : "=r"(dbrec_val)
: "r"(static_cast<uint32_t>(resv_head)));
st_na_release(dbrec_ptr, dbrec_val);
}
// Return old number of slots
return num_valid_slots;
}
__device__ static __forceinline__ void
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
// NOTES: only one thread can run this function
// TODO: consider this assertion for normal AR
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
}
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
// Get lkey and rkey, store them into lanes
uint32_t num_wqes = 0;
__be32 my_lkey = 0;
uint64_t my_laddr = 0;
__be32 my_rkey = 0;
uint64_t my_raddr = 0;
uint64_t my_chunk_size = 0;
// Decide how many messages (theoretically 3 for maximum)
auto remaining_bytes = bytes;
while (remaining_bytes > 0) {
if (lane_id == num_wqes)
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey));
// Move one more message
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
remaining_bytes -= chunk_size;
req_lptr += chunk_size;
req_rptr += chunk_size;
++ num_wqes;
}
EP_DEVICE_ASSERT(num_wqes <= 32);
// Process WQE
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = 0;
if (lane_id == 0)
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
if (lane_id < num_wqes) {
auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size,
base_wqe_idx, &wqe_ptr);
}
__syncwarp();
// Submit
if (lane_id == 0)
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
__syncwarp();
}
} // namespace deep_ep

1720
csrc/kernels/internode.cu Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,533 @@
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "ibgda_device.cuh"
namespace deep_ep {
namespace internode_ll {
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1) {
// Barrier before cleaning (in case of unfinished chunked EP)
nvshmemx_barrier_all_block();
// Clean
auto thread_id = static_cast<int>(threadIdx.x);
#pragma unroll
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
clean_0[i] = 0;
#pragma unroll
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode work fine)
nvshmemx_barrier_all_block();
}
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream) {
constexpr int kNumThreads = 256;
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert,
int* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_sms = static_cast<int>(gridDim.x);
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// FP8 staffs
constexpr int kNumPerChannels = 128;
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
const int num_scales = kHidden / kNumPerChannels;
const size_t hidden_int4 = kHidden / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
const size_t num_bytes_per_msg = kHidden + num_scales * sizeof(float) + sizeof(int4);
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
// Expert counts
__shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps - 1) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * 32;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
// Overlap top-k index read and source token index write
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// FP8 cast
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read and calculate local amax
auto int4_value = __ldg(x_int4 + i);
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
// Cast into send buffer
int2 int2_value;
auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
}
rdma_x_int2[i] = int2_value;
}
asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
// Issue IBGDA sends
if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
const auto dst_rank = dst_expert_idx / num_local_experts;
const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// Increase counter after finishing
__syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
}
} else if (warp_id == num_warps - 1) {
EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
// The first SM is also responsible for checking QPs
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += 32)
next_clean[i] = 0;
// Notify before executing `int_p`
__syncwarp();
#pragma unroll
for (int i = lane_id; i < num_experts; i += 32)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumWarpGroups] = {0};
const auto expert_begin_idx = sm_id * kNumWarpGroups;
const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
// Per lane count
#pragma unroll 8
for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx] ++;
}
// Warp reduce
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
}
}
}
__syncthreads();
// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
dst_rank, dst_expert_local_idx, 0);
nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
}
// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
}
__syncwarp();
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Receiving and packing
if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int num_recv_tokens, recv_token_begin_idx;
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
if (sub_warp_id == 1 and lane_id == 0) {
if (src_rank != rank) {
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
EP_DEVICE_ASSERT(num_recv_tokens != 0);
} else {
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
}
num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
}
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// Copy tokens
EP_DEVICE_ASSERT(num_scales <= 64);
for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);
// Copy scales
const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
// Copy source info
const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
if (lane_id == 0)
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
__syncwarp();
}
}
}
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
// NOTES: this part will be cleaned in `combine`
auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
#define DISPATCH_LAUNCH_CASE(hidden) \
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, phases); break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
// Data type staffs
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV;
// Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += 32)
next_clean[i] = 0;
// Notify before executing `int_p`
__syncwarp();
if (lane_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
// FP8 cast and issue IBGDA sends
if (responsible_expert_idx < num_experts) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
const auto local_x = reinterpret_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// Unpack layout
int offset, num_tokens_to_send;
unpack2(layout, num_tokens_to_send, offset);
// Issue IBGDA send
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
// Copy directly to local rank, or copy to buffer and issue RDMA
auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
}
}
// Put finishing flag
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
} else {
st_na_release(rdma_recv_flag + global_expert_idx, 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
__syncwarp();
}
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
if (sub_warp_id == 0 and lane_id == 0) {
// TODO: refactor QP indices
auto src_rank = responsible_expert_idx / num_local_experts;
auto src_expert_idx = responsible_expert_idx % num_local_experts;
if (src_rank != rank) {
nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
} else {
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
}
}
}
cg::this_grid().sync();
// Reduce tokens with FP8 cast
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
}
float combined_values[kNumElemsPerInt4] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<nv_bfloat16*>(&combined_values);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
}
void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
constexpr int kNumWarpsPerGroup = 10;
constexpr int kNumWarpGroups = 3;
constexpr int kNumMaxTopk = 9;
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
// Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
phases); } break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
} // namespace internode_ll
} // namespace deep_ep

803
csrc/kernels/intranode.cu Normal file
View File

@@ -0,0 +1,803 @@
#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
namespace deep_ep {
namespace intranode {
template<int kNumRanks>
__global__ void
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
if (sm_id == 0) {
// Barrier first
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
}
// After this loop:
// - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
// - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j
int num_experts_per_rank = num_experts / kNumRanks;
if (thread_id < kNumRanks) {
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
#pragma unroll
for (int i = 0; i < num_experts_per_rank; ++ i)
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
}
__syncthreads();
// Wait for all ranks to be finished
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
if (thread_id < kNumRanks) {
#pragma unroll
for (int i = 1; i < kNumRanks; ++ i)
local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];
if (thread_id == rank)
*moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];
}
// Sum per-experts counts and return to CPU
auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;
if (thread_id < num_experts_per_rank) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i)
sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
moe_recv_expert_counter_mapped[thread_id] = sum;
}
__syncthreads();
// Copy rank size prefix matrix to another tensor
#pragma unroll
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];
// Extra memset for later communication queue
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
local_per_expert_buffer[i] = 0;
// Barrier
memory_fence();
__syncthreads();
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
} else {
int dst_rank = sm_id - 1;
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// Iterate over tokens
int count = 0;
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)
count += is_token_in_rank[i * kNumRanks + dst_rank];
count = warp_reduce_sum(count);
if (lane_id == 0)
channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
}
__syncthreads();
// Pre-compute prefix sum for all channels
if (thread_id == 0) {
#pragma unroll
for (int i = 1; i < num_channels; ++ i)
channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1];
}
}
}
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
cudaStream_t stream, int num_channels) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
num_tokens_per_rank, moe_recv_counter_mapped, \
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
buffer_ptrs, task_fifo_ptrs, head, rank); \
break
constexpr int kNumThreads = 128;
EP_HOST_ASSERT(num_experts % num_ranks == 0);
EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);
SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
template<int kNumRanks>
__global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
// A simplified version for cached handles
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
ptr[i] = rank_prefix_matrix[i];
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[kNumRanks * kNumRanks + i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** task_fifo_ptrs,
int head, int rank, int num_ranks, cudaStream_t stream) {
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
break
SETUP_LAUNCH_CONFIG(1, 128, stream);
SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}
template<int kNumRanks>
__global__ void __launch_bounds__(kNumRanks * 32, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void **buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const bool is_sender = sm_id % 2 == 0;
EP_DEVICE_ASSERT(num_sms % 2 == 0);
// Each warp is responsible for a single rank
const auto num_channels = num_sms / 2;
const auto responsible_rank = (static_cast<int>(thread_id)) / 32;
// Even-numbered blocks for sending, odd-numbered blocks for receiving
const auto responsible_channel = sm_id / 2;
int num_experts_per_rank = num_experts / kNumRanks;
EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
EP_DEVICE_ASSERT(num_topk <= 32);
EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
// Calculate pointers by the specific layout
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
int target_rank = is_sender ? rank : responsible_rank;
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
// Channel buffer metadata
// Senders are responsible for tails, and receivers are responsible for heads
// Stored on the receiver side
// The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t`
// `start_offset`: kNumChannels * kNumRanks * sizeof(int)
// `end_offset`: kNumChannels * kNumRanks * sizeof(int)
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
// Channel data buffers, stored on the receiver side
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
if (is_sender) {
// Workers for sending
constexpr int num_send_warps = kNumRanks;
const auto send_thread_id = thread_id;
const auto send_warp_id = send_thread_id / 32;
const auto send_lane_id = send_thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32);
EP_DEVICE_ASSERT(num_send_warps == kNumRanks and send_warp_id == responsible_rank);
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
// NOTES: this is for distinguishing zero tokens
if (send_lane_id == 0) {
int value = responsible_channel > 0 ? channel_prefix_matrix[send_warp_id * num_channels + responsible_channel - 1] : 0;
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
value = channel_prefix_matrix[send_warp_id * num_channels + responsible_channel];
st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
}
__syncwarp();
// Get tasks
int token_start_idx, token_end_idx;
get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
// Iterate over all tokens and send by chunks
int cached_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
auto start_time = clock64();
while (send_lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
break;
// Rare cases to loop again
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
trap();
}
}
__syncwarp();
int chunk_token_idx = 0;
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
if (send_lane_id == 0)
send_head[token_idx * kNumRanks + send_warp_id] = is_token_in_rank[token_idx * kNumRanks + send_warp_id] ? cached_channel_tail_idx : -1;
// Skip if not selected
if (not is_token_in_rank[token_idx * kNumRanks + send_warp_id]) {
token_idx ++;
continue;
}
// Get an empty slot
int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens;
// Copy data
auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x + token_idx * hidden_int4;
UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
__ldg, st_na_global);
// Copy source index
if (send_lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
// Copy `topk_idx` and `topk_weights` with transformed index
if (send_lane_id < num_topk) {
// Top-k index
int recv_expert_begin = send_warp_id * num_experts_per_rank, recv_expert_end = (send_warp_id + 1) * num_experts_per_rank;
auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
// Top-k weights
auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
}
// Copy `x_scales`
#pragma unroll
for (int i = send_lane_id; i < num_scales; i += 32)
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
// Move token index
chunk_token_idx ++, token_idx ++;
}
// Move tail index
__syncwarp();
if (send_lane_id == 0)
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
}
} else {
// Workers for receiving and copying into buffer
constexpr int num_recv_warps = kNumRanks;
const auto recv_thread_id = thread_id;
const auto recv_warp_id = recv_thread_id / 32;
const auto recv_lane_id = recv_thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32 and recv_warp_id == responsible_rank);
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps == kNumRanks);
// Calculate offset first
auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
int rank_offset = recv_warp_id > 0 ? rank_prefix_matrix[(recv_warp_id - 1) * kNumRanks + rank] : 0;
// Receive channel offset
int total_offset, num_tokens_to_recv;
while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
if (recv_lane_id == 0) {
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
recv_channel_offset[recv_warp_id * num_channels + responsible_channel] = total_offset;
num_tokens_to_recv -= total_offset;
}
total_offset = __shfl_sync(0xffffffff, total_offset, 0);
total_offset += rank_offset;
num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);
auto start_time = clock64();
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
while (num_tokens_to_recv > 0) {
// Check channel status by lane 0
while (recv_lane_id == 0) {
cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());;
// Ready to copy
if (cached_channel_head_idx != cached_channel_tail_idx)
break;
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv);
trap();
}
}
// Sync queue tail
cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
// Copy data
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx) {
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
ld_nc_global, st_na_global);
}
// Copy `src_idx`
#pragma unroll 4
for (int chunk_idx = cached_channel_head_idx + recv_lane_id; chunk_idx < cached_channel_tail_idx; chunk_idx += 32)
recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);
// Copy `topk_idx` and `topk_weights`
#pragma unroll 4
for (int idx = recv_lane_id; idx < num_recv_tokens * num_topk; idx += 32) {
int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
}
// Copy `x_scales`
#pragma unroll 4
for (int i = recv_lane_id; i < num_recv_tokens * num_scales; i += 32) {
int chunk_idx = i / num_scales, scales_idx = i % num_scales;
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =
ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx);
}
// Move queue
cached_channel_head_idx += num_recv_tokens;
total_offset += num_recv_tokens;
__syncwarp();
if (recv_lane_id == 0)
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
// Exit
num_tokens_to_recv -= num_recv_tokens;
}
}
}
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
#define DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, dispatch<ranks>, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
is_token_in_rank, channel_prefix_matrix, \
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
break
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(num_sms % 2 == 0);
SETUP_LAUNCH_CONFIG(num_sms, num_ranks * 32, stream);
SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
template<int kNumRanks>
__global__ void
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank) {
const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) {
// Barrier before cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
#pragma unroll
for (int i = thread_id; i < num_memset_int; i += num_threads)
ptr[i] = 0;
memory_fence();
__syncthreads();
// Barrier after cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
} else {
const auto channel_id = sm_id - 1;
const auto thread_id = static_cast<int>(threadIdx.x);
const auto rank_id = thread_id / 32;
const auto lane_id = thread_id % 32;
int token_start_idx, token_end_idx;
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
// NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25;
#pragma unroll
for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {
int token_idx = token_idx_tail - lane_id, expected_head = 0;
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
head = __shfl_sync(0xffffffff, current_head, i);
if (head < 0) {
if (lane_id == i)
expected_head = -last_head - 1;
} else {
last_head = head;
}
}
if (current_head < 0 and token_idx >= token_start_idx)
send_head[token_idx * kNumRanks + rank_id] = expected_head;
}
}
}
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks,
cudaStream_t stream) {
#define CACHED_NOTIFY_COMBINE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
break
const int num_threads = std::max(128, 32 * num_ranks);
EP_HOST_ASSERT(num_ranks <= num_threads);
EP_HOST_ASSERT(num_threads <= 1024);
EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
#undef CACHED_NOTIFY_COMBINE
}
template<typename dtype_t, int kNumRanks, int kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
combine(dtype_t* recv_x, float* recv_topk_weights,
const dtype_t* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void **buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_channels = num_sms / 2;
const bool is_sender = sm_id % 2 == 0;
const int responsible_channel = sm_id / 2;
EP_DEVICE_ASSERT(num_topk <= 32);
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
auto x_int4 = reinterpret_cast<const int4*>(x);
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
if (is_sender) {
// Workers for sending
// Several warps are responsible for a single rank
constexpr int num_send_warps = kNumThreads / 32;
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
const auto send_thread_id = thread_id;
const auto send_lane_id = send_thread_id % 32;
const auto send_rank_id = thread_id / num_threads_per_rank;
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
// Calculate pointers by the specific layout
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
auto num_channels_total = num_channels * kNumRanks;
auto channel_rank_offset = responsible_channel * kNumRanks + rank;
// Channel meta data
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
// Get tasks
// NOTES: `channel_offset` is already shifted
int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
int num_channel_tokens = (responsible_channel == num_channels - 1 ? num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset;
int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;
// Iterate over all tokens and send by chunks
int current_channel_tail_idx = 0;
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
auto start_time = clock64();
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
while (send_lane_id == 0) {
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
break;
// Rare cases to loop again
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
trap();
}
}
__syncwarp();
// Send by chunk
#pragma unroll
for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
// Get an empty slot
int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;
// Copy data
auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
// Send source index
if (send_lane_id == 0)
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
// Send `topk_weights`
if (num_topk > 0 and send_lane_id < num_topk)
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
}
token_idx += num_round_tokens;
current_channel_tail_idx += num_round_tokens;
// Move tail index
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
if (send_lane_id == 0 and send_warp_id_in_rank == 0)
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
}
} else {
// Workers for receiving
// One warp for moving the queue head, others for reduction
constexpr int num_recv_warps = kNumThreads / 32;
const auto recv_warp_id = thread_id / 32;
const auto recv_lane_id = thread_id % 32;
EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
// Shared head, tail and retired flags for receiver warps
__shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
__shared__ volatile int channel_tail_idx[kNumRanks];
__shared__ volatile bool warp_retired[num_recv_warps];
if (thread_id < num_recv_warps)
warp_retired[thread_id] = false;
if (recv_lane_id < kNumRanks)
warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
if (thread_id < kNumRanks)
channel_tail_idx[thread_id] = 0;
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
if (thread_id < 32) {
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
// Queue head updater
int last_head = 0;
while (recv_lane_id < kNumRanks) {
// Check retired
bool retired = true;
#pragma unroll
for (int i = 1; i < num_recv_warps; ++ i)
retired = retired and warp_retired[i];
if (retired)
break;
// Update queue tail
channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
// Update minimum head
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
}
} else {
// Receivers
// Channel metadata
// All lanes will use data buffer, but only rank lane will use `head/tail/src_idx`
Buffer<int4> channel_x_buffers[kNumRanks];
Buffer<float> channel_topk_weights_buffers[kNumRanks];
// Calculate pointers by the specific layout
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) {
auto channel_rank_offset = responsible_channel * kNumRanks + i;
auto num_channels_total = num_channels * kNumRanks;
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
}
// The same tokens as the dispatch process
int token_start_idx, token_end_idx;
get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
// Iterate over all tokens and combine
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
// Read expected head
int expected_head = -1;
if (recv_lane_id < kNumRanks) {
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
}
auto start_time = clock64();
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
trap();
}
}
__syncwarp();
// Broadcast current heads
int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
#pragma unroll
for (int i = 0; i < kNumRanks; ++ i) {
auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);
if (expected_head_i >= 0) {
slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
topk_ranks[num_topk_ranks ++] = i;
}
}
// Reduce data
#pragma unroll
for (int i = recv_lane_id; i < hidden_int4; i += 32) {
// Read buffers
int4 recv_value_int4[kNumRanks];
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
// Reduce all-to-all results
float values[kDtypePerInt4] = {0};
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j) {
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
#pragma unroll
for (int k = 0; k < kDtypePerInt4; ++ k)
values[k] += static_cast<float>(recv_value_dtypes[k]);
}
// Cast back to `dtype_t` and write
int4 out_int4;
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
#pragma unroll
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
recv_int4[token_idx * hidden_int4 + i] = out_int4;
}
// Reduce `topk_weights`
if (recv_lane_id < num_topk) {
float value = 0;
#pragma unroll
for (int i = 0; i < num_topk_ranks; ++ i)
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
}
}
// Retired
__syncwarp();
if (recv_lane_id == 0)
warp_retired[recv_warp_id] = true;
}
}
}
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
const void* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void** buffer_ptrs, int rank, int num_ranks,
cudaStream_t stream, int num_sms,
int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
#define COMBINE_LAUNCH_CASE(dtype, ranks) \
LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
buffer_ptrs, rank, \
num_max_send_tokens, num_recv_buffer_tokens); \
break
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
// Even-numbered blocks for sending, odd-numbered blocks for receiving
EP_HOST_ASSERT(num_sms % 2 == 0);
EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
#undef COMBINE_DTYPE_LAUNCH_CASE
#undef COMBINE_LAUNCH_CASE
}
} // namespace intranode
} // namespace deep_ep

60
csrc/kernels/launch.cuh Normal file
View File

@@ -0,0 +1,60 @@
#pragma once
#include "configs.cuh"
#ifndef SETUP_LAUNCH_CONFIG
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
cudaLaunchAttribute attr[1]; \
attr[0].id = cudaLaunchAttributeCooperative; \
attr[0].val.cooperative = 1; \
cfg.attrs = attr; \
cfg.numAttrs = 1
#endif
#ifndef LAUNCH_KERNEL
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#endif
#define SWITCH_RANKS(case_macro) \
switch (num_ranks) { \
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} while (false)
#define SWITCH_RDMA_RANKS(case_macro) \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: case_macro(2); \
case 3: case_macro(3); \
case 4: case_macro(4); \
case 8: case_macro(8); \
case 16: case_macro(16); \
case 18: case_macro(18); \
case 20: case_macro(20); \
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} while (false)
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
switch (num_ranks) { \
case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
} while (false)
#define SWITCH_TYPES(case_macro) \
switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \
case CUDA_R_32F: case_macro(float); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \
} while (false)
#define SWITCH_HIDDEN(case_macro) \
switch (hidden) { \
case 2560: case_macro(2560); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
} while (false)

119
csrc/kernels/runtime.cu Normal file
View File

@@ -0,0 +1,119 @@
#include <vector>
#include <cstring>
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"
namespace deep_ep {
namespace intranode {
template<int kNumRanks>
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
#define BARRIER_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
break
SETUP_LAUNCH_CONFIG(1, 32, stream);
SWITCH_RANKS(BARRIER_LAUNCH_CASE);
#undef BARRIER_LAUNCH_CASE
}
} // namespace intranode
namespace internode {
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
nvshmem_team_config_t cpu_rdma_team_config;
std::vector<uint8_t> get_unique_id() {
nvshmemx_uniqueid_t unique_id;
nvshmemx_get_uniqueid(&unique_id);
std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
return result;
}
__global__ void ibgda_initialize_recv_queue(int rank) {
auto thread_idx = static_cast<int>(threadIdx.x);
auto num_threads = static_cast<int>(blockDim.x);
auto dst_rank = static_cast<int>(blockIdx.x);
if (dst_rank != rank) {
for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
auto qp = ibgda_get_rc(dst_rank, qp_id);
// Clean some necessary variables
for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
qp->mvars.rx_wq.resv_head = 0;
qp->mvars.rx_wq.cons_idx = 0;
// Allocate receive slots
nvshmemi_ibgda_allocate_recvs(qp);
}
}
}
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
nvshmemx_uniqueid_t root_unique_id;
nvshmemx_init_attr_t attr;
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
// Create sub-RDMA teams
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}
// Normal operations use IBRC, while low-latency operations use IBGDA
if (low_latency_mode) {
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
bool ibgda_is_initialized = false;
cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice);
// Initialize recv queues for low-latency mode AR
ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
}
nvshmem_barrier_all();
return nvshmem_my_pe();
}
void* alloc(size_t size, size_t alignment) {
return nvshmem_align(alignment, size);
}
void free(void* ptr) {
nvshmem_free(ptr);
}
void barrier() {
nvshmem_barrier_all();
}
void finalize() {
if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
nvshmem_team_destroy(cpu_rdma_team);
cpu_rdma_team = NVSHMEM_TEAM_INVALID;
}
nvshmem_finalize();
}
} // namespace internode
} // namespace deep_ep

381
csrc/kernels/utils.cuh Normal file
View File

@@ -0,0 +1,381 @@
#pragma once
#include "exception.cuh"
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
auto __dst = (DST); \
for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
_Pragma("unroll") \
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \
_Pragma("unroll") \
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \
} \
for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
}
namespace deep_ep {
template <int kBytes>
struct VecInt {};
template<> struct VecInt<1> { using vec_t = int8_t; };
template<> struct VecInt<2> { using vec_t = int16_t; };
template<> struct VecInt<4> { using vec_t = int; };
template<> struct VecInt<8> { using vec_t = int64_t; };
template<> struct VecInt<16> { using vec_t = int4; };
__device__ __forceinline__ void trap() {
asm("trap;");
}
__device__ __forceinline__ void memory_fence() {
asm volatile("fence.acq_rel.sys;":: : "memory");
}
__device__ __forceinline__ void memory_fence_gpu() {
asm volatile("fence.acq_rel.gpu;":: : "memory");
}
__device__ __forceinline__ void memory_fence_cta() {
asm volatile("fence.acq_rel.cta;":: : "memory");
}
__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
int ret;
asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
uint64_t ret;
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
int ret;
asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
int ret;
asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
return ret;
}
__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
int ret;
asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
return ret;
}
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
int ret;
asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
return static_cast<uint8_t>(ret);
}
__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
uint32_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
uint64_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
int ret;
asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
float ret;
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
int64_t ret;
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
int64_t ret;
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
#else
#define LD_NC_FUNC "ld.volatile.global"
#endif
// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS,
// which does not have cache allocation, and `CONSTANT` memory does not have coherence control,
// so we have to control them by queue semantics
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
return *reinterpret_cast<dtype_t*>(&ret);
}
template <>
__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
uint16_t ret;
// NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
return static_cast<uint8_t>(ret);
}
template <>
__device__ __forceinline__ int ld_nc_global(const int *ptr) {
int ret;
asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
int64_t ret;
asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ float ld_nc_global(const float *ptr) {
float ret;
asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
int2 ret;
asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
return ret;
}
template <>
__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
int4 ret;
asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
}
__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
}
__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
: : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
}
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
}
// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS,
// which does not have cache allocation (obviously in L1, I guess not in L2 too)
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
#define ST_NA_FUNC "st.global.L1::no_allocate"
#else
#define ST_NA_FUNC "st.global"
#endif
template <typename dtype_t>
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
*reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
}
template <>
__device__ __forceinline__ void st_na_global(const int *ptr, const int& value) {
asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
}
template <>
__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
}
template <>
__device__ __forceinline__ void st_na_global(const float *ptr, const float& value) {
asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
}
template <>
__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
template <typename dtype_t>
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
}
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
int& token_start_idx, int& token_end_idx) {
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
}
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
dtype_b_t packed;
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
return packed;
}
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
x = unpacked_ptr[0], y = unpacked_ptr[1];
}
template <typename dtype_t>
__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
auto send_int_values = reinterpret_cast<int*>(&ptr);
int recv_int_values[sizeof(dtype_t) / sizeof(int)];
#pragma unroll
for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
return *reinterpret_cast<dtype_t*>(recv_int_values);
}
__forceinline__ __device__ int warp_reduce_sum(int value) {
value += __shfl_xor_sync(0xffffffff, value, 16);
value += __shfl_xor_sync(0xffffffff, value, 8);
value += __shfl_xor_sync(0xffffffff, value, 4);
value += __shfl_xor_sync(0xffffffff, value, 2);
value += __shfl_xor_sync(0xffffffff, value, 1);
return value;
}
__forceinline__ __device__ float half_warp_reduce_max(float value) {
auto mask = __activemask();
// The mask be in `{0xffffffff, 0xffff}`
value = max(value, __shfl_xor_sync(mask, value, 8));
value = max(value, __shfl_xor_sync(mask, value, 4));
value = max(value, __shfl_xor_sync(mask, value, 2));
value = max(value, __shfl_xor_sync(mask, value, 1));
return value;
}
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
template <int kNumRanks>
__forceinline__ __device__ void move_fifo_slots(int &head) {
head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
}
template <int kNumRanks>
__device__ __forceinline__ bool not_finished(int *task, int expected) {
auto result = false;
auto lane_id = threadIdx.x % 32;
if (lane_id < kNumRanks)
result = ld_volatile_global(task + lane_id) != expected;
return __any_sync(0xffffffff, result);
}
template <int kNumRanks>
__forceinline__ __device__ void
timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
auto start_time = clock64();
while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
trap();
}
}
}
template <int kNumRanks>
__forceinline__ __device__ void
barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
auto thread_id = static_cast<int>(threadIdx.x);
EP_DEVICE_ASSERT(kNumRanks <= 32);
if (thread_id < kNumRanks) {
atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
memory_fence();
atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
}
timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
}
} // namespace deep_ep

7
deep_ep/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
import torch
from .utils import EventOverlap
from .buffer import Buffer
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config

534
deep_ep/buffer.py Normal file
View File

@@ -0,0 +1,534 @@
import os
import torch
import torch.distributed as dist
from typing import Callable, List, Tuple, Optional, Union
# noinspection PyUnresolvedReferences
import deep_ep_cpp
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle
from .utils import EventOverlap
class Buffer:
"""
The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports:
- high-throughput intranode all-to-all (dispatch and combine, using NVLink)
- high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
- low-latency all-to-all (dispatch and combine, using RDMA, AR supported)
Attributes:
num_sms: the SMs used in high-throughput kernels.
rank: the local rank number.
group_size: the number of ranks in the group.
group: the communication group.
num_nvl_bytes: the buffer size for intranode NVLink communication.
num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
runtime: the C++ runtime.
"""
num_sms: int = 20
def __init__(self, group: dist.ProcessGroup,
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
"""
Initialize the communication buffer.
Arguments:
group: the communication group.
num_nvl_bytes: the buffer size for intranode NVLink communication.
num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
low_latency_mode: whether to enable low-latency mode.
num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
to the number of local experts.
"""
# TODO: argument docs
# Initialize the CPP runtime
self.rank = group.rank()
self.group_size = group.size()
self.group = group
self.num_nvl_bytes = num_nvl_bytes
self.num_rdma_bytes = num_rdma_bytes
self.low_latency_mode = low_latency_mode
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)
# Synchronize device IDs
device_ids = [None, ] * self.group_size
local_device_id = self.runtime.get_local_device_id()
dist.all_gather_object(device_ids, local_device_id, group)
# Synchronize IPC handles
ipc_handles = [None, ] * self.group_size
local_ipc_handle = self.runtime.get_local_ipc_handle()
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
# Synchronize NVSHMEM unique IDs
root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
if low_latency_mode:
assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
# NOTES: NVSHMEM initialization requires at least 256 MiB
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
# NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
# Synchronize using the root ID
nvshmem_unique_ids = [None, ] * self.group_size
if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
root_unique_id = self.runtime.get_local_nvshmem_unique_id()
dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
# Make CPP runtime available
self.runtime.sync(device_ids, ipc_handles, root_unique_id)
assert self.runtime.is_available()
@staticmethod
def set_num_sms(new_num_sms: int) -> None:
"""
Set the number of SMs to use in high-throughput kernels.
Arguments:
new_num_sms: the new number to be set.
"""
assert new_num_sms % 2 == 0, 'The SM count must be even'
Buffer.num_sms = new_num_sms
@staticmethod
def capture() -> EventOverlap:
"""
Capture a CUDA event on the current stream, i.e. `torch.cuda.current_stream()`.
Returns:
event: the captured event.
"""
return EventOverlap(EventHandle())
@staticmethod
def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int) -> int:
"""
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
Arguments:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token.
num_ranks: the number of EP group ranks.
num_experts: the number of all experts.
Returns:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor:
"""
Get the raw buffer (slice supported) as a PyTorch tensor.
Argument:
dtype: the data type (PyTorch `dtype`) for the tensor.
size: the slice size (by elements) to get from the buffer.
offset: the offset of the beginning element.
use_rdma_buffer: whether to return the RDMA buffer.
"""
tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)
if size is None:
return tensor
assert tensor.numel() >= size.numel()
return tensor[:size.numel()].view(size)
@staticmethod
def get_dispatch_config(num_ranks: int) -> Config:
"""
Get a recommended dispatch config.
Argument:
num_ranks: the number of ranks.
Returns:
config: the recommended config.
"""
# Intranode
if num_ranks <= 8:
return Config(Buffer.num_sms, 6, 256, 6, 128)
# Internode
config_map = {
16: Config(Buffer.num_sms, 16, 288, 20, 128),
24: Config(Buffer.num_sms, 8, 288, 32, 128),
32: Config(Buffer.num_sms, 8, 288, 32, 128),
64: Config(Buffer.num_sms, 20, 288, 28, 128),
128: Config(Buffer.num_sms, 20, 560, 32, 128),
144: Config(Buffer.num_sms, 32, 720, 12, 128),
160: Config(Buffer.num_sms, 28, 720, 12, 128),
}
assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
return config_map[num_ranks]
@staticmethod
def get_combine_config(num_ranks: int) -> Config:
"""
Get a recommended combine config.
Argument:
num_ranks: the number of ranks.
Returns:
config: the recommended config.
"""
# Intranode
if num_ranks <= 8:
return Config(Buffer.num_sms, 6, 256, 6, 128)
# Internode
config_map = {
16: Config(Buffer.num_sms, 2, 288, 28, 128),
24: Config(Buffer.num_sms, 1, 288, 20, 128),
32: Config(Buffer.num_sms, 1, 288, 20, 128),
64: Config(Buffer.num_sms, 1, 288, 20, 128),
128: Config(Buffer.num_sms, 1, 560, 12, 128),
144: Config(Buffer.num_sms, 2, 720, 8, 128),
160: Config(Buffer.num_sms, 2, 720, 8, 128),
}
assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
return config_map[num_ranks]
# noinspection PyTypeChecker
def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:
"""
Calculate the layout required for later communication.
Arguments:
topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token,
`-1` means no selections.
num_experts: the number of experts.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
Returns:
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
rank (with the same GPU index), return `None` for intranode settings.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
async_finish, allocate_on_comm_stream)
return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
# noinspection PyTypeChecker
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
"""
Dispatch tokens to different ranks, both intranode and internode settings are supported.
Intranode kernels require all the ranks should be visible via NVLink.
Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
index should be visible via RDMA. AR must be disabled.
Arguments:
x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
and type must be `torch.bfloat16`; for the second type, the first element of the tuple must be shaped as
`[num_tokens, hidden]` with type `torch.float8_e4m3fn`, the second must be `[num_tokens, hidden // 128]`
(requiring divisible) with type `torch.float`.
handle: an optional communication handle, if set, the CPU will reuse the layout information to save some time.
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
rank (with the same GPU index), return `None` for intranode settings.
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank.
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
`-1` means no selections.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
expert_alignment: align the number of tokens received by each local expert to this variable.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
Returns:
recv_x: received tokens, the same type and tuple as the input `x`, but the number of tokens equals to the
received token count.
recv_topk_idx: received expert indices.
recv_topk_weights: received expert weights.
num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by
each local expert, aligned to the input `expert_alignment`.
handle: the returned communication handle.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
# Default config
config = self.get_dispatch_config(self.group_size) if config is None else config
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
# Launch the kernel with cached or non-cached mode
x, x_scales = x if isinstance(x, tuple) else (x, None)
if handle is not None:
assert topk_idx is None and topk_weights is None
rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle
num_recv_tokens = recv_src_idx.size(0)
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
x, x_scales, None, None,
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
# noinspection PyTypeChecker
def combine(self, x: torch.Tensor, handle: Tuple,
topk_weights: Optional[torch.Tensor] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
"""
Combine (reduce) tokens (addition **without** weights) from different ranks, both intranode and internode
settings are supported.
Intranode kernels require all the ranks should be visible via NVLink.
Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU
index should be visible via RDMA. AR must be disabled.
Arguments:
x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks.
handle: a must-set communication handle, you can obtain this from the dispatch function.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks.
config: the performance tuning config.
previous_event: the event to wait before actually executing the kernel.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream.
Returns:
recv_x: the reduced token from its dispatched ranks.
recv_topk_weights: the reduced top-k weights from its dispatch ranks.
event: the event after executing the kernel (valid only if `async_finish` is set).
"""
# Default config
config = self.get_combine_config(self.group_size) if config is None else config
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
# NOTES: the second `_` is for the sending side, so we should use the third one
rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
# Launch the kernel
recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
x, topk_weights,
src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return recv_x, recv_topk_weights, EventOverlap(event)
# noinspection PyTypeChecker
def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Optional[Tuple] = None,
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
"""
Internode dispatch implementation, for more details, please refer to the `dispatch` docs.
Normally, you should not directly call this function.
"""
assert config is not None
# Launch the kernel with cached or non-cached mode
x, x_scales = x if isinstance(x, tuple) else (x, None)
if handle is not None:
assert topk_idx is None and topk_weights is None
is_token_in_rank, \
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
recv_src_meta, send_rdma_head, send_nvl_head = handle
num_recv_tokens = recv_src_meta.size(0)
num_rdma_recv_tokens = send_nvl_head.size(0)
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
x, x_scales, topk_idx, topk_weights,
None, None, is_token_in_rank, None,
num_recv_tokens, num_rdma_recv_tokens,
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
recv_src_meta, send_rdma_head, send_nvl_head, event = self.runtime.internode_dispatch(
x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
0, 0, None, None, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (is_token_in_rank,
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
recv_src_meta, send_rdma_head, send_nvl_head)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
# noinspection PyTypeChecker
def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
topk_weights: Optional[torch.Tensor] = None,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
"""
Internode combine implementation, for more details, please refer to the `combine` docs.
Normally, you should not directly call this function.
"""
assert config is not None
# Unpack handle
is_combined_token_in_rank, \
_, _, \
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
src_meta, send_rdma_head, send_nvl_head = handle
# Launch the kernel
combined_x, combined_topk_weights, event = self.runtime.internode_combine(
x, topk_weights,
src_meta, is_combined_token_in_rank,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
async_finish, allocate_on_comm_stream)
return combined_x, combined_topk_weights, EventOverlap(event)
def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
"""
As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
if the buffer is dirty at some time.
For example, after running the normal dispatch/combine, you must run this function before executing any
low-latency kernel.
Arguments:
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
hidden: the hidden dimension of each token.
num_experts: the number of all experts.
"""
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
# noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
A low-latency implementation for dispatching with IBGDA **with implicit FP8 casting**.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes
are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receive. As mentioned before, all not tokens are valid in `recv_x`.
handle: the communication handle to be used in the `low_latency_combine` function.
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales), packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks in the same node, NVLink are fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
low-latency kernels' result tensor at a single moment.
Arguments:
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
the local calculated tokens to be sent to this original rank and reduced.
topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched
tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals
to the number of dispatched tokens.
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
Returns:
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
num_max_dispatch_tokens_per_rank, num_experts,
async_finish, return_recv_hook)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook

60
deep_ep/utils.py Normal file
View File

@@ -0,0 +1,60 @@
import torch
from typing import Any, Optional, Tuple
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle
class EventOverlap:
"""
A wrapper class to manage CUDA events, also for better overlapping convenience.
Attributes:
event: the CUDA event captured.
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
"""
def __init__(self, event: Optional[EventHandle] = None,
extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None:
"""
Initialize the class.
Arguments:
event: the CUDA event captured.
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
"""
self.event = event
# NOTES: we use extra tensors to achieve stream recording, otherwise,
# stream recording will be incompatible with CUDA graph.
self.extra_tensors = extra_tensors
def current_stream_wait(self) -> None:
"""
The current stream `torch.cuda.current_stream()` waits for the event to be finished.
"""
assert self.event is not None
self.event.current_stream_wait()
def __enter__(self) -> Any:
"""
Utility for overlapping and Python `with` syntax.
You can overlap the kernels on the current stream with the following example:
```python
event_overlap = event_after_all_to_all_kernels()
with event_overlap():
do_something_on_current_stream()
# After exiting the `with` scope, the current stream with wait the event to be finished.
```
"""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""
Utility for overlapping and Python `with` syntax.
Please follow the example in the `__enter__` function.
"""
if self.event is not None:
self.event.current_stream_wait()

BIN
figures/low-latency.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 674 KiB

BIN
figures/normal.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 506 KiB

63
setup.py Normal file
View File

@@ -0,0 +1,63 @@
import os
import subprocess
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
if __name__ == '__main__':
nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
assert nvshmem_dir is not None and os.path.exists(nvshmem_dir), 'Failed to find NVSHMEM'
print(f'NVSHMEM directory: {nvshmem_dir}')
# TODO: currently, we only support Hopper architecture, we may add Ampere support later
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']
include_dirs = ['csrc/', f'{nvshmem_dir}/include']
sources = ['csrc/deep_ep.cpp',
'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu',
'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
library_dirs = [f'{nvshmem_dir}/lib']
# Disable aggressive PTX instructions
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
# Disable DLTO (default by PyTorch)
nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
extra_compile_args = {
'cxx': cxx_flags,
'nvcc': nvcc_flags,
'nvcc_dlink': nvcc_dlink
}
# noinspection PyBroadException
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except Exception as _:
revision = ''
setuptools.setup(
name='deep_ep',
version='1.0.0' + revision,
packages=setuptools.find_packages(
include=['deep_ep']
),
ext_modules=[
CUDAExtension(
name='deep_ep_cpp',
include_dirs=include_dirs,
library_dirs=library_dirs,
sources=sources,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
],
cmdclass={
'build_ext': BuildExtension
}
)

246
tests/test_internode.py Normal file
View File

@@ -0,0 +1,246 @@
import os
import time
import torch
import torch.distributed as dist
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
# Test compatibility with low latency functions
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
inplace_unique(rdma_rank_idx, num_nodes)
# RDMA dispatch counts
rdma_idx = topk_idx // (num_experts // num_nodes)
rdma_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rdma_idx, num_nodes)
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
for i in range(num_nodes):
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print()
group.barrier()
time.sleep(1)
# Config
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, recv_gbl_rank_prefix_sum):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = recv_gbl_rank_prefix_sum[i].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
recv_gbl_rank_prefix_sum = handle[-4]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
# NOTES: handle must be refreshed
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print()
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
for rdma_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
for rdma_chunk_size in range(8, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
# Please make sure AR (Adaptive Routing) is turned off when running normal internode kernels,
num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility = False
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
if local_rank == 0:
print()
# Test compatibility with low latency functions
if test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
if __name__ == '__main__':
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)

223
tests/test_intranode.py Normal file
View File

@@ -0,0 +1,223 @@
import os
import time
import torch
import torch.distributed as dist
# noinspection PyUnresolvedReferences
import deep_ep
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
# Test compatibility with low latency functions
import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
x_e4m3 = per_token_cast_to_fp8(x)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
print()
group.barrier()
time.sleep(1)
# Config
nvl_buffer_size = 256
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, rank_prefix_matrix):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = rank_prefix_matrix[i][rank].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank,
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
if with_topk:
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
# Checks
rank_prefix_matrix = handle[0]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, rank_prefix_matrix)
# Test cached dispatch (must without top-k staffs)
# NOTES: handle must be refreshed
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
if current_x is not x_pure_rand:
check_data(recv_x, rank_prefix_matrix)
# Test combine
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
if with_topk:
combine_args.update({'topk_weights': recv_topk_weights})
if previous_mode:
dispatch_args.update({'previous_event': buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
if local_rank == 0:
print(' passed', flush=True)
if local_rank == 0:
print()
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank,
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
'config': dispatch_config if dispatch_config is not None else config}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
if t < best_time:
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
print()
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
test_ll_compatibility, num_rdma_bytes = False, 0
if test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
torch.manual_seed(rank)
for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, rank, buffer, group)
if local_rank == 0:
print()
# Test compatibility with low latency functions
if test_ll_compatibility:
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
if __name__ == '__main__':
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)

160
tests/test_low_latency.py Normal file
View File

@@ -0,0 +1,160 @@
import random
import torch
import torch.distributed as dist
from functools import partial
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceeds the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for i in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for return_recv_hook in (False, True):
num_times += 1
for i in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
# Check combine correctness
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < 1e-5, f'Error: diff={diff}'
hash_value ^= hash_tensor(combined_x)
def create_test_cast_with_outliers(num_outliers):
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
assert tmp.abs().amax().item() <= 1
# Create some amax outliers
for i in range(num_outliers):
tmp[random.randint(0, num_tokens - 1)] *= 1e3
return tmp
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
# Separate profiling
for return_recv_hook in (False, True):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')
return hash_value
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
do_pressure_test = False
for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
num_processes = 8
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)

192
tests/utils.py Normal file
View File

@@ -0,0 +1,192 @@
import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional
def init_dist(local_rank: int, num_local_ranks: int):
# NOTES: you may rewrite this function with your own cluster settings
ip = os.getenv('MASTER_ADDR', '127.0.0.1')
port = int(os.getenv('MASTER_PORT', '8361'))
num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group(
backend='nccl',
init_method=f'tcp://{ip}:{port}',
world_size=num_nodes * num_local_ranks,
rank=node_rank * num_local_ranks + local_rank
)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda')
torch.cuda.set_device(local_rank)
return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double() + 1, y.double() + 1
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return (1 - sim).item()
def per_token_cast_to_fp8(x: torch.Tensor):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
def inplace_unique(x: torch.Tensor, num_slots: int):
assert x.dim() == 2
mask = x < 0
x_padded = x.masked_fill(mask, num_slots)
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
bin_count = bin_count[:, :num_slots]
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
x[:, :].fill_(-1)
valid_len = min(num_slots, x.size(1))
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
num_tokens, num_experts = scores.shape
scores = scores.view(num_tokens, num_groups, -1)
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
return (scores * mask).view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
# Warmup
for _ in range(num_warmups):
fn()
# Flush L2
cache.zero_()
# Testing
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
for i in range(num_tests):
# Record
start_events[i].record()
fn()
end_events[i].record()
if post_fn is not None:
post_fn()
torch.cuda.synchronize()
times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
return np.average(times), np.min(times), np.max(times)
class empty_suppress:
def __enter__(self):
return self
def __exit__(self, *_):
pass
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
fn()
prof.step()
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
prof.export_chrome_trace(trace_path)
# Return average kernel times
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item()

126
third-party/README.md vendored Normal file
View File

@@ -0,0 +1,126 @@
# Install NVSHMEM
## Important notices
**This project is neither sponsored nor supported by NVIDIA.**
**Use of NVIDIA NVSHMEM is governed by the terms at [NVSHMEM Software License Agreement](https://docs.nvidia.com/nvshmem/api/sla.html).**
## Prerequisites
1. [GDRCopy](https://github.com/NVIDIA/gdrcopy) (v2.4 and above recommended) is a low-latency GPU memory copy library based on NVIDIA GPUDirect RDMA technology, and *it requires kernel module installation with root privileges.*
2. Hardware requirements
- GPUDirect RDMA capable devices, see [GPUDirect RDMA Documentation](https://docs.nvidia.com/cuda/gpudirect-rdma/)
- InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/)
- For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements)
## Installation procedure
### 1. Install GDRCopy
GDRCopy requires kernel module installation on the host system. Complete these steps on the bare-metal host before container deployment:
#### Build and installation
```bash
git clone https://github.com/NVIDIA/gdrcopy
cd gdrcopy
make -j$(nproc)
sudo make prefix=/opt/gdrcopy install
```
#### Kernel module installation
```bash
cd packages
CUDA=/path/to/cuda ./build-deb-packages.sh
sudo dpkg -i gdrdrv-dkms_2.4-4_amd64.deb \
libgdrapi_2.4-4_amd64.deb \
gdrcopy-tests_2.4-4_amd64.deb \
gdrcopy_2.4-4_amd64.deb
sudo ./insmod.sh # Load kernel modules on bare-metal system
```
#### Container environment notes
For containerized environments:
1. Host: keep kernel modules loaded (`gdrdrv`)
2. Container: install DEB packages *without* rebuilding modules:
```bash
sudo dpkg -i gdrcopy_2.4-4_amd64.deb \
libgdrapi_2.4-4_amd64.deb \
gdrcopy-tests_2.4-4_amd64.deb
```
#### Verification
```bash
gdrcopy_copybw # Should show bandwidth test results
```
### 2. Acquiring NVSHMEM source code
Download NVSHMEM v3.1.7 from the [NVIDIA NVSHMEM Archive](https://developer.nvidia.com/nvshmem-archive).
### 3. Apply our custom patch
Navigate to your NVSHMEM source directory and apply our provided patch:
```bash
git apply /path/to/deep_ep/dir/third-party/nvshmem.patch
```
### 4. Configure NVIDIA driver
Enable IBGDA by modifying `/etc/modprobe.d/nvidia.conf`:
```bash
options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"
```
Update kernel configuration:
```bash
sudo update-initramfs -u
sudo reboot
```
For more detailed configurations, please refer to the [NVSHMEM Installation Guide](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html).
### 5. Build and installation
The following example demonstrates building NVSHMEM with IBGDA support:
```bash
CUDA_HOME=/path/to/cuda && \
GDRCOPY_HOME=/path/to/gdrcopy && \
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/path/to/your/dir/to/install
cd build
make -j$(nproc)
make install
```
## Post-installation configuration
Set environment variables in your shell configuration:
```bash
export NVSHMEM_DIR=/path/to/your/dir/to/install # Use for DeepEP installation
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
```
## Verification
```bash
nvshmem-info -a # Should display details of nvshmem
```

456
third-party/nvshmem.patch vendored Normal file
View File

@@ -0,0 +1,456 @@
From 9d784943e1032f15dd7cdd2599192937ba9d9343 Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Fri, 20 Dec 2024 10:57:12 +0800
Subject: [PATCH 1/5] Change QP creating order.
---
src/modules/transport/ibgda/ibgda.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index 31bc56a..ff02f50 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -2921,17 +2921,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id
INFO(ibgda_state->log_level, "Creating %d RC QPs", device->rc.num_eps_per_pe);
for (int i = 0; i < num_rc_eps; ++i) {
// Do not create loopback to self
- if (i / device->rc.num_eps_per_pe == mype) {
+ int dst_pe = (i + 1 + mype) % n_pes;
+ int offset = i / n_pes;
+ int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset;
+ if (dst_pe == mype) {
continue;
}
- status = ibgda_create_qp(&device->rc.eps[i], device, portid, i,
+ status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i,
NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
- "ibgda_create_dci failed on RC #%d.", i);
+ "ibgda_create_dci failed on RC #%d.", mapped_i);
- status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device);
+ status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
- "ibgda_get_rc_handle failed on RC #%d.", i);
+ "ibgda_get_rc_handle failed on RC #%d.", mapped_i);
}
if (num_rc_eps) {
--
2.25.1
From 3cd3938bcbbabed7fb7675032afb02647ea9c2fe Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Mon, 23 Dec 2024 09:55:27 +0800
Subject: [PATCH 2/5] Disable timeout check
---
CMakeLists.txt | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 771ff98..9246d29 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -140,7 +140,7 @@ option(NVSHMEM_NVTX "Enable NVSHMEM NVTX support" ${NVSHMEM_NVTX_DEFAULT})
option(NVSHMEM_PMIX_SUPPORT "Enable Compilation of the PMIX bootstrap and PMIX specific code" $ENV{NVSHMEM_PMIX_SUPPORT})
option(NVSHMEM_SHMEM_SUPPORT "Enable Compilation of the SHMEM bootstrap and SHMEM specific code" $ENV{NVSHMEM_SHMEM_SUPPORT})
option(NVSHMEM_TEST_STATIC_LIB "Force tests to link only against the combined nvshmem.a binary" $ENV{NVSHMEM_TEST_STATIC_LIB})
-option(NVSHMEM_TIMEOUT_DEVICE_POLLING "Enable timeouts for NVSHMEM device-side polling functions (e.g. wait_until)" $ENV{NVSHMEM_TIMEOUT_DEVICE_POLLING})
+option(NVSHMEM_TIMEOUT_DEVICE_POLLING "Enable timeouts for NVSHMEM device-side polling functions (e.g. wait_until)" OFF)
option(NVSHMEM_TRACE "Enable NVSHMEM trace print events" $ENV{NVSHMEM_TRACE})
option(NVSHMEM_UCX_SUPPORT "Enable compilation of the UCX remote transport" $ENV{NVSHMEM_UCX_SUPPORT})
option(NVSHMEM_USE_DLMALLOC "Set dlmalloc as the NVSHMEM heap allocation method" $ENV{NVSHMEM_USE_DLMALLOC})
@@ -165,6 +165,7 @@ set(NVSHMEM_PREFIX ${NVSHMEM_PREFIX_DEFAULT} CACHE PATH "path to NVSHMEM install
set(PMIX_HOME ${PMIX_HOME_DEFAULT} CACHE PATH "path to PMIX installation")
set(SHMEM_HOME ${MPI_HOME} CACHE PATH "path to SHMEM installation")
set(UCX_HOME ${UCX_HOME_DEFAULT} CACHE PATH "path to UCX installation")
+set(NVSHMEM_TIMEOUT_DEVICE_POLLING OFF)
message(STATUS "NVSHMEM_PREFIX: ${NVSHMEM_PREFIX}")
message(STATUS "NVSHMEM_DEVEL: ${NVSHMEM_DEVEL}")
--
2.25.1
From 4e0eaff589d38f448715e43a935479451a41c0fe Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Fri, 10 Jan 2025 11:53:38 +0800
Subject: [PATCH 3/5] Add recv queue and recv cq for rc qps.
Let the ibgda rc qps use regular recv queue.
Add recv queue to ibgda dev qp.
IBGDA create recv cq
Setup recv cq.
fix recv queue.
Remove some useless idx.
Longer recv queue.
---
.../nvshmem_common_ibgda.h | 19 +++++-
src/modules/transport/ibgda/ibgda.cpp | 65 ++++++++++++++++---
2 files changed, 71 insertions(+), 13 deletions(-)
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 32f6d02..7d4e250 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -168,14 +168,17 @@ typedef struct {
uint64_t get_head; // last wqe idx + 1 with a "fetch" operation (g, get, amo_fetch)
uint64_t get_tail; // last wqe idx + 1 polled with cst; get_tail > get_head is possible
} tx_wq;
+ struct {
+ uint64_t resv_head; // last reserved wqe idx + 1
+ } rx_wq;
struct {
uint64_t head;
uint64_t tail;
} ibuf;
char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
} __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96,
- "ibgda_device_qp_management_v1 must be 96 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
+ "ibgda_device_qp_management_v1 must be 104 bytes.");
typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
@@ -199,9 +202,19 @@ typedef struct nvshmemi_ibgda_device_qp {
// May point to mvars.prod_idx or internal prod_idx
uint64_t *prod_idx;
} tx_wq;
+ struct {
+ uint16_t nwqes;
+ uint64_t tail;
+ void *wqe;
+ __be32 *dbrec;
+ void *bf;
+ nvshmemi_ibgda_device_cq_t *cq;
+ // May point to mvars.prod_idx or internal prod_idx
+ uint64_t *prod_idx;
+ } rx_wq;
nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables
} nvshmemi_ibgda_device_qp_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, "ibgda_device_qp_v1 must be 184 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index ff02f50..b8d6bc7 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -194,6 +194,7 @@ struct ibgda_ep {
off_t dbr_offset;
struct ibgda_cq *send_cq;
+ struct ibgda_cq *recv_cq;
struct ibv_ah *ah;
uint32_t user_index;
@@ -1520,7 +1521,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
struct ibv_context *context = device->context;
- unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes;
+ // Each RC qp has one send CQ and one recv CQ.
+ unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2;
assert(ibgda_qp_depth > 0);
size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
@@ -1683,7 +1685,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
}
// Allocate and map WQ buffer for all QPs.
- wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB; // num_wqebb is always a power of 2
+ // Todo: reduce the size of wq buffer.
+ wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2; // num_wqebb is always a power of 2
wq_buf_size = wq_buf_size_per_qp * num_eps;
status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "cannot allocate wq buf.\n");
@@ -1864,8 +1867,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
int cqe_version = 0;
struct ibgda_cq *send_cq = NULL;
+ struct ibgda_cq *recv_cq = NULL;
size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
+ size_t num_recv_wqe = ibgda_qp_depth;
+ size_t recv_wqe_size = 16;
int status = 0;
@@ -1893,6 +1899,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
status = ibgda_create_cq(&send_cq, device);
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
+ if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+ status = ibgda_create_cq(&recv_cq, device);
+ NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
+ }
+
ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep));
NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
"Unable to allocate mem for ep.\n");
@@ -1921,12 +1932,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED);
DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn);
DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id); // BF register
- DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue
- DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn);
- DEVX_SET(qpc, qp_context, cqn_rcv, device->qp_shared_object.rcqn);
+ DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn);
DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb));
- DEVX_SET(qpc, qp_context, log_rq_size, 0);
DEVX_SET(qpc, qp_context, cs_req, 0); // Disable CS Request
DEVX_SET(qpc, qp_context, cs_res, 0); // Disable CS Response
DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE); // Enable dbr_umem_id
@@ -1935,6 +1943,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id); // DBR buffer
DEVX_SET(qpc, qp_context, user_index, qp_idx);
DEVX_SET(qpc, qp_context, page_offset, 0);
+ if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC){
+ DEVX_SET(qpc, qp_context, rq_type, 0); // Regular recv queue
+ DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe)); // 4 wqe
+ DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4); // max recv wqe size = 16B
+ } else {
+ DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue, DC must use this.
+ DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
+ DEVX_SET(qpc, qp_context, log_rq_size, 0);
+ }
ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));
NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out,
@@ -1944,9 +1961,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
ep->portid = portid;
ep->sq_cnt = num_wqebb;
- ep->sq_buf_offset = 0;
+ ep->sq_buf_offset = num_recv_wqe * recv_wqe_size;
- ep->rq_cnt = 0;
+ ep->rq_cnt = num_recv_wqe;
ep->rq_buf_offset = 0;
ep->wq_mobject = device->qp_shared_object.wq_mobject;
@@ -1960,6 +1977,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
ep->uar_mobject = uar_mobject;
ep->send_cq = send_cq;
+ ep->recv_cq = recv_cq;
ep->qp_type = qp_type;
@@ -1971,6 +1989,7 @@ out:
if (status) {
if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject);
if (send_cq) ibgda_destroy_cq(send_cq);
+ if (recv_cq) ibgda_destroy_cq(recv_cq);
if (ep) free(ep);
}
@@ -2269,6 +2288,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) {
ibgda_destroy_cq(ep->send_cq);
}
+ if (ep->recv_cq) {
+ ibgda_destroy_cq(ep->recv_cq);
+ }
+
if (ep->ah) {
ftable.destroy_ah(ep->ah);
}
@@ -2300,7 +2323,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
dev_qp->qpn = ep->qpn;
assert(ep->wq_mobject->has_gpu_mapping);
- dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset);
+ dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset);
if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) {
assert(ep->dbr_mobject->has_gpu_mapping);
@@ -2312,6 +2335,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
}
dev_qp->tx_wq.nwqes = ep->sq_cnt;
+ if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
+ dev_qp->rx_wq.nwqes = ep->rq_cnt;
+ dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset);
+ dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset);
+ dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr;
+ }
ibuf_dci_start = (uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr;
ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps);
@@ -2361,6 +2390,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
nvshmemi_ibgda_device_cq_t *cq_d = NULL;
nvshmemi_ibgda_device_cq_t *cq_h = NULL;
+ nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL;
+ nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL;
+
uint8_t *qp_group_switches_d = NULL;
const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars);
@@ -2368,6 +2400,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx);
const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
+ const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
@@ -2405,7 +2438,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
num_dct_handles += device->dct.num_eps * n_pes;
num_dci_handles += device->dci.num_eps;
num_rc_handles += device->rc.num_eps_per_pe * n_pes;
- num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1));
+ num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2);
num_shared_dci_handles += device->dci.num_shared_eps;
}
num_elements = num_dct_handles - NVSHMEMI_IBGDA_MAX_CONST_DCTS;
@@ -2441,6 +2474,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
for (int i = 0; i < num_cq_handles; i++) {
nvshmemi_init_ibgda_device_cq(cq_h[i]);
}
+
+ recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h));
+ NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "recv_cq calloc err.");
+ nvshmemi_init_ibgda_device_cq(recv_cq_h[0]);
/* allocate host memory for dct, rc, cq, dci end */
/* allocate device memory for dct, rc, cq, dci start */
@@ -2544,6 +2581,14 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
}
++cq_idx;
+
+ rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx];
+
+ ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
+ cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
+ cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
+ cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
+ ++cq_idx;
}
}
}
--
2.25.1
From 0cc285269f154049f1c9775e07e306e03228eedc Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Sat, 8 Feb 2025 18:02:39 +0800
Subject: [PATCH 4/5] Maintain recv queue's cons_idx.
---
src/include/device_host_transport/nvshmem_common_ibgda.h | 5 +++--
src/modules/transport/ibgda/ibgda.cpp | 6 ++++--
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 7d4e250..502645d 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -170,6 +170,7 @@ typedef struct {
} tx_wq;
struct {
uint64_t resv_head; // last reserved wqe idx + 1
+ uint64_t cons_idx; // polled wqe idx + 1 (consumer index + 1)
} rx_wq;
struct {
uint64_t head;
@@ -177,7 +178,7 @@ typedef struct {
} ibuf;
char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
} __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112,
"ibgda_device_qp_management_v1 must be 104 bytes.");
typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
@@ -214,7 +215,7 @@ typedef struct nvshmemi_ibgda_device_qp {
} rx_wq;
nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables
} nvshmemi_ibgda_device_qp_v1;
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, "ibgda_device_qp_v1 must be 248 bytes.");
typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
index b8d6bc7..a1cfe2e 100644
--- a/src/modules/transport/ibgda/ibgda.cpp
+++ b/src/modules/transport/ibgda/ibgda.cpp
@@ -1063,7 +1063,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) {
ibgda_host_mem_free(mobject);
}
-static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) {
+static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) {
int status = 0;
struct ibgda_cq *gcq = NULL;
@@ -1114,7 +1114,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device)
cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context);
DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);
DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B);
- DEVX_SET(cqc, cq_context, cc, 0x1); // Use collapsed CQ
+ DEVX_SET(cqc, cq_context, cc, cc); // Use collapsed CQ
DEVX_SET(cqc, cq_context, oi, 0x1); // Allow overrun
DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id);
DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe));
@@ -2401,6 +2401,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
+ const size_t rx_cons_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx);
nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
@@ -2586,6 +2587,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
+ cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset);
cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
++cq_idx;
--
2.25.1
From f91eb8510f8c9aa4f5769bd88434db5ab000e65a Mon Sep 17 00:00:00 2001
From: Shangyan Zhou <sy.zhou@deepseek.com>
Date: Tue, 11 Feb 2025 11:00:57 +0800
Subject: [PATCH 5/5] Init rx_wq counters.
---
src/include/device_host_transport/nvshmem_common_ibgda.h | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
index 502645d..f0bc328 100644
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
@@ -46,6 +46,8 @@
qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
+ qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
+ qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
} while (0);
--
2.25.1