// DeepEP/csrc/deep_ep.cpp
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <atomic>
#include <chrono>
#include <cuda_runtime.h>
#include <memory>
#include <pybind11/functional.h>
#include <torch/python.h>
#include "deep_ep.hpp"
#include "kernels/api.cuh"
#include "kernels/configs.cuh"
namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
rank(rank), num_ranks(num_ranks),
num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode),
comm_stream(at::cuda::getStreamFromPool(true)) {
// Task fifo memory
int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
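// NOTES: when `num_nvl_bytes > 0`, a single allocation below packs four regions back to back:
// the IPC-shared data buffer (`num_nvl_bytes`), the task FIFO slots (`fifo_bytes`), the array of
// per-peer buffer pointers (`buffer_ptr_bytes`, exposed as `buffer_ptrs_gpu`), and the array of
// per-peer task FIFO pointers (`task_ptr_bytes`, exposed as `task_fifo_ptrs_gpu`).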
// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
if (num_rdma_bytes > 0)
EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode);
// Get ranks
CUDA_CHECK(cudaGetDevice(&device_id));
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
// Get device info
cudaDeviceProp device_prop = {};
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handle
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
buffer_ptrs_gpu = reinterpret_cast<void**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes);
// Set task fifo
EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
task_fifo_ptrs_gpu = reinterpret_cast<int**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes);
// No need to synchronize, will do a full device sync during `sync`
CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream));
}
// Create 32 MiB workspace
CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES));
CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));
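// NOTES: the MoE counters below are pinned host memory mapped into the device address space
// (`cudaHostAllocMapped`): the notify kernels write received-token counts through the mapped
// device pointers, and the host polls the same memory during dispatch; `-1` means "not written yet".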
// MoE counter
CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast<int*>(moe_recv_counter), 0));
*moe_recv_counter = -1;
// MoE expert-level counter
CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast<int*>(moe_recv_expert_counter), 0));
for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i)
moe_recv_expert_counter[i] = -1;
// MoE RDMA-level counter
if (num_rdma_ranks > 0) {
CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped));
CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
*moe_recv_rdma_counter = -1;
}
}
Buffer::~Buffer() noexcept(false) {
// Synchronize
CUDA_CHECK(cudaDeviceSynchronize());
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
move_fifo_slots();
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
}
// Free the local buffer (the task FIFO and pointer arrays share this allocation)
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
}
// Free NVSHMEM
if (num_rdma_bytes > 0) {
CUDA_CHECK(cudaDeviceSynchronize());
internode::barrier();
internode::free(rdma_buffer_ptr);
internode::finalize();
}
// Free the workspace and MoE counter
CUDA_CHECK(cudaFree(workspace));
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_counter)));
// Free the expert-level counter
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}
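// Advance the task FIFO head: each "notify"-style collective consumes `num_slots` slots per rank,
// so the head moves by `num_ranks * num_slots` positions around the ring of `NUM_MAX_FIFO_SLOTS`.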
void Buffer::move_fifo_slots(int num_slots) {
head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
}
bool Buffer::is_available() const {
return available;
}
bool Buffer::is_internode_available() const {
return is_available() and num_ranks > NUM_MAX_NVL_PEERS;
}
int Buffer::get_num_rdma_ranks() const {
return num_rdma_ranks;
}
int Buffer::get_rdma_rank() const {
return rdma_rank;
}
int Buffer::get_root_rdma_rank(bool global) const {
return global ? nvl_rank : 0;
}
int Buffer::get_local_device_id() const {
return device_id;
}
pybind11::bytearray Buffer::get_local_ipc_handle() const {
return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
}
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID");
auto unique_id = internode::get_unique_id();
return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
}
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));
auto base_ptr = reinterpret_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}
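// A rough sketch of the expected setup sequence around `sync` (the handle exchange itself is
// done by the caller, e.g. a Python-side process group; the steps below are illustrative only):
//   1. each rank constructs `Buffer(rank, num_ranks, num_nvl_bytes, num_rdma_bytes, low_latency_mode)`;
//   2. ranks all-gather `get_local_device_id()` and `get_local_ipc_handle()`;
//   3. if RDMA/NVSHMEM is used, RDMA rank 0 broadcasts `get_local_nvshmem_unique_id()`;
//   4. every rank calls `sync(...)` with the gathered handles, after which `is_available()` is true.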
void Buffer::sync(const std::vector<int> &device_ids,
const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
const std::optional<pybind11::bytearray>& root_unique_id_opt) {
EP_HOST_ASSERT(not is_available());
// Sync IPC handles
if (num_nvl_bytes > 0) {
EP_HOST_ASSERT(num_ranks == device_ids.size());
EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
auto handle_str = std::string(all_gathered_handles[offset + i].value());
EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
}
}
// Copy all buffer and task pointers to GPU
CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaDeviceSynchronize());
}
// Sync NVSHMEM handles and allocate memory
if (num_rdma_bytes > 0) {
// Initialize NVSHMEM
EP_HOST_ASSERT(root_unique_id_opt.has_value());
std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
internode::barrier();
// Allocate
rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
// Clean buffer (mainly for low-latency mode)
CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));
// Barrier
internode::barrier();
CUDA_CHECK(cudaDeviceSynchronize());
}
// Ready to use
available = true;
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
EP_HOST_ASSERT(topk_idx.dim() == 2);
EP_HOST_ASSERT(topk_idx.is_contiguous());
EP_HOST_ASSERT(num_experts > 0);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
auto num_tokens = static_cast<int>(topk_idx.size(0)), num_topk = static_cast<int>(topk_idx.size(1));
auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA));
auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA));
if (is_internode_available())
num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
internode::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
num_tokens_per_rank.data_ptr<int>(),
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
num_tokens_per_expert.data_ptr<int>(),
is_token_in_rank.data_ptr<bool>(),
num_tokens, num_topk, num_ranks, num_experts,
comm_stream);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {num_tokens_per_rdma_rank}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
bool cached_mode = cached_rank_prefix_matrix.has_value();
// One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0);
int num_channels = config.num_sms / 2;
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value());
} else {
EP_HOST_ASSERT(num_tokens_per_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_expert.has_value());
}
// Type checks
EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool);
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32);
} else {
EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
}
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous());
EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks);
if (cached_mode) {
EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks);
EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels);
} else {
EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
// Top-k checks
int num_topk = 0;
int64_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
num_topk = static_cast<int>(topk_idx->size(1));
EP_HOST_ASSERT(num_experts > 0);
EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
}
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Create handles (only returned in non-cached mode)
int num_recv_tokens = -1;
auto rank_prefix_matrix = torch::Tensor();
auto channel_prefix_matrix = torch::Tensor();
std::vector<int> num_recv_tokens_per_expert_list;
// Barrier or send sizes
// To clean: channel start/end offset, head and tail
int num_memset_int = num_channels * num_ranks * 4;
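// (4 integers per channel per rank: channel start offset, channel end offset, queue head, queue tail)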
if (cached_mode) {
num_recv_tokens = cached_num_recv_tokens;
rank_prefix_matrix = cached_rank_prefix_matrix.value();
channel_prefix_matrix = cached_channel_prefix_matrix.value();
// Copy rank prefix matrix and clean flags
intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(), num_memset_int,
buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks,
comm_stream);
move_fifo_slots(2);
} else {
rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
// Send sizes
// Meta information:
// - Size prefix by ranks, shaped as `[num_ranks, num_ranks]`
// - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]`
// NOTES: no more token dropping in this version
*moe_recv_counter = -1;
for (int i = 0; i < num_local_experts; ++ i)
moe_recv_expert_counter[i] = -1;
EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes);
intranode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
num_tokens, is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
rank_prefix_matrix.data_ptr<int>(),
num_memset_int, expert_alignment,
buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank,
comm_stream, num_channels);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
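// NOTES: `notify_dispatch` fills the host-mapped counters asynchronously; the loop below
// busy-waits on the CPU until all counters turn non-negative, or gives up after `NUM_CPU_TIMEOUT_SECS`.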
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
// Read per-expert count
bool ready = (num_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++i)
ready &= moe_recv_expert_counter[i] >= 0;
if (ready)
break;
// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
throw std::runtime_error("DeepEP error: CPU recv timeout");
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
// Allocate new tensors
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA));
auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
// Assign pointers
int64_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
}
// Dispatch
EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix
num_channels * num_ranks * sizeof(int) + // Channel start offset
num_channels * num_ranks * sizeof(int) + // Channel end offset
num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer
<= num_nvl_bytes);
intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr<int>(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32);
// One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
EP_HOST_ASSERT(config.num_sms % 2 == 0);
int num_channels = config.num_sms / 2;
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_recv_tokens = static_cast<int>(send_head.size(0));
EP_HOST_ASSERT(src_idx.size(0) == num_tokens);
EP_HOST_ASSERT(send_head.size(1) == num_ranks);
EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks);
EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
int num_topk = 0;
auto recv_topk_weights = std::optional<torch::Tensor>();
float* topk_weights_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
if (topk_weights.has_value()) {
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
num_topk = static_cast<int>(topk_weights->size(1));
topk_weights_ptr = topk_weights->data_ptr<float>();
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
// Launch barrier and reset queue head and tail
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr<int>(),
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
task_fifo_ptrs_gpu, head, rank, num_ranks,
comm_stream);
// NOTES: this function uses two FIFO slots (barrier before and after)
move_fifo_slots(2);
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer
num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer
<= num_nvl_bytes);
intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
recv_x.data_ptr(), recv_topk_weights_ptr,
x.data_ptr(), topk_weights_ptr,
src_idx.data_ptr<int>(), rank_prefix_matrix.data_ptr<int>(), channel_prefix_matrix.data_ptr<int>(),
send_head.data_ptr<int>(), num_tokens, num_recv_tokens, hidden, num_topk,
buffer_ptrs_gpu, rank, num_ranks,
comm_stream, config.num_sms,
config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, recv_topk_weights}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
return {recv_x, recv_topk_weights, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % 2 == 0);
EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS);
bool cached_mode = cached_rdma_channel_prefix_matrix.has_value();
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value());
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value());
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value());
} else {
EP_HOST_ASSERT(num_tokens_per_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value());
EP_HOST_ASSERT(num_tokens_per_expert.has_value());
}
// Type checks
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32);
} else {
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32);
}
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0);
if (cached_mode) {
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels);
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous());
EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks);
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous());
EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels);
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous());
EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks);
} else {
EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous());
EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks);
EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0);
EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS);
}
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)), hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
auto num_experts = cached_mode ? 0 : static_cast<int>(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks;
// Top-k checks
int num_topk = 0;
int64_t* topk_idx_ptr = nullptr;
float* topk_weights_ptr = nullptr;
EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value());
if (topk_idx.has_value()) {
num_topk = static_cast<int>(topk_idx->size(1));
EP_HOST_ASSERT(num_experts > 0);
EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous());
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0));
EP_HOST_ASSERT(num_topk == topk_weights->size(1));
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
topk_idx_ptr = topk_idx->data_ptr<int64_t>();
topk_weights_ptr = topk_weights->data_ptr<float>();
}
// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous());
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
}
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Create handles (only returned in non-cached mode)
int num_recv_tokens = -1, num_rdma_recv_tokens = -1;
auto rdma_channel_prefix_matrix = torch::Tensor();
auto recv_rdma_rank_prefix_sum = torch::Tensor();
auto gbl_channel_prefix_matrix = torch::Tensor();
auto recv_gbl_rank_prefix_sum = torch::Tensor();
std::vector<int> num_recv_tokens_per_expert_list;
// Barrier or send sizes
if (cached_mode) {
num_recv_tokens = cached_num_recv_tokens;
num_rdma_recv_tokens = cached_num_rdma_recv_tokens;
rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value();
recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value();
gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value();
recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value();
// Just a barrier and clean flags
internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk,
num_ranks, num_channels, 0, nullptr,
nullptr, nullptr, nullptr,
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, true, low_latency_mode);
move_fifo_slots(2);
} else {
rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
// Send sizes
*moe_recv_counter = -1, *moe_recv_rdma_counter = -1;
for (int i = 0; i < num_local_experts; ++ i)
moe_recv_expert_counter[i] = -1;
internode::notify_dispatch(num_tokens_per_rank->data_ptr<int>(), moe_recv_counter_mapped, num_ranks,
num_tokens_per_rdma_rank->data_ptr<int>(), moe_recv_rdma_counter_mapped,
num_tokens_per_expert->data_ptr<int>(), moe_recv_expert_counter_mapped, num_experts,
is_token_in_rank.data_ptr<bool>(), num_tokens, num_channels,
hidden_int4, num_scales, num_topk, expert_alignment,
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, low_latency_mode);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
// Read per-expert count
bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++ i)
ready &= moe_recv_expert_counter[i] >= 0;
if (ready)
break;
// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) {
printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
for (int i = 0; i < num_local_experts; ++ i)
printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
}
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
// Allocate new tensors
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
auto recv_topk_idx = std::optional<torch::Tensor>(), recv_topk_weights = std::optional<torch::Tensor>(), recv_x_scales = std::optional<torch::Tensor>();
auto recv_src_meta = std::optional<torch::Tensor>();
auto recv_rdma_channel_prefix_matrix = std::optional<torch::Tensor>();
auto recv_gbl_channel_prefix_matrix = std::optional<torch::Tensor>();
auto send_rdma_head = std::optional<torch::Tensor>();
auto send_nvl_head = std::optional<torch::Tensor>();
if (not cached_mode) {
recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA));
recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA));
}
// Assign pointers
int64_t* recv_topk_idx_ptr = nullptr;
float* recv_topk_weights_ptr = nullptr;
float* recv_x_scales_ptr = nullptr;
if (topk_idx.has_value()) {
recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options());
recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options());
recv_topk_idx_ptr = recv_topk_idx->data_ptr<int64_t>();
recv_topk_weights_ptr = recv_topk_weights->data_ptr<float>();
}
if (x_scales.has_value()) {
recv_x_scales = x_scales->dim() == 1 ?
torch::empty({num_recv_tokens}, x_scales->options()) :
torch::empty({num_recv_tokens, num_scales}, x_scales->options());
recv_x_scales_ptr = recv_x_scales->data_ptr<float>();
}
// Launch data dispatch
// NOTES: the buffer size checks are moved into the `.cu` file
internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr,
cached_mode ? nullptr : recv_src_meta->data_ptr(),
x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
cached_mode ? nullptr : send_rdma_head->data_ptr<int>(), cached_mode ? nullptr : send_nvl_head->data_ptr<int>(),
cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr<int>(),
cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
num_tokens, hidden_int4, num_scales, num_topk, num_experts,
is_token_in_rank.data_ptr<bool>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, cached_mode,
comm_stream, num_channels, low_latency_mode);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, is_token_in_rank, recv_x,
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {x_scales, topk_idx, topk_weights,
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum,
cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum,
recv_topk_idx, recv_topk_weights, recv_x_scales,
recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head,
recv_src_meta}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list,
rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
recv_src_meta, send_rdma_head, send_nvl_head, event};
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
const int num_channels = config.num_sms / 2;
EP_HOST_ASSERT(config.num_sms % 2 == 0);
// Shape and contiguous checks
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte);
EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool);
EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32);
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32);
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1)), hidden_int4 = static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
auto num_combined_tokens = static_cast<int>(is_combined_token_in_rank.size(0));
EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0);
EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes());
EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks);
EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks);
EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels);
EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks);
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS);
// Allocate all tensors on comm stream if set
// NOTES: do not allocate tensors upfront!
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
EP_HOST_ASSERT(previous_event.has_value() and async);
at::cuda::setCurrentCUDAStream(comm_stream);
}
// Wait for previous tasks to finish
if (previous_event.has_value()) {
stream_wait(comm_stream, previous_event.value());
} else {
stream_wait(comm_stream, compute_stream);
}
// Top-k checks
int num_topk = 0;
auto combined_topk_weights = std::optional<torch::Tensor>();
float* topk_weights_ptr = nullptr;
float* combined_topk_weights_ptr = nullptr;
if (topk_weights.has_value()) {
EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous());
EP_HOST_ASSERT(topk_weights->size(0) == num_tokens);
EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32);
num_topk = static_cast<int>(topk_weights->size(1));
topk_weights_ptr = topk_weights->data_ptr<float>();
combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options());
combined_topk_weights_ptr = combined_topk_weights->data_ptr<float>();
}
// Extra checks for the deadlock-avoidance design
EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks);
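// NOTES: the two checks above suggest the NVL receive queue is partitioned evenly across RDMA
// ranks, so a single send chunk must always fit within one rank's share of the queue.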
// Launch barrier and reset queue head and tail
internode::cached_notify(hidden_int4, 0, 0, num_topk,
num_ranks, num_channels,
num_combined_tokens, combined_rdma_head.data_ptr<int>(),
rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
move_fifo_slots(2);
// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()),
combined_x.data_ptr(), combined_topk_weights_ptr,
is_combined_token_in_rank.data_ptr<bool>(),
x.data_ptr(), topk_weights_ptr,
combined_rdma_head.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), gbl_channel_prefix_matrix.data_ptr<int>(),
num_tokens, num_combined_tokens, hidden, num_topk,
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens,
rank, num_ranks, comm_stream, num_channels, low_latency_mode);
// Wait streams
std::optional<EventHandle> event;
if (async) {
event = EventHandle(comm_stream);
for (auto& t: {x, src_meta,
is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
combined_x, combined_rdma_head, combined_nvl_head}) {
t.record_stream(comm_stream);
if (allocate_on_comm_stream)
t.record_stream(compute_stream);
}
for (auto& to: {topk_weights, combined_topk_weights}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}
} else {
stream_wait(compute_stream, comm_stream);
}
// Switch back to the compute stream
if (allocate_on_comm_stream)
at::cuda::setCurrentCUDAStream(compute_stream);
// Return values
return {combined_x, combined_topk_weights, event};
}
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta();
auto check_boundary = [=](void* ptr, size_t num_bytes) {
auto offset = reinterpret_cast<int64_t>(ptr) - reinterpret_cast<int64_t>(rdma_buffer_ptr);
EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes);
};
check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int));
check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int));
internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second,
clean_meta_1.first, clean_meta_1.second,
at::cuda::getCurrentCUDAStream());
}
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
// By default, the `ptp128c` FP8 cast is used
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(num_experts % num_ranks == 0);
auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
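// NOTES: the FP8 cast works on groups of 128 channels, so each token carries `hidden / 128`
// scales (hence the `x.size(1) % 128 == 0` assertion above).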
int num_local_experts = num_experts / num_ranks;
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
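// NOTES: the low-latency layout is double-buffered: this call uses the current buffer and flips
// the index, while the kernel also cleans the next buffer's metadata (`next_clean_meta` below) so
// the following dispatch/combine needs no separate cleanup pass.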
// Wait for previous tasks to finish
// NOTES: the hook mode will always use the default stream
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not (async and return_recv_hook));
if (not return_recv_hook)
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Allocate column-major scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
float* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr<float>();
}
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, use_fp8,
workspace, launch_stream, phases);
};
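// The send phase is launched right away; with `return_recv_hook` set, the receive phase is
// deferred to the hook returned below so the caller can overlap other work in between.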
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure that all tensors are not deallocated before the stream-wait happens,
// so in the Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
}
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
auto hidden = static_cast<int>(x.size(2));
auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Wait for previous tasks to finish
// NOTES: the hook mode will always use the default stream
auto compute_stream = at::cuda::getCurrentCUDAStream();
auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
EP_HOST_ASSERT(not (async and return_recv_hook));
if (not return_recv_hook)
stream_wait(launch_stream, compute_stream);
// Allocate output tensor
torch::Tensor combined_x;
if (out.has_value()) {
EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
combined_x = out.value();
} else {
combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
}
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure that all tensors are not deallocated before the stream-wait happens,
// so in the Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {combined_x, event, recv_hook};
}
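// Expose the low-latency combine send buffer as a strided bfloat16 tensor view, intended for the
// `zero_copy` path of `low_latency_combine` so callers can write expert outputs directly into
// RDMA memory.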
torch::Tensor
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0);
return torch::from_blob(buffer.combine_rdma_send_buffer_data_start,
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
}
} // namespace deep_ep
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepEP: an efficient expert-parallel communication library";
pybind11::class_<deep_ep::Config>(m, "Config")
.def(pybind11::init<int, int, int, int, int>(),
py::arg("num_sms") = 20,
py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256,
py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256)
.def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint)
.def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint);
m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint);
pybind11::class_<deep_ep::EventHandle>(m, "EventHandle")
.def(pybind11::init<>())
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool>())
.def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
.def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank)
.def("get_local_device_id", &deep_ep::Buffer::get_local_device_id)
.def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle)
.def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id)
.def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor)
.def("sync", &deep_ep::Buffer::sync)
.def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout)
.def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch)
.def("intranode_combine", &deep_ep::Buffer::intranode_combine)
.def("internode_dispatch", &deep_ep::Buffer::internode_dispatch)
.def("internode_combine", &deep_ep::Buffer::internode_combine)
.def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
}