#include <ATen/cuda/CUDAContext.h>
#include <chrono>
#include <cstring>
#include <optional>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <vector>

#include "deep_ep.hpp"
#include "kernels/api.cuh"
#include "kernels/configs.cuh"

namespace deep_ep {

Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
        rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
        low_latency_mode(low_latency_mode),
        comm_stream(at::cuda::getStreamFromPool(true)) {
    // Task fifo memory
    int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
    int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
    int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;

    // Common checks
    EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
    EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
    EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
    EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
    if (num_rdma_bytes > 0)
        EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode);

    // Get ranks
    CUDA_CHECK(cudaGetDevice(&device_id));
    rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
    num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);

    // Get device info
    cudaDeviceProp device_prop = {};
    CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));

    if (num_nvl_bytes > 0) {
        // Local IPC: alloc local memory and set local IPC handle
        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
        CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
        buffer_ptrs_gpu = reinterpret_cast<void**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes);

        // Set task fifo
        EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
        task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
        task_fifo_ptrs_gpu = reinterpret_cast<int**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes);

        // No need to synchronize, will do a full device sync during `sync`
        CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream));
    }

    // Create 32 MiB workspace
    CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES));
    CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));

    // MoE counter
    CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped));
    CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast<int64_t*>(moe_recv_counter), 0));
    *moe_recv_counter = -1;

    // MoE expert-level counter
    CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped));
    CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast<int*>(moe_recv_expert_counter), 0));
    for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i)
        moe_recv_expert_counter[i] = -1;

    // MoE RDMA-level counter
    if (num_rdma_ranks > 0) {
        CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped));
        CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast<int*>(moe_recv_rdma_counter), 0));
        *moe_recv_rdma_counter = -1;
    }
}

Buffer::~Buffer() noexcept(false) {
    // Synchronize
    CUDA_CHECK(cudaDeviceSynchronize());

    if (num_nvl_bytes > 0) {
        // Barrier
        intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
        move_fifo_slots();
        CUDA_CHECK(cudaDeviceSynchronize());

        // Close remote IPC
        if (is_available()) {
            for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
                CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
        }

        // Free local buffer and error flag
        CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
    }

    // Free NVSHMEM
    if (num_rdma_bytes > 0) {
        CUDA_CHECK(cudaDeviceSynchronize());
        internode::barrier();
        internode::free(rdma_buffer_ptr);
        internode::finalize();
    }

    // Free workspace and MoE counter
    CUDA_CHECK(cudaFree(workspace));
    CUDA_CHECK(cudaFreeHost(const_cast<int64_t*>(moe_recv_counter)));

    // Free chunked-mode stuff
    CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}

void Buffer::move_fifo_slots(int num_slots) {
    head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
}

bool Buffer::is_available() const {
    return available;
}

bool Buffer::is_internode_available() const {
    return is_available() and num_ranks > NUM_MAX_NVL_PEERS;
}

int Buffer::get_num_rdma_ranks() const {
    return num_rdma_ranks;
}

int Buffer::get_rdma_rank() const {
    return rdma_rank;
}

int Buffer::get_root_rdma_rank(bool global) const {
    return global ? nvl_rank : 0;
}

int Buffer::get_local_device_id() const {
    return device_id;
}

pybind11::bytearray Buffer::get_local_ipc_handle() const {
    return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
}

pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
    EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID");
    auto unique_id = internode::get_unique_id();
    return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
}

torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
    torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
    auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));
    auto base_ptr = reinterpret_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
    auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;
    return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}

void Buffer::sync(const std::vector<int> &device_ids,
                  const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
                  const std::optional<pybind11::bytearray>& root_unique_id_opt) {
    EP_HOST_ASSERT(not is_available());

    // Sync IPC handles
    if (num_nvl_bytes > 0) {
        EP_HOST_ASSERT(num_ranks == device_ids.size());
        EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
        for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
            EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
            auto handle_str = std::string(all_gathered_handles[offset + i].value());
            EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
            if (offset + i != rank) {
                std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
                CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
                task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
            } else {
                EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
            }
        }

        // Copy all buffer and task pointers to GPU
        CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
        CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
        CUDA_CHECK(cudaDeviceSynchronize());
    }

    // Sync NVSHMEM handles and allocate memory
    if (num_rdma_bytes > 0) {
        // Initialize NVSHMEM
        EP_HOST_ASSERT(root_unique_id_opt.has_value());
        std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
        auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
        std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
        auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
        auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
        EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
        internode::barrier();

        // Allocate
        rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);

        // Clean buffer (mainly for low-latency mode)
        CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));

        // Barrier
        internode::barrier();
        CUDA_CHECK(cudaDeviceSynchronize());
    }

    // Ready to use
    available = true;
}

std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
                            std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
    EP_HOST_ASSERT(topk_idx.dim() == 2);
    EP_HOST_ASSERT(topk_idx.is_contiguous());
    EP_HOST_ASSERT(num_experts > 0);

    // Allocate all tensors on comm stream if set
    // NOTES: do not allocate tensors upfront!
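    // The block below is the stream-handling pattern shared by every public op in this file:
    // optionally redirect tensor allocations to `comm_stream`, make `comm_stream` wait on either
    // `previous_event` or the current compute stream, launch the kernels on `comm_stream`, and
    // finally either return an `EventHandle` (async mode) or make the compute stream wait on
    // `comm_stream` (sync mode).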
    auto compute_stream = at::cuda::getCurrentCUDAStream();
    if (allocate_on_comm_stream) {
        EP_HOST_ASSERT(previous_event.has_value() and async);
        at::cuda::setCurrentCUDAStream(comm_stream);
    }

    // Wait previous tasks to be finished
    if (previous_event.has_value()) {
        stream_wait(comm_stream, previous_event.value());
    } else {
        stream_wait(comm_stream, compute_stream);
    }

    auto num_tokens = static_cast<int>(topk_idx.size(0)), num_topk = static_cast<int>(topk_idx.size(1));
    auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
    auto num_tokens_per_rdma_rank = std::optional<torch::Tensor>();
    auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA));
    auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA));
    if (is_internode_available())
        num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));

    internode::get_dispatch_layout(topk_idx.data_ptr<int64_t>(), num_tokens_per_rank.data_ptr<int>(),
                                   num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
                                   num_tokens_per_expert.data_ptr<int>(), is_token_in_rank.data_ptr<bool>(),
                                   num_tokens, num_topk, num_ranks, num_experts, comm_stream);

    // Wait streams
    std::optional<EventHandle> event;
    if (async) {
        event = EventHandle(comm_stream);
        for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
            t.record_stream(comm_stream);
            if (allocate_on_comm_stream)
                t.record_stream(compute_stream);
        }
        for (auto& to: {num_tokens_per_rdma_rank}) {
            to.has_value() ? to->record_stream(comm_stream) : void();
            if (allocate_on_comm_stream)
                to.has_value() ? to->record_stream(compute_stream) : void();
        }
    } else {
        stream_wait(compute_stream, comm_stream);
    }

    // Switch back compute stream
    if (allocate_on_comm_stream)
        at::cuda::setCurrentCUDAStream(compute_stream);

    return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event};
}

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>,
           std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
                           const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                           const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank,
                           const std::optional<torch::Tensor>& num_tokens_per_expert,
                           int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix,
                           const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
                           int expert_alignment, const Config& config,
                           std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
    bool cached_mode = cached_rank_prefix_matrix.has_value();

    // One channel uses two blocks: even-numbered blocks for sending, odd-numbered blocks for receiving.
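    // For example, the default `Config(num_sms=20, ...)` bound below yields 10 channels; presumably
    // block `2 * c` does the sending and block `2 * c + 1` the receiving for channel `c`, hence the
    // even-SM-count requirement right after this comment.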
EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); } else { EP_HOST_ASSERT(num_tokens_per_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_expert.has_value()); } // Type checks EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool); if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32); } else { EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks); if (cached_mode) { EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks); EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels); } else { EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); } auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; int64_t* topk_idx_ptr = nullptr; float* topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { num_topk = static_cast(topk_idx->size(1)); EP_HOST_ASSERT(num_experts > 0); EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; int num_scales = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); x_scales_ptr = x_scales->data_ptr(); } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! 
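    // In the non-cached path below, `notify_dispatch` first exchanges per-rank/per-expert token counts
    // and writes the totals into host-mapped pinned memory (`moe_recv_counter` / `moe_recv_expert_counter`);
    // the CPU then busy-waits on those counters (with a timeout) so the receive tensors can be allocated
    // with exact sizes before the actual `dispatch` kernel runs. The cached path reuses previously
    // computed prefix matrices and skips this handshake.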
auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } // Create handles (only return for non-cached mode) int num_recv_tokens = -1; auto rank_prefix_matrix = torch::Tensor(); auto channel_prefix_matrix = torch::Tensor(); std::vector num_recv_tokens_per_expert_list; // Barrier or send sizes // To clean: channel start/end offset, head and tail int num_memset_int = num_channels * num_ranks * 4; if (cached_mode) { num_recv_tokens = cached_num_recv_tokens; rank_prefix_matrix = cached_rank_prefix_matrix.value(); channel_prefix_matrix = cached_channel_prefix_matrix.value(); // Copy rank prefix matrix and clean flags intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks, comm_stream); move_fifo_slots(2); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); // Send sizes // Meta information: // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` // NOTES: no more token dropping in this version *moe_recv_counter = -1; for (int i = 0; i < num_local_experts; ++ i) moe_recv_expert_counter[i] = -1; EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, num_tokens, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), rank_prefix_matrix.data_ptr(), num_memset_int, expert_alignment, buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, comm_stream, num_channels); move_fifo_slots(3); // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); while (true) { // Read total count num_recv_tokens = static_cast(*moe_recv_counter); // Read per-expert count bool ready = (num_recv_tokens >= 0); for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) throw std::runtime_error("DeepEP error: CPU recv timeout"); } num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); // Assign pointers int64_t* recv_topk_idx_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; float* recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = 
torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); recv_x_scales_ptr = recv_x_scales->data_ptr(); } // Dispatch EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix num_channels * num_ranks * sizeof(int) + // Channel start offset num_channels * num_ranks * sizeof(int) + // Channel end offset num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer <= num_nvl_bytes); intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), send_head.data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), num_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); // Return values return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event}; } std::tuple, std::optional> Buffer::intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. EP_HOST_ASSERT(config.num_sms % 2 == 0); int num_channels = config.num_sms / 2; auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_recv_tokens = static_cast(send_head.size(0)); EP_HOST_ASSERT(src_idx.size(0) == num_tokens); EP_HOST_ASSERT(send_head.size(1) == num_ranks); EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks); EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! 
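    // Combine is the inverse of `intranode_dispatch`: using the `src_idx`, `send_head` and prefix
    // matrices produced by dispatch, each rank pulls back the tokens it originally sent and reduces
    // the copies received from different ranks into a single output row per token.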
auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } int num_topk = 0; auto recv_topk_weights = std::optional(); float* topk_weights_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; if (topk_weights.has_value()) { EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); num_topk = static_cast(topk_weights->size(1)); topk_weights_ptr = topk_weights->data_ptr(); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } // Launch barrier and reset queue head and tail EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr(), num_channels, num_recv_tokens, num_channels * num_ranks * 2, task_fifo_ptrs_gpu, head, rank, num_ranks, comm_stream); // NOTES: this function uses two FIFO slots (barrier before and after) move_fifo_slots(2); // Combine data auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer <= num_nvl_bytes); intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), recv_x.data_ptr(), recv_topk_weights_ptr, x.data_ptr(), topk_weights_ptr, src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to: {topk_weights, recv_topk_weights}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); return {recv_x, recv_topk_weights, event}; } std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, const std::optional& topk_weights, const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); if (cached_mode) { EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); } else { EP_HOST_ASSERT(num_tokens_per_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); EP_HOST_ASSERT(num_tokens_per_expert.has_value()); } // Type checks if (cached_mode) { EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32); } else { EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); } // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); if (cached_mode) { EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); } else { EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and 
num_tokens_per_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); } auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; int64_t* topk_idx_ptr = nullptr; float* topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { num_topk = static_cast(topk_idx->size(1)); EP_HOST_ASSERT(num_experts > 0); EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; int num_scales = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); x_scales_ptr = x_scales->data_ptr(); } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! 
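    // Internode dispatch routes each token in two hops: over RDMA (NVSHMEM) to the target node first,
    // then over NVLink to the destination GPU inside that node. That is why there are two sets of
    // layout tensors (`rdma_*` at the node level, `gbl_*` at the global/NVL level) and an additional
    // receive counter (`moe_recv_rdma_counter`) on top of `moe_recv_counter`.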
auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } // Create handles (only return for non-cached mode) int num_recv_tokens = -1, num_rdma_recv_tokens = -1; auto rdma_channel_prefix_matrix = torch::Tensor(); auto recv_rdma_rank_prefix_sum = torch::Tensor(); auto gbl_channel_prefix_matrix = torch::Tensor(); auto recv_gbl_rank_prefix_sum = torch::Tensor(); std::vector num_recv_tokens_per_expert_list; // Barrier or send sizes if (cached_mode) { num_recv_tokens = cached_num_recv_tokens; num_rdma_recv_tokens = cached_num_rdma_recv_tokens; rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value(); recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value(); gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); // Just a barrier and clean flags internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, num_ranks, num_channels, 0, nullptr, nullptr, nullptr, nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, true, low_latency_mode); move_fifo_slots(2); } else { rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); // Send sizes *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; for (int i = 0; i < num_local_experts; ++ i) moe_recv_expert_counter[i] = -1; internode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, is_token_in_rank.data_ptr(), num_tokens, num_channels, hidden_int4, num_scales, num_topk, expert_alignment, rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode); move_fifo_slots(3); // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); while (true) { // Read total count num_recv_tokens = static_cast(*moe_recv_counter); num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); // Read per-expert count bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); for (int i = 0; i < num_local_experts and ready; ++ i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > 
NUM_CPU_TIMEOUT_SECS) { printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens); for (int i = 0; i < num_local_experts; ++ i) printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); } } num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); auto recv_src_meta = std::optional(); auto recv_rdma_channel_prefix_matrix = std::optional(); auto recv_gbl_channel_prefix_matrix = std::optional(); auto send_rdma_head = std::optional(); auto send_nvl_head = std::optional(); if (not cached_mode) { recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA)); recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA)); } // Assign pointers int64_t* recv_topk_idx_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; float* recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { recv_x_scales = x_scales->dim() == 1 ? torch::empty({num_recv_tokens}, x_scales->options()) : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); recv_x_scales_ptr = recv_x_scales->data_ptr(); } // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, cached_mode ? nullptr : recv_src_meta->data_ptr(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), cached_mode ? 
nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), num_tokens, hidden_int4, num_scales, num_topk, num_experts, is_token_in_rank.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode, comm_stream, num_channels, low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t: {x, is_token_in_rank, recv_x, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, recv_topk_idx, recv_topk_weights, recv_x_scales, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, recv_src_meta}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); // Return values return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, send_nvl_head, event}; } std::tuple, std::optional> Buffer::internode_combine(const torch::Tensor& x, const std::optional& topk_weights, const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool); EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32); EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); 
EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::cuda::getCurrentCUDAStream(); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); at::cuda::setCurrentCUDAStream(comm_stream); } // Wait previous tasks to be finished if (previous_event.has_value()) { stream_wait(comm_stream, previous_event.value()); } else { stream_wait(comm_stream, compute_stream); } // Top-k checks int num_topk = 0; auto combined_topk_weights = std::optional(); float* topk_weights_ptr = nullptr; float* combined_topk_weights_ptr = nullptr; if (topk_weights.has_value()) { EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); num_topk = static_cast(topk_weights->size(1)); topk_weights_ptr = topk_weights->data_ptr(); combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options()); combined_topk_weights_ptr = combined_topk_weights->data_ptr(); } // Extra check for avoid-dead-lock design EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); // Launch barrier and reset queue head and tail internode::cached_notify(hidden_int4, 0, 0, num_topk, num_ranks, num_channels, num_combined_tokens, combined_rdma_head.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, task_fifo_ptrs_gpu, head, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode); move_fifo_slots(2); // Launch data combine auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), combined_x.data_ptr(), combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr(), x.data_ptr(), topk_weights_ptr, combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), num_tokens, num_combined_tokens, hidden, num_topk, 
rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, comm_stream, num_channels, low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); for (auto& t: {x, src_meta, is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, combined_x, combined_rdma_head, combined_nvl_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to: {topk_weights, combined_topk_weights}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { stream_wait(compute_stream, comm_stream); } // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); // Return values return {combined_x, combined_topk_weights, event}; } void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { EP_HOST_ASSERT(low_latency_mode); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto clean_meta_0 = layout.buffers[0].clean_meta(); auto clean_meta_1 = layout.buffers[1].clean_meta(); auto check_boundary = [=](void* ptr, size_t num_bytes) { auto offset = reinterpret_cast(ptr) - reinterpret_cast(rdma_buffer_ptr); EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes); }; check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second, clean_meta_1.first, clean_meta_1.second, at::cuda::getCurrentCUDAStream()); } std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool async, bool return_recv_hook) { EP_HOST_ASSERT(low_latency_mode); // Tensor checks // By default using `ptp128c` FP8 cast EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); EP_HOST_ASSERT(num_experts % num_ranks == 0); auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); int num_local_experts = num_experts / num_ranks; // Buffer control LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); auto buffer = layout.buffers[low_latency_buffer_idx]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = at::cuda::getCurrentCUDAStream(); auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream; EP_HOST_ASSERT(not (async and return_recv_hook)); if (not return_recv_hook) stream_wait(launch_stream, compute_stream); // Allocate packed tensors auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); // Allocate column-majored scales auto packed_recv_x_scales = std::optional(); float* packed_recv_x_scales_ptr = nullptr; if (use_fp8) { EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); } // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), packed_recv_count.data_ptr(), buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, buffer.dispatch_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), next_clean_meta.first, next_clean_meta.second, num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, use_fp8, workspace, launch_stream, phases); }; launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); // Wait streams std::optional event; if (async) { // NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens, // so in Python API, we must wrap all tensors into the event handle. 
event = EventHandle(launch_stream); } else if (not return_recv_hook) { stream_wait(compute_stream, launch_stream); } // Receiver callback std::optional> recv_hook = std::nullopt; if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; // Return values return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; } std::tuple, std::optional>> Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, int num_max_dispatch_tokens_per_rank, int num_experts, bool zero_copy, bool async, bool return_recv_hook, const std::optional& out) { EP_HOST_ASSERT(low_latency_mode); // Tensor checks EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); auto hidden = static_cast(x.size(2)); auto num_local_experts = num_experts / num_ranks, num_topk = static_cast(topk_weights.size(1)); auto num_combined_tokens = static_cast(topk_weights.size(0)); // Buffer control LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); auto buffer = layout.buffers[low_latency_buffer_idx]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = at::cuda::getCurrentCUDAStream(); auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream; EP_HOST_ASSERT(not (async and return_recv_hook)); if (not return_recv_hook) stream_wait(launch_stream, compute_stream); // Allocate output tensor torch::Tensor combined_x; if (out.has_value()) { EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous()); EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden); EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); combined_x = out.value(); } else { combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); } // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { internode_ll::combine(combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), src_info.data_ptr(), layout_range.data_ptr(), next_clean_meta.first, next_clean_meta.second, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, workspace, launch_stream, phases, zero_copy); }; launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); // Wait streams std::optional event; if (async) { // NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens, // so in Python API, we must wrap all tensors into the event handle. event = EventHandle(launch_stream); } else if (not return_recv_hook) { stream_wait(compute_stream, launch_stream); } // Receiver callback std::optional> recv_hook = std::nullopt; if (return_recv_hook) recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; // Return values return {combined_x, event, recv_hook}; } torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto buffer = layout.buffers[low_latency_buffer_idx]; auto dtype = torch::kBFloat16; auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); } } // namespace deep_ep PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepEP: an efficient expert-parallel communication library"; pybind11::class_(m, "Config") .def(pybind11::init(), py::arg("num_sms") = 20, py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); pybind11::class_(m, "EventHandle") .def(pybind11::init<>()) .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_(m, "Buffer") .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) 
.def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank) .def("get_local_device_id", &deep_ep::Buffer::get_local_device_id) .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) .def("sync", &deep_ep::Buffer::sync) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) .def("intranode_combine", &deep_ep::Buffer::intranode_combine) .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch) .def("internode_combine", &deep_ep::Buffer::internode_combine) .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); }