Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00

Merge remote-tracking branch 'origin/main' into internode-tma

# Conflicts:
#	csrc/kernels/configs.cuh
#	csrc/kernels/internode.cu
@@ -35,7 +35,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
     rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
     num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
 #ifdef DISABLE_NVSHMEM
-    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation");
 #endif

     // Get device info
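
The same one-word fix ("disable" becomes "disabled") repeats across every DISABLE_NVSHMEM stub in deep_ep.cpp below. The guard pattern itself is worth noting: when NVSHMEM is compiled out, each internode entry point reduces to a host-side assert plus an empty return. A minimal hypothetical sketch of that style (plain `assert` standing in for `EP_HOST_ASSERT`; the function name is made up for illustration):

    #include <cassert>
    #include <vector>

    // Sketch of the DISABLE_NVSHMEM guard style: with the feature compiled
    // out, the entry point becomes a stub that asserts and returns an empty
    // result, so misuse fails loudly instead of reaching RDMA code paths.
    #define DISABLE_NVSHMEM

    std::vector<int> internode_dispatch_stub() {
    #ifndef DISABLE_NVSHMEM
        // ... real NVSHMEM-backed implementation ...
    #else
        assert(false and "NVSHMEM is disabled during compilation");
        return {};
    #endif
    }

    int main() {
        internode_dispatch_stub();  // aborts here in a DISABLE_NVSHMEM build
        return 0;
    }
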
@@ -151,7 +151,7 @@ pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
     auto unique_id = internode::get_unique_id();
     return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
 #endif
 }

@@ -895,7 +895,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
             recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
             recv_src_meta, send_rdma_head, send_nvl_head, event};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1016,7 +1016,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
     // Return values
     return {combined_x, combined_topk_weights, event};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1040,7 +1040,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
                                           clean_meta_1.first, clean_meta_1.second,
                                           at::cuda::getCurrentCUDAStream());
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
 #endif
 }

@@ -1149,7 +1149,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     // Return values
     return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1242,7 +1242,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     // Return values
     return {combined_x, event, recv_hook};
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }
@@ -1262,7 +1262,7 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank
                          {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
                          torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
 #else
-    EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation");
+    EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
 #endif
 }

@@ -8,6 +8,7 @@

 #define FINISHED_SUM_TAG 1024
+#define NUM_WAIT_NANOSECONDS 500

 #ifndef ENABLE_FAST_DEBUG
 #define NUM_CPU_TIMEOUT_SECS 100
 #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s

@@ -421,9 +421,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);

     // RDMA sender warp synchronization
-    __shared__ volatile int rdma_send_next_token_idx;
-    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
-    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
+    // NOTES: `rdma_send_channel_tail` means the latest released tail
+    // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
+    __shared__ int rdma_send_channel_lock[kNumRDMARanks];
+    __shared__ int rdma_send_channel_tail[kNumRDMARanks];
+    __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
     auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };

     // Forward warp synchronization
@@ -437,12 +439,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         int token_start_idx, token_end_idx;
         get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

-        // Clean shared memory
-        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
-        (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
-        (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
-        (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
-
         // Send number of tokens in this channel by `-value - 1`
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
         for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
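
For context on `get_channel_task_range`: it hands each channel a contiguous slice of the tokens. A plausible re-implementation assuming a near-even contiguous split (hypothetical; the real helper may differ):

    #include <cassert>

    // Hypothetical re-implementation: split `num_tokens` into `num_channels`
    // contiguous, near-even [start, end) ranges; the first
    // `num_tokens % num_channels` channels each get one extra token.
    void get_channel_task_range(int num_tokens, int num_channels, int channel_id,
                                int& token_start_idx, int& token_end_idx) {
        int base = num_tokens / num_channels, extra = num_tokens % num_channels;
        token_start_idx = channel_id * base + (channel_id < extra ? channel_id : extra);
        token_end_idx = token_start_idx + base + (channel_id < extra ? 1 : 0);
    }

    int main() {
        int s, e;
        get_channel_task_range(10, 3, 0, s, e);
        assert(s == 0 && e == 4);   // channels get 4, 3, 3 tokens
        get_channel_task_range(10, 3, 2, s, e);
        assert(s == 7 && e == 10);
        return 0;
    }
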
@@ -471,24 +467,33 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv

         // Iterate over tokens and copy into buffer
         int64_t token_idx;
-        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
+        int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;
         auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
-        for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
+        for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) {
             // Read RDMA rank existence
             uint64_t is_token_in_rank_uint64 = 0;
-            if (lane_id < kNumRDMARanks)
-                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);

-            // Acquire sequential lock
-            while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
+            if (lane_id < kNumRDMARanks) {
+                is_token_in_rank_uint64 = __ldg(reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS));
+                global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
+            }
             __syncwarp();

-            // Acquire next tail
-            int rdma_tail_idx = -1;
-            if (is_token_in_rank_uint64 != 0) {
-                rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;
-                while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
-                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
+            // Skip the token which does not belong to this warp
+            if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)
+                continue;
+            auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;

+            // Wait the remote buffer to be released
+            auto start_time = clock64();
+            while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
+                cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
+
+                // Timeout check
+                if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
+                    printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
+                           channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
+                    trap();
+                }
             }
             __syncwarp();

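The new wait loop is plain ring-buffer backpressure: a sender lane may claim slot `rdma_tail_idx` only while it is fewer than `num_max_rdma_chunked_recv_tokens` slots ahead of the receiver-published head. A hypothetical single-threaded C++ model of that check (names mirror the kernel, behavior simplified):

    #include <cstdio>

    // Hypothetical host-side model of the kernel's ring-buffer backpressure:
    // the producer may claim slot `tail` only while `tail - head < capacity`.
    // In the kernel, `head` is re-read from the receiver via
    // `ld_volatile_global(rdma_channel_head.buffer(lane_id))`.
    struct RingBackpressure {
        int head;       // last slot released by the consumer (published remotely)
        int tail;       // next slot the producer will claim
        int capacity;   // num_max_rdma_chunked_recv_tokens

        bool can_claim() const { return tail - head < capacity; }
    };

    int main() {
        RingBackpressure ring{0, 0, 4};
        for (int token = 0; token < 6; ++token) {
            while (!ring.can_claim())      // the kernel spins here, with a timeout
                ring.head += 1;            // stand-in for the consumer making progress
            std::printf("token %d -> slot %d\n", token, ring.tail);
            ring.tail += 1;
        }
        return 0;
    }
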
@@ -496,14 +501,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             if (lane_id < kNumRDMARanks and not kCachedMode)
                 send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;

-            // Update last token tail
-            if (last_rdma_tail_idx >= 0)
-                st_release_cta(const_cast<const int *>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
-            last_rdma_tail_idx = rdma_tail_idx;
-
-            // Release sequential lock
-            lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
-
             // Broadcast tails
             SourceMeta src_meta;
             int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
@@ -560,24 +557,45 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                 st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
             }
             __syncwarp();

+            // Release the transaction in the window
+            if (is_token_in_rank_uint64 != 0) {
+                // Acquire lock first
+                acquire_lock(rdma_send_channel_lock + lane_id);
+
+                // Release the transaction slot
+                auto window = rdma_send_channel_window[lane_id];
+                auto latest_tail = rdma_send_channel_tail[lane_id];
+                auto offset = rdma_tail_idx - latest_tail;
+
+                // The same effect with `EP_DEVICE_ASSERT(offset < 32);`
+                EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
+
+                // Erase bit and move the ones if possible
+                window ^= 1u << offset;
+                if (offset == 0) {
+                    auto num_empty_slots = __ffs(~window) - 1;
+                    st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
+                    window >>= num_empty_slots;
+                }
+                rdma_send_channel_window[lane_id] = window;
+
+                // Release lock
+                release_lock(rdma_send_channel_lock + lane_id);
+            }
+            __syncwarp();
         }
-
-        // Epilogue
-        // Acquire sequential lock
-        while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
-        __syncwarp();
-
-        // Update last token tail
-        if (last_rdma_tail_idx >= 0)
-            st_release_cta(const_cast<const int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
-        __syncwarp();
-
-        // Release sequential lock
-        lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
     } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
         // NOTES: in case of splitting, the issued put at the end of the buffer
         EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);

+        // Clean shared memory
+        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
+
+        // Synchronize shared memory
+        sync_rdma_sender_smem();

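A note on the new window scheme replacing the old sequential lock: each destination RDMA rank gets a released tail plus a 32-bit completion bitmap; a warp that finishes copying slot `tail + offset` marks its bit under the lock, and the warp that completes the slot at the front of the window advances the shared tail across the whole contiguous run of finished slots. A single-threaded C++ model of that bookkeeping (hypothetical; the kernel additionally serializes updates with the CTA lock and publishes the tail with `st_release_cta`):

    #include <cassert>
    #include <cstdint>

    // Hypothetical single-threaded model of the 32-slot completion window:
    // `tail` counts contiguously finished slots already published, and bit i
    // of `window` marks that slot `tail + i` has finished out of order.
    struct CompletionWindow {
        uint32_t window = 0;
        int tail = 0;

        void complete(int slot) {
            int offset = slot - tail;
            assert(0 <= offset && offset < 32);   // the kernel asserts this statically
            window |= 1u << offset;               // mark the slot as finished
            if (offset == 0) {
                // Advance over the contiguous run of finished slots, as the
                // kernel does with `__ffs(~window) - 1` (here __builtin_ffs).
                int run = __builtin_ffs(~window) - 1;
                tail += run;
                window >>= run;
            }
        }
    };

    int main() {
        CompletionWindow w;
        w.complete(2);            // out of order: recorded, tail stays at 0
        w.complete(1);
        assert(w.tail == 0);
        w.complete(0);            // head slot done: tail jumps over 0, 1 and 2
        assert(w.tail == 3 && w.window == 0);
        return 0;
    }

The payoff is that the coordinator only ever reads the released tail, so tokens are issued strictly in order even though sender warps finish their copies out of order.
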
@@ -595,10 +613,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
             // Timeout check
             if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
-                printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, tokens to send: %d\n",
+                printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
                 trap();
             }

+            // TODO: try thread-level `put_nbi`?
             for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
+                // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
                 int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
@@ -606,9 +626,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 if (synced_num_tokens_to_send == 0)
                     continue;

-                // Read progress
-                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
+                // Read the latest progress
+                // NOTES: `rdma_send_channel_tail` does not need to be protected by lock
                 auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);
+                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
                 auto num_tokens_processed = processed_tail - synced_last_issued_tail;
                 if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
                     continue;
@@ -628,9 +649,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                     // Lighter fence for local RDMA rank
                     memory_fence();
                 }
-                __syncwarp();

                 // Update tails
+                __syncwarp();
                 if (lane_id == dst_rdma_rank) {
                     last_issued_tail += num_tokens_to_issue;
                     num_tokens_to_send -= num_tokens_to_issue;

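The `continue` condition above encodes the coordinator's chunking rule: issue an RDMA send to a destination only once a full chunk of `num_max_rdma_chunked_send_tokens` is ready, or when the ready tokens are the last ones remaining for that destination, so writes stay large. A hypothetical sketch of that decision:

    #include <cstdio>

    // Hypothetical model of the coordinator's chunking rule: issue a send only
    // when a full chunk is ready, or when the ready tokens are all that remain.
    // `processed` mirrors `num_tokens_processed` (copied but not yet issued).
    bool should_issue(int processed, int remaining, int chunk_size) {
        return processed == remaining || processed >= chunk_size;
    }

    int main() {
        // chunk_size = 8 plays the role of num_max_rdma_chunked_send_tokens
        std::printf("%d\n", should_issue(3, 20, 8));  // 0: wait for a fuller chunk
        std::printf("%d\n", should_issue(8, 20, 8));  // 1: a full chunk is ready
        std::printf("%d\n", should_issue(3, 3, 8));   // 1: tail flush of the channel
        return 0;
    }
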
@@ -498,7 +498,7 @@ combine(void* combined_x,
     }
     cg::this_grid().sync();

-    // Reduce tokens with FP8 cast
+    // Reduce tokens
     EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
     EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
     if (thread_id < hidden_bf16_int4) {

@@ -618,8 +618,8 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
     const auto num_threads_per_rank = num_send_warps_per_rank * 32;
     const auto send_thread_id = thread_id;
     const auto send_warp_id = send_thread_id / 32;
-    const auto send_rank_id = thread_id / num_threads_per_rank;
-    const auto send_warp_id_in_rank = send_warp_id % num_send_warps_per_rank;
+    const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks;
+    const auto send_warp_id_in_rank = send_warp_id / kNumRanks;
     EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");

     // Calculate pointers by the specific layout
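The remapped `send_rank_id` applies the same incast-mitigation trick as the coordinator loop earlier: each warp or channel starts its sweep over peers at a different offset, so concurrent senders spread across destinations instead of all targeting rank 0 first. A tiny illustration of the rotation (hypothetical values):

    #include <cstdio>

    // Hypothetical illustration: each warp starts its sweep over peers at a
    // different offset, so concurrent senders spread across destinations
    // rather than hammering the same rank simultaneously (incast congestion).
    int main() {
        const int kNumRanks = 4;
        for (int warp = 0; warp < 4; ++warp) {
            std::printf("warp %d visits:", warp);
            for (int i = 0; i < kNumRanks; ++i)
                std::printf(" %d", (warp + i) % kNumRanks);
            std::printf("\n");
        }
        return 0;
    }
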
@@ -777,7 +777,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
                 expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);

             auto start_time = clock64();
-            while (channel_tail_idx[lane_id] <= expected_head and expected_head >= 0) {
+            while (__any_sync(0xffffffff, channel_tail_idx[lane_id] <= expected_head and expected_head >= 0)) {
                 // Timeout check
                 if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                     printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);

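The combine-receiver fix replaces a per-lane spin with a warp-converged one: each lane still waits on its own condition, but the exit test becomes `__any_sync` over all lanes, so the warp keeps looping together until no lane is waiting. A hypothetical single-threaded model of that convergence rule with four "lanes":

    #include <array>
    #include <cstdio>

    // Hypothetical model of the `__any_sync` loop: the "warp" of 4 lanes keeps
    // iterating while ANY lane's condition still holds, instead of letting each
    // lane spin independently and diverge.
    int main() {
        std::array<int, 4> tail{0, 0, 0, 0};       // per-lane progress
        std::array<int, 4> expected{3, 1, 0, 2};   // per-lane target heads

        auto any_waiting = [&] {
            for (int lane = 0; lane < 4; ++lane)
                if (tail[lane] <= expected[lane]) return true;
            return false;
        };
        int iterations = 0;
        while (any_waiting()) {                    // __any_sync(0xffffffff, ...)
            for (int lane = 0; lane < 4; ++lane)
                tail[lane] += 1;                   // stand-in for arriving data
            ++iterations;
        }
        std::printf("all lanes done after %d iterations\n", iterations);  // 4
        return 0;
    }
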
@@ -80,6 +80,7 @@ cfg.dynamicSmemBytes = smem_size;

 #define SWITCH_HIDDEN(case_macro) \
     switch (hidden) { \
         case 2048: case_macro(2048); \
+        case 2560: case_macro(2560); \
         case 4096: case_macro(4096); \
         case 5120: case_macro(5120); \

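`SWITCH_HIDDEN` (here gaining a 2560 case) is the standard trick for mapping a runtime `hidden` value onto compile-time template instantiations, one per supported size. A minimal hypothetical sketch of the pattern:

    #include <cstdio>

    // Minimal sketch of the SWITCH_HIDDEN pattern: a macro expands a switch
    // whose cases stamp out one template instantiation per supported size.
    template <int kHidden>
    void run_kernel() { std::printf("instantiated for hidden = %d\n", kHidden); }

    #define SWITCH_HIDDEN(case_macro) \
        switch (hidden) { \
            case 2048: case_macro(2048); \
            case 2560: case_macro(2560); \
            default: std::printf("unsupported hidden size\n"); \
        }

    #define LAUNCH_CASE(h) { run_kernel<h>(); break; }

    int main() {
        int hidden = 2560;
        SWITCH_HIDDEN(LAUNCH_CASE);  // dispatches to run_kernel<2560>()
        return 0;
    }
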
@@ -466,4 +466,26 @@ barrier_block(int** barrier_signal_ptrs, int rank) {
     __syncthreads();
 }

+__forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) {
+    int ret;
+    asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(addr), "r"(x), "r"(y) : "memory");
+    return ret;
+}
+
+__forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) {
+    int ret;
+    asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(x) : "memory");
+    return ret;
+}
+
+__forceinline__ __device__ void acquire_lock(int* mutex) {
+    // To make later memory operations valid, we must use `acquire` for memory semantics
+    while (atomic_cas_cta_acquire(mutex, 0, 1) != 0);
+}
+
+__forceinline__ __device__ void release_lock(int* mutex) {
+    // To make previous memory operations visible to other threads, we must use `release` for memory semantics
+    atomic_exch_cta_release(mutex, 0);
+}
+
 } // namespace deep_ep

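The new helpers form a CTA-scoped test-and-set spinlock: the CAS that takes the lock carries acquire semantics, so the holder sees the previous holder's writes, and the exchange that frees it carries release semantics, so its own writes are published before the lock is observed free. A host-side analogue with std::atomic (a sketch of the memory-ordering contract, not the device code):

    #include <atomic>
    #include <cassert>
    #include <thread>

    // Host-side analogue of acquire_lock/release_lock: a test-and-set spinlock.
    // `acquire` on lock entry makes the previous holder's writes visible;
    // `release` on unlock publishes this holder's writes first. The kernel
    // expresses the same contract in PTX at CTA scope.
    std::atomic<int> mutex{0};
    int protected_counter = 0;

    void acquire_lock() {
        int expected = 0;
        while (!mutex.compare_exchange_weak(expected, 1, std::memory_order_acquire))
            expected = 0;  // CAS failed: reset and spin, as the device loop does
    }

    void release_lock() {
        mutex.store(0, std::memory_order_release);
    }

    int main() {
        std::thread workers[4];
        for (auto& w : workers)
            w = std::thread([] {
                for (int i = 0; i < 1000; ++i) {
                    acquire_lock();
                    ++protected_counter;  // safe: guarded by the lock
                    release_lock();
                }
            });
        for (auto& w : workers)
            w.join();
        assert(protected_counter == 4000);
        return 0;
    }
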
@@ -184,9 +184,9 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
                         best_time, best_results = t, (num_sms, nvl_chunk_size)
                     if local_rank == 0:
                         print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
-                              f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+                              f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
     if local_rank == 0:
-        print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+        print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
         print('', flush=True)

     # Gather the best config from rank 0 and the first test setting
@@ -215,12 +215,12 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
                 t = bench(lambda: buffer.combine(**tune_args))[0]
                 if local_rank == 0:
                     print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
-                          f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+                          f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), avg_t: {t * 1e6:.2f} us', flush=True)
                 if t < best_time and nvl_chunk_size > 0:
                     best_time, best_results = t, (num_sms, nvl_chunk_size)

     if local_rank == 0:
-        print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+        print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
         print('', flush=True)

@@ -1,6 +1,7 @@
+import inspect
+import numpy as np
 import os
 import sys
-import numpy as np
 import torch
 import torch.distributed as dist
 from typing import Optional
@@ -14,12 +15,17 @@ def init_dist(local_rank: int, num_local_ranks: int):
     node_rank = int(os.getenv('RANK', 0))
     assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

-    dist.init_process_group(
-        backend='nccl',
-        init_method=f'tcp://{ip}:{port}',
-        world_size=num_nodes * num_local_ranks,
-        rank=node_rank * num_local_ranks + local_rank
-    )
+    sig = inspect.signature(dist.init_process_group)
+    params = {
+        'backend': 'nccl',
+        'init_method': f'tcp://{ip}:{port}',
+        'world_size': num_nodes * num_local_ranks,
+        'rank': node_rank * num_local_ranks + local_rank,
+    }
+    if 'device_id' in sig.parameters:
+        # noinspection PyTypeChecker
+        params['device_id'] = torch.device(f'cuda:{local_rank}')
+    dist.init_process_group(**params)
     torch.set_default_dtype(torch.bfloat16)
     torch.set_default_device('cuda')
     torch.cuda.set_device(local_rank)
@@ -74,7 +80,7 @@ def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_gro
     return (scores * mask).view(num_tokens, num_experts)


-def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
+def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None):
     # Flush L2 cache with 256 MB data
     torch.cuda.synchronize()
     cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
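
`bench` (now defaulting to 50 warmups and 50 timed runs) flushes L2 with a 256 MB write so each timed iteration starts from cold caches. A hypothetical host-side C++ analogue of the same flush-then-time idea:

    #include <chrono>
    #include <cstdio>
    #include <vector>

    // Hypothetical host-side analogue of the benchmark's cache flush: write a
    // buffer much larger than any CPU cache between timed runs so each run
    // starts cold instead of reusing the previous run's cache lines.
    int main() {
        std::vector<int> data(1 << 20, 1);
        std::vector<char> flusher(256 << 20);  // 256 MB, larger than the LLC

        long long sum = 0;
        for (int run = 0; run < 3; ++run) {
            for (auto& c : flusher) c = (char)run;  // evict `data` from the caches
            auto t0 = std::chrono::steady_clock::now();
            for (int v : data) sum += v;            // the measured workload
            auto t1 = std::chrono::steady_clock::now();
            auto us = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
            std::printf("run %d: %lld us\n", run, (long long)us);
        }
        return sum == 0;  // keep `sum` observable so the loop isn't optimized away
    }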