From b8d90fb7531d6ef1d53e36d6c5819030dc633633 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Wed, 11 Jun 2025 15:48:18 +0800
Subject: [PATCH] Support Ampere architecture (#204)

* Update README
* Update `setup.py`
* Fix headers
* Add `DISABLE_NVSHMEM` for APIs
* Fix launch
* Fix TMA settings
* Fix TMA usages
* Fix dlink
* Separate layout kernels
* Update version
* Add `is_sm90_compiled`
* Fix tests
* Add NVLink connection checks
* Update README
* Fix tests
* Add some comments
* Minor fix
* Minor fix
* Fix bugs
---
 README.md                   |  20 ++++--
 csrc/config.hpp             |   6 ++
 csrc/deep_ep.cpp            |  72 ++++++++++++++++---
 csrc/kernels/CMakeLists.txt |   5 +-
 csrc/kernels/api.cuh        |  17 +++--
 csrc/kernels/configs.cuh    |  16 ++++-
 csrc/kernels/internode.cu   | 125 ---------------------------------
 csrc/kernels/intranode.cu   |  29 ++++++--
 csrc/kernels/launch.cuh     |  31 ++++++++
 csrc/kernels/layout.cu      | 136 ++++++++++++++++++++++++++++++++++++
 csrc/kernels/runtime.cu     |   5 ++
 csrc/kernels/utils.cuh      |   5 ++
 deep_ep/buffer.py           |   7 +-
 deep_ep/utils.py            |  28 ++++++++
 setup.py                    |  74 +++++++++++++++-----
 tests/test_intranode.py     |  11 ++-
 16 files changed, 413 insertions(+), 174 deletions(-)
 create mode 100644 csrc/kernels/layout.cu

diff --git a/README.md b/README.md
index fafe9d9..26123e7 100644
--- a/README.md
+++ b/README.md
@@ -42,9 +42,11 @@ We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400
 
 ### Requirements
 
-- Hopper GPUs (may support more architectures or devices later)
+- Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support
 - Python 3.8 and above
-- CUDA 12.3 and above
+- CUDA version
+  - CUDA 11.0 and above for SM80 GPUs
+  - CUDA 12.3 and above for SM90 GPUs
 - PyTorch 2.1 and above
 - NVLink for intranode communication
 - RDMA network for internode communication
@@ -75,6 +77,13 @@ python tests/test_low_latency.py
 NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
 ```
 
+#### Installation environment variables
+
+- `NVSHMEM_DIR`: the path to the NVSHMEM directory; all internode and low-latency features are disabled if it is not specified
+- `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; required for non-SM90 devices or CUDA 11
+- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"`
+- `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details
+
 Then, import `deep_ep` in your Python project, and enjoy!
 
 ## Network configurations
@@ -286,11 +295,14 @@ For two-micro-batch overlapping, you can refer to the following figure.
With our - [x] AR support - [x] Refactor low-latency mode AR code -- [ ] A100 support (intranode only) +- [x] A100 support (intranode only) - [x] Support BF16 for the low-latency dispatch kernel - [x] Support NVLink protocol for intranode low-latency kernels - [ ] TMA copy instead of LD/ST -- [ ] SM-free normal kernels and refactors + - [x] Intranode kernels + - [ ] Internode kernels + - [ ] Low-latency kernels +- [ ] SM-free kernels and refactors ## Notices diff --git a/csrc/config.hpp b/csrc/config.hpp index ec74564..b6ffd60 100644 --- a/csrc/config.hpp +++ b/csrc/config.hpp @@ -56,7 +56,9 @@ struct Config { size_t num_bytes = 0; num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; +#ifndef DISABLE_NVSHMEM num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); +#endif num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); @@ -65,6 +67,7 @@ struct Config { } size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { +#ifndef DISABLE_NVSHMEM // Legacy mode if (num_ranks <= NUM_MAX_NVL_PEERS) return 0; @@ -88,6 +91,9 @@ struct Config { num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; num_bytes = ((num_bytes + 127) / 128) * 128; return num_bytes; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); +#endif } }; diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 66c1964..ce47fcd 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -35,6 +34,9 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ CUDA_CHECK(cudaGetDevice(&device_id)); rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); +#ifdef DISABLE_NVSHMEM + EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disable during compilation"); +#endif // Get device info cudaDeviceProp device_prop = {}; @@ -104,12 +106,14 @@ Buffer::~Buffer() noexcept(false) { } // Free NVSHMEM +#ifndef DISABLE_NVSHMEM if (num_rdma_bytes > 0) { CUDA_CHECK(cudaDeviceSynchronize()); internode::barrier(); internode::free(rdma_buffer_ptr); internode::finalize(); } +#endif // Free workspace and MoE counter CUDA_CHECK(cudaFree(workspace)); @@ -148,9 +152,13 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const { } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); auto unique_id = internode::get_unique_id(); return {reinterpret_cast(unique_id.data()), unique_id.size()}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); +#endif } torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { @@ -190,6 +198,7 @@ void Buffer::sync(const std::vector &device_ids, } // Sync NVSHMEM handles and allocate memory +#ifndef DISABLE_NVSHMEM if (num_rdma_bytes > 0) { // 
Initialize NVSHMEM EP_HOST_ASSERT(root_unique_id_opt.has_value()); @@ -211,6 +220,7 @@ void Buffer::sync(const std::vector &device_ids, internode::barrier(); CUDA_CHECK(cudaDeviceSynchronize()); } +#endif // Ready to use available = true; @@ -246,13 +256,13 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, if (is_internode_available()) num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - internode::get_dispatch_layout(topk_idx.data_ptr(), - num_tokens_per_rank.data_ptr(), - num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, - num_tokens_per_expert.data_ptr(), - is_token_in_rank.data_ptr(), - num_tokens, num_topk, num_ranks, num_experts, - comm_stream); + layout::get_dispatch_layout(topk_idx.data_ptr(), + num_tokens_per_rank.data_ptr(), + num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, + num_tokens_per_expert.data_ptr(), + is_token_in_rank.data_ptr(), + num_tokens, num_topk, num_ranks, num_experts, + comm_stream); // Wait streams std::optional event; @@ -620,6 +630,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL // unless we release GIL here. 
@@ -882,6 +893,10 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional, std::optional> @@ -890,6 +905,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); @@ -998,14 +1014,24 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(low_latency_usage_flag); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); + return 0; +#endif } void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); @@ -1022,6 +1048,9 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second, clean_meta_1.first, clean_meta_1.second, at::cuda::getCurrentCUDAStream()); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); +#endif } std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> @@ -1029,6 +1058,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i const std::optional& cumulative_local_expert_recv_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, bool async, bool return_recv_hook) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1114,6 +1144,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i // Return values return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); + return {}; +#endif } std::tuple, std::optional>> @@ -1122,6 +1156,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id int num_max_dispatch_tokens_per_rank, int num_experts, bool zero_copy, bool async, bool return_recv_hook, const std::optional& out) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1141,7 +1176,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); auto hidden = static_cast(x.size(2)); - auto num_local_experts = num_experts / num_ranks, num_topk = static_cast(topk_weights.size(1)); + auto num_topk = static_cast(topk_weights.size(1)); auto num_combined_tokens = static_cast(topk_weights.size(0)); // Buffer control @@ -1202,10 +1237,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id // Return values return {combined_x, event, recv_hook}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); + return {}; +#endif } torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { +#ifndef DISABLE_NVSHMEM LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); auto buffer = layout.buffers[low_latency_buffer_idx]; @@ -1217,6 +1257,18 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank 
{num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); + return {}; +#endif +} + +bool is_sm90_compiled() { +#ifndef DISABLE_SM90_FEATURES + return true; +#else + return false; +#endif } } // namespace deep_ep @@ -1258,4 +1310,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); + + m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); } diff --git a/csrc/kernels/CMakeLists.txt b/csrc/kernels/CMakeLists.txt index 31e05ce..af5979a 100644 --- a/csrc/kernels/CMakeLists.txt +++ b/csrc/kernels/CMakeLists.txt @@ -11,10 +11,11 @@ function(add_deep_ep_library target_name source_file) target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5) endfunction() -add_deep_ep_library(intranode_cuda intranode.cu) add_deep_ep_library(runtime_cuda runtime.cu) +add_deep_ep_library(layout_cuda layout.cu) +add_deep_ep_library(intranode_cuda intranode.cu) add_deep_ep_library(internode_cuda internode.cu) add_deep_ep_library(internode_ll_cuda internode_ll.cu) # Later, we should link all libraries in `EP_CUDA_LIBRARIES` -set(EP_CUDA_LIBRARIES intranode_cuda runtime_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) +set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index d10044e..8f12a52 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -28,6 +28,17 @@ void finalize(); } // namespace internode +// Layout kernels +namespace layout { + +void get_dispatch_layout(const int64_t* topk_idx, + int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, + int num_tokens, int num_topk, int num_ranks, int num_experts, + cudaStream_t stream); + +} // namespace layout + // Intranode kernels namespace intranode { @@ -69,12 +80,6 @@ namespace internode { int get_source_meta_bytes(); -void get_dispatch_layout(const int64_t* topk_idx, - int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, - int* num_tokens_per_expert, bool* is_token_in_rank, - int num_tokens, int num_topk, int num_ranks, int num_experts, - cudaStream_t stream); - void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index cc1f914..8893b79 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -37,11 +37,25 @@ #undef __CUDA_NO_BFLOAT162_OPERATORS__ #endif +#include #include -#include #include + +#ifndef DISABLE_SM90_FEATURES +#include +#else +// Ampere does not support FP8 features +#define __NV_E4M3 0 +#define __NV_E5M2 1 +typedef int __nv_fp8_interpretation_t; +typedef int __nv_fp8x4_e4m3; +typedef uint8_t __nv_fp8_storage_t; +#endif + +#ifndef DISABLE_NVSHMEM #include #include #include #include #include +#endif diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 35b4ae2..f5bb0a5 100644 --- 
a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -11,131 +11,6 @@ namespace internode { extern nvshmem_team_t cpu_rdma_team; -template -__global__ void __launch_bounds__(kNumThreads, 1) -get_dispatch_layout(const int64_t* topk_idx, - int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, - int* num_tokens_per_expert, bool* is_token_in_rank, - int num_tokens, int num_topk, int num_ranks, int num_experts) { - auto sm_id = static_cast(blockIdx.x); - auto thread_id = static_cast(threadIdx.x); - - // Count expert statistics - __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; - int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); - if (expert_begin_idx < expert_end_idx) { - // Per-thread count - #pragma unroll - for (int i = 0; i < kNumExpertsPerSM; ++ i) - num_tokens_per_expert_per_thread[thread_id][i] = 0; - #pragma unroll - for (int i = thread_id; i < num_tokens; i += kNumThreads) { - auto shifted_topk_idx = topk_idx + i * num_topk; - #pragma unroll - for (int j = 0, expert_idx; j < num_topk; ++ j) { - expert_idx = static_cast(shifted_topk_idx[j]); - if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) - ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; - } - } - __syncthreads(); - - // Sum up - EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); - if (expert_begin_idx + thread_id < expert_end_idx) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumThreads; ++ i) - sum += num_tokens_per_expert_per_thread[i][thread_id]; - num_tokens_per_expert[expert_begin_idx + thread_id] = sum; - } - return; - } - - if (num_tokens_per_rdma_rank != nullptr) - EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); - - // Count rank statistics - constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; - __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; - __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; - auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; - int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); - int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; - if (rank_begin_idx < rank_end_idx) { - const auto num_expert_per_rank = num_experts / num_ranks; - auto expert_begin = rank_begin_idx * num_expert_per_rank; - auto expert_end = rank_end_idx * num_expert_per_rank; - - // Per-thread count - #pragma unroll - for (int i = 0; i < kNumRanksPerSM; ++ i) - num_tokens_per_rank_per_thread[thread_id][i] = 0; - #pragma unroll - for (int i = 0; i < kNumRDMARanksPerSM; ++ i) - num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; - #pragma unroll - for (int i = thread_id; i < num_tokens; i += kNumThreads) { - auto shifted_topk_idx = topk_idx + i * num_topk; - int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; - #pragma unroll - for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { - expert_idx = static_cast(shifted_topk_idx[j]); - if (expert_begin <= expert_idx and expert_idx < expert_end) { - // Count single rank - rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; - is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++; - } - } - - auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; - 
#pragma unroll - for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) { - shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); - num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); - } - - #pragma unroll - for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j) - num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); - } - __syncthreads(); - - // Sum up - EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); - if (rank_begin_idx + thread_id < rank_end_idx) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumThreads; ++ i) - sum += num_tokens_per_rank_per_thread[i][thread_id]; - num_tokens_per_rank[rank_begin_idx + thread_id] = sum; - } - - if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { - int sum = 0; - #pragma unroll - for (int i = 0; i < kNumThreads; ++ i) - sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; - num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; - } - } -} - -void get_dispatch_layout(const int64_t* topk_idx, - int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, - int* num_tokens_per_expert, bool* is_token_in_rank, - int num_tokens, int num_topk, int num_ranks, int num_experts, - cudaStream_t stream) { - constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8; - int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; - EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM"); - - SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); - LAUNCH_KERNEL(&cfg, (get_dispatch_layout), - topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, - num_tokens, num_topk, num_ranks, num_experts); -} - struct SourceMeta { int src_rdma_rank, is_token_in_nvl_rank_bits; diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 6f7c701..8915e62 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -227,6 +227,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to auto channel_x_scales_buffers = Buffer(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales); // TMA stuffs +#ifndef DISABLE_SM90_FEATURES extern __shared__ __align__(1024) uint8_t smem_buffer[]; auto half_hidden_int4 = hidden_int4 / 2; auto half_hidden_bytes = half_hidden_int4 * static_cast(sizeof(int4)); @@ -240,6 +241,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 and half_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp); } __syncwarp(); +#endif if (is_sender) { // Workers for sending @@ -399,6 +401,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens; auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4; auto shifted_recv_x_int4 = recv_x + static_cast(total_offset + chunk_idx) * hidden_int4; +#ifndef DISABLE_SM90_FEATURES #pragma unroll for (int i = 0; i < 2; ++ i) if (lane_id == 0) { tma_store_wait(); @@ -408,6 +411,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to tma_store_1d(tma_buffer, shifted_recv_x_int4 + i * half_hidden_int4, half_hidden_bytes, false); } __syncwarp(); +#else + 
UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4, + ld_nc_global, st_na_global); +#endif } // Copy `src_idx` @@ -447,8 +454,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to } // Make TMA store visible to the next kernel +#ifndef DISABLE_SM90_FEATURES if (lane_id == 0) tma_store_wait(); +#endif } @@ -473,12 +482,13 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { constexpr int kNumThreads = 768; constexpr int kNumTMABytesPerWarp = 8192; +#ifndef DISABLE_SM90_FEATURES constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32); +#endif #define DISPATCH_LAUNCH_CASE(ranks) { \ auto kernel = dispatch; \ - EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \ - cfg.dynamicSmemBytes = smem_size; \ + SET_SHARED_MEMORY_FOR_TMA(kernel); \ LAUNCH_KERNEL(&cfg, kernel, \ reinterpret_cast(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \ send_head, reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ @@ -587,8 +597,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights, auto recv_int4 = reinterpret_cast(recv_x); // TMA stuffs +#ifndef DISABLE_SM90_FEATURES extern __shared__ __align__(1024) uint8_t smem_buffer[]; auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp; +#endif if (is_sender) { // Workers for sending @@ -778,9 +790,11 @@ combine(dtype_t* recv_x, float* recv_topk_weights, } // Wait shared memory release +#ifndef DISABLE_SM90_FEATURES if (lane_id == 0) tma_store_wait(); __syncwarp(); +#endif // Reduce data with pipeline constexpr int kNumStages = 8; @@ -810,6 +824,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights, for (int j = 0; j < kDtypePerInt4; ++ j) out_dtypes[j] = static_cast(values[j]); +#ifndef DISABLE_SM90_FEATURES // Wait TMA arrival if (lane_id == 0) tma_store_wait(); @@ -828,6 +843,9 @@ combine(dtype_t* recv_x, float* recv_topk_weights, recv_int4 + token_idx * hidden_int4 + i, tma_bytes, false); } __syncwarp(); +#else + recv_int4[token_idx * hidden_int4 + i] = out_int4; +#endif } // Reduce `topk_weights` @@ -850,8 +868,10 @@ combine(dtype_t* recv_x, float* recv_topk_weights, warp_retired[recv_warp_id] = true; // Make TMA store visible to the next kernel +#ifndef DISABLE_SM90_FEATURES if (lane_id == 0) tma_store_wait(); +#endif } } } @@ -866,12 +886,13 @@ void combine(cudaDataType_t type, int num_max_send_tokens, int num_recv_buffer_tokens) { constexpr int kNumThreads = 768; constexpr int kNumTMABytesPerWarp = 4096; +#ifndef DISABLE_SM90_FEATURES constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32); +#endif #define COMBINE_LAUNCH_CASE(dtype, ranks) { \ auto kernel = combine; \ - EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \ - cfg.dynamicSmemBytes = smem_size; \ + SET_SHARED_MEMORY_FOR_TMA(kernel); \ LAUNCH_KERNEL(&cfg, kernel, \ reinterpret_cast(recv_x), recv_topk_weights, \ reinterpret_cast(x), topk_weights, \ diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh index 3d7574c..5b398bf 100644 --- a/csrc/kernels/launch.cuh +++ b/csrc/kernels/launch.cuh @@ -1,8 +1,10 @@ #pragma once #include "configs.cuh" +#include "exception.cuh" #ifndef SETUP_LAUNCH_CONFIG +#ifndef DISABLE_SM90_FEATURES #define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream) \ 
cudaLaunchConfig_t cfg = {(num_sms), (num_threads), 0, stream, nullptr, 0}; \ cudaLaunchAttribute attr[1]; \ @@ -10,10 +12,39 @@ attr[0].val.cooperative = 1; \ cfg.attrs = attr; \ cfg.numAttrs = 1 +#else +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + int __num_sms = (sms); \ + int __num_threads = (threads); \ + auto __stream = (stream) +#endif #endif #ifndef LAUNCH_KERNEL +#ifndef DISABLE_SM90_FEATURES #define LAUNCH_KERNEL(config, kernel, ...) CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__)) +#else +#define LAUNCH_KERNEL(config, kernel, ...) \ +do { \ + kernel<<<__num_sms, __num_threads, 0, __stream>>>(__VA_ARGS__); \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) { \ + EPException cuda_exception("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ + fprintf(stderr, "%s\n", cuda_exception.what()); \ + throw cuda_exception; \ + } \ +} while (0) +#endif +#endif + +#ifndef SET_SHARED_MEMORY_FOR_TMA +#ifndef DISABLE_SM90_FEATURES +#define SET_SHARED_MEMORY_FOR_TMA(kernel) \ +EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \ +cfg.dynamicSmemBytes = smem_size; +#else +#define SET_SHARED_MEMORY_FOR_TMA(kernel) void() +#endif #endif #define SWITCH_RANKS(case_macro) \ diff --git a/csrc/kernels/layout.cu b/csrc/kernels/layout.cu new file mode 100644 index 0000000..829d5bc --- /dev/null +++ b/csrc/kernels/layout.cu @@ -0,0 +1,136 @@ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" + +namespace deep_ep { + +namespace layout { + +template +__global__ void __launch_bounds__(kNumThreads, 1) +get_dispatch_layout(const int64_t* topk_idx, + int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, + int num_tokens, int num_topk, int num_ranks, int num_experts) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + + // Count expert statistics + __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; + int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); + if (expert_begin_idx < expert_end_idx) { + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumExpertsPerSM; ++ i) + num_tokens_per_expert_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + #pragma unroll + for (int j = 0, expert_idx; j < num_topk; ++ j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) + ++ num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; + } + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); + if (expert_begin_idx + thread_id < expert_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++ i) + sum += num_tokens_per_expert_per_thread[i][thread_id]; + num_tokens_per_expert[expert_begin_idx + thread_id] = sum; + } + return; + } + + if (num_tokens_per_rdma_rank != nullptr) + EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); + + // Count rank statistics + constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; + __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; + __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; + auto 
sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; + int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); + int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; + if (rank_begin_idx < rank_end_idx) { + const auto num_expert_per_rank = num_experts / num_ranks; + auto expert_begin = rank_begin_idx * num_expert_per_rank; + auto expert_end = rank_end_idx * num_expert_per_rank; + + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumRanksPerSM; ++ i) + num_tokens_per_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanksPerSM; ++ i) + num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; + #pragma unroll + for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin <= expert_idx and expert_idx < expert_end) { + // Count single rank + rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; + is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++; + } + } + + auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; + #pragma unroll + for (int j = 0; j + rank_begin_idx < rank_end_idx; ++ j) { + shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); + num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); + } + + #pragma unroll + for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j) + num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); + if (rank_begin_idx + thread_id < rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++ i) + sum += num_tokens_per_rank_per_thread[i][thread_id]; + num_tokens_per_rank[rank_begin_idx + thread_id] = sum; + } + + if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++ i) + sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; + num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; + } + } +} + +void get_dispatch_layout(const int64_t* topk_idx, + int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, + int num_tokens, int num_topk, int num_ranks, int num_experts, + cudaStream_t stream) { + constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8; + int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; + EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM"); + + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, (get_dispatch_layout), + topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, + num_tokens, num_topk, num_ranks, num_experts); +} + +} // namespace layout + +} // namespace deep_ep diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index 79abdcd..26fecd5 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -5,7 +5,10 @@ #include "exception.cuh" #include 
"launch.cuh" #include "utils.cuh" + +#ifndef DISABLE_NVSHMEM #include "ibgda_device.cuh" +#endif namespace deep_ep { @@ -30,6 +33,7 @@ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t st namespace internode { +#ifndef DISABLE_NVSHMEM nvshmem_team_t cpu_rdma_team = NVSHMEM_TEAM_INVALID; nvshmem_team_config_t cpu_rdma_team_config; @@ -81,6 +85,7 @@ void finalize() { } nvshmem_finalize(); } +#endif } // namespace internode diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 7eec3cf..e19d156 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -266,6 +266,9 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); } +// TMA PTX instructions +#ifndef DISABLE_SM90_FEATURES + __device__ __forceinline__ void fence_view_async_shared() { asm volatile("fence.proxy.async.shared::cta; \n" :: ); } @@ -327,6 +330,8 @@ __device__ __forceinline__ void tma_store_wait() { asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory"); } +#endif + template __host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) { return (a + b - 1) / b; diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index ff4634c..f72139c 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -7,7 +7,7 @@ from typing import Callable, List, Tuple, Optional, Union import deep_ep_cpp # noinspection PyUnresolvedReferences from deep_ep_cpp import Config, EventHandle -from .utils import EventOverlap +from .utils import EventOverlap, check_nvlink_connections class Buffer: @@ -50,6 +50,7 @@ class Buffer: please make sure all connections are via NVLink. allow_mnnvl: whether to allow MNNVL """ + check_nvlink_connections(group) # Initialize the CPP runtime self.rank = group.rank() @@ -105,6 +106,10 @@ class Buffer: self.runtime.sync(device_ids, ipc_handles, root_unique_id) assert self.runtime.is_available() + @staticmethod + def is_sm90_compiled(): + return deep_ep_cpp.is_sm90_compiled() + @staticmethod def set_num_sms(new_num_sms: int) -> None: """ diff --git a/deep_ep/utils.py b/deep_ep/utils.py index 009aa2a..3fce634 100644 --- a/deep_ep/utils.py +++ b/deep_ep/utils.py @@ -1,4 +1,7 @@ +import os +import subprocess import torch +import torch.distributed as dist from typing import Any, Optional, Tuple # noinspection PyUnresolvedReferences @@ -58,3 +61,28 @@ class EventOverlap: """ if self.event is not None: self.event.current_stream_wait() + + +def check_nvlink_connections(group: dist.ProcessGroup): + """ + Check NVLink connection between every pair of GPUs. + + Arguments: + group: the communication group. 
+ """ + # Check NVLink connection + # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2 + if 'PCIE' in torch.cuda.get_device_name(): + assert group.size() <= 2, 'No NVLink connection between all GPUs' + devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',') + physical_device_idx = int(devices[torch.cuda.current_device()]) + physical_device_indices = [0, ] * group.size() + dist.all_gather_object(physical_device_indices, physical_device_idx, group) + + # Get connection matrix from `nvidia-smi` + lines = subprocess.check_output(['nvidia-smi', 'topo', '-p2p', 'n']).decode('utf-8').split('\n') + for line in lines: + if line.lstrip().startswith(f'GPU{physical_device_idx}') and 'X' in line: + status = line.strip().lstrip(f'GPU{physical_device_idx}').split() + for dst_gpu_rank in physical_device_indices: + assert status[dst_gpu_rank] in ('X', 'OK'), f'No NVLink connection between GPU {physical_device_idx} and GPU {dst_gpu_rank}' diff --git a/setup.py b/setup.py index af8e4e6..b16310a 100644 --- a/setup.py +++ b/setup.py @@ -6,34 +6,76 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension if __name__ == '__main__': nvshmem_dir = os.getenv('NVSHMEM_DIR', None) - assert nvshmem_dir is not None and os.path.exists(nvshmem_dir), 'Failed to find NVSHMEM' - print(f'NVSHMEM directory: {nvshmem_dir}') + disable_nvshmem = nvshmem_dir is None + if disable_nvshmem: + print('Warning: `NVSHMEM_DIR` is not specified, all internode and low-latency features are disabled\n') + else: + assert os.path.exists(nvshmem_dir), f'Failed to find NVSHMEM: {nvshmem_dir}' - # TODO: currently, we only support Hopper architecture, we may add Ampere support later - if os.getenv('TORCH_CUDA_ARCH_LIST', None) is None: - os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0' cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes'] - nvcc_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10'] - include_dirs = ['csrc/', f'{nvshmem_dir}/include'] - sources = ['csrc/deep_ep.cpp', - 'csrc/kernels/runtime.cu', 'csrc/kernels/intranode.cu', - 'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'] - library_dirs = [f'{nvshmem_dir}/lib'] + nvcc_flags = ['-O3', '-Xcompiler', '-O3'] + sources = ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu'] + include_dirs = ['csrc/'] + library_dirs = [] + nvcc_dlink = [] + extra_link_args = [] + + # NVSHMEM flags + if disable_nvshmem: + cxx_flags.append('-DDISABLE_NVSHMEM') + nvcc_flags.append('-DDISABLE_NVSHMEM') + else: + sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']) + include_dirs.extend([f'{nvshmem_dir}/include']) + library_dirs.extend([f'{nvshmem_dir}/lib']) + nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']) + extra_link_args.extend(['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']) + + if int(os.getenv('DISABLE_SM90_FEATURES', 0)): + # Prefer A100 + os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '8.0') + + # Disable some SM90 features: FP8, launch methods, and TMA + cxx_flags.append('-DDISABLE_SM90_FEATURES') + nvcc_flags.append('-DDISABLE_SM90_FEATURES') + + # Disable internode and low-latency kernels + assert disable_nvshmem + + # Disable LD/ST tricks, as some CUDA version does not support `.L1::no_allocate` + assert 
int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', 1)) == 1 + os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1' + else: + # Prefer H800 series + os.environ['TORCH_CUDA_ARCH_LIST'] = os.getenv('TORCH_CUDA_ARCH_LIST', '9.0') + + # CUDA 12 flags + nvcc_flags.extend(['-rdc=true', '--ptxas-options=--register-usage-level=10']) # Disable aggressive PTX instructions if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')): cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') - # Disable DLTO (default by PyTorch) - nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem'] - extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib'] + # Put them together extra_compile_args = { 'cxx': cxx_flags, 'nvcc': nvcc_flags, - 'nvcc_dlink': nvcc_dlink } + if len(nvcc_dlink) > 0: + extra_compile_args['nvcc_dlink'] = nvcc_dlink + + # Summary + print(f'Build summary:') + print(f' > Sources: {sources}') + print(f' > Includes: {include_dirs}') + print(f' > Libraries: {library_dirs}') + print(f' > Compilation flags: {extra_compile_args}') + print(f' > Link flags: {extra_link_args}') + print(f' > Arch list: {os.environ["TORCH_CUDA_ARCH_LIST"]}') + print(f' > NVSHMEM path: {nvshmem_dir}') + print() # noinspection PyBroadException try: @@ -44,7 +86,7 @@ if __name__ == '__main__': setuptools.setup( name='deep_ep', - version='1.0.0' + revision, + version='1.1.0' + revision, packages=setuptools.find_packages( include=['deep_ep'] ), diff --git a/tests/test_intranode.py b/tests/test_intranode.py index fb8a573..1fe713c 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -1,4 +1,3 @@ -import os import time import torch import torch.distributed as dist @@ -21,7 +20,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: # Random data x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - x_e4m3 = per_token_cast_to_fp8(x) + x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank @@ -80,7 +79,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: for previous_mode in (False, True): for async_mode in (False, True): - for current_x in (x_pure_rand, x, x_e4m3): + for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)): for with_topk in (False, True): if local_rank == 0: print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') @@ -168,7 +167,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: # Tune dispatch performance best_dispatch_results = None fp8_factor = (1 + 4 / 128) / 2 - for current_x in (x_e4m3, x): + for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)): best_time, best_results = 1e10, None nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ): @@ -189,8 +188,8 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, 
rank: int, buffer: print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) print('', flush=True) - if isinstance(current_x, tuple): - # Gather FP8 the best config from rank 0 + # Gather the best config from rank 0 and the first test setting + if best_dispatch_results is None: best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda') all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group)
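
To make the Ampere-only path above concrete, here is a minimal usage sketch (not part of the patch) showing how a caller might branch on the new `deep_ep.Buffer.is_sm90_compiled()` flag and allocate an intranode-only buffer when NVSHMEM support is compiled out. The buffer size, the keyword argument names of the `Buffer` constructor, and the import path of `per_token_cast_to_fp8` are assumptions modeled on the existing tests, not the definitive API.

```python
# Sketch only: mirrors tests/test_intranode.py; argument names are assumptions.
import torch
import torch.distributed as dist

import deep_ep


def make_intranode_buffer(group: dist.ProcessGroup) -> deep_ep.Buffer:
    # With a `DISABLE_NVSHMEM` build, only NVLink (intranode) buffers are usable,
    # so `num_rdma_bytes` stays 0 and low-latency mode is left off.
    num_nvl_bytes = int(1e9)  # hypothetical size; tune for your workload
    return deep_ep.Buffer(group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=0)


def make_inputs(num_tokens: int, hidden: int, rank: int):
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank
    # FP8 casting relies on SM90 features; skip it when the extension was built
    # with `DISABLE_SM90_FEATURES` (e.g. an SM80 or CUDA 11 build).
    x_e4m3 = None
    if deep_ep.Buffer.is_sm90_compiled():
        from utils import per_token_cast_to_fp8  # test-suite helper; import path is an assumption
        x_e4m3 = per_token_cast_to_fp8(x)
    return x, x_e4m3
```

On an Ampere-only machine, the build itself would be driven by the new environment variables introduced in `setup.py` above (`DISABLE_SM90_FEATURES=1`, `TORCH_CUDA_ARCH_LIST="8.0"`, and no `NVSHMEM_DIR`), which compile out the internode and low-latency kernels that this sketch avoids.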