mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-05-05 20:44:48 +00:00
Merge pull request #130 from deepseek-ai/trmt/internode_multi_qp
Support multi-QP for normal kernels
This commit is contained in:
commit
007fcfcf97
@ -18,8 +18,10 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
|
||||
|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
|
||||
| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
|
||||
| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
|
||||
| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) |
|
||||
| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) |
|
||||
| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
|
||||
| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |
|
||||
|
||||
**News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
|
||||
|
||||
### Low-latency kernels with pure RDMA
|
||||
|
||||
|
@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
|
||||
}
|
||||
|
||||
template <bool kAlwaysDoPostSend = false>
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
|
||||
// Get lkey and rkey, store them into lanes
|
||||
@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
|
||||
|
||||
// Submit
|
||||
if (lane_id == 0)
|
||||
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
|
||||
ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
@ -410,20 +411,25 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
|
||||
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
|
||||
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
|
||||
if (is_local_copy) {
|
||||
// Fallback to NVSHMEM legacy API
|
||||
nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
|
||||
} else {
|
||||
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
|
||||
|
||||
__be32 rkey;
|
||||
uint64_t raddr;
|
||||
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
|
||||
__be32 rkey;
|
||||
uint64_t raddr;
|
||||
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
|
||||
|
||||
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
|
||||
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
|
||||
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
|
||||
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
|
||||
|
||||
ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
|
||||
qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
|
||||
ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
|
||||
qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
|
||||
|
||||
ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
|
||||
ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
@ -479,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
const bool is_forwarder = sm_id % 2 == 0;
|
||||
const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
|
||||
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels);
|
||||
|
||||
const auto role_meta = [=]() -> std::pair<WarpRole, int> {
|
||||
if (is_forwarder) {
|
||||
if (warp_id < NUM_MAX_NVL_PEERS) {
|
||||
@ -555,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Send number of tokens in this channel by `-value - 1`
|
||||
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
|
||||
for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
|
||||
auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
|
||||
if (lane_id < NUM_MAX_NVL_PEERS) {
|
||||
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
|
||||
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
|
||||
} else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
|
||||
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
|
||||
dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
|
||||
} else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
|
||||
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
|
||||
dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
|
||||
} else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
|
||||
rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
|
||||
dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Issue RDMA for non-local ranks
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
|
||||
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
|
||||
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
|
||||
channel_id, lane_id, 0);
|
||||
}
|
||||
nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
}
|
||||
nvshmem_fence();
|
||||
sync_rdma_sender_smem();
|
||||
|
||||
// Iterate over tokens and copy into buffer
|
||||
@ -710,11 +721,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
|
||||
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
|
||||
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
|
||||
rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
|
||||
num_bytes_per_rdma_token * num_tokens_to_issue,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
nvshmem_fence();
|
||||
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
|
||||
} else {
|
||||
// Lighter fence for local RDMA rank
|
||||
memory_fence();
|
||||
@ -725,8 +736,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
if (lane_id == dst_rdma_rank) {
|
||||
last_issued_tail += num_tokens_to_issue;
|
||||
num_tokens_to_send -= num_tokens_to_issue;
|
||||
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -926,8 +937,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
|
||||
// Update remote head
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id, lane_id == rdma_rank);
|
||||
last_head = min_head;
|
||||
}
|
||||
|
||||
@ -1558,20 +1569,21 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
|
||||
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
|
||||
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
|
||||
num_chunked_tokens * num_bytes_per_rdma_token,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
nvshmem_fence();
|
||||
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
|
||||
} else {
|
||||
memory_fence();
|
||||
}
|
||||
|
||||
// Write new RDMA tail
|
||||
__syncwarp();
|
||||
if (lane_id == 0)
|
||||
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
if (lane_id == 0) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1656,8 +1668,8 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
|
||||
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank);
|
||||
last_rdma_head = min_head;
|
||||
}
|
||||
} else {
|
||||
|
@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
EP_DEVICE_ASSERT(num_sms > 1);
|
||||
if (sm_id == 0) {
|
||||
// The first SM is also responsible for checking QPs
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts);
|
||||
|
||||
// The first SM is also responsible for cleaning the next buffer
|
||||
#pragma unroll
|
||||
|
@ -58,14 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
|
||||
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
|
||||
}
|
||||
|
||||
// Normal operations use IBRC, while low-latency operations use IBGDA
|
||||
if (low_latency_mode) {
|
||||
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
|
||||
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
|
||||
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switch to IBGDA mode later
|
||||
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
|
||||
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
|
||||
|
||||
bool ibgda_is_initialized = false;
|
||||
CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
|
||||
}
|
||||
bool ibgda_is_initialized = false;
|
||||
CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
|
||||
nvshmem_barrier_all();
|
||||
return nvshmem_my_pe();
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ class Buffer:
|
||||
|
||||
def __init__(self, group: dist.ProcessGroup,
|
||||
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
|
||||
low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
|
||||
"""
|
||||
Initialize the communication buffer.
|
||||
|
||||
@ -66,17 +66,16 @@ class Buffer:
|
||||
# Synchronize NVSHMEM unique IDs
|
||||
root_unique_id = None
|
||||
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
|
||||
# Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
|
||||
if low_latency_mode:
|
||||
assert num_qps_per_rank > 0
|
||||
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
||||
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
||||
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
|
||||
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
|
||||
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
|
||||
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
|
||||
# NOTES: NVSHMEM initialization requires at least 256 MiB
|
||||
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
|
||||
# Enable IBGDA
|
||||
assert num_qps_per_rank > 0
|
||||
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
||||
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
||||
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
|
||||
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
|
||||
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
|
||||
os.environ['NVSHMEM_QP_DEPTH'] = '1024'
|
||||
# NOTES: NVSHMEM initialization requires at least 256 MiB
|
||||
os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
|
||||
|
||||
# Synchronize using the root ID
|
||||
nvshmem_unique_ids = [None, ] * self.group_size
|
||||
|
@ -219,16 +219,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
num_nodes = int(os.getenv('WORLD_SIZE', 1))
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility = False
|
||||
test_ll_compatibility = True
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
|
||||
num_sms = 24
|
||||
num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0)
|
||||
|
||||
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24, ):
|
||||
for i in (num_sms, ):
|
||||
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
|
||||
if local_rank == 0:
|
||||
print('', flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user