mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Several code lints
This commit is contained in:
parent
3e54b78fd7
commit
edbb1bc3ff
10
README.md
10
README.md
@ -17,19 +17,11 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
Through in-depth optimization, the following enhancements have been implemented in the Internode Normal Kernel: 1) replacing IBRC with IBGDA, and 2) utilizing distinct QPs (Queue Pairs) per channel for parallel data transmission. These improvements not only enhance the robustness of the Internode Normal Kernel in scenarios involving dual-port NICs and RoCE networks but also further elevate communication performance.
| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
| Internode | 16 | 47 GB/s (RDMA) | 16 | 62 GB/s (RDMA) |
| Internode | 32 | 59 GB/s (RDMA) | 32 | 60 GB/s (RDMA) |
| Internode | 64 | 49 GB/s (RDMA) | 64 | 51 GB/s (RDMA) |
The performance optimization solution for the Internode Normal Kernel was jointly completed by our team and the Tencent Network Platform Department.

**News (2025.04.22)**: the performance is optimized by 5-35% by the Tencent Network Platform Department, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
### Low-latency kernels with pure RDMA
@ -413,7 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
|
||||
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
|
||||
if (is_local_copy) {
|
||||
// Fallback to NVSHMEM legacy API
|
||||
nvshmemx_signal_op(reinterpret_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
|
||||
nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
|
||||
} else {
|
||||
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
|
||||
|
||||
|
@ -573,10 +573,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Issue RDMA for non-local ranks
|
||||
if (dst_rdma_rank != rdma_rank and lane_id == 0) {
|
||||
nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
|
||||
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
|
||||
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
|
||||
channel_id, false);
|
||||
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
|
||||
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
|
||||
channel_id, false);
|
||||
}
|
||||
}
|
||||
sync_rdma_sender_smem();
|
||||
@ -724,9 +724,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
if (lane_id == dst_rdma_rank)
|
||||
if (lane_id == dst_rdma_rank) {
|
||||
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
}
|
||||
} else {
|
||||
// Lighter fence for local RDMA rank
|
||||
memory_fence();
|
||||
@ -1573,9 +1574,11 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
if (lane_id == 0)
|
||||
if (lane_id == 0) {
|
||||
// TODO: use the full warp to do this
|
||||
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
|
||||
}
|
||||
} else {
|
||||
memory_fence();
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ class Buffer:
|
||||
|
||||
def __init__(self, group: dist.ProcessGroup,
|
||||
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
|
||||
"""
|
||||
Initialize the communication buffer.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user