Use pynvml to detect NVLink connections (#205)

* Use `pynvml` to detect NVLink connections

* Add a TODO

* Add shutdown

* Fix comments
This commit is contained in:
Chenggang Zhao 2025-06-11 17:29:00 +08:00 committed by GitHub
parent b8d90fb753
commit 9ec061204e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -72,17 +72,30 @@ def check_nvlink_connections(group: dist.ProcessGroup):
"""
# Check NVLink connection
# NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2
# TODO: check all cases, all local-node GPUs in the group should be connected via NVLink
if 'PCIE' in torch.cuda.get_device_name():
assert group.size() <= 2, 'No NVLink connection between all GPUs'
assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections'
# noinspection PyUnresolvedReferences
import pynvml
pynvml.nvmlInit()
# noinspection PyTypeChecker
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',')
physical_device_idx = int(devices[torch.cuda.current_device()])
physical_device_indices = [0, ] * group.size()
dist.all_gather_object(physical_device_indices, physical_device_idx, group)
# Get connection matrix from `nvidia-smi`
lines = subprocess.check_output(['nvidia-smi', 'topo', '-p2p', 'n']).decode('utf-8').split('\n')
for line in lines:
if line.lstrip().startswith(f'GPU{physical_device_idx}') and 'X' in line:
status = line.strip().lstrip(f'GPU{physical_device_idx}').split()
for dst_gpu_rank in physical_device_indices:
assert status[dst_gpu_rank] in ('X', 'OK'), f'No NVLink connection between GPU {physical_device_idx} and GPU {dst_gpu_rank}'
# Check whether they are all connected via NVLink
# Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i >= j:
continue
status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
assert status == pynvml.NVML_P2P_STATUS_OK,\
f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink'
# Close NVML
pynvml.nvmlShutdown()