mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Initial commit
This commit is contained in:
20
csrc/kernels/CMakeLists.txt
Normal file
20
csrc/kernels/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
function(add_deep_ep_library target_name source_file)
|
||||
add_library(${target_name} STATIC ${source_file})
|
||||
set_target_properties(${target_name} PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD_REQUIRED ON
|
||||
CUDA_STANDARD_REQUIRED ON
|
||||
CXX_STANDARD 14
|
||||
CUDA_STANDARD 14
|
||||
CUDA_SEPARABLE_COMPILATION ON
|
||||
)
|
||||
target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
|
||||
endfunction()
|
||||
|
||||
add_deep_ep_library(intranode_cuda intranode.cu)
|
||||
add_deep_ep_library(runtime_cuda runtime.cu)
|
||||
add_deep_ep_library(internode_cuda internode.cu)
|
||||
add_deep_ep_library(internode_ll_cuda internode_ll.cu)
|
||||
|
||||
# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
|
||||
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
|
||||
153
csrc/kernels/api.cuh
Normal file
153
csrc/kernels/api.cuh
Normal file
@@ -0,0 +1,153 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
// Intranode runtime
|
||||
namespace intranode {
|
||||
|
||||
void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
// Internode runtime
|
||||
namespace internode {
|
||||
|
||||
std::vector<uint8_t> get_unique_id();
|
||||
|
||||
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);
|
||||
|
||||
void *alloc(size_t size, size_t alignment);
|
||||
|
||||
void free(void *ptr);
|
||||
|
||||
void barrier();
|
||||
|
||||
void finalize();
|
||||
|
||||
} // namespace internode
|
||||
|
||||
// Intranode kernels
|
||||
namespace intranode {
|
||||
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
|
||||
cudaStream_t stream, int num_sms);
|
||||
|
||||
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
|
||||
cudaStream_t stream);
|
||||
|
||||
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens);
|
||||
|
||||
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
|
||||
|
||||
void combine(cudaDataType_t type,
|
||||
void* recv_x, float* recv_topk_weights,
|
||||
const void* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens);
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
// Internode kernels
|
||||
namespace internode {
|
||||
|
||||
int get_source_meta_bytes();
|
||||
|
||||
void get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
const bool* is_token_in_rank, int num_tokens, int num_channels,
|
||||
int hidden_int4, int num_scales, int num_topk, int expert_alignment,
|
||||
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
|
||||
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
||||
int** task_fifo_ptrs, int head, int rank,
|
||||
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
||||
bool low_latency_mode);
|
||||
|
||||
void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
|
||||
const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
int* send_rdma_head, int* send_nvl_head,
|
||||
int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
|
||||
const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum,
|
||||
const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum,
|
||||
int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts,
|
||||
const bool* is_token_in_rank,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, bool is_cached_dispatch,
|
||||
cudaStream_t stream, int num_channels, bool low_latency_mode);
|
||||
|
||||
void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
|
||||
int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head,
|
||||
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
|
||||
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
|
||||
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
|
||||
bool is_cached_dispatch, bool low_latency_mode);
|
||||
|
||||
void combine(cudaDataType_t type,
|
||||
void* combined_x, float* combined_topk_weights,
|
||||
const bool* is_combined_token_in_rank,
|
||||
const void* x, const float* topk_weights,
|
||||
const int* combined_rdma_head, const int* combined_nvl_head,
|
||||
const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix,
|
||||
int num_tokens, int num_combined_tokens, int hidden, int num_topk,
|
||||
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
|
||||
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
|
||||
int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode);
|
||||
|
||||
} // namespace internode
|
||||
|
||||
// Internode low-latency kernels
|
||||
namespace internode_ll {
|
||||
|
||||
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
cudaStream_t stream);
|
||||
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases);
|
||||
|
||||
void combine(void* combined_x,
|
||||
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
|
||||
const void* x, const int64_t* topk_idx, const float* topk_weights,
|
||||
const int* src_info, const int64_t* layout_range,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases);
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
} // namespace deep_ep
|
||||
138
csrc/kernels/buffer.cuh
Normal file
138
csrc/kernels/buffer.cuh
Normal file
@@ -0,0 +1,138 @@
|
||||
#pragma once
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
template <typename dtype_t>
|
||||
struct Buffer {
|
||||
private:
|
||||
uint8_t* ptr;
|
||||
|
||||
public:
|
||||
int total_bytes;
|
||||
|
||||
__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}
|
||||
|
||||
__device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) {
|
||||
total_bytes = num_elems * sizeof(dtype_t);
|
||||
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
|
||||
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) {
|
||||
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t* buffer() {
|
||||
return reinterpret_cast<dtype_t*>(ptr);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t& operator[](int idx) {
|
||||
return buffer()[idx];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t, int kNumRanks = 1>
|
||||
struct AsymBuffer {
|
||||
private:
|
||||
uint8_t* ptrs[kNumRanks];
|
||||
int num_bytes;
|
||||
|
||||
public:
|
||||
int total_bytes;
|
||||
|
||||
__device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
|
||||
int sm_id = 0, int num_sms = 1, int offset = 0) {
|
||||
EP_STATIC_ASSERT(kNumRanks == 1, "");
|
||||
num_bytes = num_elems * sizeof(dtype_t);
|
||||
|
||||
int per_channel_bytes = num_bytes * num_ranks;
|
||||
total_bytes = per_channel_bytes * num_sms;
|
||||
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
|
||||
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
|
||||
int sm_id = 0, int num_sms = 1, int offset = 0) {
|
||||
EP_STATIC_ASSERT(kNumRanks > 1, "");
|
||||
num_bytes = num_elems * sizeof(dtype_t);
|
||||
|
||||
int per_channel_bytes = num_bytes * num_ranks;
|
||||
total_bytes = per_channel_bytes * num_sms;
|
||||
for (int i = 0; i < kNumRanks; ++ i) {
|
||||
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
|
||||
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void advance(int shift) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i)
|
||||
ptrs[i] = ptrs[i] + shift * sizeof(dtype_t);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) {
|
||||
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<int kNumAlsoRanks>
|
||||
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
|
||||
for (int i = 0; i < kNumAlsoRanks; ++ i)
|
||||
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
|
||||
EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
|
||||
return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
|
||||
EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case");
|
||||
return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dtype_t, bool kDecoupled = true>
|
||||
struct SymBuffer {
|
||||
private:
|
||||
// NOTES: for non-decoupled case, `recv_ptr` is not used
|
||||
uint8_t* send_ptr;
|
||||
uint8_t* recv_ptr;
|
||||
int num_bytes;
|
||||
|
||||
public:
|
||||
int total_bytes;
|
||||
|
||||
__device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
|
||||
int sm_id = 0, int num_sms = 1) {
|
||||
num_bytes = num_elems * sizeof(dtype_t);
|
||||
|
||||
int per_channel_bytes = num_bytes * num_ranks;
|
||||
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
|
||||
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
|
||||
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
|
||||
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
|
||||
EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case");
|
||||
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) {
|
||||
EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case");
|
||||
return reinterpret_cast<dtype_t*>(recv_ptr + num_bytes * idx);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ dtype_t* buffer(int idx = 0) {
|
||||
EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case");
|
||||
return reinterpret_cast<dtype_t*>(send_ptr + num_bytes * idx);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_ep
|
||||
50
csrc/kernels/configs.cuh
Normal file
50
csrc/kernels/configs.cuh
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
|
||||
#define NUM_MAX_NVL_PEERS 8
|
||||
#define NUM_MAX_RDMA_PEERS 20
|
||||
#define NUM_MAX_FIFO_SLOTS 32768
|
||||
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
|
||||
#define NUM_MAX_LOCAL_EXPERTS 1024
|
||||
#define NUM_BUFFER_ALIGNMENT_BYTES 128
|
||||
|
||||
#define FINISHED_SUM_TAG 1024
|
||||
#define NUM_CPU_TIMEOUT_SECS 100
|
||||
#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s
|
||||
#define NUM_WAIT_NANOSECONDS 500
|
||||
|
||||
#define LOW_LATENCY_SEND_PHASE 1
|
||||
#define LOW_LATENCY_RECV_PHASE 2
|
||||
|
||||
// Make CLion CUDA indexing work
|
||||
#ifdef __CLION_IDE__
|
||||
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
|
||||
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
|
||||
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
|
||||
#define printf host_device_printf
|
||||
#endif
|
||||
|
||||
// Remove Torch restrictions
|
||||
#ifdef __CUDA_NO_HALF_CONVERSIONS__
|
||||
#undef __CUDA_NO_HALF_CONVERSIONS__
|
||||
#endif
|
||||
#ifdef __CUDA_NO_HALF_OPERATORS__
|
||||
#undef __CUDA_NO_HALF_OPERATORS__
|
||||
#endif
|
||||
#ifdef __CUDA_NO_HALF2_OPERATORS__
|
||||
#undef __CUDA_NO_HALF2_OPERATORS__
|
||||
#endif
|
||||
#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__
|
||||
#undef __CUDA_NO_BFLOAT16_CONVERSIONS__
|
||||
#endif
|
||||
#ifdef __CUDA_NO_BFLOAT162_OPERATORS__
|
||||
#undef __CUDA_NO_BFLOAT162_OPERATORS__
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <nvshmem.h>
|
||||
#include <nvshmemx.h>
|
||||
#include <infiniband/mlx5dv.h>
|
||||
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
|
||||
#include <device_host_transport/nvshmem_common_ibgda.h>
|
||||
51
csrc/kernels/exception.cuh
Normal file
51
csrc/kernels/exception.cuh
Normal file
@@ -0,0 +1,51 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <exception>
|
||||
|
||||
#include "configs.cuh"
|
||||
|
||||
#ifndef EP_STATIC_ASSERT
|
||||
#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
|
||||
#endif
|
||||
|
||||
class EPException: public std::exception {
|
||||
private:
|
||||
std::string message = {};
|
||||
|
||||
public:
|
||||
explicit EPException(const char *name, const char* file, const int line, const std::string& error) {
|
||||
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
|
||||
}
|
||||
|
||||
const char *what() const noexcept override { return message.c_str(); }
|
||||
};
|
||||
|
||||
#ifndef CUDA_CHECK
|
||||
#define CUDA_CHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = (cmd); \
|
||||
if (e != cudaSuccess) { \
|
||||
throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef EP_HOST_ASSERT
|
||||
#define EP_HOST_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
throw EPException("Assertion", __FILE__, __LINE__, #cond); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef EP_DEVICE_ASSERT
|
||||
#define EP_DEVICE_ASSERT(cond) \
|
||||
do { \
|
||||
if (not (cond)) { \
|
||||
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
|
||||
asm("trap;"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
423
csrc/kernels/ibgda_device.cuh
Normal file
423
csrc/kernels/ibgda_device.cuh
Normal file
@@ -0,0 +1,423 @@
|
||||
// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem)
|
||||
// Copyright (c) NVIDIA Corporation.
|
||||
// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019).
|
||||
// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html
|
||||
//
|
||||
// Modified from original source:
|
||||
// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh
|
||||
#pragma once
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint64_t HtoBE64(uint64_t x) {
|
||||
uint64_t ret;
|
||||
asm("{\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
".reg .b32 lo;\n\t"
|
||||
".reg .b32 hi;\n\t"
|
||||
".reg .b32 new_lo;\n\t"
|
||||
".reg .b32 new_hi;\n\t"
|
||||
"mov.b64 {lo,hi}, %1;\n\t"
|
||||
"prmt.b32 new_hi, lo, ign, 0x0123;\n\t"
|
||||
"prmt.b32 new_lo, hi, ign, 0x0123;\n\t"
|
||||
"mov.b64 %0, {new_lo,new_hi};\n\t"
|
||||
"}" : "=l"(ret) : "l"(x));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint32_t HtoBE32(uint32_t x) {
|
||||
uint32_t ret;
|
||||
asm("{\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"prmt.b32 %0, %1, ign, 0x0123;\n\t"
|
||||
"}" : "=r"(ret) : "r"(x));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint16_t HtoBE16(uint16_t x) {
|
||||
// TODO: simplify PTX using 16-bit instructions
|
||||
auto a = static_cast<uint32_t>(x);
|
||||
uint32_t d;
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .b32 mask;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"mov.b32 mask, 0x4401;\n\t"
|
||||
"mov.b32 ign, 0x0;\n\t"
|
||||
"prmt.b32 %0, %1, ign, mask;\n\t"
|
||||
"}"
|
||||
: "=r"(d)
|
||||
: "r"(a));
|
||||
return static_cast<uint16_t>(d);
|
||||
}
|
||||
|
||||
typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
|
||||
|
||||
__device__ static __forceinline__
|
||||
nvshmemi_ibgda_device_state_t* ibgda_get_state() {
|
||||
return &nvshmemi_ibgda_device_state_d;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
|
||||
auto state = ibgda_get_state();
|
||||
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
|
||||
return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe];
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_lock_acquire(int *lock) {
|
||||
while (atomicCAS(lock, 0, 1) == 1);
|
||||
|
||||
// Prevent reordering before the lock is acquired
|
||||
memory_fence_cta();
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_lock_release(int *lock) {
|
||||
memory_fence_cta();
|
||||
|
||||
// Prevent reordering before lock is released
|
||||
st_na_relaxed(lock, 0);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) {
|
||||
// `DBREC` contains the index of the next empty `WQEBB`
|
||||
__be32 dbrec_val;
|
||||
__be32 *dbrec_ptr = qp->tx_wq.dbrec;
|
||||
|
||||
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))`
|
||||
asm("{\n\t"
|
||||
".reg .b32 dbrec_head_16b;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
|
||||
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
|
||||
"}"
|
||||
: "=r"(dbrec_val)
|
||||
: "r"(dbrec_head));
|
||||
st_na_release(dbrec_ptr, dbrec_val);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) {
|
||||
auto bf_ptr = reinterpret_cast<uint64_t*>(qp->tx_wq.bf);
|
||||
ibgda_ctrl_seg_t ctrl_seg = {
|
||||
.opmod_idx_opcode = HtoBE32(prod_idx << 8),
|
||||
.qpn_ds = HtoBE32(qp->qpn << 8)
|
||||
};
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), "");
|
||||
st_na_release(bf_ptr, *(reinterpret_cast<uint64_t*>(&ctrl_seg)));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) {
|
||||
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
|
||||
uint64_t old_prod_idx;
|
||||
|
||||
// Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence
|
||||
ibgda_lock_acquire(&mvars->post_send_lock);
|
||||
|
||||
old_prod_idx = atomicMax(reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.prod_idx), new_prod_idx);
|
||||
if (new_prod_idx > old_prod_idx) {
|
||||
ibgda_update_dbr(qp, new_prod_idx);
|
||||
ibgda_ring_db(qp, new_prod_idx);
|
||||
}
|
||||
ibgda_lock_release(&mvars->post_send_lock);
|
||||
}
|
||||
|
||||
template <bool kAlwaysDoPostSend>
|
||||
__device__ static __forceinline__
|
||||
void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
|
||||
uint32_t num_wqes, int message_idx = 0) {
|
||||
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
|
||||
uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
|
||||
|
||||
// WQE writes must be finished first
|
||||
__threadfence();
|
||||
|
||||
// Wait for prior WQE slots to be filled first
|
||||
auto *ready_idx = reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.ready_head);
|
||||
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
|
||||
|
||||
// Always post, not in batch
|
||||
constexpr int kNumRequestInBatch = 4;
|
||||
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
|
||||
ibgda_post_send(qp, new_wqe_idx);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
|
||||
__be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
|
||||
ibgda_ctrl_seg_t ctrl_seg;
|
||||
struct mlx5_wqe_raddr_seg raddr_seg;
|
||||
struct mlx5_wqe_inl_data_seg inl_seg;
|
||||
|
||||
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
|
||||
auto *raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
|
||||
auto *inl_seg_ptr = reinterpret_cast<mlx5_wqe_inl_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
|
||||
auto *wqe_data_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(inl_seg_ptr) + sizeof(*inl_seg_ptr));
|
||||
|
||||
raddr_seg.raddr = HtoBE64(raddr);
|
||||
raddr_seg.rkey = rkey;
|
||||
raddr_seg.reserved = 0;
|
||||
|
||||
inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG);
|
||||
|
||||
// `imm == std::numeric_limits<uint32_t>::max()` means no imm writes
|
||||
ctrl_seg = {0};
|
||||
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
|
||||
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
|
||||
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits<uint32_t>::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE));
|
||||
if (imm != std::numeric_limits<uint32_t>::max())
|
||||
ctrl_seg.imm = HtoBE32(imm);
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4");
|
||||
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
|
||||
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
|
||||
st_na_relaxed(reinterpret_cast<uint32_t*>(inl_seg_ptr), *reinterpret_cast<const uint32_t*>(&inl_seg));
|
||||
st_na_relaxed(reinterpret_cast<uint32_t*>(wqe_data_ptr), *reinterpret_cast<const uint32_t*>(val));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__
|
||||
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
|
||||
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
|
||||
auto state = ibgda_get_state();
|
||||
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
|
||||
auto log2_cumem_granularity = state->log2_cumem_granularity;
|
||||
|
||||
// Local key
|
||||
uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity;
|
||||
auto device_key = state->constmem.lkeys[idx];
|
||||
auto lchunk_size = device_key.next_addr - laddr;
|
||||
*lkey = device_key.key;
|
||||
|
||||
// Remote key
|
||||
uint64_t roffset = raddr - heap_start;
|
||||
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
|
||||
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
|
||||
device_key = state->constmem.rkeys[idx];
|
||||
} else {
|
||||
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
|
||||
}
|
||||
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
|
||||
*out_rkey = device_key.key;
|
||||
|
||||
// Return the minimum of local and remote chunk sizes
|
||||
auto rchunk_size = device_key.next_addr - roffset;
|
||||
return min(lchunk_size, rchunk_size);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
|
||||
auto state = ibgda_get_state();
|
||||
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
|
||||
|
||||
uint64_t roffset = addr - heap_start;
|
||||
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
|
||||
nvshmemi_ibgda_device_key_t device_key;
|
||||
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
|
||||
device_key = state->constmem.rkeys[idx];
|
||||
else
|
||||
device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS];
|
||||
*out_raddr = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset;
|
||||
*out_rkey = device_key.key;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ uint64_t
|
||||
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
|
||||
auto mvars = &qp->mvars;
|
||||
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void*
|
||||
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
|
||||
uint16_t cnt = qp->tx_wq.nwqes;
|
||||
uint16_t idx = wqe_idx & (cnt - 1);
|
||||
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
|
||||
}
|
||||
|
||||
// Wait until wqe `idx - 1` is completed.
|
||||
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
|
||||
// Because we post recv and poll recv in the same thread, so we don't need to maintain queue status.
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
auto cq = qp->rx_wq.cq;
|
||||
|
||||
const uint32_t ncqes = cq->ncqes;
|
||||
auto *cqe64 = reinterpret_cast<struct mlx5_cqe64*>(cq->cqe);
|
||||
auto old_cons_idx = *cq->cons_idx;
|
||||
*cq->cons_idx = old_cons_idx + 1;
|
||||
|
||||
// Wait until `wqe_counter >= old_cons_idx`
|
||||
while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
|
||||
// Get rkey
|
||||
// NOTES: the `p` operation will not cross multiple remote chunks
|
||||
__be32 rkey;
|
||||
uint64_t raddr;
|
||||
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey);
|
||||
|
||||
// Write WQEs
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
|
||||
void *wqe_ptrs;
|
||||
wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
|
||||
ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast<const uint32_t*>(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm);
|
||||
|
||||
// Submit requests
|
||||
ibgda_submit_requests<true>(qp, base_wqe_idx, 1);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
|
||||
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
|
||||
void **out_wqes) {
|
||||
ibgda_ctrl_seg_t ctrl_seg;
|
||||
struct mlx5_wqe_raddr_seg raddr_seg;
|
||||
struct mlx5_wqe_data_seg data_seg;
|
||||
|
||||
auto *ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
|
||||
void *av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
|
||||
struct mlx5_wqe_raddr_seg *raddr_seg_ptr;
|
||||
struct mlx5_wqe_data_seg *data_seg_ptr;
|
||||
|
||||
raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
|
||||
data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
|
||||
|
||||
raddr_seg.raddr = HtoBE64(raddr);
|
||||
raddr_seg.rkey = rkey;
|
||||
raddr_seg.reserved = 0;
|
||||
|
||||
data_seg.byte_count = HtoBE32(bytes);
|
||||
data_seg.lkey = lkey;
|
||||
data_seg.addr = HtoBE64(laddr);
|
||||
|
||||
ctrl_seg = {0};
|
||||
ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
|
||||
ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
|
||||
ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
|
||||
EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
|
||||
st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
|
||||
st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
ibgda_write_empty_recv_wqe(void *out_wqe) {
|
||||
auto *data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
|
||||
struct mlx5_wqe_data_seg data_seg;
|
||||
|
||||
// Make the first segment in the WQE invalid, then the entire list will be invalid
|
||||
data_seg.byte_count = 0;
|
||||
data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
|
||||
data_seg.addr = 0;
|
||||
|
||||
EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ uint64_t
|
||||
nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
|
||||
auto mvars = &qp->mvars;
|
||||
|
||||
// Allocate if not enough
|
||||
constexpr int kMinIBGDARecvs = 32;
|
||||
auto resv_head = mvars->rx_wq.resv_head;
|
||||
auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
|
||||
if (num_valid_slots < kMinIBGDARecvs) {
|
||||
resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
|
||||
mvars->rx_wq.resv_head = resv_head;
|
||||
|
||||
// Ensure WQE is written before `dbrec`
|
||||
__be32 dbrec_val;
|
||||
__be32 *dbrec_ptr = qp->rx_wq.dbrec;
|
||||
|
||||
// Compared to sending, for each QP, we only post recv in a single thread,
|
||||
// so we don't need to do synchronization here
|
||||
// This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
|
||||
asm("{\n\t"
|
||||
".reg .b32 dbrec_head_16b;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
|
||||
"prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
|
||||
"}" : "=r"(dbrec_val)
|
||||
: "r"(static_cast<uint32_t>(resv_head)));
|
||||
st_na_release(dbrec_ptr, dbrec_val);
|
||||
}
|
||||
|
||||
// Return old number of slots
|
||||
return num_valid_slots;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
|
||||
// NOTES: only one thread can run this function
|
||||
// TODO: consider this assertion for normal AR
|
||||
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
|
||||
// Get lkey and rkey, store them into lanes
|
||||
uint32_t num_wqes = 0;
|
||||
__be32 my_lkey = 0;
|
||||
uint64_t my_laddr = 0;
|
||||
__be32 my_rkey = 0;
|
||||
uint64_t my_raddr = 0;
|
||||
uint64_t my_chunk_size = 0;
|
||||
|
||||
// Decide how many messages (theoretically 3 for maximum)
|
||||
auto remaining_bytes = bytes;
|
||||
while (remaining_bytes > 0) {
|
||||
if (lane_id == num_wqes)
|
||||
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey));
|
||||
|
||||
// Move one more message
|
||||
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
|
||||
remaining_bytes -= chunk_size;
|
||||
req_lptr += chunk_size;
|
||||
req_rptr += chunk_size;
|
||||
++ num_wqes;
|
||||
}
|
||||
EP_DEVICE_ASSERT(num_wqes <= 32);
|
||||
|
||||
// Process WQE
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
uint64_t base_wqe_idx = 0;
|
||||
if (lane_id == 0)
|
||||
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
|
||||
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
|
||||
if (lane_id < num_wqes) {
|
||||
auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
|
||||
ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size,
|
||||
base_wqe_idx, &wqe_ptr);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Submit
|
||||
if (lane_id == 0)
|
||||
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
1720
csrc/kernels/internode.cu
Normal file
1720
csrc/kernels/internode.cu
Normal file
File diff suppressed because it is too large
Load Diff
533
csrc/kernels/internode_ll.cu
Normal file
533
csrc/kernels/internode_ll.cu
Normal file
@@ -0,0 +1,533 @@
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace internode_ll {
|
||||
|
||||
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
|
||||
__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1) {
|
||||
// Barrier before cleaning (in case of unfinished chunked EP)
|
||||
nvshmemx_barrier_all_block();
|
||||
|
||||
// Clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
|
||||
clean_0[i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
|
||||
clean_1[i] = 0;
|
||||
|
||||
// Barrier after cleaning (make sure low-latency mode work fine)
|
||||
nvshmemx_barrier_all_block();
|
||||
}
|
||||
|
||||
void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
|
||||
int* clean_1, int num_clean_int_1,
|
||||
cudaStream_t stream) {
|
||||
constexpr int kNumThreads = 256;
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
|
||||
LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
|
||||
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
|
||||
}
|
||||
|
||||
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, int* atomic_counter_per_local_expert,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
int phases) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_local_experts = num_experts / num_ranks;
|
||||
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// FP8 staffs
|
||||
constexpr int kNumPerChannels = 128;
|
||||
constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, kFP8AmaxInv = 1.0f / 448.0f;
|
||||
const int num_scales = kHidden / kNumPerChannels;
|
||||
const size_t hidden_int4 = kHidden / sizeof(int4);
|
||||
|
||||
// Message package: hidden data, FP8 scales, index at source
|
||||
// NOTES: currently we have 3 reserved int fields for future use
|
||||
const size_t num_bytes_per_msg = kHidden + num_scales * sizeof(float) + sizeof(int4);
|
||||
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
|
||||
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
|
||||
|
||||
// Sending phase
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
|
||||
goto LOW_LATENCY_DISPATCH_RECV;
|
||||
|
||||
// Expert counts
|
||||
__shared__ int shared_num_tokens_sent_per_expert[kNumWarpGroups];
|
||||
|
||||
// There are 2 kinds of warps in this part:
|
||||
// 1. The first-kind warps for FP8 cast and sending top-k tokens
|
||||
// 2. The last warp for reading `topk_idx` and count for per-expert information
|
||||
if (warp_id < num_warps - 1) {
|
||||
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
|
||||
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
|
||||
EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
|
||||
const auto num_threads = (num_warps - 1) * 32;
|
||||
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
|
||||
|
||||
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
|
||||
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
|
||||
const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
|
||||
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
|
||||
const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
|
||||
|
||||
// Overlap top-k index read and source token index write
|
||||
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
|
||||
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
|
||||
|
||||
// FP8 cast
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
|
||||
// Read and calculate local amax
|
||||
auto int4_value = __ldg(x_int4 + i);
|
||||
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
|
||||
float fp32_values[kNumElemsPerRead];
|
||||
float amax = kFP8Margin, scale, scale_inv;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerRead; ++ j) {
|
||||
fp32_values[j] = static_cast<float>(bf16_values[j]);
|
||||
amax = fmaxf(amax, fabsf(fp32_values[j]));
|
||||
}
|
||||
|
||||
// Reduce amax and scale
|
||||
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
|
||||
amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, scale_inv = amax * kFP8AmaxInv;
|
||||
if (lane_id == 0 or lane_id == 16)
|
||||
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
|
||||
|
||||
// Cast into send buffer
|
||||
int2 int2_value;
|
||||
auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerRead; j += 2) {
|
||||
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
|
||||
fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
|
||||
}
|
||||
rdma_x_int2[i] = int2_value;
|
||||
}
|
||||
asm volatile("bar.sync 1, %0;" :: "r"(num_threads));
|
||||
|
||||
// Issue IBGDA sends
|
||||
if (dst_expert_idx >= 0) {
|
||||
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
|
||||
slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
|
||||
const auto dst_rank = dst_expert_idx / num_local_experts;
|
||||
const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
|
||||
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
slot_idx * num_bytes_per_msg;
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
|
||||
} else {
|
||||
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
|
||||
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
|
||||
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
|
||||
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
|
||||
}
|
||||
|
||||
// Increase counter after finishing
|
||||
__syncwarp();
|
||||
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
|
||||
}
|
||||
}
|
||||
} else if (warp_id == num_warps - 1) {
|
||||
EP_DEVICE_ASSERT(num_sms > 1);
|
||||
if (sm_id == 0) {
|
||||
// The first SM is also responsible for checking QPs
|
||||
EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts);
|
||||
|
||||
// The first SM is also responsible for cleaning the next buffer
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_next_clean_int; i += 32)
|
||||
next_clean[i] = 0;
|
||||
|
||||
// Notify before executing `int_p`
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_experts; i += 32)
|
||||
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
|
||||
}
|
||||
|
||||
// This SM should be responsible for some destination experts, read `topk_idx` for them
|
||||
int expert_count[kNumWarpGroups] = {0};
|
||||
const auto expert_begin_idx = sm_id * kNumWarpGroups;
|
||||
const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
|
||||
|
||||
// Per lane count
|
||||
#pragma unroll 8
|
||||
for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
|
||||
auto idx = static_cast<int>(__ldg(topk_idx + i));
|
||||
if (idx >= expert_begin_idx and idx < expert_end_idx)
|
||||
expert_count[idx - expert_begin_idx] ++;
|
||||
}
|
||||
|
||||
// Warp reduce
|
||||
#pragma unroll
|
||||
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
|
||||
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
|
||||
if (lane_id == 0) {
|
||||
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
|
||||
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Issue count sends
|
||||
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
|
||||
const auto dst_rank = responsible_expert_idx / num_local_experts;
|
||||
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
|
||||
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
|
||||
|
||||
// Wait local sends issued and send expert counts
|
||||
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
|
||||
dst_rank, dst_expert_local_idx, 0);
|
||||
nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
|
||||
} else {
|
||||
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
|
||||
}
|
||||
|
||||
// Clean workspace for next use
|
||||
atomic_counter_per_expert[responsible_expert_idx] = 0;
|
||||
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Receiving phase
|
||||
LOW_LATENCY_DISPATCH_RECV:
|
||||
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
|
||||
return;
|
||||
|
||||
// Receiving and packing
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
const auto src_rank = responsible_expert_idx / num_local_experts;
|
||||
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
|
||||
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
|
||||
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
|
||||
const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
|
||||
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
|
||||
|
||||
// Shared between sub-warps in warp groups
|
||||
__shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];
|
||||
|
||||
// Wait tokens to arrive
|
||||
// NOTES: using sub-warp 1 to overlap with sub-warp 0
|
||||
int num_recv_tokens, recv_token_begin_idx;
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
if (src_rank != rank) {
|
||||
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
|
||||
num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
|
||||
EP_DEVICE_ASSERT(num_recv_tokens != 0);
|
||||
} else {
|
||||
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
|
||||
}
|
||||
num_recv_tokens = -num_recv_tokens - 1;
|
||||
recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
|
||||
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
|
||||
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
|
||||
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
|
||||
}
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
|
||||
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
|
||||
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
|
||||
|
||||
// Copy tokens
|
||||
EP_DEVICE_ASSERT(num_scales <= 64);
|
||||
for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
|
||||
// Copy data
|
||||
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
|
||||
const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
|
||||
const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);
|
||||
|
||||
// Copy scales
|
||||
const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
|
||||
const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
|
||||
const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
|
||||
auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
|
||||
lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
|
||||
(lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;
|
||||
|
||||
// Copy source info
|
||||
const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
|
||||
if (lane_id == 0)
|
||||
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
|
||||
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
|
||||
const void* x, const int64_t* topk_idx,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases) {
|
||||
constexpr int kNumMaxTopK = 9;
|
||||
constexpr int kNumWarpsPerGroup = 10;
|
||||
constexpr int kNumWarpGroups = 3;
|
||||
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
|
||||
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
|
||||
EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
|
||||
|
||||
// Workspace checks
|
||||
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
|
||||
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
|
||||
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
|
||||
|
||||
// Use the last part `rdma_recv_count` as `atomic_counter_per_local_expert`
|
||||
// NOTES: this part will be cleaned in `combine`
|
||||
auto atomic_counter_per_local_expert = rdma_recv_count + num_ranks * (num_experts / num_ranks);
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(hidden) \
|
||||
LAUNCH_KERNEL(&cfg, dispatch<kNumWarpGroups, kNumWarpsPerGroup, hidden>, \
|
||||
packed_recv_x, packed_recv_x_scales, \
|
||||
packed_recv_src_info, packed_recv_layout_range, \
|
||||
rdma_recv_x, rdma_recv_count, rdma_x, \
|
||||
x, topk_idx, \
|
||||
atomic_counter_per_expert, atomic_finish_counter_per_expert, atomic_counter_per_local_expert, \
|
||||
next_clean, num_next_clean_int, \
|
||||
num_tokens, num_max_dispatch_tokens_per_rank, \
|
||||
num_topk, num_experts, rank, num_ranks, phases); break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
|
||||
#undef DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
|
||||
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
|
||||
combine(void* combined_x,
|
||||
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
|
||||
const void* x, const int64_t* topk_idx, const float* topk_weights,
|
||||
const int* src_info, const int64_t* layout_range,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int* atomic_clean_flag,
|
||||
int num_combined_tokens, int hidden, int num_topk,
|
||||
int num_max_dispatch_tokens_per_rank,
|
||||
int num_experts, int rank, int num_ranks,
|
||||
int phases) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto num_threads = static_cast<int>(blockDim.x);
|
||||
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
|
||||
const auto num_local_experts = num_experts / num_ranks;
|
||||
const auto warp_group_id = warp_id / kNumWarpsPerGroup;
|
||||
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
|
||||
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
|
||||
|
||||
// Data type staffs
|
||||
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16);
|
||||
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
|
||||
|
||||
// Message package
|
||||
// BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
|
||||
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(nv_bfloat16);
|
||||
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
|
||||
|
||||
// Sending phase
|
||||
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
|
||||
goto LOW_LATENCY_COMBINE_RECV;
|
||||
|
||||
// Clean up next buffer
|
||||
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = lane_id; i < num_next_clean_int; i += 32)
|
||||
next_clean[i] = 0;
|
||||
|
||||
// Notify before executing `int_p`
|
||||
__syncwarp();
|
||||
if (lane_id == 0)
|
||||
atomic_add_release_global(atomic_clean_flag, num_experts);
|
||||
}
|
||||
|
||||
// FP8 cast and issue IBGDA sends
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
const auto dst_rank = responsible_expert_idx / num_local_experts;
|
||||
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
|
||||
const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
|
||||
const auto local_x = reinterpret_cast<const int4*>(x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
|
||||
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
|
||||
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
|
||||
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
|
||||
|
||||
// Unpack layout
|
||||
int offset, num_tokens_to_send;
|
||||
unpack2(layout, num_tokens_to_send, offset);
|
||||
|
||||
// Issue IBGDA send
|
||||
for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
|
||||
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
|
||||
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
|
||||
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);
|
||||
|
||||
// Copy directly to local rank, or copy to buffer and issue RDMA
|
||||
auto src_idx = __ldg(local_src_info + token_idx);
|
||||
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
|
||||
if (dst_rank == rank) {
|
||||
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
} else {
|
||||
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
|
||||
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
|
||||
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
|
||||
}
|
||||
}
|
||||
|
||||
// Put finishing flag
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
|
||||
if (sub_warp_id == 1 and lane_id == 0) {
|
||||
while (ld_acquire_global(atomic_clean_flag) == 0);
|
||||
if (dst_rank != rank) {
|
||||
nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
|
||||
} else {
|
||||
st_na_release(rdma_recv_flag + global_expert_idx, 1);
|
||||
}
|
||||
atomic_add_release_global(atomic_clean_flag, -1);
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Receiving phase
|
||||
LOW_LATENCY_COMBINE_RECV:
|
||||
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
|
||||
return;
|
||||
|
||||
// Wait all ranks to arrive and notify PCIe usage
|
||||
if (responsible_expert_idx < num_experts) {
|
||||
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
|
||||
if (sub_warp_id == 0 and lane_id == 0) {
|
||||
// TODO: refactor QP indices
|
||||
auto src_rank = responsible_expert_idx / num_local_experts;
|
||||
auto src_expert_idx = responsible_expert_idx % num_local_experts;
|
||||
if (src_rank != rank) {
|
||||
nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
|
||||
} else {
|
||||
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
cg::this_grid().sync();
|
||||
|
||||
// Reduce tokens with FP8 cast
|
||||
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
|
||||
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
|
||||
if (thread_id < hidden_bf16_int4) {
|
||||
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
|
||||
// Read top-k indices and weights
|
||||
int reg_topk_idx[kNumMaxTopk];
|
||||
float reg_topk_weights[kNumMaxTopk];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk; ++ i) {
|
||||
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
|
||||
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
|
||||
}
|
||||
|
||||
float combined_values[kNumElemsPerInt4] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
|
||||
// Read from sources
|
||||
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
|
||||
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
|
||||
|
||||
// Reduce
|
||||
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
|
||||
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerInt4; ++ j)
|
||||
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
|
||||
}
|
||||
|
||||
// Write results
|
||||
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
|
||||
auto combined_bf16 = reinterpret_cast<nv_bfloat16*>(&combined_values);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kNumElemsPerInt4; ++ j)
|
||||
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
|
||||
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void combine(void* combined_x,
|
||||
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
|
||||
const void* x, const int64_t* topk_idx, const float* topk_weights,
|
||||
const int* src_info, const int64_t* layout_range,
|
||||
int* next_clean, int num_next_clean_int,
|
||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||
int num_topk, int num_experts, int rank, int num_ranks,
|
||||
void* workspace, cudaStream_t stream, int phases) {
|
||||
constexpr int kNumWarpsPerGroup = 10;
|
||||
constexpr int kNumWarpGroups = 3;
|
||||
constexpr int kNumMaxTopk = 9;
|
||||
|
||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
|
||||
|
||||
// Check workspace
|
||||
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
|
||||
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
|
||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(hidden) { \
|
||||
auto combine_func = combine<kNumWarpGroups, kNumWarpsPerGroup, hidden, kNumMaxTopk>; \
|
||||
LAUNCH_KERNEL(&cfg, combine_func, \
|
||||
combined_x, \
|
||||
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
|
||||
x, topk_idx, topk_weights, src_info, layout_range, \
|
||||
next_clean, num_next_clean_int, \
|
||||
atomic_clean_flag, \
|
||||
num_combined_tokens, hidden, num_topk, \
|
||||
num_max_dispatch_tokens_per_rank, \
|
||||
num_experts, rank, num_ranks, \
|
||||
phases); } break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
|
||||
SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
|
||||
#undef COMBINE_LAUNCH_CASE
|
||||
}
|
||||
|
||||
} // namespace internode_ll
|
||||
|
||||
} // namespace deep_ep
|
||||
803
csrc/kernels/intranode.cu
Normal file
803
csrc/kernels/intranode.cu
Normal file
@@ -0,0 +1,803 @@
|
||||
#include "configs.cuh"
|
||||
#include "buffer.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace intranode {
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
|
||||
|
||||
if (sm_id == 0) {
|
||||
// Barrier first
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
int *per_rank_buffer, *per_expert_buffer;
|
||||
if (thread_id < kNumRanks) {
|
||||
per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
|
||||
per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
|
||||
}
|
||||
|
||||
// After this loop:
|
||||
// - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
|
||||
// - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j
|
||||
int num_experts_per_rank = num_experts / kNumRanks;
|
||||
if (thread_id < kNumRanks) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i)
|
||||
per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_experts_per_rank; ++ i)
|
||||
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for all ranks to be finished
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
// Sum per-rank counts and return to CPU
|
||||
// Also pre-compute the prefix sum for data sending
|
||||
auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
if (thread_id < kNumRanks) {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < kNumRanks; ++ i)
|
||||
local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];
|
||||
if (thread_id == rank)
|
||||
*moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];
|
||||
}
|
||||
|
||||
// Sum per-experts counts and return to CPU
|
||||
auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;
|
||||
if (thread_id < num_experts_per_rank) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i)
|
||||
sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];
|
||||
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
|
||||
moe_recv_expert_counter_mapped[thread_id] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Copy rank size prefix matrix to another tensor
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
|
||||
rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];
|
||||
|
||||
// Extra memset for later communication queue
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_memset_int; i += num_threads)
|
||||
local_per_expert_buffer[i] = 0;
|
||||
|
||||
// Barrier
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
} else {
|
||||
int dst_rank = sm_id - 1;
|
||||
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
|
||||
|
||||
// Iterate over tokens
|
||||
int count = 0;
|
||||
for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)
|
||||
count += is_token_in_rank[i * kNumRanks + dst_rank];
|
||||
count = warp_reduce_sum(count);
|
||||
if (lane_id == 0)
|
||||
channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Pre-compute prefix sum for all channels
|
||||
if (thread_id == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_channels; ++ i)
|
||||
channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
|
||||
cudaStream_t stream, int num_channels) {
|
||||
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
|
||||
num_tokens_per_rank, moe_recv_counter_mapped, \
|
||||
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
|
||||
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
|
||||
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
|
||||
buffer_ptrs, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
constexpr int kNumThreads = 128;
|
||||
EP_HOST_ASSERT(num_experts % num_ranks == 0);
|
||||
EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
|
||||
SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
|
||||
#undef NOTIFY_DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
|
||||
// A simplified version for cached handles
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
// Copy and clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
|
||||
ptr[i] = rank_prefix_matrix[i];
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_memset_int; i += num_threads)
|
||||
ptr[kNumRanks * kNumRanks + i] = 0;
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Barrier after cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
}
|
||||
|
||||
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs,
|
||||
int head, int rank, int num_ranks, cudaStream_t stream) {
|
||||
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
|
||||
rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, 128, stream);
|
||||
SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
|
||||
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void __launch_bounds__(kNumRanks * 32, 1)
|
||||
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void **buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const bool is_sender = sm_id % 2 == 0;
|
||||
EP_DEVICE_ASSERT(num_sms % 2 == 0);
|
||||
|
||||
// Each warp is responsible for a single rank
|
||||
const auto num_channels = num_sms / 2;
|
||||
const auto responsible_rank = (static_cast<int>(thread_id)) / 32;
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving
|
||||
const auto responsible_channel = sm_id / 2;
|
||||
|
||||
int num_experts_per_rank = num_experts / kNumRanks;
|
||||
EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
|
||||
EP_DEVICE_ASSERT(num_topk <= 32);
|
||||
EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
|
||||
EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
// `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
|
||||
int target_rank = is_sender ? rank : responsible_rank;
|
||||
auto num_channels_total = num_channels * kNumRanks;
|
||||
auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
|
||||
|
||||
// Channel buffer metadata
|
||||
// Senders are responsible for tails, and receivers are responsible for heads
|
||||
// Stored on the receiver side
|
||||
// The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t`
|
||||
// `start_offset`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `end_offset`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
|
||||
// Channel data buffers, stored on the receiver side
|
||||
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
|
||||
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
|
||||
// `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t)
|
||||
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
|
||||
// `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
|
||||
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
|
||||
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
|
||||
auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
constexpr int num_send_warps = kNumRanks;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_warp_id = send_thread_id / 32;
|
||||
const auto send_lane_id = send_thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32);
|
||||
EP_DEVICE_ASSERT(num_send_warps == kNumRanks and send_warp_id == responsible_rank);
|
||||
|
||||
// Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
|
||||
// NOTES: this is for distinguishing zero tokens
|
||||
if (send_lane_id == 0) {
|
||||
int value = responsible_channel > 0 ? channel_prefix_matrix[send_warp_id * num_channels + responsible_channel - 1] : 0;
|
||||
st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
|
||||
value = channel_prefix_matrix[send_warp_id * num_channels + responsible_channel];
|
||||
st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Get tasks
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
|
||||
|
||||
// Iterate over all tokens and send by chunks
|
||||
int cached_channel_tail_idx = 0;
|
||||
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
|
||||
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
|
||||
auto start_time = clock64();
|
||||
while (send_lane_id == 0) {
|
||||
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
|
||||
int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
|
||||
if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
|
||||
break;
|
||||
|
||||
// Rare cases to loop again
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
int chunk_token_idx = 0;
|
||||
while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
|
||||
if (send_lane_id == 0)
|
||||
send_head[token_idx * kNumRanks + send_warp_id] = is_token_in_rank[token_idx * kNumRanks + send_warp_id] ? cached_channel_tail_idx : -1;
|
||||
// Skip if not selected
|
||||
if (not is_token_in_rank[token_idx * kNumRanks + send_warp_id]) {
|
||||
token_idx ++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get an empty slot
|
||||
int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens;
|
||||
|
||||
// Copy data
|
||||
auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
|
||||
auto shifted_x = x + token_idx * hidden_int4;
|
||||
UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
|
||||
__ldg, st_na_global);
|
||||
|
||||
// Copy source index
|
||||
if (send_lane_id == 0)
|
||||
channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);
|
||||
|
||||
// Copy `topk_idx` and `topk_weights` with transformed index
|
||||
if (send_lane_id < num_topk) {
|
||||
// Top-k index
|
||||
int recv_expert_begin = send_warp_id * num_experts_per_rank, recv_expert_end = (send_warp_id + 1) * num_experts_per_rank;
|
||||
auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
|
||||
idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
|
||||
channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;
|
||||
|
||||
// Top-k weights
|
||||
auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
|
||||
weight_value = (idx_value >= 0) ? weight_value : 0.0f;
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
|
||||
}
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll
|
||||
for (int i = send_lane_id; i < num_scales; i += 32)
|
||||
channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
|
||||
|
||||
// Move token index
|
||||
chunk_token_idx ++, token_idx ++;
|
||||
}
|
||||
|
||||
// Move tail index
|
||||
__syncwarp();
|
||||
if (send_lane_id == 0)
|
||||
st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
|
||||
}
|
||||
} else {
|
||||
// Workers for receiving and copying into buffer
|
||||
constexpr int num_recv_warps = kNumRanks;
|
||||
const auto recv_thread_id = thread_id;
|
||||
const auto recv_warp_id = recv_thread_id / 32;
|
||||
const auto recv_lane_id = recv_thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32 and recv_warp_id == responsible_rank);
|
||||
EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps == kNumRanks);
|
||||
|
||||
// Calculate offset first
|
||||
auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
int rank_offset = recv_warp_id > 0 ? rank_prefix_matrix[(recv_warp_id - 1) * kNumRanks + rank] : 0;
|
||||
|
||||
// Receive channel offset
|
||||
int total_offset, num_tokens_to_recv;
|
||||
while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
|
||||
while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
|
||||
if (recv_lane_id == 0) {
|
||||
total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
|
||||
recv_channel_offset[recv_warp_id * num_channels + responsible_channel] = total_offset;
|
||||
num_tokens_to_recv -= total_offset;
|
||||
}
|
||||
total_offset = __shfl_sync(0xffffffff, total_offset, 0);
|
||||
total_offset += rank_offset;
|
||||
num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);
|
||||
|
||||
auto start_time = clock64();
|
||||
int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
|
||||
while (num_tokens_to_recv > 0) {
|
||||
// Check channel status by lane 0
|
||||
while (recv_lane_id == 0) {
|
||||
cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());;
|
||||
|
||||
// Ready to copy
|
||||
if (cached_channel_head_idx != cached_channel_tail_idx)
|
||||
break;
|
||||
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
|
||||
// Sync queue tail
|
||||
cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
|
||||
|
||||
// Copy data
|
||||
int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
|
||||
for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx) {
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
|
||||
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
|
||||
ld_nc_global, st_na_global);
|
||||
}
|
||||
|
||||
// Copy `src_idx`
|
||||
#pragma unroll 4
|
||||
for (int chunk_idx = cached_channel_head_idx + recv_lane_id; chunk_idx < cached_channel_tail_idx; chunk_idx += 32)
|
||||
recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);
|
||||
|
||||
// Copy `topk_idx` and `topk_weights`
|
||||
#pragma unroll 4
|
||||
for (int idx = recv_lane_id; idx < num_recv_tokens * num_topk; idx += 32) {
|
||||
int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
|
||||
auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
|
||||
recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
|
||||
recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
|
||||
}
|
||||
|
||||
// Copy `x_scales`
|
||||
#pragma unroll 4
|
||||
for (int i = recv_lane_id; i < num_recv_tokens * num_scales; i += 32) {
|
||||
int chunk_idx = i / num_scales, scales_idx = i % num_scales;
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =
|
||||
ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx);
|
||||
}
|
||||
|
||||
// Move queue
|
||||
cached_channel_head_idx += num_recv_tokens;
|
||||
total_offset += num_recv_tokens;
|
||||
__syncwarp();
|
||||
if (recv_lane_id == 0)
|
||||
st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);
|
||||
|
||||
// Exit
|
||||
num_tokens_to_recv -= num_recv_tokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
|
||||
int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, dispatch<ranks>, \
|
||||
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
is_token_in_rank, channel_prefix_matrix, \
|
||||
num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
break
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving.
|
||||
EP_HOST_ASSERT(num_sms % 2 == 0);
|
||||
SETUP_LAUNCH_CONFIG(num_sms, num_ranks * 32, stream);
|
||||
SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
|
||||
#undef DISPATCH_LAUNCH_CASE
|
||||
}
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
if (sm_id == 0) {
|
||||
// Barrier before cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
|
||||
// Clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_memset_int; i += num_threads)
|
||||
ptr[i] = 0;
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Barrier after cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
} else {
|
||||
const auto channel_id = sm_id - 1;
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto rank_id = thread_id / 32;
|
||||
const auto lane_id = thread_id % 32;
|
||||
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
|
||||
|
||||
// NOTES: `1 << 25` is a heuristic large number
|
||||
int last_head = 1 << 25;
|
||||
#pragma unroll
|
||||
for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {
|
||||
int token_idx = token_idx_tail - lane_id, expected_head = 0;
|
||||
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
|
||||
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
|
||||
head = __shfl_sync(0xffffffff, current_head, i);
|
||||
if (head < 0) {
|
||||
if (lane_id == i)
|
||||
expected_head = -last_head - 1;
|
||||
} else {
|
||||
last_head = head;
|
||||
}
|
||||
}
|
||||
if (current_head < 0 and token_idx >= token_start_idx)
|
||||
send_head[token_idx * kNumRanks + rank_id] = expected_head;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
|
||||
int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank, int num_ranks,
|
||||
cudaStream_t stream) {
|
||||
#define CACHED_NOTIFY_COMBINE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
|
||||
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
const int num_threads = std::max(128, 32 * num_ranks);
|
||||
EP_HOST_ASSERT(num_ranks <= num_threads);
|
||||
EP_HOST_ASSERT(num_threads <= 1024);
|
||||
EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
|
||||
SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
|
||||
SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
|
||||
#undef CACHED_NOTIFY_COMBINE
|
||||
}
|
||||
|
||||
template<typename dtype_t, int kNumRanks, int kNumThreads>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
const dtype_t* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void **buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto num_channels = num_sms / 2;
|
||||
const bool is_sender = sm_id % 2 == 0;
|
||||
const int responsible_channel = sm_id / 2;
|
||||
EP_DEVICE_ASSERT(num_topk <= 32);
|
||||
|
||||
constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
|
||||
int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
|
||||
auto x_int4 = reinterpret_cast<const int4*>(x);
|
||||
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
// Several warps are responsible for a single rank
|
||||
constexpr int num_send_warps = kNumThreads / 32;
|
||||
constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
|
||||
const auto num_threads_per_rank = num_send_warps_per_rank * 32;
|
||||
const auto send_thread_id = thread_id;
|
||||
const auto send_lane_id = send_thread_id % 32;
|
||||
const auto send_rank_id = thread_id / num_threads_per_rank;
|
||||
const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
|
||||
auto num_channels_total = num_channels * kNumRanks;
|
||||
auto channel_rank_offset = responsible_channel * kNumRanks + rank;
|
||||
|
||||
// Channel meta data
|
||||
// `head_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
|
||||
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
|
||||
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
|
||||
auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
|
||||
auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
|
||||
auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
|
||||
auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
|
||||
// Get tasks
|
||||
// NOTES: `channel_offset` is already shifted
|
||||
int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
|
||||
int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
|
||||
int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
|
||||
int num_channel_tokens = (responsible_channel == num_channels - 1 ? num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset;
|
||||
int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;
|
||||
|
||||
// Iterate over all tokens and send by chunks
|
||||
int current_channel_tail_idx = 0;
|
||||
for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
|
||||
// Check destination queue emptiness, or wait a buffer to be released (rare cases)
|
||||
auto start_time = clock64();
|
||||
int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
|
||||
while (send_lane_id == 0) {
|
||||
// NOTES: we only consider the worst case, because counting the real numbers are time-consuming
|
||||
int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
|
||||
if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
|
||||
break;
|
||||
|
||||
// Rare cases to loop again
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Send by chunk
|
||||
#pragma unroll
|
||||
for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
|
||||
// Get an empty slot
|
||||
int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;
|
||||
|
||||
// Copy data
|
||||
auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
|
||||
auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
|
||||
UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
|
||||
|
||||
// Send source index
|
||||
if (send_lane_id == 0)
|
||||
channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);
|
||||
|
||||
// Send `topk_weights`
|
||||
if (num_topk > 0 and send_lane_id < num_topk)
|
||||
channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
|
||||
}
|
||||
token_idx += num_round_tokens;
|
||||
current_channel_tail_idx += num_round_tokens;
|
||||
|
||||
// Move tail index
|
||||
asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
|
||||
if (send_lane_id == 0 and send_warp_id_in_rank == 0)
|
||||
st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
|
||||
}
|
||||
} else {
|
||||
// Workers for receiving
|
||||
// One warp for moving the queue head, others for reduction
|
||||
constexpr int num_recv_warps = kNumThreads / 32;
|
||||
const auto recv_warp_id = thread_id / 32;
|
||||
const auto recv_lane_id = thread_id % 32;
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
|
||||
EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);
|
||||
|
||||
// Shared head, tail and retired flags for receiver warps
|
||||
__shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
|
||||
__shared__ volatile int channel_tail_idx[kNumRanks];
|
||||
__shared__ volatile bool warp_retired[num_recv_warps];
|
||||
if (thread_id < num_recv_warps)
|
||||
warp_retired[thread_id] = false;
|
||||
if (recv_lane_id < kNumRanks)
|
||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
|
||||
if (thread_id < kNumRanks)
|
||||
channel_tail_idx[thread_id] = 0;
|
||||
asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));
|
||||
|
||||
if (thread_id < 32) {
|
||||
int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
|
||||
int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;
|
||||
|
||||
// Queue head updater
|
||||
int last_head = 0;
|
||||
while (recv_lane_id < kNumRanks) {
|
||||
// Check retired
|
||||
bool retired = true;
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_recv_warps; ++ i)
|
||||
retired = retired and warp_retired[i];
|
||||
if (retired)
|
||||
break;
|
||||
|
||||
// Update queue tail
|
||||
channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);
|
||||
|
||||
// Update minimum head
|
||||
int min_head = std::numeric_limits<int>::max();
|
||||
#pragma unroll
|
||||
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
|
||||
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
|
||||
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
|
||||
}
|
||||
} else {
|
||||
// Receivers
|
||||
// Channel metadata
|
||||
// All lanes will use data buffer, but only rank lane will use `head/tail/src_idx`
|
||||
Buffer<int4> channel_x_buffers[kNumRanks];
|
||||
Buffer<float> channel_topk_weights_buffers[kNumRanks];
|
||||
|
||||
// Calculate pointers by the specific layout
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i) {
|
||||
auto channel_rank_offset = responsible_channel * kNumRanks + i;
|
||||
auto num_channels_total = num_channels * kNumRanks;
|
||||
// `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
|
||||
auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
|
||||
|
||||
// `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
|
||||
channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
|
||||
|
||||
// `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
|
||||
ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
|
||||
|
||||
// `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
|
||||
channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
|
||||
}
|
||||
|
||||
// The same tokens as the dispatch process
|
||||
int token_start_idx, token_end_idx;
|
||||
get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);
|
||||
|
||||
// Iterate over all tokens and combine
|
||||
for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
|
||||
// Read expected head
|
||||
int expected_head = -1;
|
||||
if (recv_lane_id < kNumRanks) {
|
||||
expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);
|
||||
warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
|
||||
}
|
||||
auto start_time = clock64();
|
||||
while (channel_tail_idx[recv_lane_id] <= expected_head and expected_head >= 0) {
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Broadcast current heads
|
||||
int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanks; ++ i) {
|
||||
auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);
|
||||
if (expected_head_i >= 0) {
|
||||
slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
|
||||
topk_ranks[num_topk_ranks ++] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce data
|
||||
#pragma unroll
|
||||
for (int i = recv_lane_id; i < hidden_int4; i += 32) {
|
||||
// Read buffers
|
||||
int4 recv_value_int4[kNumRanks];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j)
|
||||
recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);
|
||||
|
||||
// Reduce all-to-all results
|
||||
float values[kDtypePerInt4] = {0};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_topk_ranks; ++ j) {
|
||||
auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kDtypePerInt4; ++ k)
|
||||
values[k] += static_cast<float>(recv_value_dtypes[k]);
|
||||
}
|
||||
|
||||
// Cast back to `dtype_t` and write
|
||||
int4 out_int4;
|
||||
auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
out_dtypes[j] = static_cast<dtype_t>(values[j]);
|
||||
recv_int4[token_idx * hidden_int4 + i] = out_int4;
|
||||
}
|
||||
|
||||
// Reduce `topk_weights`
|
||||
if (recv_lane_id < num_topk) {
|
||||
float value = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_topk_ranks; ++ i)
|
||||
value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
|
||||
recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// Retired
|
||||
__syncwarp();
|
||||
if (recv_lane_id == 0)
|
||||
warp_retired[recv_warp_id] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void combine(cudaDataType_t type,
|
||||
void* recv_x, float* recv_topk_weights,
|
||||
const void* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void** buffer_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream, int num_sms,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(dtype, ranks) \
|
||||
LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
|
||||
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
|
||||
reinterpret_cast<const dtype*>(x), topk_weights, \
|
||||
src_idx, rank_prefix_matrix, channel_prefix_matrix, \
|
||||
send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
|
||||
buffer_ptrs, rank, \
|
||||
num_max_send_tokens, num_recv_buffer_tokens); \
|
||||
break
|
||||
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break
|
||||
|
||||
// Even-numbered blocks for sending, odd-numbered blocks for receiving
|
||||
EP_HOST_ASSERT(num_sms % 2 == 0);
|
||||
EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);
|
||||
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
|
||||
SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
|
||||
#undef COMBINE_DTYPE_LAUNCH_CASE
|
||||
#undef COMBINE_LAUNCH_CASE
|
||||
}
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
} // namespace deep_ep
|
||||
60
csrc/kernels/launch.cuh
Normal file
60
csrc/kernels/launch.cuh
Normal file
@@ -0,0 +1,60 @@
|
||||
#pragma once
|
||||
|
||||
#include "configs.cuh"
|
||||
|
||||
#ifndef SETUP_LAUNCH_CONFIG
|
||||
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
|
||||
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
|
||||
cudaLaunchAttribute attr[1]; \
|
||||
attr[0].id = cudaLaunchAttributeCooperative; \
|
||||
attr[0].val.cooperative = 1; \
|
||||
cfg.attrs = attr; \
|
||||
cfg.numAttrs = 1
|
||||
#endif
|
||||
|
||||
#ifndef LAUNCH_KERNEL
|
||||
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
|
||||
#endif
|
||||
|
||||
#define SWITCH_RANKS(case_macro) \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(2); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_RDMA_RANKS(case_macro) \
|
||||
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
|
||||
case 2: case_macro(2); \
|
||||
case 3: case_macro(3); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
case 16: case_macro(16); \
|
||||
case 18: case_macro(18); \
|
||||
case 20: case_macro(20); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(dtype, 2); \
|
||||
case 4: case_macro(dtype, 4); \
|
||||
case 8: case_macro(dtype, 8); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_TYPES(case_macro) \
|
||||
switch (type) { \
|
||||
case CUDA_R_16BF: case_macro(nv_bfloat16); \
|
||||
case CUDA_R_32F: case_macro(float); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported type"); \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_HIDDEN(case_macro) \
|
||||
switch (hidden) { \
|
||||
case 2560: case_macro(2560); \
|
||||
case 5120: case_macro(5120); \
|
||||
case 7168: case_macro(7168); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
|
||||
} while (false)
|
||||
119
csrc/kernels/runtime.cu
Normal file
119
csrc/kernels/runtime.cu
Normal file
@@ -0,0 +1,119 @@
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace intranode {
|
||||
|
||||
template<int kNumRanks>
|
||||
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
}
|
||||
|
||||
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
|
||||
#define BARRIER_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
|
||||
break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, 32, stream);
|
||||
SWITCH_RANKS(BARRIER_LAUNCH_CASE);
|
||||
#undef BARRIER_LAUNCH_CASE
|
||||
}
|
||||
|
||||
} // namespace intranode
|
||||
|
||||
namespace internode {
|
||||
|
||||
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
|
||||
nvshmem_team_config_t cpu_rdma_team_config;
|
||||
|
||||
std::vector<uint8_t> get_unique_id() {
|
||||
nvshmemx_uniqueid_t unique_id;
|
||||
nvshmemx_get_uniqueid(&unique_id);
|
||||
std::vector<uint8_t> result(sizeof(nvshmemx_uniqueid_t));
|
||||
std::memcpy(result.data(), &unique_id, sizeof(nvshmemx_uniqueid_t));
|
||||
return result;
|
||||
}
|
||||
|
||||
__global__ void ibgda_initialize_recv_queue(int rank) {
|
||||
auto thread_idx = static_cast<int>(threadIdx.x);
|
||||
auto num_threads = static_cast<int>(blockDim.x);
|
||||
|
||||
auto dst_rank = static_cast<int>(blockIdx.x);
|
||||
if (dst_rank != rank) {
|
||||
for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
|
||||
auto qp = ibgda_get_rc(dst_rank, qp_id);
|
||||
|
||||
// Clean some necessary variables
|
||||
for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
|
||||
ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
|
||||
qp->mvars.rx_wq.resv_head = 0;
|
||||
qp->mvars.rx_wq.cons_idx = 0;
|
||||
|
||||
// Allocate receive slots
|
||||
nvshmemi_ibgda_allocate_recvs(qp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
|
||||
nvshmemx_uniqueid_t root_unique_id;
|
||||
nvshmemx_init_attr_t attr;
|
||||
std::memcpy(&root_unique_id, root_unique_id_val.data(), sizeof(nvshmemx_uniqueid_t));
|
||||
nvshmemx_set_attr_uniqueid_args(rank, num_ranks, &root_unique_id, &attr);
|
||||
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
|
||||
|
||||
// Create sub-RDMA teams
|
||||
// NOTES: if `num_ranks <= NUM_MAX_NVL_PEERS` then only low-latency kernels are used
|
||||
if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) {
|
||||
EP_HOST_ASSERT(cpu_rdma_team == NVSHMEM_TEAM_INVALID);
|
||||
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
|
||||
EP_HOST_ASSERT(nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS,
|
||||
num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0);
|
||||
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
|
||||
}
|
||||
|
||||
// Normal operations use IBRC, while low-latency operations use IBGDA
|
||||
if (low_latency_mode) {
|
||||
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
|
||||
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
|
||||
|
||||
bool ibgda_is_initialized = false;
|
||||
cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice);
|
||||
|
||||
// Initialize recv queues for low-latency mode AR
|
||||
ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
|
||||
}
|
||||
nvshmem_barrier_all();
|
||||
return nvshmem_my_pe();
|
||||
}
|
||||
|
||||
void* alloc(size_t size, size_t alignment) {
|
||||
return nvshmem_align(alignment, size);
|
||||
}
|
||||
|
||||
void free(void* ptr) {
|
||||
nvshmem_free(ptr);
|
||||
}
|
||||
|
||||
void barrier() {
|
||||
nvshmem_barrier_all();
|
||||
}
|
||||
|
||||
void finalize() {
|
||||
if (cpu_rdma_team != NVSHMEM_TEAM_INVALID) {
|
||||
nvshmem_team_destroy(cpu_rdma_team);
|
||||
cpu_rdma_team = NVSHMEM_TEAM_INVALID;
|
||||
}
|
||||
nvshmem_finalize();
|
||||
}
|
||||
|
||||
} // namespace internode
|
||||
|
||||
} // namespace deep_ep
|
||||
381
csrc/kernels/utils.cuh
Normal file
381
csrc/kernels/utils.cuh
Normal file
@@ -0,0 +1,381 @@
|
||||
#pragma once
|
||||
|
||||
#include "exception.cuh"
|
||||
|
||||
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
|
||||
{ \
|
||||
constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
|
||||
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
|
||||
auto __src = (SRC); \
|
||||
auto __dst = (DST); \
|
||||
for (int __i = (LANE_ID); __i < ((N) / kLoopStride) * kLoopStride; __i += kLoopStride) { \
|
||||
_Pragma("unroll") \
|
||||
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
|
||||
unrolled_values[__j] = LD_FUNC(__src + __i + __j * 32); \
|
||||
_Pragma("unroll") \
|
||||
for (int __j = 0; __j < (UNROLL_FACTOR); ++ __j) \
|
||||
ST_FUNC(__dst + __i + __j * 32, unrolled_values[__j]); \
|
||||
} \
|
||||
for (int __i = ((N) / kLoopStride) * kLoopStride + (LANE_ID); __i < (N); __i += 32) \
|
||||
ST_FUNC(__dst + __i, LD_FUNC(__src + __i)); \
|
||||
}
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
template <int kBytes>
|
||||
struct VecInt {};
|
||||
template<> struct VecInt<1> { using vec_t = int8_t; };
|
||||
template<> struct VecInt<2> { using vec_t = int16_t; };
|
||||
template<> struct VecInt<4> { using vec_t = int; };
|
||||
template<> struct VecInt<8> { using vec_t = int64_t; };
|
||||
template<> struct VecInt<16> { using vec_t = int4; };
|
||||
|
||||
__device__ __forceinline__ void trap() {
|
||||
asm("trap;");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void memory_fence() {
|
||||
asm volatile("fence.acq_rel.sys;":: : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void memory_fence_gpu() {
|
||||
asm volatile("fence.acq_rel.gpu;":: : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void memory_fence_cta() {
|
||||
asm volatile("fence.acq_rel.cta;":: : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_relaxed_sys_global(const int *ptr, int val) {
|
||||
asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_release_sys_global(const int *ptr, int val) {
|
||||
asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_release_cta(const int *ptr, int val) {
|
||||
asm volatile("st.release.cta.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_acquire_sys_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr) {
|
||||
uint64_t ret;
|
||||
asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_acquire_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) {
|
||||
int ret;
|
||||
asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) {
|
||||
int ret;
|
||||
asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_acquire_cta(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ld_na_relaxed(const uint8_t *ptr) {
|
||||
uint16_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return static_cast<uint8_t>(ret);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint16_t ld_na_relaxed(const uint16_t *ptr) {
|
||||
uint16_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t ld_na_relaxed(const uint32_t *ptr) {
|
||||
uint32_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t ld_na_relaxed(const uint64_t *ptr) {
|
||||
uint64_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int ld_volatile_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ld_volatile_global(const float *ptr) {
|
||||
float ret;
|
||||
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t ld_volatile_global(const int64_t *ptr) {
|
||||
int64_t ret;
|
||||
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) {
|
||||
int64_t ret;
|
||||
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
|
||||
#define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B"
|
||||
#else
|
||||
#define LD_NC_FUNC "ld.volatile.global"
|
||||
#endif
|
||||
|
||||
// `ld.global.nc.L1::no_allocate` will be translated into `LDG.E.NA.[width].CONSTANT` in SASS,
|
||||
// which does not have cache allocation, and `CONSTANT` memory does not have coherence control,
|
||||
// so we have to control them by queue semantics
|
||||
template <typename dtype_t>
|
||||
__device__ __forceinline__ dtype_t ld_nc_global(const dtype_t *ptr) {
|
||||
auto ret = ld_nc_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr));
|
||||
return *reinterpret_cast<dtype_t*>(&ret);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ uint8_t ld_nc_global(const uint8_t *ptr) {
|
||||
uint16_t ret;
|
||||
// NOTES: we must use `uint16_t` as inline ASM does not support 8-bit constraint letter (`h` below means unsigned 16-bit)
|
||||
asm volatile(LD_NC_FUNC ".u8 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return static_cast<uint8_t>(ret);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int ld_nc_global(const int *ptr) {
|
||||
int ret;
|
||||
asm volatile(LD_NC_FUNC ".s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int64_t ld_nc_global(const int64_t *ptr) {
|
||||
int64_t ret;
|
||||
asm volatile(LD_NC_FUNC ".s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ float ld_nc_global(const float *ptr) {
|
||||
float ret;
|
||||
asm volatile(LD_NC_FUNC ".f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int2 ld_nc_global(const int2 *ptr) {
|
||||
int2 ret;
|
||||
asm volatile(LD_NC_FUNC ".v2.s32 {%0, %1}, [%2];" : "=r"(ret.x), "=r"(ret.y) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
|
||||
int4 ret;
|
||||
asm volatile(LD_NC_FUNC ".v4.s32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const uint16_t *ptr, uint16_t val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b16 [%0], %1;" : : "l"(ptr), "h"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const uint32_t *ptr, uint32_t val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const int *ptr, int val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_relaxed(const int4 *ptr, int4 val) {
|
||||
asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
|
||||
: : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_release(const int *ptr, int val) {
|
||||
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) {
|
||||
asm volatile("st.release.gpu.global.L1::no_allocate.b32 [%0], %1;" : : "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) {
|
||||
asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val));
|
||||
}
|
||||
|
||||
// `st.global.L1::no_allocate` will be translated into `ST.E.NA.[width]` in SASS,
|
||||
// which does not have cache allocation (obviously in L1, I guess not in L2 too)
|
||||
#ifndef DISABLE_AGGRESSIVE_PTX_INSTRS
|
||||
#define ST_NA_FUNC "st.global.L1::no_allocate"
|
||||
#else
|
||||
#define ST_NA_FUNC "st.global"
|
||||
#endif
|
||||
|
||||
template <typename dtype_t>
|
||||
__device__ __forceinline__ void st_na_global(const dtype_t *ptr, const dtype_t& value) {
|
||||
st_na_global(reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(ptr),
|
||||
*reinterpret_cast<const typename VecInt<sizeof(dtype_t)>::vec_t*>(&value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const int *ptr, const int& value) {
|
||||
asm volatile(ST_NA_FUNC ".s32 [%0], %1;" ::"l"(ptr), "r"(value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const int64_t *ptr, const int64_t& value) {
|
||||
asm volatile(ST_NA_FUNC ".s64 [%0], %1;" ::"l"(ptr), "l"(value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const float *ptr, const float& value) {
|
||||
asm volatile(ST_NA_FUNC ".f32 [%0], %1;" ::"l"(ptr), "f"(value));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value) {
|
||||
asm volatile(ST_NA_FUNC ".v4.s32 [%0], {%1, %2, %3, %4};"
|
||||
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
|
||||
return cell_div<dtype_t>(a, b) * b;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void get_channel_task_range(int num_tokens, int num_sms, int sm_id,
|
||||
int& token_start_idx, int& token_end_idx) {
|
||||
int num_tokens_per_sm = cell_div(num_tokens, num_sms);
|
||||
token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
|
||||
token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
|
||||
}
|
||||
|
||||
template <typename dtype_a_t, typename dtype_b_t>
|
||||
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
|
||||
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
|
||||
dtype_b_t packed;
|
||||
auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
|
||||
unpacked_ptr[0] = x, unpacked_ptr[1] = y;
|
||||
return packed;
|
||||
}
|
||||
|
||||
template <typename dtype_a_t, typename dtype_b_t>
|
||||
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
|
||||
EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
|
||||
auto unpacked_ptr = reinterpret_cast<const dtype_a_t*>(&packed);
|
||||
x = unpacked_ptr[0], y = unpacked_ptr[1];
|
||||
}
|
||||
|
||||
template <typename dtype_t>
|
||||
__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
|
||||
EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
|
||||
auto send_int_values = reinterpret_cast<int*>(&ptr);
|
||||
int recv_int_values[sizeof(dtype_t) / sizeof(int)];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < sizeof(dtype_t) / sizeof(int); ++ i)
|
||||
recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], src_lane_idx);
|
||||
return *reinterpret_cast<dtype_t*>(recv_int_values);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int warp_reduce_sum(int value) {
|
||||
value += __shfl_xor_sync(0xffffffff, value, 16);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 8);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 4);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 2);
|
||||
value += __shfl_xor_sync(0xffffffff, value, 1);
|
||||
return value;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float half_warp_reduce_max(float value) {
|
||||
auto mask = __activemask();
|
||||
// The mask be in `{0xffffffff, 0xffff}`
|
||||
value = max(value, __shfl_xor_sync(mask, value, 8));
|
||||
value = max(value, __shfl_xor_sync(mask, value, 4));
|
||||
value = max(value, __shfl_xor_sync(mask, value, 2));
|
||||
value = max(value, __shfl_xor_sync(mask, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int get_lane_id() {
|
||||
int lane_id;
|
||||
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
|
||||
return lane_id;
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void move_fifo_slots(int &head) {
|
||||
head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__device__ __forceinline__ bool not_finished(int *task, int expected) {
|
||||
auto result = false;
|
||||
auto lane_id = threadIdx.x % 32;
|
||||
if (lane_id < kNumRanks)
|
||||
result = ld_volatile_global(task + lane_id) != expected;
|
||||
return __any_sync(0xffffffff, result);
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void
|
||||
timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
|
||||
auto start_time = clock64();
|
||||
while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
|
||||
printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int kNumRanks>
|
||||
__forceinline__ __device__ void
|
||||
barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
EP_DEVICE_ASSERT(kNumRanks <= 32);
|
||||
|
||||
if (thread_id < kNumRanks) {
|
||||
atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
|
||||
memory_fence();
|
||||
atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
|
||||
}
|
||||
timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
Reference in New Issue
Block a user