Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Initial commit
.gitignore (vendored, new file, 8 lines)
@@ -0,0 +1,8 @@
compile_commands.json
.idea
.DS_Store
*.pyc
build/
.cache/
.vscode/
*/cmake-build-*/
LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (new file, 304 lines)
@@ -0,0 +1,304 @@
# DeepEP

DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8.

To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from the NVLink domain to the RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. They also support controlling the number of SMs (Streaming Multiprocessors) used.

For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels that use pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resources.

Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper.

## Performance

### Normal kernels with NVLink and RDMA forwarding

We test the normal kernels on H800 GPUs (~160 GB/s maximum NVLink bandwidth), each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth), following the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden size, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining).

| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |

### Low-latency kernels with pure RDMA

We test the low-latency kernels on H800 GPUs, each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth), following a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden size, top-8 experts, FP8 dispatching and BF16 combining).

| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth |
|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:|
| 8 | 163 us | 46 GB/s | 8 | 318 us | 46 GB/s |
| 16 | 173 us | 43 GB/s | 16 | 329 us | 44 GB/s |
| 32 | 182 us | 41 GB/s | 32 | 350 us | 41 GB/s |
| 64 | 186 us | 40 GB/s | 64 | 353 us | 41 GB/s |
| 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s |
| 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s |

## Quick start

### Requirements

- Hopper GPUs (more architectures or devices may be supported later)
- Python 3.8 and above
- CUDA 12.3 and above
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication
### Download and install NVSHMEM dependency

DeepEP also depends on our modified NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions.

### Development

```bash
# Build and make symbolic links for SO files
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
# You may modify the specific SO names according to your own platform
ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so

# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch into multiple nodes
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
```
### Installation

```bash
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```

Then, import `deep_ep` in your Python project, and enjoy!
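As a quick post-install sanity check, the minimal sketch below (not part of the documented workflow) verifies that the extension imports; `Buffer` and `set_num_sms` are the same symbols used in the examples later in this README.

```python
# Minimal post-install check (illustrative only)
from deep_ep import Buffer, EventOverlap

# Optionally cap the number of SMs the normal (high-throughput) kernels may use
Buffer.set_num_sms(24)
print('DeepEP imported successfully')
```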
## Network configurations

DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well.

### Traffic isolation

Traffic isolation is supported by InfiniBand through Virtual Lanes (VL).

To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows:

- workloads using normal kernels
- workloads using low-latency kernels
- other workloads

For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable.
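For example, a launch script might pin the two kinds of DeepEP workloads to different service levels. This is only a sketch: the SL values and script names below are placeholders, and the lanes should match your fabric configuration.

```bash
# Hypothetical assignment: normal kernels on SL 0, low-latency kernels on SL 1
NVSHMEM_IB_SL=0 python train_with_normal_kernels.py
NVSHMEM_IB_SL=1 python serve_with_low_latency_kernels.py
```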
### Adaptive routing

Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Currently, the low-latency kernels support adaptive routing, while the normal kernels do not (support may be added soon). **Enabling adaptive routing for the normal internode kernels may lead to deadlocks or data-corruption issues**.

For the low-latency kernels, enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance:

- enable adaptive routing in environments with heavy network loads
- use static routing in environments with light network loads

### Congestion control

Congestion control is disabled, as we have not observed significant congestion in our production environment.
## Interfaces and examples

### Example use in model training or inference prefilling

The normal kernels can be used in model training or the inference prefilling phase (without the backward part), as the example code below shows.

```python
import torch
import torch.distributed as dist
from typing import List, Tuple, Optional, Union

from deep_ep import Buffer, EventOverlap

# Communication buffer (will be allocated at runtime)
_buffer: Optional[Buffer] = None

# Set the number of SMs to use
# NOTES: this is a static variable
Buffer.set_num_sms(24)


# You may call this function at framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    global _buffer

    # NOTES: you may also replace `get_*_config` with your auto-tuned results from running all the tests
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)

    # Allocate a buffer if none exists or if the existing one is too small
    # NOTES: the adaptive routing configuration of the network **must be off**
    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


def get_hidden_bytes(x: torch.Tensor) -> int:
    t = x[0] if isinstance(x, tuple) else x
    return t.size(1) * max(t.element_size(), 2)


def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                     num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
    # NOTES: an optional `previous_event` is a captured CUDA event that the dispatch kernel should depend on;
    # it may be useful for communication-computation overlap. For more information, please
    # refer to the docs of `Buffer.dispatch`
    global _buffer

    # Calculate layout before actual dispatch
    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
        _buffer.get_dispatch_layout(topk_idx, num_experts,
                                    previous_event=previous_event, async_finish=True,
                                    allocate_on_comm_stream=previous_event is not None)
    # Do MoE dispatch
    # NOTES: the CPU will wait for the GPU's signal to arrive, so this is not compatible with CUDA graphs
    # For more advanced usages, please refer to the docs of the `dispatch` function
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
        _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                         num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                         is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
                         previous_event=previous_event, async_finish=True,
                         allocate_on_comm_stream=True)
    # For event management, please refer to the docs of the `EventOverlap` class
    return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event


def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
        Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
    global _buffer

    # The backward process of MoE dispatch is actually a combine
    # For more advanced usages, please refer to the docs of the `combine` function
    combined_grad_x, combined_grad_recv_topk_weights, event = \
        _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_grad_x, combined_grad_recv_topk_weights, event


def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[torch.Tensor, EventOverlap]:
    global _buffer

    # Do MoE combine
    # For more advanced usages, please refer to the docs of the `combine` function
    combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
                                           allocate_on_comm_stream=previous_event is not None)

    # For event management, please refer to the docs of the `EventOverlap` class
    return combined_x, event


def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
        Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
    global _buffer

    # The backward process of MoE combine is actually a dispatch
    # For more advanced usages, please refer to the docs of the `dispatch` function
    grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
                                                 previous_event=previous_event,
                                                 allocate_on_comm_stream=previous_event is not None)

    # For event management, please refer to the docs of the `EventOverlap` class
    return grad_x, event
```
Moreover, inside the dispatch function, we may not know how many tokens the current rank will receive. So an implicit CPU wait for the GPU's received-count signal is involved, as the following figure shows.

![normal](figures/normal.png)
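Putting the helpers above together, a full MoE layer forward pass could look like the sketch below. This is illustrative only: `expert_mlp` is a placeholder for your grouped expert computation, and the event waits assume `EventOverlap` exposes a `current_stream_wait()` method as described in its docs.

```python
# Illustrative only: a forward pass built from the helpers defined above
def moe_forward(x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_experts: int) -> torch.Tensor:
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
        dispatch_forward(x, topk_idx, topk_weights, num_experts)
    event.current_stream_wait()  # or overlap other computation before waiting

    # Placeholder: run the local experts on the received tokens
    expert_out = expert_mlp(recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list)

    combined_x, combine_event = combine_forward(expert_out, handle)
    combine_event.current_stream_wait()
    return combined_x
```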
### Example use in inference decoding

The low-latency kernels can be used in the inference decoding phase, as the example code below shows.

```python
import torch
import torch.distributed as dist
from typing import Tuple, Optional

from deep_ep import Buffer

# Communication buffer (will be allocated at runtime)
# NOTES: there is no SM control API for the low-latency kernels
_buffer: Optional[Buffer] = None


# You may call this function at framework initialization
def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
    # NOTES: the low-latency mode will consume much more space than the normal mode
    # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) be less than 256
    global _buffer
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)

    # Allocate a buffer if none exists or if the existing one is too small
    if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
        # NOTES: for best performance, the QP number **must** be equal to the number of local experts
        assert num_experts % group.size() == 0
        _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
    return _buffer


def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
    global _buffer

    # Do MoE dispatch, compatible with CUDA graphs (but you may need to restore some buffer status once you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                     async_finish=False, return_recv_hook=True)

    # NOTES: the actual tensors will not be received until you call `hook()`;
    # this is useful for double-batch overlapping, **without any SM occupation**
    # If you don't want to overlap, please set `return_recv_hook=False`
    # Later, you can use our GEMM library to do the computation with this specific format
    return recv_hidden_states, recv_expert_count, handle, event, hook


def low_latency_combine(hidden_states: torch.Tensor,
                        topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple):
    global _buffer

    # Do MoE combine, compatible with CUDA graphs (but you may need to restore some buffer status once you replay)
    combined_hidden_states, event_overlap, hook = \
        _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle,
                                    async_finish=False, return_recv_hook=True)

    # NOTES: the same behavior as described for the dispatch kernel
    return combined_hidden_states, event_overlap, hook
```
For two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic happens in the background without costing any GPU SMs from the computation part. Note that the overlapped parts can be adjusted, i.e. the four stages (attention/dispatch/MoE/combine) may not have exactly the same execution time, so you may tune the stage split according to your workload. A minimal sketch follows the figure.

![low-latency](figures/low-latency.png)
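The sketch below illustrates the idea for a single pipeline step. It is a sketch only: `attention` and `grouped_gemm` are placeholders for your own compute, the constants are assumptions, and the exact interleaving should be tuned as noted above.

```python
# Illustrative only: overlap micro-batch B's attention with micro-batch A's dispatch traffic
MAX_TOKENS_PER_RANK = 128   # assumed decoding batch size
NUM_EXPERTS = 256           # assumed expert count

def overlapped_step(batch_a, batch_b):
    # Launch A's dispatch; with `return_recv_hook=True` the RDMA receive is deferred to `hook_a()`
    recv_a, count_a, handle_a, _, hook_a = low_latency_dispatch(
        batch_a.hidden_states, batch_a.topk_idx, MAX_TOKENS_PER_RANK, NUM_EXPERTS)

    attn_b = attention(batch_b)           # compute B while A's tokens are in flight (no SMs spent on comms)

    hook_a()                              # now make A's dispatched tokens visible
    expert_out_a = grouped_gemm(recv_a, count_a)

    combined_a, _, hook_combine_a = low_latency_combine(
        expert_out_a, batch_a.topk_idx, batch_a.topk_weights, handle_a)
    # ... continue with B's dispatch, then call `hook_combine_a()` before using `combined_a`
    return combined_a, attn_b, hook_combine_a
```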
## Notices

- For extreme performance, we discovered and use an undocumented PTX instruction: `ld.global.nc.L1::no_allocate.L2::256B`. This instruction is technically undefined behavior: it accesses volatile GPU memory with the non-coherent read-only modifier `.nc`. However, correctness is tested to be guaranteed with `.L1::no_allocate` on Hopper architectures, and performance is much better. If you find the kernels not working on some other platform, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` to disable this (see the example after this list), or file an issue.
- For better performance on your cluster, we recommend running all the tests and using the best auto-tuned configuration. The default configurations are optimized for DeepSeek's internal cluster.
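One way to do this, assuming the flag is read from the environment by `setup.py` (otherwise edit the file directly), is:

```bash
# Illustrative only: rebuild with the aggressive PTX load disabled
DISABLE_AGGRESSIVE_PTX_INSTRS=1 NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```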
## License

This code repository is released under [the MIT License](LICENSE), except for code that references NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which is subject to the [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html).

## Citation

If you use this codebase, or otherwise find our work valuable, please cite:

```bibtex
@misc{deepep2025,
    title={DeepEP: an efficient expert-parallel communication library},
    author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao},
    year={2025},
    publisher={GitHub},
    howpublished={\url{https://github.com/deepseek-ai/DeepEP}},
}
```
csrc/CMakeLists.txt (new file, 33 lines)
@@ -0,0 +1,33 @@
# NOTES: this CMake file is only for debugging; for setup, please use the Torch extension
cmake_minimum_required(VERSION 3.10)
project(deep_ep LANGUAGES CUDA CXX)
set(CMAKE_VERBOSE_MAKEFILE ON)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

set(TORCH_CUDA_ARCH_LIST "9.0")
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem)

add_library(nvshmem ALIAS nvshmem::nvshmem)
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)

# There seem to be bugs with CMake, NVCC 12 and C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 14)

include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})

add_subdirectory(kernels)

# Link C++ and CUDA together
pybind11_add_module(deep_ep_cpp deep_ep.cpp)
target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
csrc/config.hpp (new file, 177 lines)
@@ -0,0 +1,177 @@
#pragma once

#include "kernels/api.cuh"
#include "kernels/exception.cuh"

namespace deep_ep {

template <typename dtype_t>
dtype_t cell_div(dtype_t a, dtype_t b) {
    return (a + b - 1) / b;
}

template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
    return cell_div<dtype_t>(a, b) * b;
}
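// Illustration (not part of the original file): `cell_div` is a ceiling division and `align`
// rounds up to the next multiple, e.g. cell_div(130, 128) == 2 and align(130, 128) == 256.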
struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;

    Config(int num_sms,
           int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
           int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) :
            num_sms(num_sms),
            num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens),
            num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens),
            num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens),
            num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) {
        EP_HOST_ASSERT(num_sms >= 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
        EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
        EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
        this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
    }

    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
        // Below are some assumptions
        // TODO: add assertions
        constexpr int kNumMaxTopK = 128;
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
        const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
        const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
        const int num_channels = num_sms / 2;

        size_t num_bytes = 0;
        num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }

    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
        // Legacy mode
        if (num_ranks <= NUM_MAX_NVL_PEERS)
            return 0;

        // Below are some assumptions
        // TODO: add assertions
        constexpr int kNumMaxTopK = 128;
        constexpr int kNumMaxScales = 128;
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        EP_HOST_ASSERT(num_sms % 2 == 0);
        const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
        const int num_channels = num_sms / 2;

        size_t num_bytes = 0;
        num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
        num_bytes = ((num_bytes + 127) / 128) * 128;
        return num_bytes;
    }
};

struct LowLatencyBuffer {
    int num_clean_int = 0;

    void* dispatch_rdma_send_buffer = nullptr;
    void* dispatch_rdma_recv_data_buffer = nullptr;
    int* dispatch_rdma_recv_count_buffer = nullptr;
    int* dispatch_rdma_atomic_token_counter = nullptr;

    void* combine_rdma_send_buffer = nullptr;
    void* combine_rdma_recv_data_buffer = nullptr;
    int* combine_rdma_recv_flag_buffer = nullptr;

    std::pair<int*, int> clean_meta() {
        EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
        return {dispatch_rdma_recv_count_buffer, num_clean_int};
    }
};

struct LowLatencyLayout {
    size_t total_bytes = 0;
    LowLatencyBuffer buffers[2];

    template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
    out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
        return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
    }

    LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
        const int num_scales = hidden / 128;
        const int num_local_experts = num_experts / num_ranks;

        // Dispatch and combine layout:
        // - 2 symmetric odd/even send buffers
        // - 2 symmetric odd/even receive buffers
        // - 2 symmetric odd/even signaling buffers

        // Message sizes
        EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
        size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
        size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);

        // Send buffer
        size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
        size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
        EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
        total_bytes += send_buffer_bytes * 2;

        // Symmetric receive buffers
        // TODO: optimize memory usage
        size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
        size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
        size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
        EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
        total_bytes += recv_buffer_bytes * 2;

        // Symmetric signaling buffers
        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
        size_t dispatch_recv_atomic_token_counter_bytes = num_local_experts * sizeof(int);
        size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
        size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes + dispatch_recv_atomic_token_counter_bytes,
                                                 combine_recv_flag_buffer_bytes);
        total_bytes += signaling_buffer_bytes * 2;

        // Assign pointers
        // NOTES: we still leave some space for distinguishing dispatch/combine buffers,
        // so you may see some parameters duplicated
        for (int i = 0; i < 2; ++ i) {
            buffers[i] = {
                static_cast<int>(signaling_buffer_bytes / sizeof(int)),
                advance(rdma_buffer, send_buffer_bytes * i),
                advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i),
                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i + dispatch_recv_count_buffer_bytes),
                advance(rdma_buffer, send_buffer_bytes * i),
                advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
                advance<int*>(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i)
            };
        }
    }
};

size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
    return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}

} // namespace deep_ep
csrc/deep_ep.cpp (new file, 1208 lines)
File diff suppressed because it is too large.
csrc/deep_ep.hpp (new file, 149 lines)
@@ -0,0 +1,149 @@
|
||||
#pragma once
|
||||
|
||||
// Forcibly disable NDEBUG
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <torch/types.h>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "event.hpp"
|
||||
#include "kernels/configs.cuh"
|
||||
#include "kernels/exception.cuh"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME deep_ep_cpp
|
||||
#endif
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
struct Buffer {
|
||||
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
|
||||
|
||||
private:
|
||||
// Low-latency mode buffer
|
||||
int low_latency_buffer_idx = 0;
|
||||
bool low_latency_mode = false;
|
||||
|
||||
// NVLink Buffer
|
||||
int64_t num_nvl_bytes;
|
||||
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
|
||||
void** buffer_ptrs_gpu = nullptr;
|
||||
|
||||
// NVSHMEM Buffer
|
||||
int64_t num_rdma_bytes;
|
||||
void* rdma_buffer_ptr = nullptr;
|
||||
|
||||
// Device info and communication
|
||||
int device_id;
|
||||
int rank, rdma_rank, nvl_rank;
|
||||
int num_ranks, num_rdma_ranks, num_nvl_ranks;
|
||||
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
|
||||
|
||||
// Stream for communication
|
||||
at::cuda::CUDAStream comm_stream;
|
||||
|
||||
// After IPC/NVSHMEM synchronization, this flag will be true
|
||||
bool available = false;
|
||||
|
||||
// Task fifo
|
||||
int head = 0;
|
||||
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
|
||||
int** task_fifo_ptrs_gpu = nullptr;
|
||||
|
||||
// Workspace
|
||||
void* workspace = nullptr;
|
||||
|
||||
// Host-side MoE info
|
||||
volatile int* moe_recv_counter = nullptr;
|
||||
int* moe_recv_counter_mapped = nullptr;
|
||||
|
||||
// Host-side expert-level MoE info
|
||||
volatile int* moe_recv_expert_counter = nullptr;
|
||||
int* moe_recv_expert_counter_mapped = nullptr;
|
||||
|
||||
// Host-side RDMA-level MoE info
|
||||
volatile int* moe_recv_rdma_counter = nullptr;
|
||||
int* moe_recv_rdma_counter_mapped = nullptr;
|
||||
|
||||
private:
|
||||
void move_fifo_slots(int num_slots = 1);
|
||||
|
||||
public:
|
||||
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
|
||||
|
||||
~Buffer() noexcept(false);
|
||||
|
||||
bool is_available() const;
|
||||
|
||||
bool is_internode_available() const;
|
||||
|
||||
int get_num_rdma_ranks() const;
|
||||
|
||||
int get_rdma_rank() const;
|
||||
|
||||
int get_root_rdma_rank(bool global) const;
|
||||
|
||||
int get_local_device_id() const;
|
||||
|
||||
pybind11::bytearray get_local_ipc_handle() const;
|
||||
|
||||
pybind11::bytearray get_local_nvshmem_unique_id() const;
|
||||
|
||||
torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;
|
||||
|
||||
void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
|
||||
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
|
||||
bool async, bool allocate_on_comm_stream);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
|
||||
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
|
||||
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
|
||||
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
|
||||
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
|
||||
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
|
||||
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
|
||||
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
|
||||
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
|
||||
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
|
||||
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
|
||||
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
|
||||
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
|
||||
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
|
||||
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
|
||||
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
|
||||
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
|
||||
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
|
||||
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
|
||||
|
||||
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook);
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook);
|
||||
};
|
||||
|
||||
} // namespace deep_ep
|
||||
csrc/event.hpp (new file, 43 lines)
@@ -0,0 +1,43 @@
#include <ATen/cuda/CUDAContext.h>
#include <memory>

#include "kernels/exception.cuh"

namespace deep_ep {

struct EventHandle {
    std::shared_ptr<torch::Event> event;

    EventHandle() {
        event = std::make_shared<torch::Event>(torch::kCUDA);
        event->record(at::cuda::getCurrentCUDAStream());
    }

    explicit EventHandle(const at::cuda::CUDAStream& stream) {
        event = std::make_shared<torch::Event>(torch::kCUDA);
        event->record(stream);
    }

    EventHandle(const EventHandle& other) = default;

    void current_stream_wait() const {
        at::cuda::getCurrentCUDAStream().unwrap().wait(*event);
    }
};
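// Usage illustration (not part of the original file): record an event on the communication
// stream with `EventHandle(comm_stream)`, then call `current_stream_wait()` from the compute
// stream to order compute after the communication without a host-side synchronization.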
torch::Event create_event(const at::cuda::CUDAStream &s) {
    auto event = torch::Event(torch::kCUDA);
    event.record(s);
    return event;
}

void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) {
    EP_HOST_ASSERT(s_0.id() != s_1.id());
    s_0.unwrap().wait(create_event(s_1));
}

void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) {
    s.unwrap().wait(*event.event);
}

} // namespace deep_ep
csrc/kernels/CMakeLists.txt (new file, 20 lines)
@@ -0,0 +1,20 @@
function(add_deep_ep_library target_name source_file)
    add_library(${target_name} STATIC ${source_file})
    set_target_properties(${target_name} PROPERTIES
        POSITION_INDEPENDENT_CODE ON
        CXX_STANDARD_REQUIRED ON
        CUDA_STANDARD_REQUIRED ON
        CXX_STANDARD 14
        CUDA_STANDARD 14
        CUDA_SEPARABLE_COMPILATION ON
    )
    target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction()

add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)

# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
csrc/kernels/api.cuh (new file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
// Intranode runtime
|
||||
namespace intranode {
|
||||
|
||||
void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
// Internode runtime
|
||||
namespace internode {
|
||||
|
||||
std::vector<uint8_t> get_unique_id();
|
||||
|
||||
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
|
||||
|
||||
void *alloc(size_t size, size_t alignment);
|
||||
|
||||
void free(void *ptr);
|
||||
|
||||
void barrier();
|
||||
|
||||
void finalize();
|
||||
|
||||
} // namespace internode
|
||||
|
||||
// Intranode kernels
|
||||
namespace intranode {
|
||||
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
|
||||
cudaStream_t stream, int num_sms);
|
||||
|
||||
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
|
||||
cudaStream_t stream);
|
||||
|
||||
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens);
|
||||
|
||||
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
|
||||
|
||||
void combine(cudaDataType_t type,
|
||||
void* recv_x, float* recv_topk_weights,
|
||||
const void* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens);
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
// Internode kernels
|
||||
namespace internode {
|
||||
|
||||
int get_source_meta_bytes();
|
||||
|
||||
void get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
const bool* is_token_in_rank, int num_tokens, int num_channels,
|
||||
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
|
||||
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
|
||||
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
||||
int** task_fifo_ptrs, int head, int rank,
|
||||
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
||||
bool low_latency_mode);
|
||||
|
||||
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
|
||||
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
int* send_rdma_head, int* send_nvl_head,
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, bool is_cached_dispatch,
|
||||
cudaStream_t stream, int num_channels, bool low_latency_mode);
|
||||
|
||||
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
|
||||
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
|
||||
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
||||
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
|
||||
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
||||
bool is_cached_dispatch, bool low_latency_mode);
|
||||
|
||||
void combine(cudaDataType_t type,
|
||||
void* combined_x, float* combined_topk_weights,
|
||||
const bool* is_combined_token_in_rank,
|
||||
const void* x, const float* topk_weights,
|
||||
const int* combined_rdma_head, const int* combined_nvl_head,
|
||||
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
|
||||
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
|
||||
|
||||
} // namespace internode
|
||||
|
||||
// Internode low-latency kernels
|
||||
namespace internode_ll {
|
||||
|
||||
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
cudaStream_t stream);
|
||||
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases);
|
||||
|
||||
void combine(void* combined_x,
|
||||
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
|
||||
const void* x, const int64_t* topk_idx, const float* topk_weights,
|
||||
const int* src_info, const int64_t* layout_range,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases);
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
} // namespace deep_ep
|
||||
csrc/kernels/buffer.cuh (new file, 138 lines)
@@ -0,0 +1,138 @@
#pragma once

#include "configs.cuh"
#include "exception.cuh"

namespace deep_ep {

template <typename dtype_t>
struct Buffer {
private:
    uint8_t* ptr;

public:
    int total_bytes;

    __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}

    __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
        total_bytes = num_elems * sizeof(dtype_t);
        ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
    }

    __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
        return *this;
    }

    __device__ __forceinline__ dtype_t* buffer() {
        return reinterpret_cast<dtype_t*>(ptr);
    }

    __device__ __forceinline__ dtype_t& operator[](int idx) {
        return buffer()[idx];
    }
};
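// Illustration (not part of the original file): each constructor carves `num_elems * sizeof(dtype_t)`
// bytes off `gbl_ptr` and advances it, so consecutive `Buffer` constructions partition one
// pre-allocated region into typed, back-to-back sub-buffers.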
template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
private:
    uint8_t* ptrs[kNumRanks];
    int num_bytes;

public:
    int total_bytes;

    __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
        EP_STATIC_ASSERT(kNumRanks == 1, "");
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes = per_channel_bytes * num_sms;
        ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
    }

    __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
        EP_STATIC_ASSERT(kNumRanks > 1, "");
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes = per_channel_bytes * num_sms;
        for (int i = 0; i < kNumRanks; ++ i) {
            ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
            gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
        }
    }

    __device__ __forceinline__ void advance(int shift) {
        #pragma unroll
        for (int i = 0; i < kNumRanks; ++ i)
            ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
    }

    __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
        return *this;
    }

    template<int kNumAlsoRanks>
    __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
        for (int i = 0; i < kNumAlsoRanks; ++ i)
            gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
        return *this;
    }

    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
        EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for the single-rank case");
        return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
    }

    __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
        EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for the multi-rank case");
        return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
    }
};

template <typename dtype_t, bool kDecoupled = true>
struct SymBuffer {
private:
    // NOTES: for the non-decoupled case, `recv_ptr` is not used
    uint8_t* send_ptr;
    uint8_t* recv_ptr;
    int num_bytes;

public:
    int total_bytes;

    __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
                                         int sm_id = 0, int num_sms = 1) {
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
        send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
        recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
        gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
    }

    __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
        EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for the decoupled case");
        return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
    }

    __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
        EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for the decoupled case");
        return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
    }

    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
        EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for the non-decoupled case");
        return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
    }
};

} // namespace deep_ep
csrc/kernels/configs.cuh (new file, 50 lines)
@@ -0,0 +1,50 @@
#pragma once

#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128

#define FINISHED_SUM_TAG 1024
#define NUM_CPU_TIMEOUT_SECS 100
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
#define NUM_WAIT_NANOSECONDS 500

#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2

// Make CLion CUDA indexing work
#ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif

// Remove Torch restrictions
#ifdef __CUDA_NO_HALF_CONVERSIONS__
#undef __CUDA_NO_HALF_CONVERSIONS__
#endif
#ifdef __CUDA_NO_HALF_OPERATORS__
#undef __CUDA_NO_HALF_OPERATORS__
#endif
#ifdef __CUDA_NO_HALF2_OPERATORS__
#undef __CUDA_NO_HALF2_OPERATORS__
#endif
#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
#undef __CUDA_NO_BFLOAT16_CONVERSIONS__
#endif
#ifdef __CUDA_NO_BFLOAT162_OPERATORS__
#undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>
51
csrc/kernels/exception.cuh
Normal file
51
csrc/kernels/exception.cuh
Normal file
@@ -0,0 +1,51 @@
#pragma once

#include <string>
#include <exception>

#include "configs.cuh"

#ifndef EP_STATIC_ASSERT
#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif

class EPException: public std::exception {
private:
    std::string message = {};

public:
    explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
        message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
    }

    const char *what() const noexcept override { return message.c_str(); }
};

#ifndef CUDA_CHECK
#define CUDA_CHECK(cmd) \
do { \
    cudaError_t e = (cmd); \
    if (e != cudaSuccess) { \
        throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
    } \
} while (0)
#endif

#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
    if (not (cond)) { \
        throw EPException("Assertion", __FILE__, __LINE__, #cond); \
    } \
} while (0)
#endif

#ifndef EP_DEVICE_ASSERT
#define EP_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        asm("trap;"); \
    } \
} while (0)
#endif
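A minimal host-side usage sketch for the macros above; the helper function and buffer size are illustrative, not part of the library:

void* allocate_workspace(size_t num_bytes) {
    EP_HOST_ASSERT(num_bytes > 0 and num_bytes <= NUM_WORKSPACE_BYTES);
    void* workspace = nullptr;
    CUDA_CHECK(cudaMalloc(&workspace, num_bytes));    // Throws `EPException` with the CUDA error string on failure
    CUDA_CHECK(cudaMemset(workspace, 0, num_bytes));  // E.g. zero-initialize counters before first use
    return workspace;
}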
423
csrc/kernels/ibgda_device.cuh
Normal file
@@ -0,0 +1,423 @@
|
||||
// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)
|
||||
// Copyright (c) NVIDIA Corporation.
|
||||
// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).
|
||||
// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html
|
||||
//
|
||||
// Modified from original source:
|
||||
// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh
|
||||
#pragma once
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint64_t HtoBE64(uint64_t x) {
|
||||
uint64_t ret;
|
||||
asm("{\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
".reg .b32 lo;\n\t"
|
||||
".reg .b32 hi;\n\t"
|
||||
".reg .b32 new_lo;\n\t"
|
||||
".reg .b32 new_hi;\n\t"
|
||||
"mov.b64 {lo,hi}, %1;\n\t"
|
||||
"prmt.b32 new_hi, lo, ign, 0x0123;\n\t"
|
||||
"prmt.b32 new_lo, hi, ign, 0x0123;\n\t"
|
||||
"mov.b64 %0, {new_lo,new_hi};\n\t"
|
||||
"}" : "=l"(ret) : "l"(x));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint32_t HtoBE32(uint32_t x) {
|
||||
uint32_t ret;
|
||||
asm("{\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"prmt.b32 %0, %1, ign, 0x0123;\n\t"
|
||||
"}" : "=r"(ret) : "r"(x));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint16_t HtoBE16(uint16_t x) {
|
||||
// TODO: simplify PTX using 16-bit instructions
|
||||
auto a = static_cast<uint32_t>(x);
|
||||
uint32_t d;
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .b32 mask;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"mov.b32 mask, 0x4401;\n\t"
|
||||
"mov.b32 ign, 0x0;\n\t"
|
||||
"prmt.b32 %0, %1, ign, mask;\n\t"
|
||||
"}"
|
||||
: "=r"(d)
|
||||
: "r"(a));
|
||||
return static_cast<uint16_t>(d);
|
||||
}
|
||||
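// For intuition: these helpers byte-swap host (little-endian) values into the big-endian layout
// the NIC expects, i.e. a device-side analogue of `htobe16/32/64`. For example:
//   HtoBE32(0x11223344u) == 0x44332211u, HtoBE16(0x1122) == 0x2211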
|
||||
typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
|
||||
|
||||
__device__ static __forceinline__
|
||||
nvshmemi_ibgda_device_state_t* ibgda_get_state() {
|
||||
return &nvshmemi_ibgda_device_state_d;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
|
||||
auto state = ibgda_get_state();
|
||||
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
|
||||
return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe];
|
||||
}
|
||||
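// Example: with `num_rc_per_pe == 4`, (pe = 2, id = 9) selects RC QP 2 * 4 + 9 % 4 = 9; each
// destination PE owns a contiguous group of RC QPs and `id` round-robins within that group.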
|
||||
__device__ static __forceinline__
|
||||
void ibgda_lock_acquire(int *lock) {
|
||||
while (atomicCAS(lock, 0, 1) == 1);
|
||||
|
||||
// Prevent reordering before the lock is acquired
|
||||
memory_fence_cta();
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_lock_release(int *lock) {
|
||||
memory_fence_cta();
|
||||
|
||||
// Prevent reordering before lock is released
|
||||
st_na_relaxed(lock, 0);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) {
|
||||
// `DBREC` contains the index of the next empty `WQEBB`
|
||||
__be32 dbrec_val;
|
||||
__be32 *dbrec_ptr = qp->tx_wq.dbrec;
|
||||
|
||||
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`
|
||||
asm("{\n\t"
|
||||
".reg .b32 dbrec_head_16b;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
|
||||
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
|
||||
"}"
|
||||
: "=r"(dbrec_val)
|
||||
: "r"(dbrec_head));
|
||||
st_na_release(dbrec_ptr, dbrec_val);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) {
|
||||
auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);
|
||||
ibgda_ctrl_seg_t ctrl_seg = {
|
||||
.opmod_idx_opcode = HtoBE32(prod_idx << 8),
|
||||
.qpn_ds = HtoBE32(qp->qpn << 8)
|
||||
};
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), "");
|
||||
st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) {
|
||||
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
|
||||
uint64_t old_prod_idx;
|
||||
|
||||
// Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence
|
||||
ibgda_lock_acquire(&mvars->post_send_lock);
|
||||
|
||||
old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
|
||||
if (new_prod_idx > old_prod_idx) {
|
||||
ibgda_update_dbr(qp, new_prod_idx);
|
||||
ibgda_ring_db(qp, new_prod_idx);
|
||||
}
|
||||
ibgda_lock_release(&mvars->post_send_lock);
|
||||
}
|
||||
|
||||
template <bool kAlwaysDoPostSend>
|
||||
__device__ static __forceinline__
|
||||
void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
|
||||
uint32_t num_wqes, int message_idx = 0) {
|
||||
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
|
||||
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
|
||||
|
||||
// WQE writes must be finished first
|
||||
__threadfence();
|
||||
|
||||
// Wait for prior WQE slots to be filled first
|
||||
auto *ready_idx = reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.ready_head);
|
||||
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
|
||||
|
||||
// Post either immediately, or in batches of `kNumRequestInBatch` messages
|
||||
constexpr int kNumRequestInBatch = 4;
|
||||
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
|
||||
ibgda_post_send(qp, new_wqe_idx);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
|
||||
__be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
|
||||
ibgda_ctrl_seg_t ctrl_seg;
|
||||
struct mlx5_wqe_raddr_seg raddr_seg;
|
||||
struct mlx5_wqe_inl_data_seg inl_seg;
|
||||
|
||||
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
|
||||
auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
|
||||
auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
|
||||
auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));
|
||||
|
||||
raddr_seg.raddr = HtoBE64(raddr);
|
||||
raddr_seg.rkey = rkey;
|
||||
raddr_seg.reserved = 0;
|
||||
|
||||
inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);
|
||||
|
||||
// `imm == std::numeric_limits<uint32_t>::max()` means no imm writes
|
||||
ctrl_seg = {0};
|
||||
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
|
||||
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
|
||||
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));
|
||||
if (imm != std::numeric_limits<uint32_t>::max())
|
||||
ctrl_seg.imm = HtoBE32(imm);
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4");
|
||||
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
|
||||
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
|
||||
st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));
|
||||
st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
|
||||
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
|
||||
auto state = ibgda_get_state();
|
||||
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
|
||||
auto log2_cumem_granularity = state->log2_cumem_granularity;
|
||||
|
||||
// Local key
|
||||
uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity;
|
||||
auto device_key = state->constmem.lkeys[idx];
|
||||
auto lchunk_size = device_key.next_addr - laddr;
|
||||
*lkey = device_key.key;
|
||||
|
||||
// Remote key
|
||||
uint64_t roffset = raddr - heap_start;
|
||||
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
|
||||
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
|
||||
device_key = state->constmem.rkeys[idx];
|
||||
} else {
|
||||
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
|
||||
}
|
||||
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
|
||||
*out_rkey = device_key.key;
|
||||
|
||||
// Return the minimum of local and remote chunk sizes
|
||||
auto rchunk_size = device_key.next_addr - roffset;
|
||||
return min(lchunk_size, rchunk_size);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
|
||||
auto state = ibgda_get_state();
|
||||
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
|
||||
|
||||
uint64_t roffset = addr - heap_start;
|
||||
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
|
||||
nvshmemi_ibgda_device_key_t device_key;
|
||||
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
|
||||
device_key = state->constmem.rkeys[idx];
|
||||
else
|
||||
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
|
||||
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
|
||||
*out_rkey = device_key.key;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ uint64_t
|
||||
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
|
||||
auto mvars = &qp->mvars;
|
||||
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void*
|
||||
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
|
||||
uint16_t cnt = qp->tx_wq.nwqes;
|
||||
uint16_t idx = wqe_idx & (cnt - 1);
|
||||
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
|
||||
}
|
||||
|
||||
// Wait until wqe `idx - 1` is completed.
|
||||
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
|
||||
// Because we post recv and poll recv in the same thread, we don't need to maintain the queue status.
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
auto cq = qp->rx_wq.cq;
|
||||
|
||||
const uint32_t ncqes = cq->ncqes;
|
||||
auto *cqe64 = reinterpret_cast<struct mlx5_cqe64*>(cq->cqe);
|
||||
auto old_cons_idx = *cq->cons_idx;
|
||||
*cq->cons_idx = old_cons_idx + 1;
|
||||
|
||||
// Wait until `wqe_counter >= old_cons_idx`
|
||||
while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
|
||||
// Get rkey
|
||||
// NOTES: the `p` operation will not cross multiple remote chunks
|
||||
__be32 rkey;
|
||||
uint64_t raddr;
|
||||
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey);
|
||||
|
||||
// Write WQEs
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
|
||||
void *wqe_ptrs;
|
||||
wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
|
||||
ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);
|
||||
|
||||
// Submit requests
|
||||
ibgda_submit_requests<true>(qp, base_wqe_idx, 1);
|
||||
}
|
||||
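// A usage sketch mirroring how the low-latency kernels deliver a single signaling integer
// (pointer names are illustrative): the receiver polls the destination word until it becomes
// non-zero and then decodes the count with `-value - 1`:
//   nvshmemi_ibgda_rma_p(remote_count_ptr, -num_tokens_sent - 1, dst_rank, local_expert_idx, 0);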
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
|
||||
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
|
||||
void **out_wqes) {
|
||||
ibgda_ctrl_seg_t ctrl_seg;
|
||||
struct mlx5_wqe_raddr_seg raddr_seg;
|
||||
struct mlx5_wqe_data_seg data_seg;
|
||||
|
||||
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
|
||||
void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
|
||||
struct mlx5_wqe_raddr_seg *raddr_seg_ptr;
|
||||
struct mlx5_wqe_data_seg *data_seg_ptr;
|
||||
|
||||
raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
|
||||
data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
|
||||
|
||||
raddr_seg.raddr = HtoBE64(raddr);
|
||||
raddr_seg.rkey = rkey;
|
||||
raddr_seg.reserved = 0;
|
||||
|
||||
data_seg.byte_count = HtoBE32(bytes);
|
||||
data_seg.lkey = lkey;
|
||||
data_seg.addr = HtoBE64(laddr);
|
||||
|
||||
ctrl_seg = {0};
|
||||
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
|
||||
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
|
||||
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
|
||||
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
|
||||
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_write_empty_recv_wqe(void *out_wqe) {
|
||||
auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
|
||||
struct mlx5_wqe_data_seg data_seg;
|
||||
|
||||
// Make the first segment in the WQE invalid, then the entire list will be invalid
|
||||
data_seg.byte_count = 0;
|
||||
data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
|
||||
data_seg.addr = 0;
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ uint64_t
|
||||
nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
|
||||
auto mvars = &qp->mvars;
|
||||
|
||||
// Allocate if not enough
|
||||
constexpr int kMinIBGDARecvs = 32;
|
||||
auto resv_head = mvars->rx_wq.resv_head;
|
||||
auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
|
||||
if (num_valid_slots < kMinIBGDARecvs) {
|
||||
resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
|
||||
mvars->rx_wq.resv_head = resv_head;
|
||||
|
||||
// Ensure WQE is written before `dbrec`
|
||||
__be32 dbrec_val;
|
||||
__be32 *dbrec_ptr = qp->rx_wq.dbrec;
|
||||
|
||||
// Compared to sending, for each QP, we only post recv in a single thread,
|
||||
// so we don't need to do synchronization here
|
||||
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
|
||||
asm("{\n\t"
|
||||
".reg .b32 dbrec_head_16b;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
|
||||
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
|
||||
"}" : "=r"(dbrec_val)
|
||||
: "r"(static_cast<uint32_t>(resv_head)));
|
||||
st_na_release(dbrec_ptr, dbrec_val);
|
||||
}
|
||||
|
||||
// Return old number of slots
|
||||
return num_valid_slots;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
|
||||
// NOTES: only one thread can run this function
|
||||
// TODO: consider this assertion for normal AR
|
||||
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
|
||||
// Get lkey and rkey, store them into lanes
|
||||
uint32_t num_wqes = 0;
|
||||
__be32 my_lkey = 0;
|
||||
uint64_t my_laddr = 0;
|
||||
__be32 my_rkey = 0;
|
||||
uint64_t my_raddr = 0;
|
||||
uint64_t my_chunk_size = 0;
|
||||
|
||||
// Decide how many messages are needed (at most 3 in theory)
|
||||
auto remaining_bytes = bytes;
|
||||
while (remaining_bytes > 0) {
|
||||
if (lane_id == num_wqes)
|
||||
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey));
|
||||
|
||||
// Move one more message
|
||||
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
|
||||
remaining_bytes -= chunk_size;
|
||||
req_lptr += chunk_size;
|
||||
req_rptr += chunk_size;
|
||||
++ num_wqes;
|
||||
}
|
||||
EP_DEVICE_ASSERT(num_wqes <= 32);
|
||||
|
||||
// Process WQE
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
uint64_t base_wqe_idx = 0;
|
||||
if (lane_id == 0)
|
||||
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
|
||||
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
|
||||
if (lane_id < num_wqes) {
|
||||
auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
|
||||
ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size,
|
||||
base_wqe_idx, &wqe_ptr);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Submit
|
||||
if (lane_id == 0)
|
||||
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
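// A sketch of how the low-latency dispatch below drives this warp-collective put; all 32 lanes
// call it with the same arguments except `lane_id`:
//   nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
//                               dst_rank, dst_expert_local_idx, lane_id, slot_idx);
// Lane 0 reserves the WQE slots, up to 3 lanes each write one RDMA-write WQE (one per
// registration chunk), and lane 0 submits, batching doorbell rings every few messages.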
|
||||
} // namespace deep_ep
|
||||
1720
csrc/kernels/internode.cu
Normal file
File diff suppressed because it is too large
533
csrc/kernels/internode_ll.cu
Normal file
@@ -0,0 +1,533 @@
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace internode_ll {
|
||||
|
||||
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
|
||||
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1) {
|
||||
// Barrier before cleaning (in case of unfinished chunked EP)
|
||||
nvshmemx_barrier_all_block();
|
||||
|
||||
// Clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
|
||||
clean_0[i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
|
||||
clean_1[i] = 0;
|
||||
|
||||
// Barrier after cleaning (make sure low-latency mode works fine)
|
||||
nvshmemx_barrier_all_block();
|
||||
}
|
||||
|
||||
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
cudaStream_t stream) {
|
||||
constexpr int kNumThreads = 256;
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
|
||||
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
|
||||
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
|
||||
}
|
||||
|
||||
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
int phases) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_local_experts = num_experts / num_ranks;
|
||||
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// FP8 casting setup
|
||||
constexpr int kNumPerChannels = 128;
|
||||
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
|
||||
const int num_scales = kHidden / kNumPerChannels;
|
||||
const size_t hidden_int4 = kHidden / sizeof(int4);
|
||||
|
||||
// Message package: hidden data, FP8 scales, index at source
|
||||
// NOTES: currently we have 3 reserved int fields for future use
|
||||
const size_t num_bytes_per_msg = kHidden + num_scales * sizeof(float) + sizeof(int4);
|
||||
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
|
||||
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
|
||||
|
||||
// Sending phase
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
|
||||
goto LOW_LATENCY_DISPATCH_RECV;
|
||||
|
||||
// Expert counts
|
||||
__shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
|
||||
|
||||
// There are 2 kinds of warps in this part:
|
||||
// 1. The first-kind warps for FP8 cast and sending top-k tokens
|
||||
// 2. The last warp for reading `topk_idx` and count for per-expert information
|
||||
if (warp_id < num_warps - 1) {
|
||||
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
|
||||
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
|
||||
EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
|
||||
const auto num_threads = (num_warps - 1) * 32;
|
||||
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
|
||||
|
||||
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
|
||||
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
|
||||
const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
|
||||
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
|
||||
const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
|
||||
|
||||
// Overlap top-k index read and source token index write
|
||||
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
|
||||
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
|
||||
|
||||
// FP8 cast
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
|
||||
// Read and calculate local amax
|
||||
auto int4_value = __ldg(x_int4 + i);
|
||||
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
|
||||
float fp32_values[kNumElemsPerRead];
|
||||
float amax = kFP8Margin, scale, scale_inv;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerRead; ++ j) {
|
||||
fp32_values[j] = static_cast<float>(bf16_values[j]);
|
||||
amax = fmaxf(amax, fabsf(fp32_values[j]));
|
||||
}
|
||||
|
||||
// Reduce amax and scale
|
||||
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
|
||||
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
|
||||
if (lane_id == 0 or lane_id == 16)
|
||||
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
|
||||
|
||||
// Cast into send buffer
|
||||
int2 int2_value;
|
||||
auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerRead; j += 2) {
|
||||
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
|
||||
fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
|
||||
}
|
||||
rdma_x_int2[i] = int2_value;
|
||||
}
|
||||
asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
|
||||
|
||||
// Issue IBGDA sends
|
||||
if (dst_expert_idx >= 0) {
|
||||
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
|
||||
slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
|
||||
const auto dst_rank = dst_expert_idx / num_local_experts;
|
||||
const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
|
||||
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
slot_idx * num_bytes_per_msg;
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
|
||||
} else {
|
||||
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
|
||||
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
|
||||
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
|
||||
}
|
||||
|
||||
// Increase counter after finishing
|
||||
__syncwarp();
|
||||
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
|
||||
}
|
||||
}
|
||||
} else if (warp_id == num_warps - 1) {
|
||||
EP_DEVICE_ASSERT(num_sms > 1);
|
||||
if (sm_id == 0) {
|
||||
// The first SM is also responsible for checking QPs
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
|
||||
|
||||
// The first SM is also responsible for cleaning the next buffer
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_next_clean_int; i += 32)
|
||||
next_clean[i] = 0;
|
||||
|
||||
// Notify before executing `int_p`
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_experts; i += 32)
|
||||
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
|
||||
}
|
||||
|
||||
// This SM should be responsible for some destination experts, read `topk_idx` for them
|
||||
int expert_count[kNumWarpGroups] = {0};
|
||||
const auto expert_begin_idx = sm_id * kNumWarpGroups;
|
||||
const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
|
||||
|
||||
// Per lane count
|
||||
#pragma unroll 8
|
||||
for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
|
||||
auto idx = static_cast<int>(__ldg(topk_idx + i));
|
||||
if (idx >= expert_begin_idx and idx < expert_end_idx)
|
||||
expert_count[idx - expert_begin_idx] ++;
|
||||
}
|
||||
|
||||
// Warp reduce
|
||||
#pragma unroll
|
||||
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
|
||||
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
|
||||
if (lane_id == 0) {
|
||||
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
|
||||
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Issue count sends
|
||||
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
|
||||
const auto dst_rank = responsible_expert_idx / num_local_experts;
|
||||
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
|
||||
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
|
||||
|
||||
// Wait until local sends are issued, then send the expert count (encoded as `-count - 1` so that zero is distinguishable from an unwritten flag)
|
||||
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
|
||||
dst_rank, dst_expert_local_idx, 0);
|
||||
nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
|
||||
} else {
|
||||
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
|
||||
}
|
||||
|
||||
// Clean workspace for next use
|
||||
atomic_counter_per_expert[responsible_expert_idx] = 0;
|
||||
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Receiving phase
|
||||
LOW_LATENCY_DISPATCH_RECV:
|
||||
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
|
||||
return;
|
||||
|
||||
// Receiving and packing
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
const auto src_rank = responsible_expert_idx / num_local_experts;
|
||||
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
|
||||
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
|
||||
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
|
||||
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
|
||||
|
||||
// Shared between sub-warps in warp groups
|
||||
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
|
||||
|
||||
// Wait for tokens to arrive
|
||||
// NOTES: using sub-warp 1 to overlap with sub-warp 0
|
||||
int num_recv_tokens, recv_token_begin_idx;
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
if (src_rank != rank) {
|
||||
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
|
||||
num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
|
||||
EP_DEVICE_ASSERT(num_recv_tokens != 0);
|
||||
} else {
|
||||
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
|
||||
}
|
||||
num_recv_tokens = -num_recv_tokens - 1;
|
||||
recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
|
||||
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
|
||||
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
|
||||
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
|
||||
}
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
|
||||
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
|
||||
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
|
||||
|
||||
// Copy tokens
|
||||
EP_DEVICE_ASSERT(num_scales <= 64);
|
||||
for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
|
||||
// Copy data
|
||||
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
|
||||
const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
|
||||
const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);
|
||||
|
||||
// Copy scales
|
||||
const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
|
||||
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
|
||||
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
|
||||
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
|
||||
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
|
||||
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
|
||||
|
||||
// Copy source info
|
||||
const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
|
||||
if (lane_id == 0)
|
||||
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
}
|
||||
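For reference, the per-128-channel FP8 (E4M3) quantization performed in the send phase above reduces to the following host-side sketch; it reuses the same `kFP8Margin`/`kFP8Amax` constants but is not the kernel's actual code path:

void quantize_group_e4m3_reference(const float* x, int n /* 128 channels */, float* scale_inv_out) {
    float amax = 1e-4f;                  // kFP8Margin keeps the scale finite for an all-zero group
    for (int i = 0; i < n; ++ i)
        amax = fmaxf(amax, fabsf(x[i]));
    float scale = 448.0f / amax;         // scale values into the E4M3 representable range [-448, 448]
    *scale_inv_out = amax / 448.0f;      // stored per group so the consumer can dequantize later
    // Each x[i] * scale is then packed with __nv_cvt_float2_to_fp8x2(..., __NV_SATFINITE, __NV_E4M3)
}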
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases) {
|
||||
constexpr int kNumMaxTopK = 9;
|
||||
constexpr int kNumWarpsPerGroup = 10;
|
||||
constexpr int kNumWarpGroups = 3;
|
||||
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
|
||||
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
|
||||
EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
|
||||
|
||||
// Workspace checks
|
||||
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
|
||||
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
|
||||
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
|
||||
|
||||
// Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
|
||||
// NOTES: this part will be cleaned in `combine`
|
||||
auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(hidden) \
|
||||
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
|
||||
packed_recv_x, packed_recv_x_scales, \
|
||||
packed_recv_src_info, packed_recv_layout_range, \
|
||||
rdma_recv_x, rdma_recv_count, rdma_x, \
|
||||
x, topk_idx, \
|
||||
atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
|
||||
next_clean, num_next_clean_int, \
|
||||
num_tokens, num_max_dispatch_tokens_per_rank, \
|
||||
num_topk, num_experts, rank, num_ranks, phases); break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
|
||||
#undef DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
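A call sketch for the host wrapper above (buffer and stream names are placeholders): passing both phase bits issues send and receive in one launch, while splitting them lets the receive half be deferred behind computation:

// internode_ll::dispatch(packed_recv_x, packed_recv_x_scales,
//                        packed_recv_src_info, packed_recv_layout_range,
//                        rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx,
//                        next_clean, num_next_clean_int,
//                        num_tokens, hidden, num_max_dispatch_tokens_per_rank,
//                        num_topk, num_experts, rank, num_ranks,
//                        workspace, stream,
//                        LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE);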
|
||||
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
combine(void* combined_x,
|
||||
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
|
||||
const void* x, const int64_t* topk_idx, const float* topk_weights,
|
||||
const int* src_info, const int64_t* layout_range,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int* atomic_clean_flag,
|
||||
int num_combined_tokens, int hidden, int num_topk,
|
||||
int num_max_dispatch_tokens_per_rank,
|
||||
int num_experts, int rank, int num_ranks,
|
||||
int phases) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto num_threads = static_cast<int>(blockDim.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
const auto num_local_experts = num_experts / num_ranks;
|
||||
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// Data type setup
|
||||
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
|
||||
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
|
||||
|
||||
// Message package
|
||||
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
|
||||
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
|
||||
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
|
||||
|
||||
// Sending phase
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
|
||||
goto LOW_LATENCY_COMBINE_RECV;
|
||||
|
||||
// Clean up next buffer
|
||||
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_next_clean_int; i += 32)
|
||||
next_clean[i] = 0;
|
||||
|
||||
// Notify before executing `int_p`
|
||||
__syncwarp();
|
||||
if (lane_id == 0)
|
||||
atomic_add_release_global(atomic_clean_flag, num_experts);
|
||||
}
|
||||
|
||||
// FP8 cast and issue IBGDA sends
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
const auto dst_rank = responsible_expert_idx / num_local_experts;
|
||||
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
|
||||
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
|
||||
const auto local_x = reinterpret_cast<const int4*>(x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
|
||||
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
|
||||
|
||||
// Unpack layout
|
||||
int offset, num_tokens_to_send;
|
||||
unpack2(layout, num_tokens_to_send, offset);
|
||||
|
||||
// Issue IBGDA send
|
||||
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
|
||||
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
|
||||
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
|
||||
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
|
||||
|
||||
// Copy directly to local rank, or copy to buffer and issue RDMA
|
||||
auto src_idx = __ldg(local_src_info + token_idx);
|
||||
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
|
||||
if (dst_rank == rank) {
|
||||
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Put finishing flag
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
while (ld_acquire_global(atomic_clean_flag) == 0);
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
|
||||
} else {
|
||||
st_na_release(rdma_recv_flag + global_expert_idx, 1);
|
||||
}
|
||||
atomic_add_release_global(atomic_clean_flag, -1);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Receiving phase
|
||||
LOW_LATENCY_COMBINE_RECV:
|
||||
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
|
||||
return;
|
||||
|
||||
// Wait for all ranks to arrive and notify PCIe usage
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
|
||||
if (sub_warp_id == 0 and lane_id == 0) {
|
||||
// TODO: refactor QP indices
|
||||
auto src_rank = responsible_expert_idx / num_local_experts;
|
||||
auto src_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
if (src_rank != rank) {
|
||||
nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
|
||||
} else {
|
||||
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
cg::this_grid().sync();
|
||||
|
||||
// Reduce tokens with FP8 cast
|
||||
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
|
||||
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
|
||||
if (thread_id < hidden_bf16_int4) {
|
||||
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
|
||||
// Read top-k indices and weights
|
||||
int reg_topk_idx[kNumMaxTopk];
|
||||
float reg_topk_weights[kNumMaxTopk];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk; ++ i) {
|
||||
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
|
||||
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
|
||||
}
|
||||
|
||||
float combined_values[kNumElemsPerInt4] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
|
||||
// Read from sources
|
||||
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
|
||||
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
|
||||
|
||||
// Reduce
|
||||
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
|
||||
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerInt4; ++ j)
|
||||
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
|
||||
}
|
||||
|
||||
// Write results
|
||||
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
|
||||
auto combined_bf16 = reinterpret_cast<nv_bfloat16*>(&combined_values);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerInt4; ++ j)
|
||||
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
|
||||
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
|
||||
}
|
||||
}
|
||||
}
|
||||
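// Per combined token t and hidden element h, the reduction above computes in FP32
//     combined_x[t][h] = sum_i topk_weights[t][i] * expert_output(reg_topk_idx[i], t, h)
// and writes the result back as BF16; slots with `reg_topk_idx[i] < 0` (token not routed there)
// are skipped.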
|
||||
void combine(void* combined_x,
|
||||
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
|
||||
const void* x, const int64_t* topk_idx, const float* topk_weights,
|
||||
const int* src_info, const int64_t* layout_range,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases) {
|
||||
constexpr int kNumWarpsPerGroup = 10;
|
||||
constexpr int kNumWarpGroups = 3;
|
||||
constexpr int kNumMaxTopk = 9;
|
||||
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
|
||||
|
||||
// Check workspace
|
||||
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
|
||||
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(hidden) { \
|
||||
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
combined_x, \
|
||||
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
|
||||
x, topk_idx, topk_weights, src_info, layout_range, \
|
||||
next_clean, num_next_clean_int, \
|
||||
atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases); } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
|
||||
#undef COMBINE_LAUNCH_CASE
|
||||
}
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
} // namespace deep_ep
|
||||
803
csrc/kernels/intranode.cu
Normal file
@@ -0,0 +1,803 @@
|
||||
#include "configs.cuh"
|
||||
#include "buffer.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace intranode {
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
|
||||
|
||||
if (sm_id == 0) {
|
||||
// Barrier first
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
int *per_rank_buffer, *per_expert_buffer;
|
||||
if (thread_id < kNumRanks) {
|
||||
per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
|
||||
per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
|
||||
}
|
||||
|
||||
// After this loop:
|
||||
// - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
|
||||
// - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j
|
||||
int num_experts_per_rank = num_experts / kNumRanks;
|
||||
if (thread_id < kNumRanks) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i)
|
||||
per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_experts_per_rank; ++ i)
|
||||
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for all ranks to be finished
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
// Sum per-rank counts and return to CPU
|
||||
// Also pre-compute the prefix sum for data sending
|
||||
auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
if (thread_id < kNumRanks) {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < kNumRanks; ++ i)
|
||||
local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];
|
||||
if (thread_id == rank)
|
||||
*moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];
|
||||
}
|
||||
|
||||
// Sum per-experts counts and return to CPU
|
||||
auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;
|
||||
if (thread_id < num_experts_per_rank) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i)
|
||||
sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];
|
||||
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
|
||||
moe_recv_expert_counter_mapped[thread_id] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Copy rank size prefix matrix to another tensor
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
|
||||
rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];
|
||||
|
||||
// Extra memset for later communication queue
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_memset_int; i += num_threads)
|
||||
local_per_expert_buffer[i] = 0;
|
||||
|
||||
// Barrier
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
} else {
|
||||
int dst_rank = sm_id - 1;
|
||||
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
|
||||
|
||||
// Iterate over tokens
|
||||
int count = 0;
|
||||
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)
|
||||
count += is_token_in_rank[i * kNumRanks + dst_rank];
|
||||
count = warp_reduce_sum(count);
|
||||
if (lane_id == 0)
|
||||
channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Pre-compute prefix sum for all channels
|
||||
if (thread_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_channels; ++ i)
|
||||
channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
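// Worked example of the per-rank prefix matrix built above (kNumRanks = 2):
//   per_rank_buffer[i][j] = tokens rank i sends to rank j, e.g. [[3, 1], [2, 4]]
//   after the column-wise prefix sum, row i holds the tokens ranks 0..i send to each rank: [[3, 1], [5, 5]]
//   so each rank reads its total receive count from the last row (here both ranks receive 5),
//   and the intermediate rows give per-source offsets into the receive tensor.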
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
|
||||
cudaStream_t stream, int num_channels) {
|
||||
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
|
||||
num_tokens_per_rank, moe_recv_counter_mapped, \
|
||||
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
|
||||
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
|
||||
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
|
||||
buffer_ptrs, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
constexpr int kNumThreads = 128;
|
||||
EP_HOST_ASSERT(num_experts % num_ranks == 0);
|
||||
EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
|
||||
SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
|
||||
#undef NOTIFY_DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
|
||||
// A simplified version for cached handles
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
// Copy and clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
|
||||
ptr[i] = rank_prefix_matrix[i];
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_memset_int; i += num_threads)
|
||||
ptr[kNumRanks * kNumRanks + i] = 0;
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Barrier after cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
}
|
||||
|
||||
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs,
|
||||
int head, int rank, int num_ranks, cudaStream_t stream) {
|
||||
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
|
||||
rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, 128, stream);
|
||||
SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
|
||||
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void __launch_bounds__(kNumRanks * 32, 1)
|
||||
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void **buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const bool is_sender = sm_id % 2 == 0;
|
||||
EP_DEVICE_ASSERT(num_sms % 2 == 0);
|
||||
|
||||
// Each warp is responsible for a single rank
|
||||
const auto num_channels = num_sms / 2;
|
||||
const auto responsible_rank = (static_cast<int>(thread_id)) / 32;
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving
|
||||
const auto responsible_channel = sm_id / 2;
|
||||
|
||||
int num_experts_per_rank = num_experts / kNumRanks;
|
||||
EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
|
||||
EP_DEVICE_ASSERT(num_topk <= 32);
|
||||
EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
|
||||
EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
|
||||
int target_rank = is_sender ? rank : responsible_rank;
|
||||
auto num_channels_total = num_channels * kNumRanks;
|
||||
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
|
||||
|
||||
// Channel buffer metadata
|
||||
// Senders are responsible for tails, and receivers are responsible for heads
|
||||
// Stored on the receiver side
|
||||
// The retired signals are actually boolean flags, but to align with 16 bytes, we make them `int64_t`
|
||||
// `start_offset`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `end_offset`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
|
||||
// Channel data buffers, stored on the receiver side
|
||||
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
|
||||
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
|
||||
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t)
|
||||
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
|
||||
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
|
||||
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
|
||||
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
|
||||
auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
|
||||
|
||||
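// Illustrative sizing (hypothetical figures, not tied to a real launch): with kNumRanks = 8 and
// num_channels = 10, each of the four metadata arrays above takes 10 * 8 * sizeof(int) = 320 bytes,
// i.e. 1280 bytes in total, after which the per-channel data buffers (`x`, `src_idx`, `topk_idx`,
// `topk_weights`, `x_scales`) follow contiguously in the order they are constructed.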
if (is_sender) {
|
||||
// Workers for sending
|
||||
constexpr int num_send_warps = kNumRanks;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_warp_id = send_thread_id / 32;
|
||||
const auto send_lane_id = send_thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32);
|
||||
EP_DEVICE_ASSERT(num_send_warps == kNumRanks and send_warp_id == responsible_rank);
|
||||
|
||||
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
|
||||
// NOTES: this is for distinguishing zero tokens
|
||||
if (send_lane_id == 0) {
|
||||
int value = responsible_channel > 0 ? channel_prefix_matrix[send_warp_id * num_channels + responsible_channel - 1] : 0;
|
||||
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
|
||||
value = channel_prefix_matrix[send_warp_id * num_channels + responsible_channel];
|
||||
st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Get tasks
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
|
||||
|
||||
// Iterate over all tokens and send by chunks
|
||||
int cached_channel_tail_idx = 0;
|
||||
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
|
||||
// Check that the destination queue has enough free slots, or wait for buffers to be released (rare cases)
|
||||
auto start_time = clock64();
|
||||
while (send_lane_id == 0) {
|
||||
// NOTES: we only consider the worst case, because counting the real number of used slots is time-consuming
|
||||
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
|
||||
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
|
||||
break;
|
||||
|
||||
// Rare cases to loop again
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
int chunk_token_idx = 0;
|
||||
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
|
||||
if (send_lane_id == 0)
|
||||
send_head[token_idx * kNumRanks + send_warp_id] = is_token_in_rank[token_idx * kNumRanks + send_warp_id] ? cached_channel_tail_idx : -1;
|
||||
// Skip if not selected
|
||||
if (not is_token_in_rank[token_idx * kNumRanks + send_warp_id]) {
|
||||
token_idx ++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get an empty slot
|
||||
int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens;
|
||||
|
||||
// Copy data
|
||||
auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
|
||||
auto shifted_x = x + token_idx * hidden_int4;
|
||||
UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
|
||||
__ldg, st_na_global);
|
||||
|
||||
// Copy source index
|
||||
if (send_lane_id == 0)
|
||||
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
|
||||
|
||||
// Copy `topk_idx` and `topk_weights` with transformed index
|
||||
if (send_lane_id < num_topk) {
|
||||
// Top-k index
|
||||
int recv_expert_begin = send_warp_id * num_experts_per_rank, recv_expert_end = (send_warp_id + 1) * num_experts_per_rank;
|
||||
auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
|
||||
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
|
||||
channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
|
||||
|
||||
// Top-k weights
|
||||
auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
|
||||
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
|
||||
}
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll
|
||||
for (int i = send_lane_id; i < num_scales; i += 32)
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
|
||||
|
||||
// Move token index
|
||||
chunk_token_idx ++, token_idx ++;
|
||||
}
|
||||
|
||||
// Move tail index
|
||||
__syncwarp();
|
||||
if (send_lane_id == 0)
|
||||
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
|
||||
}
|
||||
} else {
|
||||
// Workers for receiving and copying into buffer
|
||||
constexpr int num_recv_warps = kNumRanks;
|
||||
const auto recv_thread_id = thread_id;
|
||||
const auto recv_warp_id = recv_thread_id / 32;
|
||||
const auto recv_lane_id = recv_thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32 and recv_warp_id == responsible_rank);
|
||||
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps == kNumRanks);
|
||||
|
||||
// Calculate offset first
|
||||
auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
int rank_offset = recv_warp_id > 0 ? rank_prefix_matrix[(recv_warp_id - 1) * kNumRanks + rank] : 0;
|
||||
|
||||
// Receive channel offset
|
||||
int total_offset, num_tokens_to_recv;
|
||||
while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
|
||||
while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
|
||||
if (recv_lane_id == 0) {
|
||||
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
|
||||
recv_channel_offset[recv_warp_id * num_channels + responsible_channel] = total_offset;
|
||||
num_tokens_to_recv -= total_offset;
|
||||
}
|
||||
total_offset = __shfl_sync(0xffffffff, total_offset, 0);
|
||||
total_offset += rank_offset;
|
||||
num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);
|
||||
|
||||
auto start_time = clock64();
|
||||
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
|
||||
while (num_tokens_to_recv > 0) {
|
||||
// Check channel status by lane 0
|
||||
while (recv_lane_id == 0) {
|
||||
cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());
|
||||
|
||||
// Ready to copy
|
||||
if (cached_channel_head_idx != cached_channel_tail_idx)
|
||||
break;
|
||||
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
|
||||
// Sync queue tail
|
||||
cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
|
||||
|
||||
// Copy data
|
||||
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
|
||||
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx) {
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
|
||||
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
|
||||
ld_nc_global, st_na_global);
|
||||
}
|
||||
|
||||
// Copy `src_idx`
|
||||
#pragma unroll 4
|
||||
for (int chunk_idx = cached_channel_head_idx + recv_lane_id; chunk_idx < cached_channel_tail_idx; chunk_idx += 32)
|
||||
recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);
|
||||
|
||||
// Copy `topk_idx` and `topk_weights`
|
||||
#pragma unroll 4
|
||||
for (int idx = recv_lane_id; idx < num_recv_tokens * num_topk; idx += 32) {
|
||||
int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
|
||||
auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
|
||||
recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
|
||||
recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
|
||||
}
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll 4
|
||||
for (int i = recv_lane_id; i < num_recv_tokens * num_scales; i += 32) {
|
||||
int chunk_idx = i / num_scales, scales_idx = i % num_scales;
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =
|
||||
ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx);
|
||||
}
|
||||
|
||||
// Move queue
|
||||
cached_channel_head_idx += num_recv_tokens;
|
||||
total_offset += num_recv_tokens;
|
||||
__syncwarp();
|
||||
if (recv_lane_id == 0)
|
||||
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
|
||||
|
||||
// Exit
|
||||
num_tokens_to_recv -= num_recv_tokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, dispatch<ranks>, \
|
||||
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
is_token_in_rank, channel_prefix_matrix, \
|
||||
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
break
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
|
||||
EP_HOST_ASSERT(num_sms % 2 == 0);
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_ranks * 32, stream);
|
||||
SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
|
||||
#undef DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
if (sm_id == 0) {
|
||||
// Barrier before cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
// Clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_memset_int; i += num_threads)
|
||||
ptr[i] = 0;
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Barrier after cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
} else {
|
||||
const auto channel_id = sm_id - 1;
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto rank_id = thread_id / 32;
|
||||
const auto lane_id = thread_id % 32;
|
||||
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
|
||||
|
||||
// NOTES: `1 << 25` is a heuristic large number
|
||||
int last_head = 1 << 25;
|
||||
#pragma unroll
|
||||
for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {
|
||||
int token_idx = token_idx_tail - lane_id, expected_head = 0;
|
||||
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
|
||||
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
|
||||
head = __shfl_sync(0xffffffff, current_head, i);
|
||||
if (head < 0) {
|
||||
if (lane_id == i)
|
||||
expected_head = -last_head - 1;
|
||||
} else {
|
||||
last_head = head;
|
||||
}
|
||||
}
|
||||
if (current_head < 0 and token_idx >= token_start_idx)
|
||||
send_head[token_idx * kNumRanks + rank_id] = expected_head;
|
||||
}
|
||||
}
|
||||
}
|
||||
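// NOTES (an informal reading of the scan above, for clarity): tokens are walked back to front; any
// `send_head` entry of -1 (the token was never sent to this rank) is rewritten as `-last_head - 1`,
// where `last_head` is the queue slot of the nearest later token that was actually sent. E.g. if that
// later token occupies slot 5, the missing entry becomes -6; the combine receiver recovers 5 via
// `-expected_head - 1` and uses it only to advance the shared queue head, without waiting for data.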
|
||||
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
|
||||
int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank, int num_ranks,
|
||||
cudaStream_t stream) {
|
||||
#define CACHED_NOTIFY_COMBINE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
|
||||
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
const int num_threads = std::max(128, 32 * num_ranks);
|
||||
EP_HOST_ASSERT(num_ranks <= num_threads);
|
||||
EP_HOST_ASSERT(num_threads <= 1024);
|
||||
EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
|
||||
SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
|
||||
SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
|
||||
#undef CACHED_NOTIFY_COMBINE
|
||||
}
|
||||
|
||||
template<typename dtype_t, int kNumRanks, int kNumThreads>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
const dtype_t* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void **buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_channels = num_sms / 2;
|
||||
const bool is_sender = sm_id % 2 == 0;
|
||||
const int responsible_channel = sm_id / 2;
|
||||
EP_DEVICE_ASSERT(num_topk <= 32);
|
||||
|
||||
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
|
||||
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
|
||||
auto x_int4 = reinterpret_cast<const int4*>(x);
|
||||
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
// Several warps are responsible for a single rank
|
||||
constexpr int num_send_warps = kNumThreads / 32;
|
||||
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
|
||||
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_lane_id = send_thread_id % 32;
|
||||
const auto send_rank_id = thread_id / num_threads_per_rank;
|
||||
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
|
||||
auto num_channels_total = num_channels * kNumRanks;
|
||||
auto channel_rank_offset = responsible_channel * kNumRanks + rank;
|
||||
|
||||
// Channel metadata
|
||||
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
|
||||
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
|
||||
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
|
||||
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
|
||||
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
|
||||
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
|
||||
// Get tasks
|
||||
// NOTES: `channel_offset` is already shifted
|
||||
int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
|
||||
int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
|
||||
int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
|
||||
int num_channel_tokens = (responsible_channel == num_channels - 1 ? num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset;
|
||||
int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;
|
||||
|
||||
// Iterate over all tokens and send by chunks
|
||||
int current_channel_tail_idx = 0;
|
||||
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
|
||||
// Check that the destination queue has enough free slots, or wait for buffers to be released (rare cases)
|
||||
auto start_time = clock64();
|
||||
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
|
||||
while (send_lane_id == 0) {
|
||||
// NOTES: we only consider the worst case, because counting the real number of used slots is time-consuming
|
||||
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
|
||||
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
|
||||
break;
|
||||
|
||||
// Rare cases to loop again
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Send by chunk
|
||||
#pragma unroll
|
||||
for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
|
||||
// Get an empty slot
|
||||
int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;
|
||||
|
||||
// Copy data
|
||||
auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
|
||||
auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
|
||||
|
||||
// Send source index
|
||||
if (send_lane_id == 0)
|
||||
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
|
||||
|
||||
// Send `topk_weights`
|
||||
if (num_topk > 0 and send_lane_id < num_topk)
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
|
||||
}
|
||||
token_idx += num_round_tokens;
|
||||
current_channel_tail_idx += num_round_tokens;
|
||||
|
||||
// Move tail index
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
|
||||
if (send_lane_id == 0 and send_warp_id_in_rank == 0)
|
||||
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
|
||||
}
|
||||
} else {
|
||||
// Workers for receiving
|
||||
// One warp for moving the queue head, others for reduction
|
||||
constexpr int num_recv_warps = kNumThreads / 32;
|
||||
const auto recv_warp_id = thread_id / 32;
|
||||
const auto recv_lane_id = thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
|
||||
EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
|
||||
|
||||
// Shared head, tail and retired flags for receiver warps
|
||||
__shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
|
||||
__shared__ volatile int channel_tail_idx[kNumRanks];
|
||||
__shared__ volatile bool warp_retired[num_recv_warps];
|
||||
if (thread_id < num_recv_warps)
|
||||
warp_retired[thread_id] = false;
|
||||
if (recv_lane_id < kNumRanks)
|
||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
|
||||
if (thread_id < kNumRanks)
|
||||
channel_tail_idx[thread_id] = 0;
|
||||
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
|
||||
|
||||
if (thread_id < 32) {
|
||||
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
|
||||
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
|
||||
|
||||
// Queue head updater
|
||||
int last_head = 0;
|
||||
while (recv_lane_id < kNumRanks) {
|
||||
// Check retired
|
||||
bool retired = true;
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_recv_warps; ++ i)
|
||||
retired = retired and warp_retired[i];
|
||||
if (retired)
|
||||
break;
|
||||
|
||||
// Update queue tail
|
||||
channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
|
||||
|
||||
// Update minimum head
|
||||
int min_head = std::numeric_limits<int>::max();
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
|
||||
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
|
||||
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
|
||||
}
|
||||
} else {
|
||||
// Receivers
|
||||
// Channel metadata
|
||||
// All lanes will use the data buffers, but only the first `kNumRanks` lanes will use `head/tail/src_idx`
|
||||
Buffer<int4> channel_x_buffers[kNumRanks];
|
||||
Buffer<float> channel_topk_weights_buffers[kNumRanks];
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i) {
|
||||
auto channel_rank_offset = responsible_channel * kNumRanks + i;
|
||||
auto num_channels_total = num_channels * kNumRanks;
|
||||
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
|
||||
|
||||
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
|
||||
channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
|
||||
|
||||
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
|
||||
ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
|
||||
|
||||
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
|
||||
channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
}
|
||||
|
||||
// The same tokens as the dispatch process
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
|
||||
|
||||
// Iterate over all tokens and combine
|
||||
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
|
||||
// Read expected head
|
||||
int expected_head = -1;
|
||||
if (recv_lane_id < kNumRanks) {
|
||||
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
|
||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
|
||||
}
|
||||
auto start_time = clock64();
|
||||
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Broadcast current heads
|
||||
int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i) {
|
||||
auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);
|
||||
if (expected_head_i >= 0) {
|
||||
slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
|
||||
topk_ranks[num_topk_ranks ++] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce data
|
||||
#pragma unroll
|
||||
for (int i = recv_lane_id; i < hidden_int4; i += 32) {
|
||||
// Read buffers
|
||||
int4 recv_value_int4[kNumRanks];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j)
|
||||
recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
|
||||
|
||||
// Reduce all-to-all results
|
||||
float values[kDtypePerInt4] = {0};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j) {
|
||||
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kDtypePerInt4; ++ k)
|
||||
values[k] += static_cast<float>(recv_value_dtypes[k]);
|
||||
}
|
||||
|
||||
// Cast back to `dtype_t` and write
|
||||
int4 out_int4;
|
||||
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
out_dtypes[j] = static_cast<dtype_t>(values[j]);
|
||||
recv_int4[token_idx * hidden_int4 + i] = out_int4;
|
||||
}
|
||||
|
||||
// Reduce `topk_weights`
|
||||
if (recv_lane_id < num_topk) {
|
||||
float value = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk_ranks; ++ i)
|
||||
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
|
||||
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// Retired
|
||||
__syncwarp();
|
||||
if (recv_lane_id == 0)
|
||||
warp_retired[recv_warp_id] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void combine(cudaDataType_t type,
|
||||
void* recv_x, float* recv_topk_weights,
|
||||
const void* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(dtype, ranks) \
|
||||
LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
|
||||
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
|
||||
reinterpret_cast<const dtype*>(x), topk_weights, \
|
||||
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
|
||||
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
break
|
||||
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving
|
||||
EP_HOST_ASSERT(num_sms % 2 == 0);
|
||||
EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);
|
||||
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
|
||||
SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
|
||||
#undef COMBINE_DTYPE_LAUNCH_CASE
|
||||
#undef COMBINE_LAUNCH_CASE
|
||||
}
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
} // namespace deep_ep
|
||||
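The channel handshake implemented by the kernels above is, at its core, a single-producer/single-consumer ring buffer: the sender publishes finished slots by advancing `tail` with a release store, and the receiver reads `tail` with an acquire load, consumes the slots, then advances `head`. The following host-side analogue using `std::atomic` is purely illustrative (it is not part of DeepEP and elides the multi-rank and multi-channel dimensions):

#include <atomic>
#include <cstdio>

constexpr int kNumSlots = 4;

struct Channel {
    int slots[kNumSlots] = {};
    std::atomic<int> head{0};   // advanced by the receiver after consuming
    std::atomic<int> tail{0};   // advanced by the sender with release semantics
};

// Sender: check free space with the worst-case estimate (as the kernels do), write, then publish.
bool try_send(Channel& ch, int value) {
    int tail = ch.tail.load(std::memory_order_relaxed);
    if (tail - ch.head.load(std::memory_order_relaxed) >= kNumSlots)
        return false;                                    // queue full, caller retries later
    ch.slots[tail % kNumSlots] = value;                  // write the payload first
    ch.tail.store(tail + 1, std::memory_order_release);  // then make it visible
    return true;
}

// Receiver: acquire the tail, consume everything available, then move the head forward.
int drain(Channel& ch) {
    int head = ch.head.load(std::memory_order_relaxed);
    int tail = ch.tail.load(std::memory_order_acquire);
    int consumed = 0;
    for (; head < tail; ++ head, ++ consumed)
        std::printf("got %d\n", ch.slots[head % kNumSlots]);
    ch.head.store(head, std::memory_order_relaxed);
    return consumed;
}

int main() {
    Channel ch;
    try_send(ch, 42);
    try_send(ch, 43);
    drain(ch);  // prints 42 and 43
    return 0;
}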
60
csrc/kernels/launch.cuh
Normal file
@@ -0,0 +1,60 @@
|
||||
#pragma once
|
||||
|
||||
#include "configs.cuh"
|
||||
|
||||
#ifndef SETUP_LAUNCH_CONFIG
|
||||
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
|
||||
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
|
||||
cudaLaunchAttribute attr[1]; \
|
||||
attr[0].id = cudaLaunchAttributeCooperative; \
|
||||
attr[0].val.cooperative = 1; \
|
||||
cfg.attrs = attr; \
|
||||
cfg.numAttrs = 1
|
||||
#endif
|
||||
|
||||
#ifndef LAUNCH_KERNEL
|
||||
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
|
||||
#endif
|
||||
|
||||
#define SWITCH_RANKS(case_macro) \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(2); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_RDMA_RANKS(case_macro) \
|
||||
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
|
||||
case 2: case_macro(2); \
|
||||
case 3: case_macro(3); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
case 16: case_macro(16); \
|
||||
case 18: case_macro(18); \
|
||||
case 20: case_macro(20); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(dtype, 2); \
|
||||
case 4: case_macro(dtype, 4); \
|
||||
case 8: case_macro(dtype, 8); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_TYPES(case_macro) \
|
||||
switch (type) { \
|
||||
case CUDA_R_16BF: case_macro(nv_bfloat16); \
|
||||
case CUDA_R_32F: case_macro(float); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported type"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_HIDDEN(case_macro) \
|
||||
switch (hidden) { \
|
||||
case 2560: case_macro(2560); \
|
||||
case 5120: case_macro(5120); \
|
||||
case 7168: case_macro(7168); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
|
||||
} while (false)
|
||||
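For clarity, a self-contained toy (hypothetical names, not DeepEP API) of how the SWITCH_* macros above are consumed: the caller defines a per-case macro that instantiates the kernel template for the matched rank count and ends with `break`, which is exactly how the dispatch/combine wrappers use DISPATCH_LAUNCH_CASE and COMBINE_LAUNCH_CASE.

#include <cstdio>

template <int kNumRanks>
void toy_launch() { std::printf("launching for %d ranks\n", kNumRanks); }

void launch_for(int num_ranks) {
    // Hand-expanded equivalent of `SWITCH_RANKS(TOY_CASE)` with
    // `#define TOY_CASE(ranks) toy_launch<ranks>(); break`
    // (the macro also appends a harmless `while (false)` after the switch)
    switch (num_ranks) {
        case 2: toy_launch<2>(); break;
        case 4: toy_launch<4>(); break;
        case 8: toy_launch<8>(); break;
        default: std::printf("Unsupported ranks\n");
    }
}

int main() { launch_for(8); return 0; }  // prints "launching for 8 ranks"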
119
csrc/kernels/runtime.cu
Normal file
@@ -0,0 +1,119 @@
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace intranode {
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
}
|
||||
|
||||
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
|
||||
#define BARRIER_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, 32, stream);
|
||||
SWITCH_RANKS(BARRIER_LAUNCH_CASE);
|
||||
#undef BARRIER_LAUNCH_CASE
|
||||
}
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
namespace internode {
|
||||
|
||||
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
|
||||
nvshmem_team_config_t cpu_rdma_team_config;
|
||||
|
||||
std::vector<uint8_t> get_unique_id() {
|
||||
nvshmemx_uniqueid_t unique_id;
|
||||
nvshmemx_get_uniqueid(&unique_id);
|
||||
std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
|
||||
std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
|
||||
return result;
|
||||
}
|
||||
|
||||
__global__ void ibgda_initialize_recv_queue(int rank) {
|
||||
auto thread_idx = static_cast<int>(threadIdx.x);
|
||||
auto num_threads = static_cast<int>(blockDim.x);
|
||||
|
||||
auto dst_rank = static_cast<int>(blockIdx.x);
|
||||
if (dst_rank != rank) {
|
||||
for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
|
||||
auto qp = ibgda_get_rc(dst_rank, qp_id);
|
||||
|
||||
// Clean some necessary variables
|
||||
for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
|
||||
ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
|
||||
qp->mvars.rx_wq.resv_head = 0;
|
||||
qp->mvars.rx_wq.cons_idx = 0;
|
||||
|
||||
// Allocate receive slots
|
||||
nvshmemi_ibgda_allocate_recvs(qp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
|
||||
nvshmemx_uniqueid_t root_unique_id;
|
||||
nvshmemx_init_attr_t attr;
|
||||
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
|
||||
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
|
||||
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
|
||||
|
||||
// Create sub-RDMA teams
|
||||
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
|
||||
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
|
||||
EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
|
||||
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
|
||||
EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
|
||||
num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
|
||||
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
|
||||
}
|
||||
|
||||
// Normal operations use IBRC, while low-latency operations use IBGDA
|
||||
if (low_latency_mode) {
|
||||
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
|
||||
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
|
||||
|
||||
bool ibgda_is_initialized = false;
|
||||
cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice);
|
||||
|
||||
// Initialize recv queues for low-latency mode AR
|
||||
ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
|
||||
}
|
||||
nvshmem_barrier_all();
|
||||
return nvshmem_my_pe();
|
||||
}
|
||||
|
||||
void* alloc(size_t size, size_t alignment) {
|
||||
return nvshmem_align(alignment, size);
|
||||
}
|
||||
|
||||
void free(void* ptr) {
|
||||
nvshmem_free(ptr);
|
||||
}
|
||||
|
||||
void barrier() {
|
||||
nvshmem_barrier_all();
|
||||
}
|
||||
|
||||
void finalize() {
|
||||
if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
|
||||
nvshmem_team_destroy(cpu_rdma_team);
|
||||
cpu_rdma_team = NVSHMEM_TEAM_INVALID;
|
||||
}
|
||||
nvshmem_finalize();
|
||||
}
|
||||
|
||||
} // namespace internode
|
||||
|
||||
} // namespace deep_ep
|
||||
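A rough host-side sketch of how the internode helpers above are intended to be sequenced. The `broadcast_bytes` helper is hypothetical and stands in for whatever bootstrap channel the caller already has (e.g. a framework process group); everything else uses only functions defined in this file.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical bootstrap helper, assumed to resize and fill `bytes` on non-root ranks.
void broadcast_bytes(std::vector<uint8_t>& bytes, int root);

void setup_internode(int rank, int num_ranks, size_t num_rdma_bytes) {
    std::vector<uint8_t> unique_id;
    if (rank == 0)
        unique_id = deep_ep::internode::get_unique_id();
    broadcast_bytes(unique_id, /*root=*/ 0);

    // Collective: every rank must call `init` with the same root unique ID
    deep_ep::internode::init(unique_id, rank, num_ranks, /*low_latency_mode=*/ true);

    // Symmetric RDMA buffer shared by the communication kernels
    void* rdma_buffer = deep_ep::internode::alloc(num_rdma_bytes, 128);
    // ... run dispatch/combine ...
    deep_ep::internode::free(rdma_buffer);
    deep_ep::internode::finalize();
}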
381
csrc/kernels/utils.cuh
Normal file
@@ -0,0 +1,381 @@
|
||||
#pragma once
|
||||
|
||||
#include "exception.cuh"
|
||||
|
||||
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
|
||||
{ \
|
||||
constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
|
||||
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
|
||||
auto __src = (SRC); \
|
||||
auto __dst = (DST); \
|
||||
for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
|
||||
_Pragma("unroll") \
|
||||
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
|
||||
unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \
|
||||
_Pragma("unroll") \
|
||||
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
|
||||
ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \
|
||||
} \
|
||||
for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \
|
||||
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
|
||||
}
|
||||
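// Usage sketch (mirroring the kernels above): all 32 lanes of a warp cooperate to copy `N` elements,
// each lane striding by 32 and buffering UNROLL_FACTOR values between the load pass and the store
// pass, e.g. UNROLLED_WARP_COPY(5, lane_id, hidden_int4, dst_int4_ptr, src_int4_ptr, __ldg, st_na_global);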
|
||||
namespace deep_ep {
|
||||
|
||||
template <int kBytes>
|
||||
struct VecInt {};
|
||||
template<> struct VecInt<1> { using vec_t = int8_t; };
|
||||
template<> struct VecInt<2> { using vec_t = int16_t; };
|
||||
template<> struct VecInt<4> { using vec_t = int; };
|
||||
template<> struct VecInt<8> { using vec_t = int64_t; };
|
||||
template<> struct VecInt<16> { using vec_t = int4; };
|
||||
|
||||
__device__ __forceinline__ void trap() {
|
||||
asm("trap;");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void memory_fence() {
|
||||
asm volatile("fence.acq_rel.sys;":: : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void memory_fence_gpu() {
|
||||
asm volatile("fence.acq_rel.gpu;":: : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void memory_fence_cta() {
|
||||
asm volatile("fence.acq_rel.cta;":: : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
|
||||
asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
|
||||
asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
|
||||
asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
|
||||
uint64_t ret;
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
|
||||
int ret;
|
||||
asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
|
||||
int ret;
|
||||
asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
|
||||
uint16_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return static_cast<uint8_t>(ret);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
|
||||
uint16_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
|
||||
uint64_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
|
||||
int64_t ret;
|
||||
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
|
||||
int64_t ret;
|
||||
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
|
||||
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
|
||||
#else
|
||||
#define LD_NC_FUNC "ld.volatile.global"
|
||||
#endif
|
||||
|
||||
// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS,
|
||||
// which does not have cache allocation, and `CONSTANT` memory does not have coherence control,
|
||||
// so coherence has to be enforced by the queue (head/tail) semantics
|
||||
template <typename dtype_t>
|
||||
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
|
||||
auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
|
||||
return *reinterpret_cast<dtype_t*>(&ret);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
|
||||
uint16_t ret;
|
||||
// NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
|
||||
asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return static_cast<uint8_t>(ret);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int ld_nc_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
|
||||
int64_t ret;
|
||||
asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ float ld_nc_global(const float *ptr) {
|
||||
float ret;
|
||||
asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
|
||||
int2 ret;
|
||||
asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
|
||||
int4 ret;
|
||||
asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
|
||||
: : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
|
||||
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
|
||||
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
|
||||
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
|
||||
}
|
||||
|
||||
// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS,
|
||||
// which does not allocate cache lines (certainly not in L1, and presumably not in L2 either)
|
||||
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
|
||||
#define ST_NA_FUNC "st.global.L1::no_allocate"
|
||||
#else
|
||||
#define ST_NA_FUNC "st.global"
|
||||
#endif
|
||||
|
||||
template <typename dtype_t>
|
||||
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
|
||||
st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
|
||||
*reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const int *ptr, const int& value) {
|
||||
asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
|
||||
asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const float *ptr, const float& value) {
|
||||
asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
|
||||
asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
|
||||
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
|
||||
return cell_div<dtype_t>(a, b) * b;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
|
||||
int& token_start_idx, int& token_end_idx) {
|
||||
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
|
||||
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
|
||||
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
|
||||
}
|
||||
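// Worked example: with num_tokens = 4096 and num_sms = 10 channels, num_tokens_per_sm = 410, so
// channel 0 handles tokens [0, 410), channel 8 handles [3280, 3690), and the last channel handles
// the shorter remainder [3690, 4096).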
|
||||
template <typename dtype_a_t, typename dtype_b_t>
|
||||
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
|
||||
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
|
||||
dtype_b_t packed;
|
||||
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
|
||||
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
|
||||
return packed;
|
||||
}
|
||||
|
||||
template <typename dtype_a_t, typename dtype_b_t>
|
||||
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
|
||||
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
|
||||
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
|
||||
x = unpacked_ptr[0], y = unpacked_ptr[1];
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
|
||||
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
|
||||
auto send_int_values = reinterpret_cast<int*>(&ptr);
|
||||
int recv_int_values[sizeof(dtype_t) / sizeof(int)];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
|
||||
recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
|
||||
return *reinterpret_cast<dtype_t*>(recv_int_values);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int warp_reduce_sum(int value) {
|
||||
value += __shfl_xor_sync(0xffffffff, value, 16);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 8);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 4);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 2);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 1);
|
||||
return value;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float half_warp_reduce_max(float value) {
|
||||
auto mask = __activemask();
|
||||
// The mask should be either `0xffffffff` (full warp) or `0xffff` (half warp)
|
||||
value = max(value, __shfl_xor_sync(mask, value, 8));
|
||||
value = max(value, __shfl_xor_sync(mask, value, 4));
|
||||
value = max(value, __shfl_xor_sync(mask, value, 2));
|
||||
value = max(value, __shfl_xor_sync(mask, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int get_lane_id() {
|
||||
int lane_id;
|
||||
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void move_fifo_slots(int &head) {
|
||||
head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__device__ __forceinline__ bool not_finished(int *task, int expected) {
|
||||
auto result = false;
|
||||
auto lane_id = threadIdx.x % 32;
|
||||
if (lane_id < kNumRanks)
|
||||
result = ld_volatile_global(task + lane_id) != expected;
|
||||
return __any_sync(0xffffffff, result);
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void
|
||||
timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
|
||||
auto start_time = clock64();
|
||||
while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
|
||||
printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void
|
||||
barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32);
|
||||
|
||||
if (thread_id < kNumRanks) {
|
||||
atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
|
||||
memory_fence();
|
||||
atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
|
||||
}
|
||||
timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
|
||||
}
|
||||
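// How the barrier works (informal): thread j of rank r adds FINISHED_SUM_TAG to rank r's own FIFO
// slot j and subtracts the same tag from slot r on peer j, so each slot returns to zero exactly
// when both sides have arrived; `timeout_check` then spins until all `kNumRanks` slots at `head`
// read the expected value (0) again.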
|
||||
} // namespace deep_ep
|
||||
7
deep_ep/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import torch
|
||||
|
||||
from .utils import EventOverlap
|
||||
from .buffer import Buffer
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from deep_ep_cpp import Config
|
||||
534
deep_ep/buffer.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Callable, List, Tuple, Optional, Union
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep_cpp
|
||||
# noinspection PyUnresolvedReferences
|
||||
from deep_ep_cpp import Config, EventHandle
|
||||
from .utils import EventOverlap
|
||||
|
||||
|
||||
class Buffer:
|
||||
"""
|
||||
The core expert-parallel (EP) communication buffers for Mixture-of-Experts (MoE) models, which support:
|
||||
- high-throughput intranode all-to-all (dispatch and combine, using NVLink)
|
||||
- high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
|
||||
- low-latency all-to-all (dispatch and combine, using RDMA, AR supported)
|
||||
|
||||
Attributes:
|
||||
num_sms: the number of SMs used in high-throughput kernels.
|
||||
rank: the local rank number.
|
||||
group_size: the number of ranks in the group.
|
||||
group: the communication group.
|
||||
num_nvl_bytes: the buffer size for intranode NVLink communication.
|
||||
num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
|
||||
runtime: the C++ runtime.
|
||||
"""
|
||||
|
||||
num_sms: int = 20
|
||||
|
||||
def __init__(self, group: dist.ProcessGroup,
|
||||
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
|
||||
"""
|
||||
Initialize the communication buffer.
|
||||
|
||||
Arguments:
|
||||
group: the communication group.
|
||||
num_nvl_bytes: the buffer size for intranode NVLink communication.
|
||||
num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
|
||||
low_latency_mode: whether to enable low-latency mode.
|
||||
num_qps_per_rank: the number of QPs for RDMA; the low-latency mode requires this number to equal
|
||||
the number of local experts.
|
||||
"""
|
||||
|
||||
# TODO: argument docs
|
||||
# Initialize the CPP runtime
|
||||
self.rank = group.rank()
|
||||
self.group_size = group.size()
|
||||
self.group = group
|
||||
self.num_nvl_bytes = num_nvl_bytes
|
||||
self.num_rdma_bytes = num_rdma_bytes
|
||||
self.low_latency_mode = low_latency_mode
|
||||
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)
|
||||
|
||||
# Synchronize device IDs
|
||||
device_ids = [None, ] * self.group_size
|
||||
local_device_id = self.runtime.get_local_device_id()
|
||||
dist.all_gather_object(device_ids, local_device_id, group)
|
||||
|
||||
# Synchronize IPC handles
|
||||
ipc_handles = [None, ] * self.group_size
|
||||
local_ipc_handle = self.runtime.get_local_ipc_handle()
|
||||
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
|
||||
|
||||
# Synchronize NVSHMEM unique IDs
|
||||
root_unique_id = None
|
||||
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
|
||||
# Enable IBGDA for the low-latency mode, which refers to "no packet forwarding between NVLink and RDMA"
|
||||
if low_latency_mode:
|
||||
assert num_qps_per_rank > 0
|
||||
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
||||
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
||||
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
|
||||
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
|
||||
# Make sure the QP depth is always larger than the number of in-flight WRs, so that we can skip the WQ slot check
|
||||
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
|
||||
# NOTES: NVSHMEM initialization requires at least 256 MiB
|
||||
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
|
||||
|
||||
# NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code
|
||||
# Synchronize using the root ID
|
||||
nvshmem_unique_ids = [None, ] * self.group_size
|
||||
if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
|
||||
root_unique_id = self.runtime.get_local_nvshmem_unique_id()
|
||||
dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
|
||||
root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
|
||||
|
||||
# Make CPP runtime available
|
||||
self.runtime.sync(device_ids, ipc_handles, root_unique_id)
|
||||
assert self.runtime.is_available()
|
||||
|
||||
@staticmethod
|
||||
def set_num_sms(new_num_sms: int) -> None:
|
||||
"""
|
||||
Set the number of SMs to use in high-throughput kernels.
|
||||
|
||||
Arguments:
|
||||
new_num_sms: the new number to be set.
|
||||
"""
|
||||
|
||||
assert new_num_sms % 2 == 0, 'The SM count must be even'
|
||||
Buffer.num_sms = new_num_sms
|
||||
|
||||
@staticmethod
|
||||
def capture() -> EventOverlap:
|
||||
"""
|
||||
Capture a CUDA event on the current stream, i.e. `torch.cuda.current_stream()`.
|
||||
|
||||
Returns:
|
||||
event: the captured event.
|
||||
"""
|
||||
return EventOverlap(EventHandle())
|
||||
|
||||
@staticmethod
|
||||
def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int) -> int:
|
||||
"""
|
||||
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
|
||||
|
||||
Arguments:
|
||||
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
|
||||
hidden: the hidden dimension of each token.
|
||||
num_ranks: the number of EP group ranks.
|
||||
num_experts: the number of all experts.
|
||||
|
||||
Returns:
|
||||
size: the RDMA buffer size recommended.
|
||||
"""
|
||||
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
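# Sizing sketch for the low-latency mode (the numbers are illustrative assumptions):
#
#     num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
#         num_max_dispatch_tokens_per_rank=128, hidden=7168, num_ranks=16, num_experts=256)
#     buffer = Buffer(group, num_nvl_bytes=0, num_rdma_bytes=num_rdma_bytes,
#                     low_latency_mode=True, num_qps_per_rank=256 // 16)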
|
||||
|
||||
def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None,
|
||||
offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Get the raw buffer (slice supported) as a PyTorch tensor.
|
||||
|
||||
Arguments:
dtype: the data type (PyTorch `dtype`) for the tensor.
size: the slice size (in elements) to get from the buffer.
|
||||
offset: the offset of the beginning element.
|
||||
use_rdma_buffer: whether to return the RDMA buffer.
|
||||
"""
|
||||
tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer)
|
||||
if size is None:
|
||||
return tensor
|
||||
|
||||
assert tensor.numel() >= size.numel()
|
||||
return tensor[:size.numel()].view(size)
|
||||
|
||||
@staticmethod
|
||||
def get_dispatch_config(num_ranks: int) -> Config:
|
||||
"""
|
||||
Get a recommended dispatch config.
|
||||
|
||||
Argument:
|
||||
num_ranks: the number of ranks.
|
||||
|
||||
Returns:
|
||||
config: the recommended config.
|
||||
"""
|
||||
# Intranode
|
||||
if num_ranks <= 8:
|
||||
return Config(Buffer.num_sms, 6, 256, 6, 128)
|
||||
|
||||
# Internode
|
||||
config_map = {
|
||||
16: Config(Buffer.num_sms, 16, 288, 20, 128),
|
||||
24: Config(Buffer.num_sms, 8, 288, 32, 128),
|
||||
32: Config(Buffer.num_sms, 8, 288, 32, 128),
|
||||
64: Config(Buffer.num_sms, 20, 288, 28, 128),
|
||||
128: Config(Buffer.num_sms, 20, 560, 32, 128),
|
||||
144: Config(Buffer.num_sms, 32, 720, 12, 128),
|
||||
160: Config(Buffer.num_sms, 28, 720, 12, 128),
|
||||
}
|
||||
assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
|
||||
return config_map[num_ranks]
|
||||
|
||||
@staticmethod
|
||||
def get_combine_config(num_ranks: int) -> Config:
|
||||
"""
|
||||
Get a recommended combine config.
|
||||
|
||||
Argument:
|
||||
num_ranks: the number of ranks.
|
||||
|
||||
Returns:
|
||||
config: the recommended config.
|
||||
"""
|
||||
# Intranode
|
||||
if num_ranks <= 8:
|
||||
return Config(Buffer.num_sms, 6, 256, 6, 128)
|
||||
|
||||
# Internode
|
||||
config_map = {
|
||||
16: Config(Buffer.num_sms, 2, 288, 28, 128),
|
||||
24: Config(Buffer.num_sms, 1, 288, 20, 128),
|
||||
32: Config(Buffer.num_sms, 1, 288, 20, 128),
|
||||
64: Config(Buffer.num_sms, 1, 288, 20, 128),
|
||||
128: Config(Buffer.num_sms, 1, 560, 12, 128),
|
||||
144: Config(Buffer.num_sms, 2, 720, 8, 128),
|
||||
160: Config(Buffer.num_sms, 2, 720, 8, 128),
|
||||
}
|
||||
assert num_ranks in config_map, f'Unsupported number of EP ranks: {num_ranks}'
|
||||
return config_map[num_ranks]
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
|
||||
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
|
||||
allocate_on_comm_stream: bool = False) -> \
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:
|
||||
"""
|
||||
Calculate the layout required for later communication.
|
||||
|
||||
Arguments:
|
||||
topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token,
|
||||
`-1` means no selections.
|
||||
num_experts: the number of experts.
|
||||
previous_event: the event to wait before actually executing the kernel.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
allocate_on_comm_stream: controls whether the allocated tensors' ownership is transferred to the communication stream.
|
||||
|
||||
Returns:
|
||||
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
|
||||
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
rank (i.e. the ranks with the same local GPU index); `None` is returned for intranode settings.
|
||||
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
|
||||
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token is sent to a rank.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
"""
|
||||
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
|
||||
self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
|
||||
async_finish, allocate_on_comm_stream)
|
||||
return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
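# Layout-then-dispatch sketch (assumes CUDA tensors `x`, `topk_idx` and `topk_weights` exist):
#
#     num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, _ = \
#         buffer.get_dispatch_layout(topk_idx, num_experts)
#     recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, _ = \
#         buffer.dispatch(x, num_tokens_per_rank=num_tokens_per_rank,
#                         num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
#                         is_token_in_rank=is_token_in_rank,
#                         num_tokens_per_expert=num_tokens_per_expert,
#                         topk_idx=topk_idx, topk_weights=topk_weights)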
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
handle: Optional[Tuple] = None,
|
||||
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
|
||||
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
|
||||
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
|
||||
config: Optional[Config] = None,
|
||||
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
|
||||
allocate_on_comm_stream: bool = False) -> \
|
||||
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
|
||||
"""
|
||||
Dispatch tokens to different ranks; both intranode and internode settings are supported.
Intranode kernels require that all ranks are visible to each other via NVLink.
Internode kernels require that the ranks within a node are visible via NVLink, while ranks with the same GPU
index across nodes are visible via RDMA. AR (Adaptive Routing) must be disabled.
|
||||
|
||||
Arguments:
|
||||
x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`,
|
||||
and type must be `torch.bfloat16`; for the second type, the first element of the tuple must be shaped as
|
||||
`[num_tokens, hidden]` with type `torch.float8_e4m3fn`, the second must be `[num_tokens, hidden // 128]`
|
||||
(requiring `hidden` to be divisible by 128) with type `torch.float`.
|
||||
handle: an optional communication handle; if set, the CPU will reuse the layout information to save some time.
|
||||
num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank.
|
||||
num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA
|
||||
rank (i.e. the ranks with the same local GPU index); pass `None` for intranode settings.
|
||||
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token is sent to a rank.
|
||||
num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert.
|
||||
topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token,
|
||||
`-1` means no selections.
|
||||
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
|
||||
expert_alignment: align the number of tokens received by each local expert to this variable.
|
||||
config: the performance tuning config.
|
||||
previous_event: the event to wait before actually executing the kernel.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
allocate_on_comm_stream: controls whether the allocated tensors' ownership is transferred to the communication stream.
|
||||
|
||||
Returns:
|
||||
recv_x: received tokens, with the same type and tuple structure as the input `x`, but the number of tokens
equals the received token count.
|
||||
recv_topk_idx: received expert indices.
|
||||
recv_topk_weights: received expert weights.
|
||||
num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by
|
||||
each local expert, aligned to the input `expert_alignment`.
|
||||
handle: the returned communication handle.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
"""
|
||||
# Default config
|
||||
config = self.get_dispatch_config(self.group_size) if config is None else config
|
||||
|
||||
# Internode
|
||||
if self.runtime.get_num_rdma_ranks() > 1:
|
||||
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
|
||||
topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
|
||||
|
||||
# Launch the kernel with cached or non-cached mode
|
||||
x, x_scales = x if isinstance(x, tuple) else (x, None)
|
||||
if handle is not None:
|
||||
assert topk_idx is None and topk_weights is None
|
||||
rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle
|
||||
num_recv_tokens = recv_src_idx.size(0)
|
||||
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
|
||||
x, x_scales, None, None,
|
||||
None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
|
||||
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
|
||||
else:
|
||||
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
|
||||
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
|
||||
self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
|
||||
num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
|
||||
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
|
||||
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
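# Cached-dispatch sketch: once a `handle` has been produced by a full dispatch, later dispatches
# with the same layout can reuse it and must omit the top-k arguments (`x_new` is an assumption).
#
#     recv_x_cached, _, _, _, _, _ = buffer.dispatch(x_new, handle=handle)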
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def combine(self, x: torch.Tensor, handle: Tuple,
|
||||
topk_weights: Optional[torch.Tensor] = None,
|
||||
config: Optional[Config] = None,
|
||||
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
|
||||
allocate_on_comm_stream: bool = False) -> \
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
|
||||
"""
|
||||
Combine (reduce) tokens (addition **without** weights) from different ranks; both intranode and internode
settings are supported.
Intranode kernels require that all ranks are visible to each other via NVLink.
Internode kernels require that the ranks within a node are visible via NVLink, while ranks with the same GPU
index across nodes are visible via RDMA. AR (Adaptive Routing) must be disabled.
|
||||
|
||||
Arguments:
|
||||
x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to be sent back and reduced on their original ranks.
handle: a required communication handle, obtained from the `dispatch` function.
topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights to be reduced on their original ranks.
|
||||
config: the performance tuning config.
|
||||
previous_event: the event to wait before actually executing the kernel.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
allocate_on_comm_stream: controls whether the allocated tensors' ownership is transferred to the communication stream.
|
||||
|
||||
Returns:
|
||||
recv_x: the tokens reduced from the ranks they were dispatched to.
recv_topk_weights: the top-k weights reduced from the ranks they were dispatched to.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
"""
|
||||
# Default config
|
||||
config = self.get_combine_config(self.group_size) if config is None else config
|
||||
|
||||
# Internode
|
||||
if self.runtime.get_num_rdma_ranks() > 1:
|
||||
return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream)
|
||||
|
||||
# NOTES: the second item in the handle belongs to the dispatch sending side, so we should use the third one here
|
||||
rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle
|
||||
|
||||
# Launch the kernel
|
||||
recv_x, recv_topk_weights, event = self.runtime.intranode_combine(
|
||||
x, topk_weights,
|
||||
src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config,
|
||||
getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
return recv_x, recv_topk_weights, EventOverlap(event)
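# Round-trip sketch: after the experts have processed the received tokens, send the results back
# with the handle returned by `dispatch` (`expert_output` and `recv_topk_weights` are assumptions).
#
#     combined_x, combined_topk_weights, event = buffer.combine(
#         expert_output, handle, topk_weights=recv_topk_weights, async_finish=True)
#     event.current_stream_wait()  # or overlap other work before waiting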
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
handle: Optional[Tuple] = None,
|
||||
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
|
||||
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
|
||||
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
|
||||
config: Optional[Config] = None,
|
||||
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
|
||||
allocate_on_comm_stream: bool = False) -> \
|
||||
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], List[int], Tuple, EventOverlap]:
|
||||
"""
|
||||
Internode dispatch implementation; for more details, please refer to the `dispatch` docs.
|
||||
Normally, you should not directly call this function.
|
||||
"""
|
||||
assert config is not None
|
||||
|
||||
# Launch the kernel with cached or non-cached mode
|
||||
x, x_scales = x if isinstance(x, tuple) else (x, None)
|
||||
if handle is not None:
|
||||
assert topk_idx is None and topk_weights is None
|
||||
is_token_in_rank, \
|
||||
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
|
||||
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
|
||||
recv_src_meta, send_rdma_head, send_nvl_head = handle
|
||||
num_recv_tokens = recv_src_meta.size(0)
|
||||
num_rdma_recv_tokens = send_nvl_head.size(0)
|
||||
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
|
||||
x, x_scales, topk_idx, topk_weights,
|
||||
None, None, is_token_in_rank, None,
|
||||
num_recv_tokens, num_rdma_recv_tokens,
|
||||
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
|
||||
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
|
||||
else:
|
||||
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
|
||||
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
|
||||
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, \
|
||||
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
|
||||
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
|
||||
recv_src_meta, send_rdma_head, send_nvl_head, event = self.runtime.internode_dispatch(
|
||||
x, x_scales, topk_idx, topk_weights,
|
||||
num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
|
||||
0, 0, None, None, None, None,
|
||||
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
|
||||
handle = (is_token_in_rank,
|
||||
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
|
||||
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
|
||||
recv_src_meta, send_rdma_head, send_nvl_head)
|
||||
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
|
||||
topk_weights: Optional[torch.Tensor] = None,
|
||||
config: Optional[Config] = None,
|
||||
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
|
||||
allocate_on_comm_stream: bool = False) -> \
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]:
|
||||
"""
|
||||
Internode combine implementation; for more details, please refer to the `combine` docs.
|
||||
Normally, you should not directly call this function.
|
||||
"""
|
||||
assert config is not None
|
||||
|
||||
# Unpack handle
|
||||
is_combined_token_in_rank, \
|
||||
_, _, \
|
||||
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \
|
||||
src_meta, send_rdma_head, send_nvl_head = handle
|
||||
|
||||
# Launch the kernel
|
||||
combined_x, combined_topk_weights, event = self.runtime.internode_combine(
|
||||
x, topk_weights,
|
||||
src_meta, is_combined_token_in_rank,
|
||||
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
|
||||
send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None),
|
||||
async_finish, allocate_on_comm_stream)
|
||||
return combined_x, combined_topk_weights, EventOverlap(event)
|
||||
|
||||
def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
|
||||
"""
|
||||
Low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer
whenever it may be dirty.
For example, after running the normal dispatch/combine, you must call this function before executing any
low-latency kernel.
|
||||
|
||||
Arguments:
|
||||
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
|
||||
hidden: the hidden dimension of each token.
|
||||
num_experts: the number of all experts.
|
||||
"""
|
||||
self.runtime.clean_low_latency_buffer(num_max_dispatch_tokens_per_rank, hidden, num_experts)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
|
||||
num_max_dispatch_tokens_per_rank: int, num_experts: int,
|
||||
async_finish: bool = False, return_recv_hook: bool = False) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
|
||||
"""
|
||||
A low-latency implementation for dispatching with IBGDA **with implicit FP8 casting**.
|
||||
This kernel requires that all ranks (whether intranode or internode) are visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks within the same node, NVLink is fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse them, you cannot hold the result
tensors of more than two low-latency kernels at any single moment.
|
||||
|
||||
Arguments:
|
||||
x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
|
||||
supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`.
|
||||
topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes
|
||||
are supported. `-1` indices (not selecting any expert) are supported.
|
||||
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
|
||||
num_experts: the number of all experts.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
return_recv_hook: return a receiving hook if set. If set, the kernel only issues the RDMA requests,
but **without actually receiving the data**; you must call the returned hook to make sure the data has arrived.
If this flag is not set, the kernel itself ensures the data's arrival.
|
||||
|
||||
Returns:
|
||||
recv_x: a tuple with received tokens for each expert. The first element is a `torch.Tensor` shaped as
|
||||
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
|
||||
The second tensor is the corresponding scales for the first element with shape
|
||||
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`.
|
||||
Notice that the last two dimensions of the scaling tensor are stored in column-major order for TMA compatibility.
|
||||
Moreover, not all tokens are valid; only some of the `num_max_dispatch_tokens_per_rank * num_ranks` slots are,
as we do not synchronize the received count between CPU and GPU (which also keeps compatibility with CUDA graph).
|
||||
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receives. As mentioned before, not all tokens in `recv_x` are valid.
|
||||
handle: the communication handle to be used in the `low_latency_combine` function.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
hook: the receiving hook function (valid only if `return_recv_hook` is set).
|
||||
"""
|
||||
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
|
||||
self.runtime.low_latency_dispatch(x, topk_idx,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
async_finish, return_recv_hook)
|
||||
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, num_experts)
|
||||
tensors_to_record = (x, topk_idx,
|
||||
packed_recv_x, packed_recv_x_scales, packed_recv_count,
|
||||
packed_recv_src_info, packed_recv_layout_range)
|
||||
return (packed_recv_x, packed_recv_x_scales), packed_recv_count, handle, \
|
||||
EventOverlap(event, tensors_to_record if async_finish else None), hook
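# Hook-based overlap sketch (an assumed decoding step): issue the RDMA requests, run other work
# on the GPU, then call the hook so the data is guaranteed to have arrived.
#
#     recv_x, recv_count, handle, _, hook = buffer.low_latency_dispatch(
#         x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, return_recv_hook=True)
#     run_other_work()  # hypothetical overlapped computation
#     hook()            # the dispatched data is valid only after this returns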
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
|
||||
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
|
||||
Tuple[torch.Tensor, EventOverlap, Callable]:
|
||||
"""
|
||||
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
|
||||
This kernel requires that all ranks (whether intranode or internode) are visible via RDMA
(specifically, IBGDA must be enabled).
Even for ranks within the same node, NVLink is fully disabled for simplicity.
Warning: as there are only two buffers, and the returned tensors reuse them, you cannot hold the result
tensors of more than two low-latency kernels at any single moment.
|
||||
|
||||
Arguments:
|
||||
x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`,
|
||||
the locally computed tokens to be sent back to their original ranks and reduced.
|
||||
topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched
|
||||
tokens. `-1` indices (not selecting any expert) are supported. Note that `num_combined_tokens` equals
the number of dispatched tokens.
|
||||
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
|
||||
tokens. The received tokens will be reduced with the weights in this tensor.
|
||||
handle: the communication handle given by the `dispatch` function.
|
||||
async_finish: the current stream will not wait for the communication kernels to be finished if set.
|
||||
return_recv_hook: return a receiving hook if set. If set, the kernel only issues the RDMA requests,
but **without actually receiving the data**; you must call the returned hook to make sure the data has arrived.
If this flag is not set, the kernel itself ensures the data's arrival.
|
||||
|
||||
Returns:
|
||||
combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
hook: the receiving hook function (valid only if `return_recv_hook` is set).
|
||||
"""
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
|
||||
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
async_finish, return_recv_hook)
|
||||
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
|
||||
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
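# Low-latency combine sketch: reduce the per-expert outputs back to the original tokens using the
# dispatch weights (`expert_output` is an assumption for the tensor computed by the local experts).
#
#     combined_x, event, hook = buffer.low_latency_combine(
#         expert_output, topk_idx, topk_weights, handle, return_recv_hook=True)
#     hook()  # ensure arrival before reading `combined_x`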
|
||||
60
deep_ep/utils.py
Normal file
60
deep_ep/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from deep_ep_cpp import Config, EventHandle
|
||||
|
||||
|
||||
class EventOverlap:
|
||||
"""
|
||||
A wrapper class to manage CUDA events, also for better overlapping convenience.
|
||||
|
||||
Attributes:
|
||||
event: the CUDA event captured.
|
||||
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
|
||||
"""
|
||||
|
||||
def __init__(self, event: Optional[EventHandle] = None,
|
||||
extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None:
|
||||
"""
|
||||
Initialize the class.
|
||||
|
||||
Arguments:
|
||||
event: the CUDA event captured.
|
||||
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
|
||||
"""
|
||||
self.event = event
|
||||
|
||||
# NOTES: we use extra tensors to achieve stream recording, otherwise,
|
||||
# stream recording will be incompatible with CUDA graph.
|
||||
self.extra_tensors = extra_tensors
|
||||
|
||||
def current_stream_wait(self) -> None:
|
||||
"""
|
||||
The current stream `torch.cuda.current_stream()` waits for the event to be finished.
|
||||
"""
|
||||
assert self.event is not None
|
||||
self.event.current_stream_wait()
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
"""
|
||||
Utility for overlapping and Python `with` syntax.
|
||||
|
||||
You can overlap the kernels on the current stream with the following example:
|
||||
```python
|
||||
event_overlap = event_after_all_to_all_kernels()
|
||||
with event_overlap:
|
||||
do_something_on_current_stream()
|
||||
# After exiting the `with` scope, the current stream will wait for the event to be finished.
|
||||
```
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""
|
||||
Utility for overlapping and Python `with` syntax.
|
||||
|
||||
Please follow the example in the `__enter__` function.
|
||||
"""
|
||||
if self.event is not None:
|
||||
self.event.current_stream_wait()
|
||||
BIN
figures/low-latency.png
Normal file
BIN
figures/low-latency.png
Normal file
BIN
figures/normal.png
Normal file
BIN
figures/normal.png
Normal file
63
setup.py
Normal file
63
setup.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import subprocess
|
||||
import setuptools
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
|
||||
assert nvshmem_dir is not None and os.path.exists(nvshmem_dir), 'Failed to find NVSHMEM'
|
||||
print(f'NVSHMEM directory: {nvshmem_dir}')
|
||||
|
||||
# TODO: currently, we only support Hopper architecture, we may add Ampere support later
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
|
||||
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
|
||||
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
|
||||
nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']
|
||||
include_dirs = ['csrc/', f'{nvshmem_dir}/include']
|
||||
sources = ['csrc/deep_ep.cpp',
|
||||
'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu',
|
||||
'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
|
||||
library_dirs = [f'{nvshmem_dir}/lib']
|
||||
|
||||
# Disable aggressive PTX instructions
|
||||
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
|
||||
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
|
||||
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
|
||||
|
||||
# Disable DLTO (default by PyTorch)
|
||||
nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
|
||||
extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
|
||||
extra_compile_args = {
|
||||
'cxx': cxx_flags,
|
||||
'nvcc': nvcc_flags,
|
||||
'nvcc_dlink': nvcc_dlink
|
||||
}
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cmd = ['git', 'rev-parse', '--short', 'HEAD']
|
||||
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
|
||||
except Exception as _:
|
||||
revision = ''
|
||||
|
||||
setuptools.setup(
|
||||
name='deep_ep',
|
||||
version='1.0.0' + revision,
|
||||
packages=setuptools.find_packages(
|
||||
include=['deep_ep']
|
||||
),
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='deep_ep_cpp',
|
||||
include_dirs=include_dirs,
|
||||
library_dirs=library_dirs,
|
||||
sources=sources,
|
||||
extra_compile_args=extra_compile_args,
|
||||
extra_link_args=extra_link_args
|
||||
)
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
}
|
||||
)
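# Build sketch (paths are assumptions): `NVSHMEM_DIR` must point to an NVSHMEM installation.
#
#     NVSHMEM_DIR=/path/to/nvshmem python setup.py build
#     NVSHMEM_DIR=/path/to/nvshmem python setup.py install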
|
||||
246
tests/test_internode.py
Normal file
246
tests/test_internode.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
import test_low_latency
|
||||
|
||||
|
||||
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks
|
||||
assert num_experts % num_ranks == 0 and num_local_ranks == 8
|
||||
if local_rank == 0:
|
||||
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
|
||||
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
|
||||
group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices
|
||||
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
|
||||
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
|
||||
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
rdma_rank_idx = rank_idx // num_local_ranks
|
||||
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
|
||||
inplace_unique(rdma_rank_idx, num_nodes)
|
||||
|
||||
# RDMA dispatch counts
|
||||
rdma_idx = topk_idx // (num_experts // num_nodes)
|
||||
rdma_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rdma_idx, num_nodes)
|
||||
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
|
||||
num_tokens_per_rdma_rank = torch.empty((num_nodes, ), dtype=torch.int, device='cuda')
|
||||
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
|
||||
for i in range(num_nodes):
|
||||
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
|
||||
buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
|
||||
print()
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, recv_gbl_rank_prefix_sum):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = recv_gbl_rank_prefix_sum[i].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
|
||||
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank,
|
||||
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
|
||||
if with_topk:
|
||||
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
|
||||
|
||||
# Checks
|
||||
recv_gbl_rank_prefix_sum = handle[-4]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
|
||||
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
|
||||
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test cached dispatch (must be without top-k stuff)
|
||||
# NOTES: handle must be refreshed
|
||||
if not with_topk:
|
||||
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test combine
|
||||
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
if with_topk:
|
||||
combine_args.update({'topk_weights': recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
|
||||
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(' passed', flush=True)
|
||||
if local_rank == 0:
|
||||
print()
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes
|
||||
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
for rdma_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
|
||||
tune_args = {'x': current_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
|
||||
print()
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather the best FP8 config from rank 0
|
||||
best_dispatch_results = torch.tensor([best_results[0], best_results[1], best_results[2]], dtype=torch.int32, device='cuda')
|
||||
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
|
||||
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size)
|
||||
|
||||
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank,
|
||||
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
|
||||
'config': dispatch_config if dispatch_config is not None else config}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 5, 1):
|
||||
for rdma_chunk_size in range(8, 33, 4):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
|
||||
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size)
|
||||
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
|
||||
print()
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
# Please make sure AR (Adaptive Routing) is turned off when running normal internode kernels.
|
||||
num_nodes = int(os.getenv('WORLD_SIZE', 1))
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility = False
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
|
||||
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24, ):
|
||||
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
|
||||
if local_rank == 0:
|
||||
print()
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
if test_ll_compatibility:
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
|
||||
223
tests/test_intranode.py
Normal file
223
tests/test_intranode.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to_fp8, per_token_cast_back
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
import test_low_latency
|
||||
|
||||
|
||||
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
|
||||
assert num_experts % num_ranks == 0 and num_local_ranks == 8
|
||||
if local_rank == 0:
|
||||
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
|
||||
topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts, ), dtype=torch.int, device='cuda')
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks, ), dtype=torch.int, device='cuda')
|
||||
token_idx_in_rank = torch.full((num_ranks, num_tokens), -1, dtype=torch.long, device='cuda')
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(count, dtype=torch.long, device='cuda')
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
|
||||
buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True)
|
||||
print()
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
nvl_buffer_size = 256
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, rank_prefix_matrix):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = rank_prefix_matrix[i][rank].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
|
||||
dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank,
|
||||
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode}
|
||||
if with_topk:
|
||||
dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
|
||||
|
||||
# Checks
|
||||
rank_prefix_matrix = handle[0]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}'
|
||||
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
|
||||
check_data(recv_topk_weights, rank_prefix_matrix)
|
||||
|
||||
# Test cached dispatch (must be without top-k stuff)
|
||||
# NOTES: handle must be refreshed
|
||||
if not with_topk:
|
||||
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
|
||||
# Test combine
|
||||
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
|
||||
if with_topk:
|
||||
combine_args.update({'topk_weights': recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({'previous_event': buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1))
|
||||
ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(' passed', flush=True)
|
||||
if local_rank == 0:
|
||||
print()
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {'x': current_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ')
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
|
||||
print()
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather the best FP8 config from rank 0
|
||||
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
|
||||
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
|
||||
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size)
|
||||
|
||||
dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank,
|
||||
'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert,
|
||||
'config': dispatch_config if dispatch_config is not None else config}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 5, 1):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {'x': recv_x, 'handle': handle, 'config': config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ')
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
|
||||
if local_rank == 0:
|
||||
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)')
|
||||
print()
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility, num_rdma_bytes = False, 0
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
|
||||
|
||||
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24, ):
|
||||
test_main(i, local_rank, num_local_ranks, num_ranks, rank, buffer, group)
|
||||
if local_rank == 0:
|
||||
print()
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
if test_ll_compatibility:
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
|
||||
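The tuning sweep above searches NVL chunk sizes separately for dispatch and combine, and the winning FP8 dispatch settings are all-gathered so every rank adopts rank 0's choice. Below is a minimal sketch of how a training script might reuse such a tuned configuration; the SM count, chunk size and buffer size are placeholder values from a hypothetical tuning run, and `buffer`, `x` and the layout tensors are assumed to exist exactly as in `test_main` above, so this is not a self-contained program.

```python
# Sketch only: placeholder tuning results; `buffer`, `x`, `num_tokens_per_rank`,
# `is_token_in_rank` and `num_tokens_per_expert` come from the surrounding test.
import deep_ep

num_sms, nvl_chunk_size, nvl_buffer_size = 24, 16, 256
tuned_config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)

recv_x, _, _, _, handle, _ = buffer.dispatch(x=x,
                                             num_tokens_per_rank=num_tokens_per_rank,
                                             is_token_in_rank=is_token_in_rank,
                                             num_tokens_per_expert=num_tokens_per_expert,
                                             config=tuned_config)
# ... expert computation on recv_x ...
# Combine returns a tuple; the combined tensor is assumed to be its first element here.
combined_x = buffer.combine(x=recv_x, handle=handle, config=tuned_config)[0]
```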
160
tests/test_low_latency.py
Normal file
@@ -0,0 +1,160 @@
import random
import torch
import torch.distributed as dist
from functools import partial

import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back


def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
              rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)

    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks

    # NOTES: integers greater than 256 exceed the BF16 precision limit
    rank_offset = 128
    assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
    x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

    # Randomly mask some positions
    for i in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1

    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for return_recv_hook in (False, True):
        num_times += 1
        for i in range((num_times % 2) + 1):
            packed_recv_x, packed_recv_count, handle, event, hook = \
                buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                            async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
            hook() if return_recv_hook else event.current_stream_wait()
        packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous())
        simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
        all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
        dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
        for i in range(num_local_experts if do_check else 0):
            expert_id = rank * num_local_experts + i
            recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
            recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]

            # Check expert indices
            int_mask = (2 ** 32) - 1
            num_valid_tokens = recv_count.item()
            assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {(recv_layout_range & int_mask).sum().item()}'
            assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'

            # Check received data
            recv_x = recv_x[:num_valid_tokens]
            recv_x_amin = recv_x[:, :-128].amin(dim=-1)
            recv_src_info = recv_src_info[:num_valid_tokens]
            assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
            assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
            for j in range(num_ranks):
                begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
                assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
                assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
            hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
            hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])

        # Check combine correctness
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
        hook() if return_recv_hook else event.current_stream_wait()
        if do_check:
            diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
            assert torch.isnan(combined_x).sum().item() == 0
            assert diff < 1e-5, f'Error: diff={diff}'
            hash_value ^= hash_tensor(combined_x)

    def create_test_cast_with_outliers(num_outliers):
        tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
        tmp /= tmp.abs().amax(dim=1).view(-1, 1)
        assert tmp.abs().amax().item() <= 1

        # Create some amax outliers
        for i in range(num_outliers):
            tmp[random.randint(0, num_tokens - 1)] *= 1e3
        return tmp

    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # noinspection PyShadowingNames
    def test_func(return_recv_hook):
        recv_x, recv_count, handle, event, hook = \
            buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                        async_finish=False, return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None
        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
                                                             return_recv_hook=return_recv_hook)
        large_gemm_with_hook(hook) if return_recv_hook else None

    # Calculate bandwidth
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections

    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
    print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
          f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)

    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True)
        if not return_recv_hook:
            print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
        else:
            print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
                  f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')

    return hash_value


# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288

    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
    if local_rank == 0:
        print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
    buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
                            num_qps_per_rank=num_experts // num_ranks)
    test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)

    do_pressure_test = False
    for seed in range(int(1e9) if do_pressure_test else 0):
        if local_rank == 0:
            print(f'Testing with seed {seed} ...', flush=True)
        ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
        for i in range(20):
            assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'


if __name__ == '__main__':
    # TODO: you may modify NUMA binding for less CPU overhead
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
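For the configuration used in this test (`hidden = 7168`), the bandwidth accounting above amounts to roughly `7168 + 7168 / 128 * 4 + 16 = 7408` bytes per dispatched FP8 token (payload, per-128-element scales, metadata) and `2 * 7168 = 14336` bytes per combined BF16 token. The `return_recv_hook=True` path exercised by `test_func` is the communication-computation overlap mode: the call only issues the sends and hands back a hook, and the received tensors are valid only after the hook is invoked. A minimal sketch of that pattern is shown below; the tensor arguments are the ones defined in `test_main`, while `other_work()` and `expert_gemm()` are hypothetical placeholders for computation to overlap with communication.

```python
# Sketch of the hook-based overlap, assuming the variables from test_main above.
packed_recv_x, packed_recv_count, handle, _, dispatch_hook = \
    buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
                                return_recv_hook=True)
other_work()         # overlapped with the dispatch receive; no SMs are held by communication
dispatch_hook()      # packed_recv_x is only guaranteed valid after the hook returns

gemm_x = expert_gemm(packed_recv_x, packed_recv_count)   # hypothetical expert computation

combined_x, _, combine_hook = \
    buffer.low_latency_combine(gemm_x, topk_idx, topk_weights, handle,
                               return_recv_hook=True)
other_work()         # overlapped with the combine receive
combine_hook()       # combined_x is only guaranteed valid after the hook returns
```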
192
tests/utils.py
Normal file
@@ -0,0 +1,192 @@
import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional


def init_dist(local_rank: int, num_local_ranks: int):
    # NOTES: you may rewrite this function with your own cluster settings
    ip = os.getenv('MASTER_ADDR', '127.0.0.1')
    port = int(os.getenv('MASTER_PORT', '8361'))
    num_nodes = int(os.getenv('WORLD_SIZE', 1))
    node_rank = int(os.getenv('RANK', 0))
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend='nccl',
        init_method=f'tcp://{ip}:{port}',
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device('cuda')
    torch.cuda.set_device(local_rank)

    return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)


def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array([s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)])[1:]
    return np.average(times), np.min(times), np.max(times)


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
                 trace_path: Optional[str] = None, barrier_comm_profiling: bool = False):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {'ms': 1e3, 'us': 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, '')) / scale)
                        break
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()
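A short, self-contained sketch of how the FP8 helpers above fit together: `per_token_cast_to_fp8` splits each row into 128-element groups and scales each group so its absolute maximum maps to 448 (the e4m3 maximum), `per_token_cast_back` undoes the scaling, and `calc_diff` reports the relative round-trip error. The tolerance below is only an illustrative assumption, not a library guarantee.

```python
import torch
from utils import per_token_cast_to_fp8, per_token_cast_back, calc_diff

if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn((16, 7168), dtype=torch.bfloat16, device='cuda')

    # Quantize per 128-element group and dequantize again
    x_fp8, x_scales = per_token_cast_to_fp8(x)
    x_back = per_token_cast_back(x_fp8, x_scales)

    # calc_diff returns 1 minus a cosine-like similarity; small values mean a close match
    error = calc_diff(x, x_back)
    print(f'Round-trip error: {error:.2e}')
    assert error < 1e-3  # illustrative tolerance only
```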
126
third-party/README.md
vendored
Normal file
@@ -0,0 +1,126 @@
# Install NVSHMEM

## Important notices

**This project is neither sponsored nor supported by NVIDIA.**

**Use of NVIDIA NVSHMEM is governed by the terms at [NVSHMEM Software License Agreement](https://docs.nvidia.com/nvshmem/api/sla.html).**

## Prerequisites

1. [GDRCopy](https://github.com/NVIDIA/gdrcopy) (v2.4 and above recommended) is a low-latency GPU memory copy library based on NVIDIA GPUDirect RDMA technology, and *it requires kernel module installation with root privileges.*

2. Hardware requirements
   - GPUDirect RDMA capable devices, see [GPUDirect RDMA Documentation](https://docs.nvidia.com/cuda/gpudirect-rdma/)
   - InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/)
   - For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements)

## Installation procedure

### 1. Install GDRCopy

GDRCopy requires kernel module installation on the host system. Complete these steps on the bare-metal host before container deployment:

#### Build and installation

```bash
git clone https://github.com/NVIDIA/gdrcopy
cd gdrcopy
make -j$(nproc)
sudo make prefix=/opt/gdrcopy install
```

#### Kernel module installation

```bash
cd packages
CUDA=/path/to/cuda ./build-deb-packages.sh
sudo dpkg -i gdrdrv-dkms_2.4-4_amd64.deb \
            libgdrapi_2.4-4_amd64.deb \
            gdrcopy-tests_2.4-4_amd64.deb \
            gdrcopy_2.4-4_amd64.deb
sudo ./insmod.sh # Load kernel modules on bare-metal system
```
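Before moving on, it is worth confirming that the module actually loaded; on a standard Linux host, a quick check is:

```bash
lsmod | grep gdrdrv   # the gdrdrv kernel module should be listed
ls /dev/gdrdrv        # the device node created by the module should exist
```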
#### Container environment notes

For containerized environments:
1. Host: keep kernel modules loaded (`gdrdrv`)
2. Container: install DEB packages *without* rebuilding modules (a container-run sketch follows below):
```bash
sudo dpkg -i gdrcopy_2.4-4_amd64.deb \
            libgdrapi_2.4-4_amd64.deb \
            gdrcopy-tests_2.4-4_amd64.deb
```
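As a hedged illustration of that split (exact flags depend on your container runtime, image, and RDMA setup), the host keeps `gdrdrv` loaded while the container is started with the device node passed through before installing the user-space packages inside it:

```bash
# Illustration only: the image name is a placeholder; add RDMA device flags as your cluster requires
docker run --gpus all --device=/dev/gdrdrv --network=host -it <your-cuda-image> bash
```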
#### Verification

```bash
gdrcopy_copybw # Should show bandwidth test results
```

### 2. Acquire the NVSHMEM source code

Download NVSHMEM v3.1.7 from the [NVIDIA NVSHMEM Archive](https://developer.nvidia.com/nvshmem-archive).

### 3. Apply our custom patch

Navigate to your NVSHMEM source directory and apply our provided patch:

```bash
git apply /path/to/deep_ep/dir/third-party/nvshmem.patch
```

### 4. Configure NVIDIA driver

Enable IBGDA by modifying `/etc/modprobe.d/nvidia.conf`:

```bash
options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"
```

Update kernel configuration:

```bash
sudo update-initramfs -u
sudo reboot
```
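After the reboot, one way to sanity-check that the options took effect (assuming the standard NVIDIA driver procfs interface) is:

```bash
grep -E "EnableStreamMemOPs|RegistryDwords" /proc/driver/nvidia/params
# Expect EnableStreamMemOPs: 1 and a RegistryDwords entry containing PeerMappingOverride=1
```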
For more detailed configurations, please refer to the [NVSHMEM Installation Guide](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html).
### 5. Build and installation

The following example demonstrates building NVSHMEM with IBGDA support:

```bash
CUDA_HOME=/path/to/cuda && \
GDRCOPY_HOME=/path/to/gdrcopy && \
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/path/to/your/dir/to/install

cd build
make -j$(nproc)
make install
```
## Post-installation configuration

Set environment variables in your shell configuration:

```bash
export NVSHMEM_DIR=/path/to/your/dir/to/install # Use for DeepEP installation
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
```
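With `NVSHMEM_DIR` exported, a DeepEP build can locate this NVSHMEM installation. The exact command depends on how you install DeepEP; a typical setuptools-based flow, given as an assumption rather than a prescribed procedure, would be:

```bash
# Assumes a standard setuptools build of DeepEP; adjust the path to your checkout
cd /path/to/deep_ep
NVSHMEM_DIR=${NVSHMEM_DIR} python setup.py install
```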
## Verification

```bash
nvshmem-info -a # Should display details of nvshmem
```
456
third-party/nvshmem.patch
vendored
Normal file
@@ -0,0 +1,456 @@
|
||||
From 9d784943e1032f15dd7cdd2599192937ba9d9343 Mon Sep 17 00:00:00 2001
|
||||
From: Shangyan Zhou <sy.zhou@deepseek.com>
|
||||
Date: Fri, 20 Dec 2024 10:57:12 +0800
|
||||
Subject: [PATCH 1/5] Change QP creating order.
|
||||
|
||||
---
|
||||
src/modules/transport/ibgda/ibgda.cpp | 13 ++++++++-----
|
||||
1 file changed, 8 insertions(+), 5 deletions(-)
|
||||
|
||||
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
|
||||
index 31bc56a..ff02f50 100644
|
||||
--- a/src/modules/transport/ibgda/ibgda.cpp
|
||||
+++ b/src/modules/transport/ibgda/ibgda.cpp
|
||||
@@ -2921,17 +2921,20 @@ int nvshmemt_ibgda_connect_endpoints(nvshmem_transport_t t, int *selected_dev_id
|
||||
INFO(ibgda_state->log_level, "Creating %d RC QPs", device->rc.num_eps_per_pe);
|
||||
for (int i = 0; i < num_rc_eps; ++i) {
|
||||
// Do not create loopback to self
|
||||
- if (i / device->rc.num_eps_per_pe == mype) {
|
||||
+ int dst_pe = (i + 1 + mype) % n_pes;
|
||||
+ int offset = i / n_pes;
|
||||
+ int mapped_i = dst_pe * device->rc.num_eps_per_pe + offset;
|
||||
+ if (dst_pe == mype) {
|
||||
continue;
|
||||
}
|
||||
- status = ibgda_create_qp(&device->rc.eps[i], device, portid, i,
|
||||
+ status = ibgda_create_qp(&device->rc.eps[mapped_i], device, portid, mapped_i,
|
||||
NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC);
|
||||
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
|
||||
- "ibgda_create_dci failed on RC #%d.", i);
|
||||
+ "ibgda_create_dci failed on RC #%d.", mapped_i);
|
||||
|
||||
- status = ibgda_get_rc_handle(&local_rc_handles[i], device->rc.eps[i], device);
|
||||
+ status = ibgda_get_rc_handle(&local_rc_handles[mapped_i], device->rc.eps[mapped_i], device);
|
||||
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
|
||||
- "ibgda_get_rc_handle failed on RC #%d.", i);
|
||||
+ "ibgda_get_rc_handle failed on RC #%d.", mapped_i);
|
||||
}
|
||||
|
||||
if (num_rc_eps) {
|
||||
--
|
||||
2.25.1
|
||||
|
||||
|
||||
From 3cd3938bcbbabed7fb7675032afb02647ea9c2fe Mon Sep 17 00:00:00 2001
|
||||
From: Shangyan Zhou <sy.zhou@deepseek.com>
|
||||
Date: Mon, 23 Dec 2024 09:55:27 +0800
|
||||
Subject: [PATCH 2/5] Disable timeout check
|
||||
|
||||
---
|
||||
CMakeLists.txt | 3 ++-
|
||||
1 file changed, 2 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 771ff98..9246d29 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -140,7 +140,7 @@ option(NVSHMEM_NVTX "Enable NVSHMEM NVTX support" ${NVSHMEM_NVTX_DEFAULT})
|
||||
option(NVSHMEM_PMIX_SUPPORT "Enable Compilation of the PMIX bootstrap and PMIX specific code" $ENV{NVSHMEM_PMIX_SUPPORT})
|
||||
option(NVSHMEM_SHMEM_SUPPORT "Enable Compilation of the SHMEM bootstrap and SHMEM specific code" $ENV{NVSHMEM_SHMEM_SUPPORT})
|
||||
option(NVSHMEM_TEST_STATIC_LIB "Force tests to link only against the combined nvshmem.a binary" $ENV{NVSHMEM_TEST_STATIC_LIB})
|
||||
-option(NVSHMEM_TIMEOUT_DEVICE_POLLING "Enable timeouts for NVSHMEM device-side polling functions (e.g. wait_until)" $ENV{NVSHMEM_TIMEOUT_DEVICE_POLLING})
|
||||
+option(NVSHMEM_TIMEOUT_DEVICE_POLLING "Enable timeouts for NVSHMEM device-side polling functions (e.g. wait_until)" OFF)
|
||||
option(NVSHMEM_TRACE "Enable NVSHMEM trace print events" $ENV{NVSHMEM_TRACE})
|
||||
option(NVSHMEM_UCX_SUPPORT "Enable compilation of the UCX remote transport" $ENV{NVSHMEM_UCX_SUPPORT})
|
||||
option(NVSHMEM_USE_DLMALLOC "Set dlmalloc as the NVSHMEM heap allocation method" $ENV{NVSHMEM_USE_DLMALLOC})
|
||||
@@ -165,6 +165,7 @@ set(NVSHMEM_PREFIX ${NVSHMEM_PREFIX_DEFAULT} CACHE PATH "path to NVSHMEM install
|
||||
set(PMIX_HOME ${PMIX_HOME_DEFAULT} CACHE PATH "path to PMIX installation")
|
||||
set(SHMEM_HOME ${MPI_HOME} CACHE PATH "path to SHMEM installation")
|
||||
set(UCX_HOME ${UCX_HOME_DEFAULT} CACHE PATH "path to UCX installation")
|
||||
+set(NVSHMEM_TIMEOUT_DEVICE_POLLING OFF)
|
||||
|
||||
message(STATUS "NVSHMEM_PREFIX: ${NVSHMEM_PREFIX}")
|
||||
message(STATUS "NVSHMEM_DEVEL: ${NVSHMEM_DEVEL}")
|
||||
--
|
||||
2.25.1
|
||||
|
||||
|
||||
From 4e0eaff589d38f448715e43a935479451a41c0fe Mon Sep 17 00:00:00 2001
|
||||
From: Shangyan Zhou <sy.zhou@deepseek.com>
|
||||
Date: Fri, 10 Jan 2025 11:53:38 +0800
|
||||
Subject: [PATCH 3/5] Add recv queue and recv cq for rc qps.
|
||||
|
||||
Let the ibgda rc qps use regular recv queue.
|
||||
|
||||
Add recv queue to ibgda dev qp.
|
||||
|
||||
IBGDA create recv cq
|
||||
|
||||
Setup recv cq.
|
||||
|
||||
fix recv queue.
|
||||
|
||||
Remove some useless idx.
|
||||
|
||||
Longer recv queue.
|
||||
---
|
||||
.../nvshmem_common_ibgda.h | 19 +++++-
|
||||
src/modules/transport/ibgda/ibgda.cpp | 65 ++++++++++++++++---
|
||||
2 files changed, 71 insertions(+), 13 deletions(-)
|
||||
|
||||
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
index 32f6d02..7d4e250 100644
|
||||
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
@@ -168,14 +168,17 @@ typedef struct {
|
||||
uint64_t get_head; // last wqe idx + 1 with a "fetch" operation (g, get, amo_fetch)
|
||||
uint64_t get_tail; // last wqe idx + 1 polled with cst; get_tail > get_head is possible
|
||||
} tx_wq;
|
||||
+ struct {
|
||||
+ uint64_t resv_head; // last reserved wqe idx + 1
|
||||
+ } rx_wq;
|
||||
struct {
|
||||
uint64_t head;
|
||||
uint64_t tail;
|
||||
} ibuf;
|
||||
char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
|
||||
} __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
|
||||
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 96,
|
||||
- "ibgda_device_qp_management_v1 must be 96 bytes.");
|
||||
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
|
||||
+ "ibgda_device_qp_management_v1 must be 104 bytes.");
|
||||
|
||||
typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
|
||||
|
||||
@@ -199,9 +202,19 @@ typedef struct nvshmemi_ibgda_device_qp {
|
||||
// May point to mvars.prod_idx or internal prod_idx
|
||||
uint64_t *prod_idx;
|
||||
} tx_wq;
|
||||
+ struct {
|
||||
+ uint16_t nwqes;
|
||||
+ uint64_t tail;
|
||||
+ void *wqe;
|
||||
+ __be32 *dbrec;
|
||||
+ void *bf;
|
||||
+ nvshmemi_ibgda_device_cq_t *cq;
|
||||
+ // May point to mvars.prod_idx or internal prod_idx
|
||||
+ uint64_t *prod_idx;
|
||||
+ } rx_wq;
|
||||
nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables
|
||||
} nvshmemi_ibgda_device_qp_v1;
|
||||
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 184, "ibgda_device_qp_v1 must be 184 bytes.");
|
||||
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
|
||||
|
||||
typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
|
||||
|
||||
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
|
||||
index ff02f50..b8d6bc7 100644
|
||||
--- a/src/modules/transport/ibgda/ibgda.cpp
|
||||
+++ b/src/modules/transport/ibgda/ibgda.cpp
|
||||
@@ -194,6 +194,7 @@ struct ibgda_ep {
|
||||
off_t dbr_offset;
|
||||
|
||||
struct ibgda_cq *send_cq;
|
||||
+ struct ibgda_cq *recv_cq;
|
||||
struct ibv_ah *ah;
|
||||
|
||||
uint32_t user_index;
|
||||
@@ -1520,7 +1521,8 @@ static int ibgda_create_cq_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
|
||||
|
||||
struct ibv_context *context = device->context;
|
||||
|
||||
- unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes;
|
||||
+ // Each RC qp has one send CQ and one recv CQ.
|
||||
+ unsigned int num_cqs = device->dci.num_eps + device->rc.num_eps_per_pe * n_pes * 2;
|
||||
|
||||
assert(ibgda_qp_depth > 0);
|
||||
size_t num_cqe = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
|
||||
@@ -1683,7 +1685,8 @@ static int ibgda_create_qp_shared_objects(nvshmemt_ibgda_state_t *ibgda_state,
|
||||
}
|
||||
|
||||
// Allocate and map WQ buffer for all QPs.
|
||||
- wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB; // num_wqebb is always a power of 2
|
||||
+ // Todo: reduce the size of wq buffer.
|
||||
+ wq_buf_size_per_qp = num_wqebb * MLX5_SEND_WQE_BB * 2; // num_wqebb is always a power of 2
|
||||
wq_buf_size = wq_buf_size_per_qp * num_eps;
|
||||
status = ibgda_nic_control_alloc(&wq_mobject, wq_buf_size, IBGDA_GPAGE_SIZE);
|
||||
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "cannot allocate wq buf.\n");
|
||||
@@ -1864,8 +1867,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
|
||||
int cqe_version = 0;
|
||||
|
||||
struct ibgda_cq *send_cq = NULL;
|
||||
+ struct ibgda_cq *recv_cq = NULL;
|
||||
|
||||
size_t num_wqebb = IBGDA_ROUND_UP_POW2_OR_0(ibgda_qp_depth);
|
||||
+ size_t num_recv_wqe = ibgda_qp_depth;
|
||||
+ size_t recv_wqe_size = 16;
|
||||
|
||||
int status = 0;
|
||||
|
||||
@@ -1893,6 +1899,11 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
|
||||
status = ibgda_create_cq(&send_cq, device);
|
||||
NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
|
||||
|
||||
+ if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
|
||||
+ status = ibgda_create_cq(&recv_cq, device);
|
||||
+ NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "ibgda_create_cq failed.\n");
|
||||
+ }
|
||||
+
|
||||
ep = (struct ibgda_ep *)calloc(1, sizeof(struct ibgda_ep));
|
||||
NVSHMEMI_NULL_ERROR_JMP(ep, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out,
|
||||
"Unable to allocate mem for ep.\n");
|
||||
@@ -1921,12 +1932,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
|
||||
DEVX_SET(qpc, qp_context, pm_state, MLX5_QPC_PM_STATE_MIGRATED);
|
||||
DEVX_SET(qpc, qp_context, pd, device->qp_shared_object.pdn);
|
||||
DEVX_SET(qpc, qp_context, uar_page, uar_mobject->uar->page_id); // BF register
|
||||
- DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue
|
||||
- DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
|
||||
DEVX_SET(qpc, qp_context, cqn_snd, send_cq->cqn);
|
||||
- DEVX_SET(qpc, qp_context, cqn_rcv, device->qp_shared_object.rcqn);
|
||||
+ DEVX_SET(qpc, qp_context, cqn_rcv, qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC ? recv_cq->cqn : device->qp_shared_object.rcqn);
|
||||
DEVX_SET(qpc, qp_context, log_sq_size, IBGDA_ILOG2_OR0(num_wqebb));
|
||||
- DEVX_SET(qpc, qp_context, log_rq_size, 0);
|
||||
DEVX_SET(qpc, qp_context, cs_req, 0); // Disable CS Request
|
||||
DEVX_SET(qpc, qp_context, cs_res, 0); // Disable CS Response
|
||||
DEVX_SET(qpc, qp_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE); // Enable dbr_umem_id
|
||||
@@ -1935,6 +1943,15 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
|
||||
DEVX_SET(qpc, qp_context, dbr_umem_id, dbr_umem->umem_id); // DBR buffer
|
||||
DEVX_SET(qpc, qp_context, user_index, qp_idx);
|
||||
DEVX_SET(qpc, qp_context, page_offset, 0);
|
||||
+ if (qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC){
|
||||
+ DEVX_SET(qpc, qp_context, rq_type, 0); // Regular recv queue
|
||||
+ DEVX_SET(qpc, qp_context, log_rq_size, IBGDA_ILOG2(num_recv_wqe)); // 4 wqe
|
||||
+ DEVX_SET(qpc, qp_context, log_rq_stride, IBGDA_ILOG2(recv_wqe_size) - 4); // max recv wqe size = 16B
|
||||
+ } else {
|
||||
+ DEVX_SET(qpc, qp_context, rq_type, IBGDA_SRQ_TYPE_VALUE); // Shared Receive Queue, DC must use this.
|
||||
+ DEVX_SET(qpc, qp_context, srqn_rmpn_xrqn, device->qp_shared_object.srqn);
|
||||
+ DEVX_SET(qpc, qp_context, log_rq_size, 0);
|
||||
+ }
|
||||
|
||||
ep->devx_qp = mlx5dv_devx_obj_create(context, cmd_in, sizeof(cmd_in), cmd_out, sizeof(cmd_out));
|
||||
NVSHMEMI_NULL_ERROR_JMP(ep->devx_qp, status, NVSHMEMX_ERROR_INTERNAL, out,
|
||||
@@ -1944,9 +1961,9 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
|
||||
ep->portid = portid;
|
||||
|
||||
ep->sq_cnt = num_wqebb;
|
||||
- ep->sq_buf_offset = 0;
|
||||
+ ep->sq_buf_offset = num_recv_wqe * recv_wqe_size;
|
||||
|
||||
- ep->rq_cnt = 0;
|
||||
+ ep->rq_cnt = num_recv_wqe;
|
||||
ep->rq_buf_offset = 0;
|
||||
|
||||
ep->wq_mobject = device->qp_shared_object.wq_mobject;
|
||||
@@ -1960,6 +1977,7 @@ static int ibgda_create_qp(struct ibgda_ep **ep_ptr, struct ibgda_device *device
|
||||
ep->uar_mobject = uar_mobject;
|
||||
|
||||
ep->send_cq = send_cq;
|
||||
+ ep->recv_cq = recv_cq;
|
||||
|
||||
ep->qp_type = qp_type;
|
||||
|
||||
@@ -1971,6 +1989,7 @@ out:
|
||||
if (status) {
|
||||
if (uar_mobject) ibgda_unmap_and_free_qp_uar(uar_mobject);
|
||||
if (send_cq) ibgda_destroy_cq(send_cq);
|
||||
+ if (recv_cq) ibgda_destroy_cq(recv_cq);
|
||||
if (ep) free(ep);
|
||||
}
|
||||
|
||||
@@ -2269,6 +2288,10 @@ static int ibgda_destroy_ep(struct ibgda_ep *ep) {
|
||||
ibgda_destroy_cq(ep->send_cq);
|
||||
}
|
||||
|
||||
+ if (ep->recv_cq) {
|
||||
+ ibgda_destroy_cq(ep->recv_cq);
|
||||
+ }
|
||||
+
|
||||
if (ep->ah) {
|
||||
ftable.destroy_ah(ep->ah);
|
||||
}
|
||||
@@ -2300,7 +2323,7 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
|
||||
dev_qp->qpn = ep->qpn;
|
||||
|
||||
assert(ep->wq_mobject->has_gpu_mapping);
|
||||
- dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset);
|
||||
+ dev_qp->tx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->sq_buf_offset);
|
||||
|
||||
if (ibgda_nic_handler == IBGDA_NIC_HANDLER_GPU) {
|
||||
assert(ep->dbr_mobject->has_gpu_mapping);
|
||||
@@ -2312,6 +2335,12 @@ static void ibgda_get_device_qp(nvshmemi_ibgda_device_qp_t *dev_qp, struct ibgda
|
||||
}
|
||||
|
||||
dev_qp->tx_wq.nwqes = ep->sq_cnt;
|
||||
+ if (ep->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC) {
|
||||
+ dev_qp->rx_wq.nwqes = ep->rq_cnt;
|
||||
+ dev_qp->rx_wq.wqe = (void *)((uintptr_t)ep->wq_mobject->aligned.gpu_ptr + ep->wq_offset + ep->rq_buf_offset);
|
||||
+ dev_qp->rx_wq.dbrec = (__be32 *)((uintptr_t)ep->dbr_mobject->aligned.gpu_ptr + ep->dbr_offset);
|
||||
+ dev_qp->rx_wq.bf = (void *)ep->uar_mobject->aligned.gpu_ptr;
|
||||
+ }
|
||||
|
||||
ibuf_dci_start = (uintptr_t)device->qp_shared_object.internal_buf.mem_object->aligned.gpu_ptr;
|
||||
ibuf_rc_start = ibuf_dci_start + (size_per_dci * device->dci.num_eps);
|
||||
@@ -2361,6 +2390,9 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
nvshmemi_ibgda_device_cq_t *cq_d = NULL;
|
||||
nvshmemi_ibgda_device_cq_t *cq_h = NULL;
|
||||
|
||||
+ nvshmemi_ibgda_device_cq_t *recv_cq_d = NULL;
|
||||
+ nvshmemi_ibgda_device_cq_t *recv_cq_h = NULL;
|
||||
+
|
||||
uint8_t *qp_group_switches_d = NULL;
|
||||
|
||||
const size_t mvars_offset = offsetof(nvshmemi_ibgda_device_qp_t, mvars);
|
||||
@@ -2368,6 +2400,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
const size_t cons_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.cons_idx);
|
||||
const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
|
||||
const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
|
||||
+ const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
|
||||
|
||||
nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
|
||||
nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
|
||||
@@ -2405,7 +2438,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
num_dct_handles += device->dct.num_eps * n_pes;
|
||||
num_dci_handles += device->dci.num_eps;
|
||||
num_rc_handles += device->rc.num_eps_per_pe * n_pes;
|
||||
- num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1));
|
||||
+ num_cq_handles += device->dci.num_eps + (device->rc.num_eps_per_pe * (n_pes - 1) * 2);
|
||||
num_shared_dci_handles += device->dci.num_shared_eps;
|
||||
}
|
||||
num_elements = num_dct_handles - NVSHMEMI_IBGDA_MAX_CONST_DCTS;
|
||||
@@ -2441,6 +2474,10 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
for (int i = 0; i < num_cq_handles; i++) {
|
||||
nvshmemi_init_ibgda_device_cq(cq_h[i]);
|
||||
}
|
||||
+
|
||||
+ recv_cq_h = (nvshmemi_ibgda_device_cq_t *)calloc(1, sizeof(*recv_cq_h));
|
||||
+ NVSHMEMI_NULL_ERROR_JMP(recv_cq_h, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "recv_cq calloc err.");
|
||||
+ nvshmemi_init_ibgda_device_cq(recv_cq_h[0]);
|
||||
/* allocate host memory for dct, rc, cq, dci end */
|
||||
|
||||
/* allocate device memory for dct, rc, cq, dci start */
|
||||
@@ -2544,6 +2581,14 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
}
|
||||
|
||||
++cq_idx;
|
||||
+
|
||||
+ rc_h[arr_idx].rx_wq.cq = &cq_d[cq_idx];
|
||||
+
|
||||
+ ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
|
||||
+ cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
|
||||
+ cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
|
||||
+ cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
|
||||
+ ++cq_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
--
|
||||
2.25.1
|
||||
|
||||
|
||||
From 0cc285269f154049f1c9775e07e306e03228eedc Mon Sep 17 00:00:00 2001
|
||||
From: Shangyan Zhou <sy.zhou@deepseek.com>
|
||||
Date: Sat, 8 Feb 2025 18:02:39 +0800
|
||||
Subject: [PATCH 4/5] Maintain recv queue's cons_idx.
|
||||
|
||||
---
|
||||
src/include/device_host_transport/nvshmem_common_ibgda.h | 5 +++--
|
||||
src/modules/transport/ibgda/ibgda.cpp | 6 ++++--
|
||||
2 files changed, 7 insertions(+), 4 deletions(-)
|
||||
|
||||
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
index 7d4e250..502645d 100644
|
||||
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
@@ -170,6 +170,7 @@ typedef struct {
|
||||
} tx_wq;
|
||||
struct {
|
||||
uint64_t resv_head; // last reserved wqe idx + 1
|
||||
+ uint64_t cons_idx; // polled wqe idx + 1 (consumer index + 1)
|
||||
} rx_wq;
|
||||
struct {
|
||||
uint64_t head;
|
||||
@@ -177,7 +178,7 @@ typedef struct {
|
||||
} ibuf;
|
||||
char padding[NVSHMEMI_IBGDA_QP_MANAGEMENT_PADDING];
|
||||
} __attribute__((__aligned__(8))) nvshmemi_ibgda_device_qp_management_v1;
|
||||
-static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 104,
|
||||
+static_assert(sizeof(nvshmemi_ibgda_device_qp_management_v1) == 112,
|
||||
"ibgda_device_qp_management_v1 must be 104 bytes.");
|
||||
|
||||
typedef nvshmemi_ibgda_device_qp_management_v1 nvshmemi_ibgda_device_qp_management_t;
|
||||
@@ -214,7 +215,7 @@ typedef struct nvshmemi_ibgda_device_qp {
|
||||
} rx_wq;
|
||||
nvshmemi_ibgda_device_qp_management_v1 mvars; // management variables
|
||||
} nvshmemi_ibgda_device_qp_v1;
|
||||
-static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 248, "ibgda_device_qp_v1 must be 248 bytes.");
|
||||
+static_assert(sizeof(nvshmemi_ibgda_device_qp_v1) == 256, "ibgda_device_qp_v1 must be 248 bytes.");
|
||||
|
||||
typedef nvshmemi_ibgda_device_qp_v1 nvshmemi_ibgda_device_qp_t;
|
||||
|
||||
diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp
|
||||
index b8d6bc7..a1cfe2e 100644
|
||||
--- a/src/modules/transport/ibgda/ibgda.cpp
|
||||
+++ b/src/modules/transport/ibgda/ibgda.cpp
|
||||
@@ -1063,7 +1063,7 @@ static inline void ibgda_nic_control_free(struct ibgda_mem_object *mobject) {
|
||||
ibgda_host_mem_free(mobject);
|
||||
}
|
||||
|
||||
-static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device) {
|
||||
+static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device, int cc = 1) {
|
||||
int status = 0;
|
||||
|
||||
struct ibgda_cq *gcq = NULL;
|
||||
@@ -1114,7 +1114,7 @@ static int ibgda_create_cq(struct ibgda_cq **pgcq, struct ibgda_device *device)
|
||||
cq_context = DEVX_ADDR_OF(create_cq_in, cmd_in, cq_context);
|
||||
DEVX_SET(cqc, cq_context, dbr_umem_valid, IBGDA_MLX5_UMEM_VALID_ENABLE);
|
||||
DEVX_SET(cqc, cq_context, cqe_sz, MLX5_CQE_SIZE_64B);
|
||||
- DEVX_SET(cqc, cq_context, cc, 0x1); // Use collapsed CQ
|
||||
+ DEVX_SET(cqc, cq_context, cc, cc); // Use collapsed CQ
|
||||
DEVX_SET(cqc, cq_context, oi, 0x1); // Allow overrun
|
||||
DEVX_SET(cqc, cq_context, dbr_umem_id, dbr_umem->umem_id);
|
||||
DEVX_SET(cqc, cq_context, log_cq_size, IBGDA_ILOG2_OR0(num_cqe));
|
||||
@@ -2401,6 +2401,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
const size_t wqe_h_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.resv_head);
|
||||
const size_t wqe_t_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, tx_wq.ready_head);
|
||||
const size_t rx_resv_head_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.resv_head);
|
||||
+ const size_t rx_cons_offset = offsetof(nvshmemi_ibgda_device_qp_management_t, rx_wq.cons_idx);
|
||||
|
||||
nvshmemi_ibgda_device_qp_map_type_t rc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
|
||||
nvshmemi_ibgda_device_qp_map_type_t dc_map_type = NVSHMEMI_IBGDA_DEVICE_QP_MAP_TYPE_INVALID;
|
||||
@@ -2586,6 +2587,7 @@ static int ibgda_setup_gpu_state(nvshmem_transport_t t) {
|
||||
|
||||
ibgda_get_device_cq(&cq_h[cq_idx], device->rc.eps[i]->recv_cq);
|
||||
cq_h[cq_idx].resv_head = (uint64_t *)(base_mvars_d_addr + rx_resv_head_offset);
|
||||
+ cq_h[cq_idx].cons_idx = (uint64_t *)(base_mvars_d_addr + rx_cons_offset);
|
||||
cq_h[cq_idx].qpn = rc_h[arr_idx].qpn;
|
||||
cq_h[cq_idx].qp_type = rc_h[arr_idx].qp_type;
|
||||
++cq_idx;
|
||||
--
|
||||
2.25.1
|
||||
|
||||
|
||||
From f91eb8510f8c9aa4f5769bd88434db5ab000e65a Mon Sep 17 00:00:00 2001
|
||||
From: Shangyan Zhou <sy.zhou@deepseek.com>
|
||||
Date: Tue, 11 Feb 2025 11:00:57 +0800
|
||||
Subject: [PATCH 5/5] Init rx_wq counters.
|
||||
|
||||
---
|
||||
src/include/device_host_transport/nvshmem_common_ibgda.h | 2 ++
|
||||
1 file changed, 2 insertions(+)
|
||||
|
||||
diff --git a/src/include/device_host_transport/nvshmem_common_ibgda.h b/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
index 502645d..f0bc328 100644
|
||||
--- a/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
+++ b/src/include/device_host_transport/nvshmem_common_ibgda.h
|
||||
@@ -46,6 +46,8 @@
|
||||
qp_man.tx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
qp_man.tx_wq.get_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
qp_man.tx_wq.get_tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
+ qp_man.rx_wq.resv_head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
+ qp_man.rx_wq.cons_idx = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
qp_man.ibuf.head = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
qp_man.ibuf.tail = NVSHMEMI_IBGDA_ULSCALAR_INVALID; \
|
||||
} while (0);
|
||||
--
|
||||
2.25.1
|
||||
|
||||