Update deep_ep.cpp (#242)

Author: fzyzcjy, 2025-06-23 11:44:06 +08:00 (committed by GitHub)
parent 7b0c25f864
commit c95997f8c4

@@ -35,7 +35,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
     rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
     num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
 #ifdef DISABLE_NVSHMEM
-    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation");
 #endif
     // Get device info
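
The div/mod pair in the context lines above splits a flat rank into an inter-node (RDMA) coordinate and an intra-node (NVLink) coordinate. A minimal worked sketch, assuming NUM_MAX_NVL_PEERS is 8 (an assumption for illustration; the constant's real value is defined elsewhere in the repo):

    #include <cstdio>

    constexpr int NUM_MAX_NVL_PEERS = 8;  // assumption: 8 GPUs per NVLink domain

    int main() {
        // e.g. in a 16-rank (2-node) job, rank 11 maps to node 1, local GPU 3
        for (int rank : {0, 7, 8, 11, 15}) {
            int rdma_rank = rank / NUM_MAX_NVL_PEERS;  // which node (RDMA peer)
            int nvl_rank  = rank % NUM_MAX_NVL_PEERS;  // which GPU inside the node
            std::printf("rank %2d -> rdma_rank %d, nvl_rank %d\n", rank, rdma_rank, nvl_rank);
        }
    }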
@@ -151,7 +151,7 @@ pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
     auto unique_id = internode::get_unique_id();
     return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
 #endif
 }
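
Every changed line in this commit relies on the classic assert(cond and "message") idiom: a string literal is a non-null pointer and therefore truthy, so and-ing it in never changes the condition's value, but the message shows up in the stringified expression when the assertion fires. A minimal sketch with a hypothetical macro standing in for EP_HOST_ASSERT (not DeepEP's actual definition):

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    // Hypothetical stand-in: throw with the stringified condition,
    // which carries the human-readable message along with it.
    #define MY_HOST_ASSERT(cond) \
        do { \
            if (not (cond)) \
                throw std::runtime_error(std::string("Assertion failed: ") + #cond); \
        } while (0)

    int main() {
        bool nvshmem_available = false;  // pretend NVSHMEM was compiled out
        try {
            MY_HOST_ASSERT(nvshmem_available and "NVSHMEM is disabled during compilation");
        } catch (const std::exception& e) {
            std::printf("%s\n", e.what());  // the message survives via #cond
        }
    }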
@@ -895,7 +895,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
             recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
             recv_src_meta, send_rdma_head, send_nvl_head, event};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
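
The #else branch above shows the compile-time stub pattern that recurs through the rest of the file: with DISABLE_NVSHMEM defined, the internode body is compiled out, the call fails loudly at runtime, and return {} value-initializes the declared return type so the function still compiles. A self-contained sketch of the shape, with hypothetical names:

    #include <cassert>
    #include <tuple>
    #include <vector>

    // Hypothetical stand-in mirroring the guard used throughout deep_ep.cpp.
    std::tuple<std::vector<float>, int> internode_dispatch_stub() {
    #ifndef DISABLE_NVSHMEM
        return {std::vector<float>(16, 0.0f), 16};  // placeholder "real" path
    #else
        assert(false and "NVSHMEM is disabled during compilation");
        return {};  // value-initialized tuple; satisfies the signature
    #endif
    }

    int main() {
        auto [x, n] = internode_dispatch_stub();
        return n == 16 ? 0 : 1;
    }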
@@ -1016,7 +1016,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
     // Return values
     return {combined_x, combined_topk_weights, event};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1040,7 +1040,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
             clean_meta_1.first, clean_meta_1.second,
             at::cuda::getCurrentCUDAStream());
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
 #endif
 }
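
This hunk's context lines pass at::cuda::getCurrentCUDAStream() so the cleanup is enqueued on the stream PyTorch is already recording work on. A minimal sketch of that pattern, with a hypothetical launcher standing in for internode_ll's; it assumes a libtorch + CUDA build:

    #include <ATen/cuda/CUDAContext.h>
    #include <cuda_runtime.h>

    // Hypothetical launcher: just zero two device buffers asynchronously.
    void launch_cleanup(int* buf0, size_t n0, int* buf1, size_t n1, cudaStream_t stream) {
        cudaMemsetAsync(buf0, 0, n0 * sizeof(int), stream);
        cudaMemsetAsync(buf1, 0, n1 * sizeof(int), stream);
    }

    void clean_buffers(int* buf0, size_t n0, int* buf1, size_t n1) {
        // getCurrentCUDAStream() yields the stream PyTorch is currently
        // enqueueing work on, so the cleanup stays ordered with it.
        auto stream = at::cuda::getCurrentCUDAStream();
        launch_cleanup(buf0, n0, buf1, n1, stream.stream());
    }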
@@ -1149,7 +1149,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     // Return values
     return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1242,7 +1242,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     // Return values
    return {combined_x, event, recv_hook};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1262,7 +1262,7 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank
             {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
             torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
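
The last hunk's context lines build the staging tensor with torch::TensorOptions().dtype(dtype).device(torch::kCUDA), the standard libtorch way to fix dtype and device at allocation time. A minimal sketch assuming a libtorch build; the sizes are illustrative, not DeepEP's actual layout:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // Illustrative sizes only; the real code derives them from num_ranks,
        // num_max_dispatch_tokens_per_rank, and num_msg_elems.
        int64_t num_ranks = 2, tokens_per_rank = 4, msg_elems = 8;
        auto options = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
        auto buffer = torch::empty({num_ranks * tokens_per_rank, msg_elems}, options);
        std::cout << buffer.sizes() << " on " << buffer.device() << "\n";
    }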