DeepEP/csrc/kernels/runtime.cu
moningchen 5ab80c28f3 In the Internode Normal Kernel, when using nvshmem ibrc for RDMA data transmission, a single QP is used for data transfer between two GPUs, which limits kernel performance in network card dual-port and RoCE network scenarios.
In our optimized Internode Normal Kernel, we implemented multiple QPs for data transmission between two GPUs, setting a different QP for each channel. Additionally, we modified the transmission method from IBRC to IBGDA.

Through these optimizations, the Internode Normal Kernel achieves optimal performance in both H800 and H20 environments, with RDMA transmission performance nearly reaching the physical network performance limit. Using the current default statistical method, in 4-node H800 and H20 environments, RDMA performance can reach 60GB/s+.
2025-04-21 15:50:39 +08:00

97 lines
3.0 KiB
Plaintext

#include <vector>
#include <cstring>
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
#include "ibgda_device.cuh"
namespace deep_ep {
namespace intranode {
// Kernel entry point for the intranode barrier.
//
// Forwards all arguments unchanged to `barrier_device<kNumRanks>`; the kernel
// itself contains no other logic.
//
// Template parameters:
//   kNumRanks      - compile-time rank count passed through to `barrier_device`.
// Parameters:
//   task_fifo_ptrs - forwarded as-is (presumably one FIFO pointer per rank; confirm
//                    against `barrier_device`).
//   head           - forwarded as-is.
//   rank           - forwarded as-is (presumably the calling rank's index).
template <int kNumRanks>
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}
// Host-side launcher for the intranode `barrier<kNumRanks>` kernel.
//
// Parameters:
//   task_fifo_ptrs, head, rank - forwarded unchanged to the kernel.
//   num_ranks - not referenced directly here; presumably consumed by
//               `SWITCH_RANKS` to pick the template instantiation (confirm in
//               launch.cuh).
//   stream    - handed to `SETUP_LAUNCH_CONFIG` together with the constants 1 and 32
//               (presumably block count and threads per block; confirm in launch.cuh).
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
    // `cfg`, used by the dispatch macro below, is expected to be introduced by this macro.
    SETUP_LAUNCH_CONFIG(1, 32, stream);

// One switch arm per supported rank count: launch the matching instantiation, then
// leave the switch. The trailing `;` is supplied at the expansion site.
#define LAUNCH_BARRIER_FOR(kRanks) \
    LAUNCH_KERNEL(&cfg, barrier<kRanks>, task_fifo_ptrs, head, rank); \
    break

    SWITCH_RANKS(LAUNCH_BARRIER_FOR);

#undef LAUNCH_BARRIER_FOR
}
} // namespace intranode
namespace internode {
// NVSHMEM team handle created by `init()` via `nvshmem_team_split_strided` when
// low-latency mode is used with more than `NUM_MAX_NVL_PEERS` ranks, and destroyed by
// `finalize()`. `NVSHMEM_TEAM_INVALID` means "no team has been created".
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
// Team configuration object passed to `nvshmem_team_split_strided` in `init()`.
// It is never explicitly populated in this file, and the split is called with a
// config mask of 0.
nvshmem_team_config_t cpu_rdma_team_config;
// Obtains an NVSHMEM unique ID and returns its raw bytes.
//
// The returned vector has exactly `sizeof(nvshmemx_uniqueid_t)` bytes and is the
// format `init()` expects for its `root_unique_id_val` argument.
//
// Fails via `EP_HOST_ASSERT` if NVSHMEM reports an error, instead of silently
// returning the bytes of an indeterminate ID.
std::vector<uint8_t> get_unique_id() {
    nvshmemx_uniqueid_t unique_id;
    // `nvshmemx_get_uniqueid` returns 0 on success; do not ignore a failure.
    EP_HOST_ASSERT(nvshmemx_get_uniqueid(&unique_id) == 0);

    std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
    std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
    return result;
}
// Initializes NVSHMEM for this process from a serialized root unique ID.
//
// Parameters:
//   root_unique_id_val - raw bytes of an `nvshmemx_uniqueid_t` (as produced by
//                        `get_unique_id()`); must hold at least
//                        `sizeof(nvshmemx_uniqueid_t)` bytes.
//   rank               - this process's rank, passed to NVSHMEM.
//   num_ranks          - total number of ranks, passed to NVSHMEM.
//   low_latency_mode   - when true and `num_ranks > NUM_MAX_NVL_PEERS`, additionally
//                        creates the strided `cpu_rdma_team`.
// Returns:
//   `nvshmem_my_pe()` after a global `nvshmem_barrier_all()`.
// Fails via `EP_HOST_ASSERT` on malformed input or on NVSHMEM setup/team-split errors.
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
    // The ID is copied out of the vector byte-for-byte below; reject buffers that are
    // too short rather than reading past their end.
    EP_HOST_ASSERT(root_unique_id_val.size() >= sizeof(nvshmemx_uniqueid_t));

