Support Ampere architecture (#204)

* Update README

* Update `setup.py`

* Fix headers

* Add `DISABLE_NVSHMEM` for APIs

* Fix launch

* Fix TMA settings

* Fix TMA usages

* Fix dlink

* Separate layout kernels

* Update version

* Add `is_sm90_compiled`

* Fix tests

* Add NVLink connection checks

* Update README

* Fix tests

* Add some comments

* Minor fix

* Minor fix

* Fix bugs
This commit is contained in:
Chenggang Zhao
2025-06-11 15:48:18 +08:00
committed by GitHub
parent dd13c7145c
commit b8d90fb753
16 changed files with 413 additions and 174 deletions

View File

@@ -11,10 +11,11 @@ function(add_deep_ep_library target_name source_file)
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction()
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(layout_cuda layout.cu)
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)

View File

@@ -28,6 +28,17 @@ void finalize();
} // namespace internode
// Layout kernels
namespace layout {
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream);
} // namespace layout
// Intranode kernels
namespace intranode {
@@ -69,12 +80,6 @@ namespace internode {
int get_source_meta_bytes();
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream);
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,

View File

@@ -37,11 +37,25 @@
#undef __CUDA_NO_BFLOAT162_OPERATORS__
#endif
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#ifndef DISABLE_SM90_FEATURES
#include <cuda_fp8.h>
#else
// Ampere does not support FP8 features
#define __NV_E4M3 0
#define __NV_E5M2 1
typedef int __nv_fp8_interpretation_t;
typedef int __nv_fp8x4_e4m3;
typedef uint8_t __nv_fp8_storage_t;
#endif
#ifndef DISABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmemx.h>
#include <infiniband/mlx5dv.h>
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
#include <device_host_transport/nvshmem_common_ibgda.h>
#endif

View File

@@ -11,131 +11,6 @@ namespace internode {
extern nvshmem_team_t cpu_rdma_team;
template<int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void __launch_bounds__(kNumThreads, 1)
get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
// Count expert statistics
__shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
if (expert_begin_idx < expert_end_idx) {
// Per-thread count
#pragma unroll
for (int i = 0; i < kNumExpertsPerSM; ++ i)
num_tokens_per_expert_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
auto shifted_topk_idx = topk_idx + i * num_topk;
#pragma unroll
for (int j = 0, expert_idx; j < num_topk; ++ j) {
expert_idx = static_cast<int>(shifted_topk_idx[j]);
if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
}
}
__syncthreads();
// Sum up
EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
if (expert_begin_idx + thread_id < expert_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_expert_per_thread[i][thread_id];
num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
}
return;
}
if (num_tokens_per_rdma_rank != nullptr)
EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
// Count rank statistics
constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
__shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
__shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
if (rank_begin_idx < rank_end_idx) {
const auto num_expert_per_rank = num_experts / num_ranks;
auto expert_begin = rank_begin_idx * num_expert_per_rank;
auto expert_end = rank_end_idx * num_expert_per_rank;
// Per-thread count
#pragma unroll
for (int i = 0; i < kNumRanksPerSM; ++ i)
num_tokens_per_rank_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
auto shifted_topk_idx = topk_idx + i * num_topk;
int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
#pragma unroll
for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
expert_idx = static_cast<int>(shifted_topk_idx[j]);
if (expert_begin <= expert_idx and expert_idx < expert_end) {
// Count single rank
rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
}
}
auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
#pragma unroll
for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
}
#pragma unroll
for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
}
__syncthreads();
// Sum up
EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
if (rank_begin_idx + thread_id < rank_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_rank_per_thread[i][thread_id];
num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
}
if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
}
}
}
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream) {
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
num_tokens, num_topk, num_ranks, num_experts);
}
struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits;

View File

