Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-14 10:18:47 +00:00

Fully remove barrier FIFO designs (#200)

* Fully remove FIFO slots
* Fully remove FIFO buffers
* Minor style fixes
* Fix some typos
* Fix bugs
* Clean up `ibgda_poll_cq`

This commit is contained in:
parent a16af40531
commit 8da2d7b38d

@@ -18,10 +18,10 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode),
comm_stream(at::cuda::getStreamFromPool(true)) {
- // Task fifo memory
- int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
- int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
- int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
+ // Metadata memory
+ int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
+ int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
+ int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
@@ -41,18 +41,17 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));

if (num_nvl_bytes > 0) {
- // Local IPC: alloc local memory and set local IPC handle
- CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
+ // Local IPC: alloc local memory and set local IPC handles
+ CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
- buffer_ptrs_gpu = reinterpret_cast<void**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes);
+ buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);

- // Set task fifo
- EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
- task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
- task_fifo_ptrs_gpu = reinterpret_cast<int**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes);
+ // Set barrier signals
+ barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
+ barrier_signal_ptrs_gpu = reinterpret_cast<int**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes);

// No need to synchronize, will do a full device sync during `sync`
- CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream));
+ CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
}

// Create 32 MiB workspace
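For reference, a minimal host-side sketch of the single-allocation layout implied by the new constructor code above. `NUM_MAX_NVL_PEERS` is taken from `configs.cuh` in this diff; the example data size and the `main` harness are made up for illustration:

    #include <cstdint>
    #include <cstdio>

    constexpr int NUM_MAX_NVL_PEERS = 8;  // from kernels/configs.cuh

    int main() {
        // Example data-buffer size; the real value comes from the Buffer constructor arguments
        const int64_t num_nvl_bytes = 1 << 20;
        const int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
        const int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
        const int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

        // One cudaMalloc covers: [data | barrier signals | buffer ptr array | barrier signal ptr array]
        const int64_t barrier_signal_offset = num_nvl_bytes;
        const int64_t buffer_ptrs_gpu_offset = num_nvl_bytes + barrier_signal_bytes;
        const int64_t barrier_signal_ptrs_gpu_offset = buffer_ptrs_gpu_offset + buffer_ptr_bytes;
        const int64_t total_bytes = barrier_signal_ptrs_gpu_offset + barrier_signal_ptr_bytes;

        std::printf("signals at +%lld, buffer ptrs at +%lld, signal ptrs at +%lld, total %lld bytes\n",
                    (long long)barrier_signal_offset, (long long)buffer_ptrs_gpu_offset,
                    (long long)barrier_signal_ptrs_gpu_offset, (long long)total_bytes);
        return 0;
    }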
@@ -91,8 +90,7 @@ Buffer::~Buffer() noexcept(false) {

if (num_nvl_bytes > 0) {
// Barrier
- intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
- move_fifo_slots();
+ intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());

// Close remote IPC
@@ -121,10 +119,6 @@ Buffer::~Buffer() noexcept(false) {
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}

- void Buffer::move_fifo_slots(int num_slots) {
- head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
- }
-
bool Buffer::is_available() const {
return available;
}
@@ -162,7 +156,7 @@ pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));
- auto base_ptr = reinterpret_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
+ auto base_ptr = static_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}
@@ -183,15 +177,15 @@ void Buffer::sync(const std::vector<int> &device_ids,
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
- task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+ barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
}
}

- // Copy all buffer and task pointers to GPU
+ // Copy all buffer and barrier signal pointers to GPU
CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
- CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
+ CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaDeviceSynchronize());
}

@@ -395,9 +389,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te

// Copy rank prefix matrix and clean flags
intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(), num_memset_int,
- buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks,
+ buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);
- move_fifo_slots(2);
} else {
rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
@@ -416,9 +409,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
num_tokens, is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
rank_prefix_matrix.data_ptr<int>(),
num_memset_int, expert_alignment,
- buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank,
+ buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
comm_stream, num_channels);
- move_fifo_slots(3);

// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
@@ -565,12 +557,9 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr<int>(),
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
- task_fifo_ptrs_gpu, head, rank, num_ranks,
+ barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);

- // NOTES: this function uses two FIFO slots (barrier before and after)
- move_fifo_slots(2);
-
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
@@ -746,10 +735,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
nullptr, nullptr, nullptr,
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
- task_fifo_ptrs_gpu, head, rank, comm_stream,
+ barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, true, low_latency_mode);
- move_fifo_slots(2);
} else {
rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
@@ -769,10 +757,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
- task_fifo_ptrs_gpu, head, rank, comm_stream,
+ barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, low_latency_mode);
- move_fifo_slots(3);

// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
@@ -958,10 +945,9 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
- task_fifo_ptrs_gpu, head, rank, comm_stream,
+ barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
- move_fifo_slots(2);

// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
@@ -51,10 +51,9 @@ private:
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;

- // Task fifo
- int head = 0;
- int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
- int** task_fifo_ptrs_gpu = nullptr;
+ // Barrier signals
+ int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
+ int** barrier_signal_ptrs_gpu = nullptr;

// Workspace
void* workspace = nullptr;
@@ -75,9 +74,6 @@ private:
volatile int* low_latency_usage_flag = nullptr;
int* low_latency_usage_flag_mapped = nullptr;

- private:
- void move_fifo_slots(int num_slots = 1);
-
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
@@ -7,7 +7,7 @@ namespace deep_ep {
// Intranode runtime
namespace intranode {

- void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
+ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);

} // namespace intranode

@@ -35,11 +35,11 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
- void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int num_sms);

void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
- void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream);

void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
@@ -51,7 +51,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int num_max_send_tokens, int num_recv_buffer_tokens);

void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
- int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
+ int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);

void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
@@ -84,7 +84,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
- int** task_fifo_ptrs, int head, int rank,
+ int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);

@@ -106,7 +106,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
- int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
+ int** barrier_signal_ptrs, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode);
@@ -2,7 +2,6 @@

#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
- #define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128
@@ -164,7 +164,7 @@ void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx

__device__ static __forceinline__ void
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
- __be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
+ __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_inl_data_seg inl_seg;
@@ -277,7 +277,7 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
__device__ static __forceinline__ void
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
- void **out_wqes) {
+ void** out_wqes) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_data_seg data_seg;
@@ -373,7 +373,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
nvshmemi_ibgda_device_qp_t *qp, const int &value,
uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
- uint16_t wqe_idx, void **out_wqes) {
+ uint16_t wqe_idx, void** out_wqes) {
ibgda_ctrl_seg_t ctrl_seg = {0};
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_atomic_seg atomic_seg_1;
@@ -448,40 +448,27 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// Note that this implementation does not guarantee thread safety,
// so we must ensure that no other threads are concurrently using the same QP.
- __device__ static __forceinline__ int
+ __device__ static __forceinline__ void
ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
- int status = 0;
- struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
-
+ const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe);
const uint32_t ncqes = cq->ncqes;
+ memory_fence_cta();

+ // NOTES: this while loop is part of do-while below.
+ // `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`.
+ // To be able to compare with the index, we need to use `wqe_counter + 1`.
+ // Because `wqe_counter` is `uint16_t`, it may be overflow. Still, we know for
+ // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1 is less than
+ // idx, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1`
+ // That's why we use `- 2` here to make this case overflow.
uint16_t wqe_counter;
- uint16_t new_wqe_counter;
-
- memory_fence_cta();
-
do {
- new_wqe_counter = ld_na_relaxed(&cqe64->wqe_counter);
- new_wqe_counter = HtoBE16(new_wqe_counter);
- wqe_counter = new_wqe_counter;
- }
- // NOTE: This while loop is part of do while above.
- // wqe_counter is the HW consumer index. However, we always maintain index
- // + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
- // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
- // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
- // idx, and thus we need to wait. We don't need to wait when idx ==
- // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
- // wraparound.
- // Example:
- // if idx = 10, we wait until wqe_counter = 9, idx - wqe_counter - 2 = 65535 > ncqes.
- while (((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes));
-
+ wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter));
+ } while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes));
*cq->cons_idx = idx;
- // Prevent reordering of this function and subsequent instructions
- memory_fence_cta();
-
- return status;
+ // Prevent reordering of this function and later instructions
+ memory_fence_cta();
}

// Wait until wqe `idx - 1` is completed.
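As a sanity check of the `uint16_t` wraparound trick described in the comments above, here is a small standalone host program; the `idx`, `ncqes`, and counter values are made up, and only the wait condition mirrors the kernel:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ncqes = 128;  // illustrative CQ depth
        const uint64_t idx = 10;     // SW index we wait for (consumer index + 1 convention)

        for (int c = 7; c <= 9; ++c) {
            const auto wqe_counter = static_cast<uint16_t>(c);  // pretend HW counter values
            // Same expression as the kernel's do-while condition: keep polling while it is < ncqes
            const auto diff = static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2));
            std::printf("wqe_counter = %u -> diff = %u -> %s\n",
                        static_cast<unsigned>(wqe_counter), static_cast<unsigned>(diff),
                        diff < ncqes ? "keep waiting" : "done");
        }
        // Counters 7 and 8 give small diffs (still waiting); counter 9 wraps to 65535, so polling stops,
        // matching the removed example comment "if idx = 10, we wait until wqe_counter = 9".
        return 0;
    }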
@@ -208,7 +208,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr,
- void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
const nvshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -219,17 +219,15 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in

if (sm_id == 0) {
// Communication with others
- // Global barrier: the first warp do intra-node sync, the second warp do internode sync
+ // Global barrier: the first warp does intra-node sync, the second warp does internode sync
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
- barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
- move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
- __syncthreads();
+ barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);

// Send numbers of tokens per rank/expert to RDMA ranks
- auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
+ auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
auto rdma_recv_num_tokens_mixed = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);

// Clean up for later data dispatch
@@ -285,7 +283,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
auto nvl_recv_num_tokens_per_expert = AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

// Clean up for later data dispatch
- auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
+ auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int));
#pragma unroll
@@ -328,11 +326,9 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
}
memory_fence();
__syncthreads();
- barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
- move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
- __syncthreads();
+ barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);

- // Reduce number of tokens per rank/expert
+ // Reduce the number of tokens per rank/expert
EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
if (thread_id == 0) {
int sum = 0;
@@ -359,8 +355,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
__syncthreads();
if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
- barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
- move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
+ barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else {
// Calculate meta data
int dst_rdma_rank = sm_id - 1;
@@ -423,7 +418,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
- int** task_fifo_ptrs, int head, int rank,
+ int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
@@ -439,7 +434,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
rdma_buffer_ptr, \
- buffer_ptrs, task_fifo_ptrs, head, rank, \
+ buffer_ptrs, barrier_signal_ptrs, rank, \
cpu_rdma_team); } break

constexpr int kNumThreads = 512;
@@ -698,7 +693,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
} else if (warp_role == WarpRole::kRDMASenderCoordinator) {
- // NOTES: in case of splitting the issued put at the end of the buffer
+ // NOTES: in case of splitting, the issued put at the end of the buffer
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);

// Synchronize shared memory
@@ -861,7 +856,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
- auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
+ auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if (lane_id == src_rdma_rank) {
@@ -881,27 +876,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),
ld_nc_global, st_na_global);
- shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
+ shifted = static_cast<int4*>(shifted) + hidden_int4;

// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
- shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
+ shifted = static_cast<SourceMeta*>(shifted) + 1;

// Copy `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
- shifted = reinterpret_cast<float*>(shifted) + num_scales;
+ shifted = static_cast<float*>(shifted) + num_scales;

// Copy `topk_idx` and `topk_weights`
// NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
if (lane_id < num_topk) {
// Read
- auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
- shifted = reinterpret_cast<int*>(shifted) + num_topk;
- auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
+ auto idx_value = ld_nc_global(static_cast<int*>(shifted) + lane_id);
+ shifted = static_cast<int*>(shifted) + num_topk;
+ auto weight_value = ld_nc_global(static_cast<float*>(shifted) + lane_id);

// Transform and write
idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
@@ -1107,7 +1102,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
int* combined_rdma_head, int num_combined_tokens, int num_channels,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr,
- void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks,
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
bool is_cached_dispatch, const nvshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
@@ -1127,7 +1122,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
__syncthreads();

// Clean
- auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
+ auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
#pragma unroll
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
@@ -1138,12 +1133,10 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
- barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
- move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
- __syncthreads();
+ barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);

// Clean
- auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
+ auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
#pragma unroll
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
@@ -1151,8 +1144,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
__syncthreads();

// Barrier again
- barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
- move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
+ barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else if (sm_id == 2) {
if (is_cached_dispatch)
return;
@@ -1213,7 +1205,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
- int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
+ int** barrier_signal_ptrs, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = std::max(128, 32 * num_channels);
@@ -1237,7 +1229,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
combined_rdma_head, num_combined_tokens, num_channels,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
rdma_buffer_ptr,
- buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks,
+ buffer_ptrs, barrier_signal_ptrs, rank, num_ranks,
is_cached_dispatch, cpu_rdma_team);
}

@@ -1397,7 +1389,7 @@ combine(int4* combined_x, float* combined_topk_weights,
if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
break;

- // Decide next RDMA buffer to send
+ // Decide the next RDMA buffer to send
bool is_lane_ready = false;
auto start_time = clock64();
while (true) {
@@ -1575,8 +1567,8 @@ combine(int4* combined_x, float* combined_topk_weights,
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
- reinterpret_cast<int4*>(shifted),
- reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
+ static_cast<int4*>(shifted),
+ reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);

// Update head
@@ -22,7 +22,7 @@ __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;

- // Barrier after cleaning (make sure low-latency mode work fine)
+ // Barrier after cleaning (make sure the low-latency mode works fine)
nvshmemx_barrier_all_block();
}
@@ -14,16 +14,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
- void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;

if (sm_id == 0) {
// Barrier first
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
- move_fifo_slots<kNumRanks>(head);
- __syncthreads();
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
@@ -46,9 +44,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
__syncthreads();

// Wait for all ranks to be finished
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
- move_fifo_slots<kNumRanks>(head);
- __syncthreads();
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
@@ -86,7 +82,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
// Barrier
memory_fence();
__syncthreads();
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else {
int dst_rank = sm_id - 1;
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
@@ -116,7 +112,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
- void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int num_channels) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
@@ -124,7 +120,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
- buffer_ptrs, task_fifo_ptrs, head, rank); \
+ buffer_ptrs, barrier_signal_ptrs, rank); \
break

constexpr int kNumThreads = 128;
@@ -139,11 +135,9 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
template<int kNumRanks>
__global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
- void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
+ void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
// A simplified version for cached handles
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
- move_fifo_slots<kNumRanks>(head);
- __syncthreads();
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -158,15 +152,15 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
__syncthreads();

// Barrier after cleaning
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}

void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
- void** buffer_ptrs, int** task_fifo_ptrs,
- int head, int rank, int num_ranks, cudaStream_t stream) {
+ void** buffer_ptrs, int** barrier_signal_ptrs,
+ int rank, int num_ranks, cudaStream_t stream) {
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
- rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
+ rank_prefix_matrix, num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \
break

SETUP_LAUNCH_CONFIG(1, 128, stream);
@@ -180,7 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
- void **buffer_ptrs, int rank,
+ void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
@@ -491,13 +485,11 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
template<int kNumRanks>
__global__ void
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
- int** task_fifo_ptrs, int head, int rank) {
+ int** barrier_signal_ptrs, int rank) {
const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) {
// Barrier before cleaning
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
- move_fifo_slots<kNumRanks>(head);
- __syncthreads();
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -509,7 +501,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
__syncthreads();

// Barrier after cleaning
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else {
const auto channel_id = sm_id - 1;
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -528,7 +520,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
int token_idx = token_idx_tail - lane_id, expected_head = 0;
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
- head = __shfl_sync(0xffffffff, current_head, i);
+ const int head = __shfl_sync(0xffffffff, current_head, i);
if (head < 0) {
if (lane_id == i)
expected_head = -last_head - 1;
@@ -544,11 +536,11 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int

void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
int num_recv_tokens, int num_memset_int,
- int** task_fifo_ptrs, int head, int rank, int num_ranks,
+ int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream) {
#define CACHED_NOTIFY_COMBINE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
- buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
+ buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, barrier_signal_ptrs, rank); \
break

const int num_threads = std::max(128, 32 * num_ranks);
@@ -566,7 +558,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
const dtype_t* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
- void **buffer_ptrs, int rank,
+ void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -12,13 +12,13 @@ namespace deep_ep {
namespace intranode {

template<int kNumRanks>
- __global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
- barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+ __global__ void barrier(int** barrier_signal_ptrs, int rank) {
+ barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}

- void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
+ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) {
#define BARRIER_LAUNCH_CASE(ranks) \
- LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
+ LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \
break

SETUP_LAUNCH_CONFIG(1, 32, stream);
@@ -396,44 +396,32 @@ __forceinline__ __device__ int get_lane_id() {
return lane_id;
}

- template <int kNumRanks>
- __forceinline__ __device__ void move_fifo_slots(int &head) {
- head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
- }
-
- template <int kNumRanks>
- __device__ __forceinline__ bool not_finished(int *task, int expected) {
- auto result = false;
- auto lane_id = threadIdx.x % 32;
- if (lane_id < kNumRanks)
- result = ld_volatile_global(task + lane_id) != expected;
- return __any_sync(0xffffffff, result);
- }
-
template <int kNumRanks>
__forceinline__ __device__ void
- timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
+ barrier_block(int** barrier_signal_ptrs, int rank) {
+ auto thread_id = static_cast<int>(threadIdx.x);
+
+ // Add self-ranks, sub other ranks
+ if (thread_id < kNumRanks) {
+ atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
+ memory_fence();
+ atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
+ }
+ EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
+
+ // Check timeout
auto start_time = clock64();
- while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
- if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
- printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
+ while (true) {
+ auto value = thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;
+ if (__all_sync(0xffffffff, value <= 0))
+ break;
+
+ if (clock64() - start_time > NUM_TIMEOUT_CYCLES and get_lane_id() == 0) {
+ printf("DeepEP timeout check failed: rank = %d, thread = %d)\n", rank, thread_id);
trap();
}
}
}

- template <int kNumRanks>
- __forceinline__ __device__ void
- barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
- auto thread_id = static_cast<int>(threadIdx.x);
- EP_DEVICE_ASSERT(kNumRanks <= 32);
-
- if (thread_id < kNumRanks) {
- atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
- memory_fence();
- atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
- }
- timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
__syncthreads();
}

} // namespace deep_ep
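To make the add-self / subtract-others protocol of the new `barrier_block` easier to follow, here is a single-process host simulation; the real kernel uses `atomicAdd_system`/`atomicSub_system` on IPC-mapped GPU memory with one thread per peer, and the exact value of `FINISHED_SUM_TAG` is an assumption here:

    #include <array>
    #include <cstdio>

    constexpr int kNumRanks = 4;            // illustrative rank count
    constexpr int FINISHED_SUM_TAG = 1024;  // stand-in for the kernel's tag constant

    // signals[r][p] plays the role of barrier_signal_ptrs[r][p]: rank r's local slot for peer p
    std::array<std::array<int, kNumRanks>, kNumRanks> signals{};

    // What one rank does on arrival (the kernel spreads this over kNumRanks threads)
    void arrive(int rank) {
        for (int peer = 0; peer < kNumRanks; ++peer) {
            signals[rank][peer] += FINISHED_SUM_TAG;  // "add self-ranks" into the local row
            signals[peer][rank] -= FINISHED_SUM_TAG;  // "sub other ranks" in each peer's row
        }
    }

    // A rank may leave once every entry of its own row is <= 0, i.e. every peer has
    // subtracted its tag; "<= 0" (not "== 0") tolerates a fast peer that has already
    // signalled the next barrier before this rank finished waiting on the current one.
    bool may_leave(int rank) {
        for (int peer = 0; peer < kNumRanks; ++peer)
            if (signals[rank][peer] > 0)
                return false;
        return true;
    }

    int main() {
        arrive(0);
        std::printf("only rank 0 arrived, rank 0 may leave: %s\n", may_leave(0) ? "yes" : "no");
        for (int r = 1; r < kNumRanks; ++r)
            arrive(r);
        for (int r = 0; r < kNumRanks; ++r)
            std::printf("all arrived, rank %d may leave: %s\n", r, may_leave(r) ? "yes" : "no");
        return 0;
    }

After every rank has both added its own tags and had them subtracted by all peers, each row returns to zero, so the same signal array can be reused for the next barrier without any moving FIFO head.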