mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fully remove barrier FIFO designs (#200)
* Fully remove FIFO slots * Fully remove FIFO buffers * Minor fix styles * Fix some typos * Bugs fixed * Cleanup `ibgda_poll_cq`
This commit is contained in:
@@ -208,7 +208,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
|
||||
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
|
||||
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
|
||||
void* rdma_buffer_ptr,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
|
||||
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
|
||||
const nvshmem_team_t rdma_team) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
@@ -219,17 +219,15 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
|
||||
|
||||
if (sm_id == 0) {
|
||||
// Communication with others
|
||||
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
|
||||
// Global barrier: the first warp does intra-node sync, the second warp does internode sync
|
||||
EP_DEVICE_ASSERT(num_warps > 1);
|
||||
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
|
||||
if (thread_id == 32)
|
||||
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
|
||||
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
|
||||
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
|
||||
__syncthreads();
|
||||
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
|
||||
|
||||
// Send numbers of tokens per rank/expert to RDMA ranks
|
||||
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
|
||||
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
|
||||
auto rdma_recv_num_tokens_mixed = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);
|
||||
|
||||
// Clean up for later data dispatch
|
||||
@@ -285,7 +283,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
|
||||
auto nvl_recv_num_tokens_per_expert = AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
|
||||
|
||||
// Clean up for later data dispatch
|
||||
auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
|
||||
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
|
||||
EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
|
||||
nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int));
|
||||
#pragma unroll
|
||||
@@ -328,11 +326,9 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
|
||||
}
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
|
||||
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
|
||||
__syncthreads();
|
||||
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
|
||||
|
||||
// Reduce number of tokens per rank/expert
|
||||
// Reduce the number of tokens per rank/expert
|
||||
EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
|
||||
if (thread_id == 0) {
|
||||
int sum = 0;
|
||||
@@ -359,8 +355,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
|
||||
__syncthreads();
|
||||
if (thread_id == 32)
|
||||
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
|
||||
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
|
||||
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
|
||||
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
|
||||
} else {
|
||||
// Calculate meta data
|
||||
int dst_rdma_rank = sm_id - 1;
|
||||
@@ -423,7 +418,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
|
||||
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
||||
int** task_fifo_ptrs, int head, int rank,
|
||||
int** barrier_signal_ptrs, int rank,
|
||||
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
||||
bool low_latency_mode) {
|
||||
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
|
||||
@@ -439,7 +434,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
|
||||
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
|
||||
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
|
||||
rdma_buffer_ptr, \
|
||||
buffer_ptrs, task_fifo_ptrs, head, rank, \
|
||||
buffer_ptrs, barrier_signal_ptrs, rank, \
|
||||
cpu_rdma_team); } break
|
||||
|
||||
constexpr int kNumThreads = 512;
|
||||
@@ -698,7 +693,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Release sequential lock
|
||||
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
|
||||
} else if (warp_role == WarpRole::kRDMASenderCoordinator) {
|
||||
// NOTES: in case of splitting the issued put at the end of the buffer
|
||||
// NOTES: in case of splitting, the issued put at the end of the buffer
|
||||
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
|
||||
|
||||
// Synchronize shared memory
|
||||
@@ -861,7 +856,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
|
||||
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
|
||||
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
|
||||
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
|
||||
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
|
||||
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
|
||||
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
|
||||
if (lane_id == src_rdma_rank) {
|
||||
@@ -881,27 +876,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
|
||||
reinterpret_cast<int4*>(shifted),
|
||||
ld_nc_global, st_na_global);
|
||||
shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
|
||||
shifted = static_cast<int4*>(shifted) + hidden_int4;
|
||||
|
||||
// Copy source meta
|
||||
if (lane_id == 0)
|
||||
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
|
||||
shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
|
||||
shifted = static_cast<SourceMeta*>(shifted) + 1;
|
||||
|
||||
// Copy `x_scales`
|
||||
UNROLLED_WARP_COPY(1, lane_id, num_scales,
|
||||
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
|
||||
reinterpret_cast<float*>(shifted),
|
||||
ld_nc_global, st_na_global);
|
||||
shifted = reinterpret_cast<float*>(shifted) + num_scales;
|
||||
shifted = static_cast<float*>(shifted) + num_scales;
|
||||
|
||||
// Copy `topk_idx` and `topk_weights`
|
||||
// NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
|
||||
if (lane_id < num_topk) {
|
||||
// Read
|
||||
auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
|
||||
shifted = reinterpret_cast<int*>(shifted) + num_topk;
|
||||
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
|
||||
auto idx_value = ld_nc_global(static_cast<int*>(shifted) + lane_id);
|
||||
shifted = static_cast<int*>(shifted) + num_topk;
|
||||
auto weight_value = ld_nc_global(static_cast<float*>(shifted) + lane_id);
|
||||
|
||||
// Transform and write
|
||||
idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
|
||||
@@ -1107,7 +1102,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
int* combined_rdma_head, int num_combined_tokens, int num_channels,
|
||||
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
|
||||
void* rdma_buffer_ptr,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks,
|
||||
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
|
||||
bool is_cached_dispatch, const nvshmem_team_t rdma_team) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
@@ -1127,7 +1122,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
__syncthreads();
|
||||
|
||||
// Clean
|
||||
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
|
||||
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
|
||||
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
|
||||
@@ -1138,12 +1133,10 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
|
||||
} else if (sm_id == 1) {
|
||||
// Barrier for NVL
|
||||
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
|
||||
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
|
||||
__syncthreads();
|
||||
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
|
||||
|
||||
// Clean
|
||||
auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
|
||||
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
|
||||
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
|
||||
@@ -1151,8 +1144,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
|
||||
__syncthreads();
|
||||
|
||||
// Barrier again
|
||||
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
|
||||
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
|
||||
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
|
||||
} else if (sm_id == 2) {
|
||||
if (is_cached_dispatch)
|
||||
return;
|
||||
@@ -1213,7 +1205,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
|
||||
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
||||
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
|
||||
int** barrier_signal_ptrs, int rank, cudaStream_t stream,
|
||||
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
||||
bool is_cached_dispatch, bool low_latency_mode) {
|
||||
const int num_threads = std::max(128, 32 * num_channels);
|
||||
@@ -1237,7 +1229,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
|
||||
combined_rdma_head, num_combined_tokens, num_channels,
|
||||
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
|
||||
rdma_buffer_ptr,
|
||||
buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks,
|
||||
buffer_ptrs, barrier_signal_ptrs, rank, num_ranks,
|
||||
is_cached_dispatch, cpu_rdma_team);
|
||||
}
|
||||
|
||||
@@ -1397,7 +1389,7 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
|
||||
break;
|
||||
|
||||
// Decide next RDMA buffer to send
|
||||
// Decide the next RDMA buffer to send
|
||||
bool is_lane_ready = false;
|
||||
auto start_time = clock64();
|
||||
while (true) {
|
||||
@@ -1575,8 +1567,8 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
|
||||
expected_head, lane_id,
|
||||
hidden_int4, num_topk,
|
||||
reinterpret_cast<int4*>(shifted),
|
||||
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
|
||||
static_cast<int4*>(shifted),
|
||||
reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
|
||||
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
|
||||
|
||||
// Update head
|
||||
|
||||
Reference in New Issue
Block a user