Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00

Support Ampere architecture (#204)

* Update README
* Update `setup.py`
* Fix headers
* Add `DISABLE_NVSHMEM` for APIs
* Fix launch
* Fix TMA settings
* Fix TMA usages
* Fix dlink
* Separate layout kernels
* Update version
* Add `is_sm90_compiled`
* Fix tests
* Add NVLink connection checks
* Update README
* Fix tests
* Add some comments
* Minor fix
* Minor fix
* Fix bugs

Parent: dd13c7145c
Commit: b8d90fb753

README.md (20 lines changed)
@@ -42,9 +42,11 @@ We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400
### Requirements

- Hopper GPUs (may support more architectures or devices later)
- Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support
- Python 3.8 and above
- CUDA 12.3 and above
- CUDA version
  - CUDA 11.0 and above for SM80 GPUs
  - CUDA 12.3 and above for SM90 GPUs
- PyTorch 2.1 and above
- NVLink for intranode communication
- RDMA network for internode communication
@@ -75,6 +77,13 @@ python tests/test_low_latency.py
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
```

#### Installation environment variables

- `NVSHMEM_DIR`: the path to the NVSHMEM directory; if not specified, all internode and low-latency features are disabled
- `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; required for devices without SM90 support or for CUDA 11
- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"`
- `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details
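As a concrete illustration of how these variables combine for the new Ampere path (a sketch inferred from the `setup.py` changes in this commit, not an officially documented recipe), an intranode-only SM80 build could look like:

```
DISABLE_SM90_FEATURES=1 TORCH_CUDA_ARCH_LIST="8.0" python setup.py install
```

Leaving `NVSHMEM_DIR` unset in such a build compiles out all internode and low-latency features, and the build script forces `DISABLE_AGGRESSIVE_PTX_INSTRS=1`.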
Then, import `deep_ep` in your Python project, and enjoy!

## Network configurations
@@ -286,11 +295,14 @@ For two-micro-batch overlapping, you can refer to the following figure. With our
- [x] AR support
- [x] Refactor low-latency mode AR code
- [ ] A100 support (intranode only)
- [x] A100 support (intranode only)
- [x] Support BF16 for the low-latency dispatch kernel
- [x] Support NVLink protocol for intranode low-latency kernels
- [ ] TMA copy instead of LD/ST
- [ ] SM-free normal kernels and refactors
- [x] Intranode kernels
- [ ] Internode kernels
- [ ] Low-latency kernels
- [ ] SM-free kernels and refactors

## Notices
|
@ -56,7 +56,9 @@ struct Config {
|
||||
size_t num_bytes = 0;
|
||||
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
|
||||
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
|
||||
#endif
|
||||
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t);
|
||||
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
|
||||
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float);
|
||||
@ -65,6 +67,7 @@ struct Config {
|
||||
}
|
||||
|
||||
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
// Legacy mode
|
||||
if (num_ranks <= NUM_MAX_NVL_PEERS)
|
||||
return 0;
|
||||
@ -88,6 +91,9 @@ struct Config {
|
||||
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
|
||||
num_bytes = ((num_bytes + 127) / 128) * 128;
|
||||
return num_bytes;
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cuda_runtime.h>
|
||||
#include <memory>
|
||||
@ -35,6 +34,9 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
|
||||
CUDA_CHECK(cudaGetDevice(&device_id));
|
||||
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
|
||||
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
|
||||
#ifdef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation");
|
||||
#endif
|
||||
|
||||
// Get device info
|
||||
cudaDeviceProp device_prop = {};
|
||||
@ -104,12 +106,14 @@ Buffer::~Buffer() noexcept(false) {
|
||||
}
|
||||
|
||||
// Free NVSHMEM
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
if (num_rdma_bytes > 0) {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
internode::barrier();
|
||||
internode::free(rdma_buffer_ptr);
|
||||
internode::finalize();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Free workspace and MoE counter
|
||||
CUDA_CHECK(cudaFree(workspace));
|
||||
@ -148,9 +152,13 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const {
|
||||
}
|
||||
|
||||
pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID");
|
||||
auto unique_id = internode::get_unique_id();
|
||||
return {reinterpret_cast<const char*>(unique_id.data()), unique_id.size()};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
|
||||
@ -190,6 +198,7 @@ void Buffer::sync(const std::vector<int> &device_ids,
|
||||
}
|
||||
|
||||
// Sync NVSHMEM handles and allocate memory
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
if (num_rdma_bytes > 0) {
|
||||
// Initialize NVSHMEM
|
||||
EP_HOST_ASSERT(root_unique_id_opt.has_value());
|
||||
@ -211,6 +220,7 @@ void Buffer::sync(const std::vector<int> &device_ids,
|
||||
internode::barrier();
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
#endif
|
||||
|
||||
// Ready to use
|
||||
available = true;
|
||||
@ -246,13 +256,13 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
|
||||
if (is_internode_available())
|
||||
num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
internode::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
|
||||
num_tokens_per_rank.data_ptr<int>(),
|
||||
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
|
||||
num_tokens_per_expert.data_ptr<int>(),
|
||||
is_token_in_rank.data_ptr<bool>(),
|
||||
num_tokens, num_topk, num_ranks, num_experts,
|
||||
comm_stream);
|
||||
layout::get_dispatch_layout(topk_idx.data_ptr<int64_t>(),
|
||||
num_tokens_per_rank.data_ptr<int>(),
|
||||
num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr<int>() : nullptr,
|
||||
num_tokens_per_expert.data_ptr<int>(),
|
||||
is_token_in_rank.data_ptr<bool>(),
|
||||
num_tokens, num_topk, num_ranks, num_experts,
|
||||
comm_stream);
|
||||
|
||||
// Wait streams
|
||||
std::optional<EventHandle> event;
|
||||
@ -620,6 +630,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
|
||||
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
|
||||
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
// In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long.
|
||||
// If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL
|
||||
// unless we release GIL here.
|
||||
@ -882,6 +893,10 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
|
||||
recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
|
||||
recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
|
||||
recv_src_meta, send_rdma_head, send_nvl_head, event};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
|
||||
@ -890,6 +905,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
|
||||
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
|
||||
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
const int num_channels = config.num_sms / 2;
|
||||
EP_HOST_ASSERT(config.num_sms % 2 == 0);
|
||||
|
||||
@ -998,14 +1014,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
|
||||
|
||||
// Return values
|
||||
return {combined_x, combined_topk_weights, event};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
uint64_t Buffer::get_low_latency_usage_flag() const {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_usage_flag != nullptr);
|
||||
return reinterpret_cast<uint64_t>(low_latency_usage_flag);
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
@ -1022,6 +1048,9 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
|
||||
internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second,
|
||||
clean_meta_1.first, clean_meta_1.second,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
@ -1029,6 +1058,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool use_fp8, bool async, bool return_recv_hook) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@ -1114,6 +1144,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
|
||||
|
||||
// Return values
|
||||
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
|
||||
@ -1122,6 +1156,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool zero_copy, bool async, bool return_recv_hook,
|
||||
const std::optional<torch::Tensor>& out) {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@ -1141,7 +1176,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
|
||||
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
|
||||
auto hidden = static_cast<int>(x.size(2));
|
||||
auto num_local_experts = num_experts / num_ranks, num_topk = static_cast<int>(topk_weights.size(1));
|
||||
auto num_topk = static_cast<int>(topk_weights.size(1));
|
||||
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
|
||||
|
||||
// Buffer control
|
||||
@ -1202,10 +1237,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
|
||||
// Return values
|
||||
return {combined_x, event, recv_hook};
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
torch::Tensor
|
||||
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
|
||||
|
||||
auto buffer = layout.buffers[low_latency_buffer_idx];
|
||||
@ -1217,6 +1257,18 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank
|
||||
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
|
||||
{num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1},
|
||||
torch::TensorOptions().dtype(dtype).device(torch::kCUDA));
|
||||
#else
|
||||
EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
bool is_sm90_compiled() {
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace deep_ep
|
||||
@ -1258,4 +1310,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
|
||||
.def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
|
||||
.def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);
|
||||
|
||||
m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
|
||||
}
|
||||
|
@@ -11,10 +11,11 @@ function(add_deep_ep_library target_name source_file)
    target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5)
endfunction()

add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(runtime_cuda runtime.cu)
add_deep_ep_library(layout_cuda layout.cu)
add_deep_ep_library(intranode_cuda intranode.cu)
add_deep_ep_library(internode_cuda internode.cu)
add_deep_ep_library(internode_ll_cuda internode_ll.cu)

# Later, we should link all libraries in `EP_CUDA_LIBRARIES`
set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE)
@ -28,6 +28,17 @@ void finalize();
|
||||
|
||||
} // namespace internode
|
||||
|
||||
// Layout kernels
|
||||
namespace layout {
|
||||
|
||||
void get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace layout
|
||||
|
||||
// Intranode kernels
|
||||
namespace intranode {
|
||||
|
||||
@ -69,12 +80,6 @@ namespace internode {
|
||||
|
||||
int get_source_meta_bytes();
|
||||
|
||||
void get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
|
||||
const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
|
@ -37,11 +37,25 @@
|
||||
#undef __CUDA_NO_BFLOAT162_OPERATORS__
|
||||
#endif
|
||||
|
||||
#include <cstdint>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
#include <cuda_fp8.h>
|
||||
#else
|
||||
// Ampere does not support FP8 features
|
||||
#define __NV_E4M3 0
|
||||
#define __NV_E5M2 1
|
||||
typedef int __nv_fp8_interpretation_t;
|
||||
typedef int __nv_fp8x4_e4m3;
|
||||
typedef uint8_t __nv_fp8_storage_t;
|
||||
#endif
|
||||
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
#include <nvshmem.h>
|
||||
#include <nvshmemx.h>
|
||||
#include <infiniband/mlx5dv.h>
|
||||
#include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh>
|
||||
#include <device_host_transport/nvshmem_common_ibgda.h>
|
||||
#endif
|
||||
|
@ -11,131 +11,6 @@ namespace internode {
|
||||
|
||||
extern nvshmem_team_t cpu_rdma_team;
|
||||
|
||||
template<int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
|
||||
// Count expert statistics
|
||||
__shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
|
||||
int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
|
||||
if (expert_begin_idx < expert_end_idx) {
|
||||
// Per-thread count
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumExpertsPerSM; ++ i)
|
||||
num_tokens_per_expert_per_thread[thread_id][i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
|
||||
auto shifted_topk_idx = topk_idx + i * num_topk;
|
||||
#pragma unroll
|
||||
for (int j = 0, expert_idx; j < num_topk; ++ j) {
|
||||
expert_idx = static_cast<int>(shifted_topk_idx[j]);
|
||||
if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
|
||||
++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Sum up
|
||||
EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
|
||||
if (expert_begin_idx + thread_id < expert_end_idx) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumThreads; ++ i)
|
||||
sum += num_tokens_per_expert_per_thread[i][thread_id];
|
||||
num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (num_tokens_per_rdma_rank != nullptr)
|
||||
EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
|
||||
|
||||
// Count rank statistics
|
||||
constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
|
||||
__shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
|
||||
__shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
|
||||
auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
|
||||
int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
|
||||
int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
|
||||
if (rank_begin_idx < rank_end_idx) {
|
||||
const auto num_expert_per_rank = num_experts / num_ranks;
|
||||
auto expert_begin = rank_begin_idx * num_expert_per_rank;
|
||||
auto expert_end = rank_end_idx * num_expert_per_rank;
|
||||
|
||||
// Per-thread count
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanksPerSM; ++ i)
|
||||
num_tokens_per_rank_per_thread[thread_id][i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
|
||||
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
|
||||
auto shifted_topk_idx = topk_idx + i * num_topk;
|
||||
int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
|
||||
#pragma unroll
|
||||
for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
|
||||
expert_idx = static_cast<int>(shifted_topk_idx[j]);
|
||||
if (expert_begin <= expert_idx and expert_idx < expert_end) {
|
||||
// Count single rank
|
||||
rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
|
||||
is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
|
||||
}
|
||||
}
|
||||
|
||||
auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
|
||||
#pragma unroll
|
||||
for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
|
||||
shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
|
||||
num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
|
||||
num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Sum up
|
||||
EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
|
||||
if (rank_begin_idx + thread_id < rank_end_idx) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumThreads; ++ i)
|
||||
sum += num_tokens_per_rank_per_thread[i][thread_id];
|
||||
num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
|
||||
}
|
||||
|
||||
if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumThreads; ++ i)
|
||||
sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
|
||||
num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
||||
cudaStream_t stream) {
|
||||
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
|
||||
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
|
||||
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
|
||||
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
|
||||
topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
|
||||
num_tokens, num_topk, num_ranks, num_experts);
|
||||
}
|
||||
|
||||
struct SourceMeta {
|
||||
int src_rdma_rank, is_token_in_nvl_rank_bits;
|
||||
|
||||
|
@ -227,6 +227,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);
|
||||
|
||||
// TMA stuffs
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto half_hidden_int4 = hidden_int4 / 2;
|
||||
auto half_hidden_bytes = half_hidden_int4 * static_cast<int>(sizeof(int4));
|
||||
@ -240,6 +241,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
|
||||
}
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
@ -399,6 +401,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
|
||||
auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
|
||||
auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++ i) if (lane_id == 0) {
|
||||
tma_store_wait();
|
||||
@ -408,6 +411,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false);
|
||||
}
|
||||
__syncwarp();
|
||||
#else
|
||||
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
|
||||
ld_nc_global, st_na_global);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Copy `src_idx`
|
||||
@ -447,8 +454,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
}
|
||||
|
||||
// Make TMA store visible to the next kernel
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -473,12 +482,13 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
constexpr int kNumTMABytesPerWarp = 8192;
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
#endif
|
||||
|
||||
#define DISPATCH_LAUNCH_CASE(ranks) { \
|
||||
auto kernel = dispatch<ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size; \
|
||||
SET_SHARED_MEMORY_FOR_TMA(kernel); \
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
|
||||
send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
|
||||
@ -587,8 +597,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
auto recv_int4 = reinterpret_cast<int4*>(recv_x);
|
||||
|
||||
// TMA stuffs
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp;
|
||||
#endif
|
||||
|
||||
if (is_sender) {
|
||||
// Workers for sending
|
||||
@ -778,9 +790,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
}
|
||||
|
||||
// Wait shared memory release
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
// Reduce data with pipeline
|
||||
constexpr int kNumStages = 8;
|
||||
@ -810,6 +824,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
for (int j = 0; j < kDtypePerInt4; ++ j)
|
||||
out_dtypes[j] = static_cast<dtype_t>(values[j]);
|
||||
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
// Wait TMA arrival
|
||||
if (lane_id == 0)
|
||||
tma_store_wait<kNumStages - 1>();
|
||||
@ -828,6 +843,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false);
|
||||
}
|
||||
__syncwarp();
|
||||
#else
|
||||
recv_int4[token_idx * hidden_int4 + i] = out_int4;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Reduce `topk_weights`
|
||||
@ -850,8 +868,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
warp_retired[recv_warp_id] = true;
|
||||
|
||||
// Make TMA store visible to the next kernel
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
if (lane_id == 0)
|
||||
tma_store_wait();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -866,12 +886,13 @@ void combine(cudaDataType_t type,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
constexpr int kNumThreads = 768;
|
||||
constexpr int kNumTMABytesPerWarp = 4096;
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32);
|
||||
#endif
|
||||
|
||||
#define COMBINE_LAUNCH_CASE(dtype, ranks) { \
|
||||
auto kernel = combine<dtype, ranks, kNumThreads, kNumTMABytesPerWarp>; \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size; \
|
||||
SET_SHARED_MEMORY_FOR_TMA(kernel); \
|
||||
LAUNCH_KERNEL(&cfg, kernel, \
|
||||
reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
|
||||
reinterpret_cast<const dtype*>(x), topk_weights, \
|
||||
|
@ -1,8 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
|
||||
#ifndef SETUP_LAUNCH_CONFIG
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \
|
||||
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \
|
||||
cudaLaunchAttribute attr[1]; \
|
||||
@ -10,10 +12,39 @@
|
||||
attr[0].val.cooperative = 1; \
|
||||
cfg.attrs = attr; \
|
||||
cfg.numAttrs = 1
|
||||
#else
|
||||
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
|
||||
int __num_sms = (sms); \
|
||||
int __num_threads = (threads); \
|
||||
auto __stream = (stream)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef LAUNCH_KERNEL
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
#define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__))
|
||||
#else
|
||||
#define LAUNCH_KERNEL(config, kernel, ...) \
|
||||
do { \
|
||||
kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__); \
|
||||
cudaError_t e = cudaGetLastError(); \
|
||||
if (e != cudaSuccess) { \
|
||||
EPException cuda_exception("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \
|
||||
fprintf(stderr, "%s\n", cuda_exception.what()); \
|
||||
throw cuda_exception; \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef SET_SHARED_MEMORY_FOR_TMA
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
#define SET_SHARED_MEMORY_FOR_TMA(kernel) \
|
||||
EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \
|
||||
cfg.dynamicSmemBytes = smem_size;
|
||||
#else
|
||||
#define SET_SHARED_MEMORY_FOR_TMA(kernel) void()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define SWITCH_RANKS(case_macro) \
|
||||
|
csrc/kernels/layout.cu (new file, 136 lines)
@@ -0,0 +1,136 @@
|
||||
#include "configs.cuh"
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
namespace layout {
|
||||
|
||||
template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x);
|
||||
|
||||
// Count expert statistics
|
||||
__shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM];
|
||||
int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
|
||||
if (expert_begin_idx < expert_end_idx) {
|
||||
// Per-thread count
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumExpertsPerSM; ++ i)
|
||||
num_tokens_per_expert_per_thread[thread_id][i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
|
||||
auto shifted_topk_idx = topk_idx + i * num_topk;
|
||||
#pragma unroll
|
||||
for (int j = 0, expert_idx; j < num_topk; ++ j) {
|
||||
expert_idx = static_cast<int>(shifted_topk_idx[j]);
|
||||
if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx)
|
||||
++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Sum up
|
||||
EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM");
|
||||
if (expert_begin_idx + thread_id < expert_end_idx) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumThreads; ++ i)
|
||||
sum += num_tokens_per_expert_per_thread[i][thread_id];
|
||||
num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (num_tokens_per_rdma_rank != nullptr)
|
||||
EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
|
||||
|
||||
// Count rank statistics
|
||||
constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
|
||||
__shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
|
||||
__shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
|
||||
auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
|
||||
int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
|
||||
int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
|
||||
if (rank_begin_idx < rank_end_idx) {
|
||||
const auto num_expert_per_rank = num_experts / num_ranks;
|
||||
auto expert_begin = rank_begin_idx * num_expert_per_rank;
|
||||
auto expert_end = rank_end_idx * num_expert_per_rank;
|
||||
|
||||
// Per-thread count
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRanksPerSM; ++ i)
|
||||
num_tokens_per_rank_per_thread[thread_id][i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
|
||||
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
|
||||
#pragma unroll
|
||||
for (int i = thread_id; i < num_tokens; i += kNumThreads) {
|
||||
auto shifted_topk_idx = topk_idx + i * num_topk;
|
||||
int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
|
||||
#pragma unroll
|
||||
for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
|
||||
expert_idx = static_cast<int>(shifted_topk_idx[j]);
|
||||
if (expert_begin <= expert_idx and expert_idx < expert_end) {
|
||||
// Count single rank
|
||||
rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
|
||||
is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
|
||||
}
|
||||
}
|
||||
|
||||
auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
|
||||
#pragma unroll
|
||||
for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) {
|
||||
shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
|
||||
num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
|
||||
num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Sum up
|
||||
EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
|
||||
if (rank_begin_idx + thread_id < rank_end_idx) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumThreads; ++ i)
|
||||
sum += num_tokens_per_rank_per_thread[i][thread_id];
|
||||
num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
|
||||
}
|
||||
|
||||
if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
|
||||
int sum = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumThreads; ++ i)
|
||||
sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
|
||||
num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_dispatch_layout(const int64_t* topk_idx,
|
||||
int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
|
||||
int* num_tokens_per_expert, bool* is_token_in_rank,
|
||||
int num_tokens, int num_topk, int num_ranks, int num_experts,
|
||||
cudaStream_t stream) {
|
||||
constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
|
||||
int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
|
||||
EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
|
||||
|
||||
SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
|
||||
LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
|
||||
topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
|
||||
num_tokens, num_topk, num_ranks, num_experts);
|
||||
}
|
||||
|
||||
} // namespace layout
|
||||
|
||||
} // namespace deep_ep
|
@ -5,7 +5,10 @@
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
#include "ibgda_device.cuh"
|
||||
#endif
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
@ -30,6 +33,7 @@ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t st
|
||||
|
||||
namespace internode {
|
||||
|
||||
#ifndef DISABLE_NVSHMEM
|
||||
nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID;
|
||||
nvshmem_team_config_t cpu_rdma_team_config;
|
||||
|
||||
@ -81,6 +85,7 @@ void finalize() {
|
||||
}
|
||||
nvshmem_finalize();
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace internode
|
||||
|
||||
|
@ -266,6 +266,9 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
|
||||
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
|
||||
}
|
||||
|
||||
// TMA PTX instructions
|
||||
#ifndef DISABLE_SM90_FEATURES
|
||||
|
||||
__device__ __forceinline__ void fence_view_async_shared() {
|
||||
asm volatile("fence.proxy.async.shared::cta; \n" :: );
|
||||
}
|
||||
@ -327,6 +330,8 @@ __device__ __forceinline__ void tma_store_wait() {
|
||||
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename dtype_t>
|
||||
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
|
||||
return (a + b - 1) / b;
|
||||
|
@@ -7,7 +7,7 @@ from typing import Callable, List, Tuple, Optional, Union
import deep_ep_cpp
# noinspection PyUnresolvedReferences
from deep_ep_cpp import Config, EventHandle
from .utils import EventOverlap
from .utils import EventOverlap, check_nvlink_connections


class Buffer:
@@ -50,6 +50,7 @@ class Buffer:
                please make sure all connections are via NVLink.
            allow_mnnvl: whether to allow MNNVL
        """
        check_nvlink_connections(group)

        # Initialize the CPP runtime
        self.rank = group.rank()
@@ -105,6 +106,10 @@ class Buffer:
        self.runtime.sync(device_ids, ipc_handles, root_unique_id)
        assert self.runtime.is_available()

    @staticmethod
    def is_sm90_compiled():
        return deep_ep_cpp.is_sm90_compiled()

    @staticmethod
    def set_num_sms(new_num_sms: int) -> None:
        """
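The new `Buffer.is_sm90_compiled()` static method lets user code detect whether the extension was built with SM90 features (and thus the FP8 path). Below is a minimal caller-side sketch mirroring how the updated `tests/test_intranode.py` further down skips its FP8 input on non-SM90 builds; the `per_token_cast_to_fp8` helper and its import location are assumptions based on DeepEP's test utilities:

```python
import torch
import deep_ep

x = torch.ones((4096, 7168), dtype=torch.bfloat16, device='cuda')

# Only build an FP8 variant of the input on SM90 builds; on Ampere (SM80)
# builds the FP8 cast/dispatch path is compiled out, so keep BF16 only.
if deep_ep.Buffer.is_sm90_compiled():
    from tests.utils import per_token_cast_to_fp8  # assumed location of the test helper
    x_e4m3 = per_token_cast_to_fp8(x)
else:
    x_e4m3 = None

# Iterate over whichever inputs are actually available, as the updated test does.
for current_x in filter(lambda t: t is not None, (x, x_e4m3)):
    ...  # run dispatch/combine with `current_x`
```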
@@ -1,4 +1,7 @@
import os
import subprocess
import torch
import torch.distributed as dist
from typing import Any, Optional, Tuple

# noinspection PyUnresolvedReferences
@@ -58,3 +61,28 @@ class EventOverlap:
        """
        if self.event is not None:
            self.event.current_stream_wait()


def check_nvlink_connections(group: dist.ProcessGroup):
    """
    Check NVLink connection between every pair of GPUs.

    Arguments:
        group: the communication group.
    """
    # Check NVLink connection
    # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2
    if 'PCIE' in torch.cuda.get_device_name():
        assert group.size() <= 2, 'No NVLink connection between all GPUs'
        devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',')
        physical_device_idx = int(devices[torch.cuda.current_device()])
        physical_device_indices = [0, ] * group.size()
        dist.all_gather_object(physical_device_indices, physical_device_idx, group)

        # Get connection matrix from `nvidia-smi`
        lines = subprocess.check_output(['nvidia-smi', 'topo', '-p2p', 'n']).decode('utf-8').split('\n')
        for line in lines:
            if line.lstrip().startswith(f'GPU{physical_device_idx}') and 'X' in line:
                status = line.strip().lstrip(f'GPU{physical_device_idx}').split()
                for dst_gpu_rank in physical_device_indices:
                    assert status[dst_gpu_rank] in ('X', 'OK'), f'No NVLink connection between GPU {physical_device_idx} and GPU {dst_gpu_rank}'
setup.py (74 lines changed)
@@ -6,34 +6,76 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
if __name__ == '__main__':
|
||||
nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
|
||||
assert nvshmem_dir is not None and os.path.exists(nvshmem_dir), 'Failed to find NVSHMEM'
|
||||
print(f'NVSHMEM directory: {nvshmem_dir}')
|
||||
disable_nvshmem = nvshmem_dir is None
|
||||
if disable_nvshmem:
|
||||
print('Warning: `NVSHMEM_DIR` is not specified, all internode and low-latency features are disabled\n')
|
||||
else:
|
||||
assert os.path.exists(nvshmem_dir), f'Failed to find NVSHMEM: {nvshmem_dir}'
|
||||
|
||||
# TODO: currently, we only support Hopper architecture, we may add Ampere support later
|
||||
if os.getenv('TORCH_CUDA_ARCH_LIST', None) is None:
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
|
||||
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
|
||||
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
|
||||
nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10']
|
||||
include_dirs = ['csrc/', f'{nvshmem_dir}/include']
|
||||
sources = ['csrc/deep_ep.cpp',
|
||||
'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu',
|
||||
'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
|
||||
library_dirs = [f'{nvshmem_dir}/lib']
|
||||
nvcc_flags = ['-O3', '-Xcompiler', '-O3']
|
||||
sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu']
|
||||
include_dirs = ['csrc/']
|
||||
library_dirs = []
|
||||
nvcc_dlink = []
|
||||
extra_link_args = []
|
||||
|
||||
# NVSHMEM flags
|
||||
if disable_nvshmem:
|
||||
cxx_flags.append('-DDISABLE_NVSHMEM')
|
||||
nvcc_flags.append('-DDISABLE_NVSHMEM')
|
||||
else:
|
||||
sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'])
|
||||
include_dirs.extend([f'{nvshmem_dir}/include'])
|
||||
library_dirs.extend([f'{nvshmem_dir}/lib'])
|
||||
nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem'])
|
||||
extra_link_args.extend(['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib'])
|
||||
|
||||
if int(os.getenv('DISABLE_SM90_FEATURES', 0)):
|
||||
# Prefer A100
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0')
|
||||
|
||||
# Disable some SM90 features: FP8, launch methods, and TMA
|
||||
cxx_flags.append('-DDISABLE_SM90_FEATURES')
|
||||
nvcc_flags.append('-DDISABLE_SM90_FEATURES')
|
||||
|
||||
# Disable internode and low-latency kernels
|
||||
assert disable_nvshmem
|
||||
|
||||
# Disable LD/ST tricks, as some CUDA versions do not support `.L1::no_allocate`
|
||||
assert int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1
|
||||
os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
|
||||
else:
|
||||
# Prefer H800 series
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0')
|
||||
|
||||
# CUDA 12 flags
|
||||
nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10'])
|
||||
|
||||
# Disable aggressive PTX instructions
|
||||
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')):
|
||||
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
|
||||
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
|
||||
|
||||
# Disable DLTO (default by PyTorch)
|
||||
nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
|
||||
extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
|
||||
# Put them together
|
||||
extra_compile_args = {
|
||||
'cxx': cxx_flags,
|
||||
'nvcc': nvcc_flags,
|
||||
'nvcc_dlink': nvcc_dlink
|
||||
}
|
||||
if len(nvcc_dlink) > 0:
|
||||
extra_compile_args['nvcc_dlink'] = nvcc_dlink
|
||||
|
||||
# Summary
|
||||
print(f'Build summary:')
|
||||
print(f' > Sources: {sources}')
|
||||
print(f' > Includes: {include_dirs}')
|
||||
print(f' > Libraries: {library_dirs}')
|
||||
print(f' > Compilation flags: {extra_compile_args}')
|
||||
print(f' > Link flags: {extra_link_args}')
|
||||
print(f' > Arch list: {os.environ["TORCH_CUDA_ARCH_LIST"]}')
|
||||
print(f' > NVSHMEM path: {nvshmem_dir}')
|
||||
print()
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -44,7 +86,7 @@ if __name__ == '__main__':
|
||||
|
||||
setuptools.setup(
|
||||
name='deep_ep',
|
||||
version='1.0.0' + revision,
|
||||
version='1.1.0' + revision,
|
||||
packages=setuptools.find_packages(
|
||||
include=['deep_ep']
|
||||
),
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -21,7 +20,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None
|
||||
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank
|
||||
@ -80,7 +79,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='')
|
||||
@ -168,7 +167,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)):
|
||||
best_time, best_results = 1e10, None
|
||||
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
|
||||
for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
|
||||
@ -189,8 +188,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
|
||||
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
|
||||
print('', flush=True)
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather FP8 the best config from rank 0
|
||||
# Gather the best config from rank 0 and the first test setting
|
||||
if best_dispatch_results is None:
|
||||
best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda')
|
||||
all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())]
|
||||
dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
|