    nvshmemx_uniqueid_t root_unique_id;
    nvshmemx_init_attr_t attr;
    std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));

    // Both calls report status through their return value (0 on success). The calls
    // sit inside `EP_HOST_ASSERT`, matching how the team split below is checked.
    EP_HOST_ASSERT(nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr) == 0);
    EP_HOST_ASSERT(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr) == 0);

    // Create sub-RDMA teams
    // NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
    if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
        EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
        EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
        // Team members start at `rank % NUM_MAX_NVL_PEERS` and step by
        // `NUM_MAX_NVL_PEERS`, giving `num_ranks / NUM_MAX_NVL_PEERS` members.
        EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
                                                  num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
        EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
    }

    // Normal internode operations now use IBGDA as well (previously IBRC), so the
    // branch below is taken regardless of `low_latency_mode`. The flag is kept as a
    // single switch point should the normal path need to go back to IBRC.
    bool internode_use_ibgda = true;
    if (low_latency_mode or internode_use_ibgda) {
        // Overwrite `ibgda_is_initialized` in NVSHMEM's device-side state
        // (`nvshmemi_device_state_d`) with `false`.
        // NOTE(review): presumably this keeps NVSHMEM's own device-side operations
        // off the IBGDA path while DeepEP drives IBGDA itself; confirm intent.
        nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
        CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));

        bool ibgda_is_initialized = false;
        CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
    }

    nvshmem_barrier_all();
    return nvshmem_my_pe();
}
// Allocates a buffer through `nvshmem_align`.
//
// Parameters:
//   size      - number of bytes requested.
//   alignment - requested alignment.
// Returns whatever `nvshmem_align` returns; the result is not checked here.
// Note the argument order: `nvshmem_align` takes the alignment first.
void* alloc(size_t size, size_t alignment) {
    void* buffer = nvshmem_align(alignment, size);
    return buffer;
}
// Releases a buffer by handing `ptr` to `nvshmem_free`.
//
// Parameters:
//   ptr - pointer to release (presumably one previously returned by `alloc()`).
void free(void* ptr) {
    nvshmem_free(ptr);
}
// Host-side internode barrier: a direct wrapper around `nvshmem_barrier_all()`.
void barrier() {
    nvshmem_barrier_all();
}
// Tears down the NVSHMEM state set up by `init()`.
//
// If `cpu_rdma_team` holds a team, it is destroyed and the handle reset to
// `NVSHMEM_TEAM_INVALID`. `nvshmem_finalize()` is then called unconditionally.
void finalize() {
    const bool has_rdma_team = (cpu_rdma_team != NVSHMEM_TEAM_INVALID);
    if (has_rdma_team) {
        nvshmem_team_destroy(cpu_rdma_team);
        // Reset so the "team not yet created" assertion in `init()` holds on a later call.
        cpu_rdma_team = NVSHMEM_TEAM_INVALID;
    }

    nvshmem_finalize();
}
} // namespace internode
} // namespace deep_ep