diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 819db50..7ea5fcb 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -18,10 +18,10 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode), comm_stream(at::cuda::getStreamFromPool(true)) { - // Task fifo memory - int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; - int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; - int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS; + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); @@ -41,18 +41,17 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); if (num_nvl_bytes > 0) { - // Local IPC: alloc local memory and set local IPC handle - CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); + // Local IPC: alloc local memory and set local IPC handles + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes)); CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); - buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes); + buffer_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); - // Set task fifo - EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); - task_fifo_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - task_fifo_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes); + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream)); + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace @@ -91,8 +90,7 @@ Buffer::~Buffer() noexcept(false) { if (num_nvl_bytes > 0) { // Barrier - intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream); - move_fifo_slots(); + intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); CUDA_CHECK(cudaDeviceSynchronize()); // Close remote IPC @@ -121,10 +119,6 @@ Buffer::~Buffer() noexcept(false) { CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); } -void Buffer::move_fifo_slots(int num_slots) { - head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; -} - bool Buffer::is_available() const { return available; } @@ -162,7 +156,7 @@ pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { torch::ScalarType casted_dtype = 
torch::python::detail::py_object_to_dtype(dtype); auto element_bytes = static_cast(elementSize(casted_dtype)); - auto base_ptr = reinterpret_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; + auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } @@ -183,15 +177,15 @@ void Buffer::sync(const std::vector &device_ids, if (offset + i != rank) { std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - task_fifo_ptrs[i] = reinterpret_cast(reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes); + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); } } - // Copy all buffer and task pointers to GPU + // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaDeviceSynchronize()); } @@ -395,9 +389,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional(), num_memset_int, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks, + buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); - move_fifo_slots(2); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -416,9 +409,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional(), channel_prefix_matrix.data_ptr(), rank_prefix_matrix.data_ptr(), num_memset_int, expert_alignment, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, + buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, comm_stream, num_channels); - move_fifo_slots(3); // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); @@ -565,12 +557,9 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional(), num_channels, num_recv_tokens, num_channels * num_ranks * 2, - task_fifo_ptrs_gpu, head, rank, num_ranks, + barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); - // NOTES: this function uses two FIFO slots (barrier before and after) - move_fifo_slots(2); - // Combine data auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail @@ -746,10 +735,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional(), recv_gbl_rank_prefix_sum.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, + barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode); - move_fifo_slots(3); // Synchronize total received 
tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); @@ -958,10 +945,9 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, + barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode); - move_fifo_slots(2); // Launch data combine auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index f193bcc..85e723c 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -51,10 +51,9 @@ private: // After IPC/NVSHMEM synchronization, this flag will be true bool available = false; - // Task fifo - int head = 0; - int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - int** task_fifo_ptrs_gpu = nullptr; + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; // Workspace void* workspace = nullptr; @@ -75,9 +74,6 @@ private: volatile int* low_latency_usage_flag = nullptr; int* low_latency_usage_flag_mapped = nullptr; -private: - void move_fifo_slots(int num_slots = 1); - public: Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index 1ffa061..df7aece 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -7,7 +7,7 @@ namespace deep_ep { // Intranode runtime namespace intranode { -void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); } // namespace intranode @@ -35,11 +35,11 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, cudaStream_t stream, int num_sms); void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, @@ -51,7 +51,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re int num_max_send_tokens, int num_recv_buffer_tokens); void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); void combine(cudaDataType_t type, void* recv_x, float* recv_topk_weights, @@ -84,7 +84,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** 
buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode); @@ -106,7 +106,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode); diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index f6937da..cc1f914 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -2,7 +2,6 @@ #define NUM_MAX_NVL_PEERS 8 #define NUM_MAX_RDMA_PEERS 20 -#define NUM_MAX_FIFO_SLOTS 32768 #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) #define NUM_MAX_LOCAL_EXPERTS 1024 #define NUM_BUFFER_ALIGNMENT_BYTES 128 diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index 3d02af4..c9a10a5 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -164,7 +164,7 @@ void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx __device__ static __forceinline__ void ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr, - __be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) { + __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) { ibgda_ctrl_seg_t ctrl_seg; struct mlx5_wqe_raddr_seg raddr_seg; struct mlx5_wqe_inl_data_seg inl_seg; @@ -277,7 +277,7 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t __device__ static __forceinline__ void ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx, - void **out_wqes) { + void** out_wqes) { ibgda_ctrl_seg_t ctrl_seg; struct mlx5_wqe_raddr_seg raddr_seg; struct mlx5_wqe_data_seg data_seg; @@ -373,7 +373,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, __device__ static __forceinline__ void ibgda_write_amo_add_wqe( nvshmemi_ibgda_device_qp_t *qp, const int &value, uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey, - uint16_t wqe_idx, void **out_wqes) { + uint16_t wqe_idx, void** out_wqes) { ibgda_ctrl_seg_t ctrl_seg = {0}; struct mlx5_wqe_raddr_seg raddr_seg; struct mlx5_wqe_atomic_seg atomic_seg_1; @@ -448,40 +448,27 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co // This is a simplified version of NVSHMEM's `ibgda_poll_cq`. // Note that this implementation does not guarantee thread safety, // so we must ensure that no other threads are concurrently using the same QP. -__device__ static __forceinline__ int +__device__ static __forceinline__ void ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) { - int status = 0; - struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe; - + const auto cqe64 = static_cast(cq->cqe); const uint32_t ncqes = cq->ncqes; + memory_fence_cta(); + // NOTES: this while loop is part of do-while below. + // `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`. + // To be able to compare with the index, we need to use `wqe_counter + 1`. 
+ // Because `wqe_counter` is `uint16_t`, it may be overflow. Still, we know for + // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1 is less than + // idx, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1` + // That's why we use `- 2` here to make this case overflow. uint16_t wqe_counter; - uint16_t new_wqe_counter; - - memory_fence_cta(); - do { - new_wqe_counter = ld_na_relaxed(&cqe64->wqe_counter); - new_wqe_counter = HtoBE16(new_wqe_counter); - wqe_counter = new_wqe_counter; - } - // NOTE: This while loop is part of do while above. - // wqe_counter is the HW consumer index. However, we always maintain index - // + 1 in SW. To be able to compare with idx, we need to use wqe_counter + - // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for - // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than - // idx, and thus we need to wait. We don't need to wait when idx == - // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case - // wraparound. - // Example: - // if idx = 10, we wait until wqe_counter = 9, idx - wqe_counter - 2 = 65535 > ncqes. - while (((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes)); - + wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)); + } while ((static_cast(static_cast(idx) - wqe_counter - static_cast(2)) < ncqes)); *cq->cons_idx = idx; - // Prevent reordering of this function and subsequent instructions - memory_fence_cta(); - return status; + // Prevent reordering of this function and later instructions + memory_fence_cta(); } // Wait until wqe `idx - 1` is completed. diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 6593f96..35b4ae2 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -208,7 +208,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, const nvshmem_team_t rdma_team) { auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); @@ -219,17 +219,15 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in if (sm_id == 0) { // Communication with others - // Global barrier: the first warp do intra-node sync, the second warp do internode sync + // Global barrier: the first warp does intra-node sync, the second warp does internode sync EP_DEVICE_ASSERT(num_warps > 1); EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); if (thread_id == 32) nvshmem_sync_with_same_gpu_idx(rdma_team); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, nvl_rank); // Send numbers of tokens per rank/expert to RDMA ranks - auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks); // Clean up for later data dispatch @@ -285,7 +283,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); // Clean up for later data dispatch - auto 
nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int)); #pragma unroll @@ -328,11 +326,9 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in } memory_fence(); __syncthreads(); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, nvl_rank); - // Reduce number of tokens per rank/expert + // Reduce the number of tokens per rank/expert EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); if (thread_id == 0) { int sum = 0; @@ -359,8 +355,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in __syncthreads(); if (thread_id == 32) nvshmem_sync_with_same_gpu_idx(rdma_team); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); + barrier_block(barrier_signal_ptrs, nvl_rank); } else { // Calculate meta data int dst_rdma_rank = sm_id - 1; @@ -423,7 +418,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode) { #define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ @@ -439,7 +434,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ rdma_buffer_ptr, \ - buffer_ptrs, task_fifo_ptrs, head, rank, \ + buffer_ptrs, barrier_signal_ptrs, rank, \ cpu_rdma_team); } break constexpr int kNumThreads = 512; @@ -698,7 +693,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Release sequential lock lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; } else if (warp_role == WarpRole::kRDMASenderCoordinator) { - // NOTES: in case of splitting the issued put at the end of the buffer + // NOTES: in case of splitting, the issued put at the end of the buffer EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); // Synchronize shared memory @@ -861,7 +856,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) { auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; - auto src_meta = ld_nc_global(reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes)); + auto src_meta = ld_nc_global(reinterpret_cast(static_cast(shifted) + hidden_bytes)); lane_id == src_rdma_rank ? 
(num_tokens_to_recv_from_rdma -= 1) : 0; bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); if (lane_id == src_rdma_rank) { @@ -881,27 +876,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, reinterpret_cast(shifted), ld_nc_global, st_na_global); - shifted = reinterpret_cast(shifted) + hidden_int4; + shifted = static_cast(shifted) + hidden_int4; // Copy source meta if (lane_id == 0) st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); - shifted = reinterpret_cast(shifted) + 1; + shifted = static_cast(shifted) + 1; // Copy `x_scales` UNROLLED_WARP_COPY(1, lane_id, num_scales, nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, reinterpret_cast(shifted), ld_nc_global, st_na_global); - shifted = reinterpret_cast(shifted) + num_scales; + shifted = static_cast(shifted) + num_scales; // Copy `topk_idx` and `topk_weights` // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted if (lane_id < num_topk) { // Read - auto idx_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); - shifted = reinterpret_cast(shifted) + num_topk; - auto weight_value = ld_nc_global(reinterpret_cast(shifted) + lane_id); + auto idx_value = ld_nc_global(static_cast(shifted) + lane_id); + shifted = static_cast(shifted) + num_topk; + auto weight_value = ld_nc_global(static_cast(shifted) + lane_id); // Transform and write idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1; @@ -1107,7 +1102,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in int* combined_rdma_head, int num_combined_tokens, int num_channels, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, void* rdma_buffer_ptr, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, bool is_cached_dispatch, const nvshmem_team_t rdma_team) { auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x); @@ -1127,7 +1122,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in __syncthreads(); // Clean - auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); #pragma unroll for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; @@ -1138,12 +1133,10 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in nvshmem_sync_with_same_gpu_idx(rdma_team); } else if (sm_id == 1) { // Barrier for NVL - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, nvl_rank); // Clean - auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); #pragma unroll for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; @@ -1151,8 +1144,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in __syncthreads(); // Barrier again - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); + barrier_block(barrier_signal_ptrs, nvl_rank); } else if (sm_id == 2) { if (is_cached_dispatch) return; @@ -1213,7 +1205,7 @@ void cached_notify(int hidden_int4, int 
num_scales, int num_topk_idx, int num_to const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode) { const int num_threads = std::max(128, 32 * num_channels); @@ -1237,7 +1229,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to combined_rdma_head, num_combined_tokens, num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, - buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks, + buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch, cpu_rdma_team); } @@ -1397,7 +1389,7 @@ combine(int4* combined_x, float* combined_topk_weights, if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) break; - // Decide next RDMA buffer to send + // Decide the next RDMA buffer to send bool is_lane_ready = false; auto start_time = clock64(); while (true) { @@ -1575,8 +1567,8 @@ combine(int4* combined_x, float* combined_topk_weights, combine_token(expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, - reinterpret_cast(shifted), - reinterpret_cast(reinterpret_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + static_cast(shifted), + reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); // Update head diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 1452c8e..7d850e1 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -22,7 +22,7 @@ __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) clean_1[i] = 0; - // Barrier after cleaning (make sure low-latency mode work fine) + // Barrier after cleaning (make sure the low-latency mode works fine) nvshmemx_barrier_all_block(); } diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 2954ac2..cd545a0 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -14,16 +14,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) { + void** buffer_ptrs, int** barrier_signal_ptrs, int rank) { auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32; if (sm_id == 0) { // Barrier first - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); int *per_rank_buffer, *per_expert_buffer; if (thread_id < kNumRanks) { @@ -46,9 +44,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, __syncthreads(); // Wait for all ranks to be finished - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, 
rank); // Sum per-rank counts and return to CPU // Also pre-compute the prefix sum for data sending @@ -86,7 +82,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, // Barrier memory_fence(); __syncthreads(); - barrier_device(task_fifo_ptrs, head, rank); + barrier_block(barrier_signal_ptrs, rank); } else { int dst_rank = sm_id - 1; for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { @@ -116,7 +112,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, cudaStream_t stream, int num_channels) { #define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL(&cfg, notify_dispatch, \ @@ -124,7 +120,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \ rank_prefix_matrix_copy, num_memset_int, expert_alignment, \ - buffer_ptrs, task_fifo_ptrs, head, rank); \ + buffer_ptrs, barrier_signal_ptrs, rank); \ break constexpr int kNumThreads = 128; @@ -139,11 +135,9 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe template __global__ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) { + void** buffer_ptrs, int** barrier_signal_ptrs, int rank) { // A simplified version for cached handles - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); // Copy and clean auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); @@ -158,15 +152,15 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, __syncthreads(); // Barrier after cleaning - barrier_device(task_fifo_ptrs, head, rank); + barrier_block(barrier_signal_ptrs, rank); } void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int** task_fifo_ptrs, - int head, int rank, int num_ranks, cudaStream_t stream) { + void** buffer_ptrs, int** barrier_signal_ptrs, + int rank, int num_ranks, cudaStream_t stream) { #define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL(&cfg, cached_notify_dispatch, \ - rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \ + rank_prefix_matrix, num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \ break SETUP_LAUNCH_CONFIG(1, 128, stream); @@ -180,7 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, - void **buffer_ptrs, int rank, + void** buffer_ptrs, int rank, int num_max_send_tokens, int num_recv_buffer_tokens) { const auto num_sms = static_cast(gridDim.x), sm_id = static_cast(blockIdx.x); const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); @@ -491,13 +485,11 @@ void dispatch(void* 
recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re template __global__ void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank) { + int** barrier_signal_ptrs, int rank) { const auto sm_id = static_cast(blockIdx.x); if (sm_id == 0) { // Barrier before cleaning - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); // Clean auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); @@ -509,7 +501,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int __syncthreads(); // Barrier after cleaning - barrier_device(task_fifo_ptrs, head, rank); + barrier_block(barrier_signal_ptrs, rank); } else { const auto channel_id = sm_id - 1; const auto thread_id = static_cast(threadIdx.x); @@ -528,7 +520,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int int token_idx = token_idx_tail - lane_id, expected_head = 0; auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1; for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) { - head = __shfl_sync(0xffffffff, current_head, i); + const int head = __shfl_sync(0xffffffff, current_head, i); if (head < 0) { if (lane_id == i) expected_head = -last_head - 1; @@ -544,11 +536,11 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) { #define CACHED_NOTIFY_COMBINE(ranks) \ LAUNCH_KERNEL(&cfg, cached_notify_combine, \ - buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \ + buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, barrier_signal_ptrs, rank); \ break const int num_threads = std::max(128, 32 * num_ranks); @@ -566,7 +558,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights, const dtype_t* x, const float* topk_weights, const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void **buffer_ptrs, int rank, + void** buffer_ptrs, int rank, int num_max_send_tokens, int num_recv_buffer_tokens) { const auto num_sms = static_cast(gridDim.x); const auto thread_id = static_cast(threadIdx.x); diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index 7792953..79abdcd 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -12,13 +12,13 @@ namespace deep_ep { namespace intranode { template -__global__ void barrier(int** task_fifo_ptrs, int head, int rank) { - barrier_device(task_fifo_ptrs, head, rank); +__global__ void barrier(int** barrier_signal_ptrs, int rank) { + barrier_block(barrier_signal_ptrs, rank); } -void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) { #define BARRIER_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ + LAUNCH_KERNEL(&cfg, barrier, barrier_signal_ptrs, rank); \ break SETUP_LAUNCH_CONFIG(1, 32, stream); diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh 
index 5c4f2ba..7eec3cf 100644
--- a/csrc/kernels/utils.cuh
+++ b/csrc/kernels/utils.cuh
@@ -396,44 +396,32 @@ __forceinline__ __device__ int get_lane_id() {
     return lane_id;
 }
 
-template <int kNumRanks>
-__forceinline__ __device__ void move_fifo_slots(int &head) {
-    head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
-}
-
-template <int kNumRanks>
-__device__ __forceinline__ bool not_finished(int *task, int expected) {
-    auto result = false;
-    auto lane_id = threadIdx.x % 32;
-    if (lane_id < kNumRanks)
-        result = ld_volatile_global(task + lane_id) != expected;
-    return __any_sync(0xffffffff, result);
-}
-
 template <int kNumRanks>
 __forceinline__ __device__ void
-timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
+barrier_block(int** barrier_signal_ptrs, int rank) {
+    auto thread_id = static_cast<int>(threadIdx.x);
+
+    // Add self-ranks, sub other ranks
+    if (thread_id < kNumRanks) {
+        atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
+        memory_fence();
+        atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
+    }
+    EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
+
+    // Check timeout
     auto start_time = clock64();
-    while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
-        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
-            printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
+    while (true) {
+        auto value = thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;
+        if (__all_sync(0xffffffff, value <= 0))
+            break;
+
+        if (clock64() - start_time > NUM_TIMEOUT_CYCLES and get_lane_id() == 0) {
+            printf("DeepEP timeout check failed: rank = %d, thread = %d)\n", rank, thread_id);
             trap();
         }
     }
-}
-
-template <int kNumRanks>
-__forceinline__ __device__ void
-barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
-    auto thread_id = static_cast<int>(threadIdx.x);
-    EP_DEVICE_ASSERT(kNumRanks <= 32);
-
-    if (thread_id < kNumRanks) {
-        atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
-        memory_fence();
-        atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
-    }
-    timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
+    __syncthreads();
 }
 
 } // namespace deep_ep
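
The `barrier_block` primitive above is the heart of this change: each rank adds `FINISHED_SUM_TAG` to every slot of its own signal array and subtracts the tag from its own slot in every peer's array, then spins until all of its own slots are back to a non-positive value, which can only happen once every peer has checked in. Below is a minimal single-GPU sketch of the same signaling protocol, with thread blocks standing in for NVL ranks and ordinary device memory standing in for the IPC-mapped buffers; the names (`kTag`, `block_barrier`, `demo`) are illustrative and not part of DeepEP.

// Single-GPU analogue: blocks play the role of NVL ranks, and one flat `signals`
// array of kNumRanks x kNumRanks ints plays the role of the per-rank signal arrays.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kNumRanks = 8;
constexpr int kTag = 1024;   // stands in for FINISHED_SUM_TAG

// signals[r * kNumRanks + p] is rank r's slot for peer p
__device__ void block_barrier(int* signals, int rank) {
    const int thread_id = static_cast<int>(threadIdx.x);
    if (thread_id < kNumRanks) {
        // Announce arrival: +kTag into my own slot for peer `thread_id`,
        // -kTag into peer `thread_id`'s slot for me
        atomicAdd(signals + rank * kNumRanks + thread_id, kTag);
        __threadfence();
        atomicSub(signals + thread_id * kNumRanks + rank, kTag);
    }
    // My slot for peer p returns to <= 0 only after peer p has subtracted its tag
    while (true) {
        int value = thread_id < kNumRanks
            ? *(volatile int*)(signals + rank * kNumRanks + thread_id) : 0;
        if (__all_sync(0xffffffff, value <= 0))
            break;
    }
    __syncthreads();
}

__global__ void demo(int* signals, int* arrived) {
    block_barrier(signals, static_cast<int>(blockIdx.x));
    if (threadIdx.x == 0)
        arrived[blockIdx.x] = 1;  // only reachable once every "rank" has checked in
}

int main() {
    int *signals, *arrived, host_arrived[kNumRanks];
    cudaMalloc(&signals, kNumRanks * kNumRanks * sizeof(int));
    cudaMalloc(&arrived, kNumRanks * sizeof(int));
    cudaMemset(signals, 0, kNumRanks * kNumRanks * sizeof(int));
    cudaMemset(arrived, 0, kNumRanks * sizeof(int));
    // Assumes all kNumRanks blocks are co-resident, which holds for 8 one-warp blocks
    demo<<<kNumRanks, 32>>>(signals, arrived);
    cudaMemcpy(host_arrived, arrived, sizeof(host_arrived), cudaMemcpyDeviceToHost);
    for (int i = 0; i < kNumRanks; ++ i)
        printf("rank %d arrived: %d\n", i, host_arrived[i]);
    cudaFree(signals);
    cudaFree(arrived);
    return 0;
}

Unlike the old task FIFO, there is no rotating `head` to maintain: the add/subtract pair returns every slot to zero, so the same slots can be reused by the next barrier without host-side bookkeeping.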
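
On the host side (`Buffer::Buffer` in `csrc/deep_ep.cpp`), the FIFO region at the tail of the NVL allocation is replaced by one int per NVL peer, followed by the arrays of buffer pointers and barrier-signal pointers that `sync` later copies to the GPU. A small host-side sketch of that offset arithmetic, assuming `NUM_MAX_NVL_PEERS == 8`; `MetadataLayout` and `make_layout` are illustrative names, not DeepEP APIs.

#include <cstdint>
#include <cstdio>

constexpr int kNumMaxNvlPeers = 8;  // NUM_MAX_NVL_PEERS in the real code

struct MetadataLayout {
    int64_t barrier_signal_offset;           // int[NUM_MAX_NVL_PEERS]
    int64_t buffer_ptrs_gpu_offset;          // void*[NUM_MAX_NVL_PEERS]
    int64_t barrier_signal_ptrs_gpu_offset;  // int*[NUM_MAX_NVL_PEERS]
    int64_t total_bytes;                     // size passed to cudaMalloc
};

MetadataLayout make_layout(int64_t num_nvl_bytes) {
    const int64_t barrier_signal_bytes = kNumMaxNvlPeers * sizeof(int);
    const int64_t buffer_ptr_bytes = kNumMaxNvlPeers * sizeof(void*);
    const int64_t barrier_signal_ptr_bytes = kNumMaxNvlPeers * sizeof(int*);
    MetadataLayout l;
    l.barrier_signal_offset = num_nvl_bytes;
    l.buffer_ptrs_gpu_offset = num_nvl_bytes + barrier_signal_bytes;
    l.barrier_signal_ptrs_gpu_offset = num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes;
    l.total_bytes = l.barrier_signal_ptrs_gpu_offset + barrier_signal_ptr_bytes;
    return l;
}

int main() {
    auto l = make_layout(1 << 20);
    printf("signals @ %lld, buffer ptrs @ %lld, signal ptrs @ %lld, total %lld bytes\n",
           (long long)l.barrier_signal_offset, (long long)l.buffer_ptrs_gpu_offset,
           (long long)l.barrier_signal_ptrs_gpu_offset, (long long)l.total_bytes);
    return 0;
}

Since the signal region is only `NUM_MAX_NVL_PEERS * sizeof(int)` bytes instead of `NUM_MAX_FIFO_SLOTS * sizeof(int)`, the per-rank metadata footprint shrinks and the `NUM_MAX_FIFO_SLOTS % num_nvl_ranks` divisibility assertion becomes unnecessary.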
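
The rewritten `ibgda_poll_cq` keeps the original wraparound trick: `idx` is the software index (always one ahead of the hardware consumer), `wqe_counter` is the 16-bit hardware consumer index, and the loop keeps polling while `(uint16_t)(idx - wqe_counter - 2) < ncqes`, so the `idx == wqe_counter + 1` case wraps out of the window and terminates the wait. A tiny host-side check of that predicate, using the example from the removed comment; `must_wait` is an illustrative name.

#include <cstdint>
#include <cstdio>

// True while polling must continue: the completion for WQE `idx - 1`
// has not yet been reported by the hardware consumer index `wqe_counter`.
bool must_wait(uint16_t idx, uint16_t wqe_counter, uint32_t ncqes) {
    return static_cast<uint16_t>(idx - wqe_counter - 2) < ncqes;
}

int main() {
    const uint32_t ncqes = 128;
    // Example from the original comment: idx = 10, stop once wqe_counter reaches 9
    printf("%d\n", must_wait(10, 8, ncqes));    // 1: 10 - 8 - 2 = 0 < 128, keep waiting
    printf("%d\n", must_wait(10, 9, ncqes));    // 0: 10 - 9 - 2 wraps to 65535, done
    // Wraparound of the 16-bit counter itself is handled by the same modular arithmetic
    printf("%d\n", must_wait(3, 65534, ncqes)); // 1: 3 - 65534 - 2 wraps to 3 < 128
    return 0;
}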