@@ -227,6 +227,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
// TMA stuffs
#ifndef DISABLE_SM90_FEATURES
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto half_hidden_int4 = hidden_int4 / 2;
auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
@@ -240,6 +241,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();
#endif
if (is_sender) {
// Workers for sending
@@ -399,6 +401,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
#ifndef DISABLE_SM90_FEATURES
#pragma unroll
for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
tma_store_wait();
@@ -408,6 +411,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
}
__syncwarp();
#else
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
ld_nc_global, st_na_global);
#endif
}
// Copy `src_idx`
@@ -447,8 +454,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
}
// Make TMA store visible to the next kernel
#ifndef DISABLE_SM90_FEATURES
if (lane_id == 0)
tma_store_wait();
#endif
}
@@ -473,12 +482,13 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
constexpr int kNumTMABytesPerWarp = 8192;
#ifndef DISABLE_SM90_FEATURES
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
#endif
#define DISPATCH_LAUNCH_CASE(ranks) { \
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
cfg.dynamicSmemBytes = smem_size; \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
LAUNCH_KERNEL(&cfg, kernel, \
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
@@ -587,8 +597,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
// TMA stuffs
#ifndef DISABLE_SM90_FEATURES
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
#endif
if (is_sender) {
// Workers for sending
@@ -778,9 +790,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
}
// Wait shared memory release
#ifndef DISABLE_SM90_FEATURES
if (lane_id == 0)
tma_store_wait();
__syncwarp();
#endif
// Reduce data with pipeline
constexpr int kNumStages = 8;
@@ -810,6 +824,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
for (int j = 0; j < kDtypePerInt4; ++ j)
out_dtypes[j] = static_cast<dtype_t>(values[j]);
#ifndef DISABLE_SM90_FEATURES
// Wait TMA arrival
if (lane_id == 0)
tma_store_wait<kNumStages - 1>();
@@ -828,6 +843,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
}
__syncwarp();
#else
recv_int4[token_idx * hidden_int4 + i] = out_int4;
#endif
}
// Reduce `topk_weights`
@@ -850,8 +868,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
warp_retired[recv_warp_id] = true;
// Make TMA store visible to the next kernel
#ifndef DISABLE_SM90_FEATURES
if (lane_id == 0)
tma_store_wait();
#endif
}
}
}
@@ -866,12 +886,13 @@ void combine(cudaDataType_t type,
int num_max_send_tokens, int num_recv_buffer_tokens) {
constexpr int kNumThreads = 768;
constexpr int kNumTMABytesPerWarp = 4096;
#ifndef DISABLE_SM90_FEATURES
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
#endif
#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
cfg.dynamicSmemBytes = smem_size; \
SET_SHARED_MEMORY_FOR_TMA(kernel); \
LAUNCH_KERNEL(&cfg, kernel, \
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
reinterpret_cast<const dtype*>(x), topk_weights, \

View File

@@ -1,8 +1,10 @@
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#ifndef SETUP_LAUNCH_CONFIG
#ifndef DISABLE_SM90_FEATURES
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
cudaLaunchAttribute attr[1]; \
@@ -10,10 +12,39 @@
attr[0].val.cooperative = 1; \
cfg.attrs = attr; \
cfg.numAttrs = 1
#else
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
int __num_sms = (sms); \
int __num_threads = (threads); \
auto __stream = (stream)
#endif
#endif
#ifndef LAUNCH_KERNEL
#ifndef DISABLE_SM90_FEATURES
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
#else
#define LAUNCH_KERNEL(config, kernel, ...) \
do { \
kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__); \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
EPException cuda_exception("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
fprintf(stderr, "%s\n", cuda_exception.what()); \
throw cuda_exception; \
} \
} while (0)
#endif
#endif
#ifndef SET_SHARED_MEMORY_FOR_TMA
#ifndef DISABLE_SM90_FEATURES
#define SET_SHARED_MEMORY_FOR_TMA(kernel) \
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
cfg.dynamicSmemBytes = smem_size;
#else
#define SET_SHARED_MEMORY_FOR_TMA(kernel) void()
#endif
#endif
#define SWITCH_RANKS(case_macro) \

136
csrc/kernels/layout.cu Normal file
View File

@@ -0,0 +1,136 @@
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
namespace deep_ep {
namespace layout {
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
__global__ void __launch_bounds__(kNumThreads, 1)
get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
// Count expert statistics
__shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
if (expert_begin_idx < expert_end_idx) {
// Per-thread count
#pragma unroll
for (int i = 0; i < kNumExpertsPerSM; ++ i)
num_tokens_per_expert_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
auto shifted_topk_idx = topk_idx + i * num_topk;
#pragma unroll
for (int j = 0, expert_idx; j < num_topk; ++ j) {
expert_idx = static_cast<int>(shifted_topk_idx[j]);
if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
}
}
__syncthreads();
// Sum up
EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
if (expert_begin_idx + thread_id < expert_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_expert_per_thread[i][thread_id];
num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
}
return;
}
if (num_tokens_per_rdma_rank != nullptr)
EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
// Count rank statistics
constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
__shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
__shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
if (rank_begin_idx < rank_end_idx) {
const auto num_expert_per_rank = num_experts / num_ranks;
auto expert_begin = rank_begin_idx * num_expert_per_rank;
auto expert_end = rank_end_idx * num_expert_per_rank;
// Per-thread count
#pragma unroll
for (int i = 0; i < kNumRanksPerSM; ++ i)
num_tokens_per_rank_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
#pragma unroll
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
auto shifted_topk_idx = topk_idx + i * num_topk;
int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
#pragma unroll
for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
expert_idx = static_cast<int>(shifted_topk_idx[j]);
if (expert_begin <= expert_idx and expert_idx < expert_end) {
// Count single rank
rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
}
}
auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
#pragma unroll
for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
}
#pragma unroll
for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
}
__syncthreads();
// Sum up
EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
if (rank_begin_idx + thread_id < rank_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_rank_per_thread[i][thread_id];
num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
}
if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
int sum = 0;
#pragma unroll
for (int i = 0; i < kNumThreads; ++ i)
sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
}
}
}
void get_dispatch_layout(const int64_t* topk_idx,
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
int* num_tokens_per_expert, bool* is_token_in_rank,
int num_tokens, int num_topk, int num_ranks, int num_experts,
cudaStream_t stream) {
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
num_tokens, num_topk, num_ranks, num_experts);
}
} // namespace layout
} // namespace deep_ep

View File

@@ -5,7 +5,10 @@
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#ifndef DISABLE_NVSHMEM
#include "ibgda_device.cuh"
#endif
namespace deep_ep {
@@ -30,6 +33,7 @@ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t st
namespace internode {
#ifndef DISABLE_NVSHMEM
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
nvshmem_team_config_t cpu_rdma_team_config;
@@ -81,6 +85,7 @@ void finalize() {
}
nvshmem_finalize();
}
#endif
} // namespace internode

View File

@@ -266,6 +266,9 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES
__device__ __forceinline__ void fence_view_async_shared() {
asm volatile("fence.proxy.async.shared::cta; \n" :: );
}
@@ -327,6 +330,8 @@ __device__ __forceinline__ void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
}
#endif
template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;