Improve AR performance

Chenggang Zhao 2025-03-06 21:41:19 +08:00
parent 41385ba5b3
commit 1fc40d50f3
3 changed files with 8 additions and 4 deletions


@@ -34,8 +34,12 @@ struct Config {
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0);
EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens);
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0);
-EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
+// Ceil up RDMA buffer size
+this->num_max_rdma_chunked_recv_tokens = align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);
+EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
+// NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push
+EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
}
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
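
The first hunk rounds the RDMA receive ring up to a whole number of send chunks and then requires it to hold at least two of them, which is what the lazy head update introduced below relies on. A minimal host-side sketch of that ceiling logic, assuming align is the usual round-up-to-a-multiple helper (the values and the helper body here are illustrative, not the repository's exact code):

#include <cassert>

// Assumed helper: smallest multiple of `unit` that is >= x.
template <typename T>
static T align(T x, T unit) { return ((x + unit - 1) / unit) * unit; }

int main() {
    int num_max_rdma_chunked_send_tokens = 128;   // hypothetical config values
    int num_max_rdma_chunked_recv_tokens = 300;

    // Ceil the receive ring up to a whole number of send chunks: 300 -> 384.
    num_max_rdma_chunked_recv_tokens =
        align<int>(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens);

    // The ring must hold at least two full send chunks, so a sender can keep
    // pushing while the receiver defers (batches) its head updates.
    assert(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens);
    assert(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
    return 0;
}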


@@ -925,7 +925,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
break;
// Update remote head
-if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks) {
+if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head;
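
In the dispatch kernel, the receiver previously signalled the remote head whenever it had advanced at all; now it waits until at least num_max_rdma_chunked_send_tokens slots have been freed before issuing nvshmemx_signal_op, so head updates are batched into roughly one RDMA atomic per send chunk instead of one per poll. A minimal sketch of that policy with assumed names (the kernel signals via NVSHMEM rather than this counter):

#include <limits>

// "Lazy" remote-head update: only advertise freed slots once a full send
// chunk's worth has accumulated, to cut the number of RDMA atomics.
struct LazyHeadUpdater {
    int last_head = 0;
    int send_chunk;          // stands in for num_max_rdma_chunked_send_tokens
    int signals_sent = 0;    // counts would-be nvshmemx_signal_op calls

    explicit LazyHeadUpdater(int chunk) : send_chunk(chunk) {}

    void on_progress(int min_head) {
        // Old condition: min_head > last_head (signal on every bit of progress).
        // New condition: signal only once a whole send chunk has been freed.
        if (min_head != std::numeric_limits<int>::max() &&
            min_head >= last_head + send_chunk) {
            ++signals_sent;  // one NVSHMEM_SIGNAL_ADD of (min_head - last_head)
            last_head = min_head;
        }
    }
};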
@@ -1655,7 +1655,7 @@ combine(int4* combined_x, float* combined_topk_weights,
#pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
-if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks) {
+if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head;
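
The combine kernel gets the same treatment. The reason the Config change insists on num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2 is that a lazy receiver may now sit on up to one send chunk minus one token of freed-but-unadvertised slots; with a ring of at least two chunks, the sender still sees a full chunk of free space and never stalls waiting for an update the receiver is deferring. A small worked check with the same hypothetical sizes as above:

#include <cassert>

int main() {
    const int send_chunk = 128;
    const int recv_ring  = 384;                  // 300 ceiled up to three chunks
    // Worst case once the ring has actually drained: the receiver has freed up to
    // send_chunk - 1 slots without yet crossing the threshold for its next signal.
    const int max_unadvertised = send_chunk - 1;
    const int visible_free = recv_ring - max_unadvertised;   // 257 slots
    assert(visible_free >= send_chunk);          // the sender can still push a chunk
    return 0;
}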


@@ -255,7 +255,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
if (sub_warp_id == 1 and lane_id == 0) {
if (src_rank != rank) {
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
-num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
+num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
EP_DEVICE_ASSERT(num_recv_tokens != 0);
} else {
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
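
In the low-latency dispatch path, rdma_recv_count is written by a remote rank over IBGDA RDMA, so the per-expert counter arrives from outside the local GPU. A GPU-scope acquire only orders the load against writes made by the local device, which is why the remote branch is widened to a system-scope acquire, while the local branch (src_rank == rank) keeps the GPU-scope load. A sketch of what the two helpers plausibly look like; only the names appear in the diff, the PTX bodies below are an assumption:

// GPU-scope acquire: orders against writes made by the local GPU only.
__device__ __forceinline__ int ld_acquire_global(const int* ptr) {
    int value;
    asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(value) : "l"(ptr));
    return value;
}

// System-scope acquire: also observes and orders against writes arriving from
// other agents in the system, e.g. the NIC delivering an IBGDA RDMA write.
__device__ __forceinline__ int ld_acquire_sys_global(const int* ptr) {
    int value;
    asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(value) : "l"(ptr));
    return value;
}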