diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu
index d008098..8e536ca 100644
--- a/csrc/kernels/runtime.cu
+++ b/csrc/kernels/runtime.cu
@@ -58,15 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
         EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
     }
 
-    // Normal operations use IBRC, while low-latency operations use IBGDA
-    bool internode_use_ibgda = true;
-    if (low_latency_mode or internode_use_ibgda) {
-        nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
-        CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
+    // TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
+    nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
+    CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
 
-        bool ibgda_is_initialized = false;
-        CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
-    }
+    bool ibgda_is_initialized = false;
+    CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
 
     nvshmem_barrier_all();
     return nvshmem_my_pe();
 }
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 0665eb7..dd14838 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -65,19 +65,17 @@ class Buffer:
 
         # Synchronize NVSHMEM unique IDs
         root_unique_id = None
-        internode_use_ibgda = True
         if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
-            # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
-            if low_latency_mode or internode_use_ibgda:
-                assert num_qps_per_rank > 0
-                os.environ['NVSHMEM_DISABLE_P2P'] = '1'
-                os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
-                os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
-                os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
-                # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
-                os.environ['NVSHMEM_QP_DEPTH'] = '1024'
-                # NOTES: NVSHMEM initialization requires at least 256 MiB
-                os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
+            # Enable IBGDA
+            assert num_qps_per_rank > 0
+            os.environ['NVSHMEM_DISABLE_P2P'] = '1'
+            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
+            os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
+            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
+            # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
+            os.environ['NVSHMEM_QP_DEPTH'] = '1024'
+            # NOTES: NVSHMEM initialization requires at least 256 MiB
+            os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
 
         # Synchronize using the root ID
         nvshmem_unique_ids = [None, ] * self.group_size