From 9ec061204e6763c12a9dd9f4cc5ca3b6c868b552 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Wed, 11 Jun 2025 17:29:00 +0800
Subject: [PATCH] Use `pynvml` to detect NVLink connections (#205)

* Use `pynvml` to detect NVLink connections

* Add a TODO

* Add shutdown

* Fix comments
---
 deep_ep/utils.py | 29 +++++++++++++++++++++--------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/deep_ep/utils.py b/deep_ep/utils.py
index 3fce634..b84e175 100644
--- a/deep_ep/utils.py
+++ b/deep_ep/utils.py
@@ -72,17 +72,30 @@ def check_nvlink_connections(group: dist.ProcessGroup):
     """
     # Check NVLink connection
     # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2
+    # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink
     if 'PCIE' in torch.cuda.get_device_name():
-        assert group.size() <= 2, 'No NVLink connection between all GPUs'
+        assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections'
+
+        # noinspection PyUnresolvedReferences
+        import pynvml
+        pynvml.nvmlInit()
+
+        # noinspection PyTypeChecker
         devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',')
         physical_device_idx = int(devices[torch.cuda.current_device()])
         physical_device_indices = [0, ] * group.size()
         dist.all_gather_object(physical_device_indices, physical_device_idx, group)
 
-        # Get connection matrix from `nvidia-smi`
-        lines = subprocess.check_output(['nvidia-smi', 'topo', '-p2p', 'n']).decode('utf-8').split('\n')
-        for line in lines:
-            if line.lstrip().startswith(f'GPU{physical_device_idx}') and 'X' in line:
-                status = line.strip().lstrip(f'GPU{physical_device_idx}').split()
-                for dst_gpu_rank in physical_device_indices:
-                    assert status[dst_gpu_rank] in ('X', 'OK'), f'No NVLink connection between GPU {physical_device_idx} and GPU {dst_gpu_rank}'
+        # Check whether they are all connected via NVLink
+        # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i >= j:
+                    continue
+                status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                assert status == pynvml.NVML_P2P_STATUS_OK,\
+                    f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink'
+
+        # Close NVML
+        pynvml.nvmlShutdown()
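
For reference, the pairwise check this patch introduces can be exercised outside of `torch.distributed` with the minimal standalone sketch below. It assumes the `nvidia-ml-py` package (which provides the `pynvml` module) is installed; the `all_pairs_nvlink` helper and the `[0, 1]` device list are illustrative only, not part of the patch.

import pynvml


def all_pairs_nvlink(physical_device_indices):
    # Initialize NVML before any query; the `finally` block guarantees a
    # matching shutdown even if a query raises
    pynvml.nvmlInit()
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i >= j:
                    continue  # P2P status is symmetric, so query each unordered pair once
                status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle,
                                                       pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True
    finally:
        pynvml.nvmlShutdown()


if __name__ == '__main__':
    # NVML indices are physical (board order), unaffected by CUDA_VISIBLE_DEVICES
    print(all_pairs_nvlink([0, 1]))

Querying NVML directly, as the patch does, avoids the old approach's dependence on the table layout of `nvidia-smi topo -p2p n` output and yields a machine-readable status for each GPU pair.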