mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Support Ampere architecture (#204)
* Update README * Update `setup.py` * Fix headers * Add `DISABLE_NVSHMEM` for APIs * Fix launch * Fix TMA settings * Fix TMA usages * Fix dlink * Separate layout kernels * Update version * Add `is_sm90_compiled` * Fix tests * Add NVLink connection checks * Update README * Fix tests * Add some comments * Minor fix * Minor fix * Fix bugs
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cuda_runtime.h>
|
||||
#include <memory>
|
||||
@@ -35,6 +34,9 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
|
||||
CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
|
||||
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
|
||||
#ifdef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disable during compilation");
|
||||
#endif
|
||||
|
||||
// Get device info
|
||||
cudaDeviceProp device_prop = {};
|
||||
@@ -104,12 +106,14 @@ Buffer::~Buffer() noexcept(false) {
|
||||
}
|
||||
|
||||
// Free NVSHMEM
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
if (num_rdma_bytes > 0) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
internode::barrier();
|
||||
internode::free(rdma_buffer_ptr);
|
||||
internode::finalize();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Free workspace and MoE counter
|
||||
CUDA_CHECK(cudaFree(workspace));
|
||||
@@ -148,9 +152,13 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const {
|
||||
}
|
||||
|
||||
// Returns the NVSHMEM unique ID obtained from `internode::get_unique_id()`, copied into a Python bytearray.
// With NVSHMEM compiled in, the call asserts that `rdma_rank == 0` before fetching the ID.
// With `DISABLE_NVSHMEM` defined, the call unconditionally trips the host assertion.
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
#ifndef DISABLE_NVSHMEM
    EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID");
    auto unique_id = internode::get_unique_id();
    return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
#else
    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
    // A non-void function must not flow off its end; return an empty value like the other
    // NVSHMEM-disabled stubs in this file do (presumably unreachable if the assertion throws — TODO confirm).
    return {};
#endif
}
|
||||
|
||||
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
|
||||
@@ -190,6 +198,7 @@ void Buffer::sync(const std::vector<int> &device_ids,
|
||||
}
|
||||
|
||||
// Sync NVSHMEM handles and allocate memory
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
if (num_rdma_bytes > 0) {
|
||||
// Initialize NVSHMEM
|
||||
EP_HOST_ASSERT(root_unique_id_opt.has_value());
|
||||
@@ -211,6 +220,7 @@ void Buffer::sync(const std::vector<int> &device_ids,
|
||||
internode::barrier();
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
#endif
|
||||
|
||||
// Ready to use
|
||||
available = true;
|
||||
@@ -246,13 +256,13 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
|
||||
if (is_internode_available())
|
||||
num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
internode::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
|
||||
num_tokens_per_rank.data_ptr<int>(),
|
||||
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
|
||||
num_tokens_per_expert.data_ptr<int>(),
|
||||
is_token_in_rank.data_ptr<bool>(),
|
||||
num_tokens, num_topk, num_ranks, num_experts,
|
||||
comm_stream);
|
||||
layout::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
|
||||
num_tokens_per_rank.data_ptr<int>(),
|
||||
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
|
||||
num_tokens_per_expert.data_ptr<int>(),
|
||||
is_token_in_rank.data_ptr<bool>(),
|
||||
num_tokens, num_topk, num_ranks, num_experts,
|
||||
comm_stream);
|
||||
|
||||
// Wait streams
|
||||
std::optional<EventHandle> event;
|
||||
@@ -620,6 +630,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
|
||||
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
|
||||
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
// In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long.
|
||||
// If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL
|
||||
// unless we release GIL here.
|
||||
@@ -882,6 +893,10 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
|
||||
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
|
||||
recv_src_meta, send_rdma_head, send_nvl_head, event};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
@@ -890,6 +905,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
|
||||
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
|
||||
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
const int num_channels = config.num_sms / 2;
|
||||
EP_HOST_ASSERT(config.num_sms % 2 == 0);
|
||||
|
||||
@@ -998,14 +1014,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
|
||||
// Return values
|
||||
return {combined_x, combined_topk_weights, event};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Exposes the address held in `low_latency_usage_flag` as a 64-bit integer.
// Builds with `DISABLE_NVSHMEM` defined always fail the host assertion; the `return 0` after it
// only keeps every path of this non-void function returning a value.
uint64_t Buffer::get_low_latency_usage_flag() const {
#ifdef DISABLE_NVSHMEM
    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
    return 0;
#else
    // The flag pointer must be set before its address can be handed out
    EP_HOST_ASSERT(low_latency_usage_flag != nullptr);
    return reinterpret_cast<uint64_t>(low_latency_usage_flag);
#endif
}
|
||||
|
||||
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
@@ -1022,6 +1048,9 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
|
||||
internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second,
|
||||
clean_meta_1.first, clean_meta_1.second,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
@@ -1029,6 +1058,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@@ -1114,6 +1144,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
|
||||
// Return values
|
||||
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
@@ -1122,6 +1156,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool zero_copy, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@@ -1141,7 +1176,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
|
||||
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
|
||||
auto hidden = static_cast<int>(x.size(2));
|
||||
auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
|
||||
auto num_topk = static_cast<int>(topk_weights.size(1));
|
||||
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
|
||||
|
||||
// Buffer control
|
||||
@@ -1202,10 +1237,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
|
||||
// Return values
|
||||
return {combined_x, event, recv_hook};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
|
||||
auto buffer = layout.buffers[low_latency_buffer_idx];
|
||||
@@ -1217,6 +1257,18 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank
|
||||
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
|
||||
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Reports the build configuration: `true` unless the translation unit was compiled
// with `DISABLE_SM90_FEATURES` defined, in which case `false`.
bool is_sm90_compiled() {
#ifdef DISABLE_SM90_FEATURES
    constexpr bool sm90_enabled = false;
#else
    constexpr bool sm90_enabled = true;
#endif
    return sm90_enabled;
}
|
||||
|
||||
} // namespace deep_ep
|
||||
@@ -1258,4 +1310,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
|
||||
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
|
||||
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
|
||||
|
||||
m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user