From bf4a4a21d282026b293ed61668aeb807540a3dba Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Wed, 18 Jun 2025 14:43:38 +0800 Subject: [PATCH] Set `device_id` to suppress pytorch warning. --- tests/utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 1a9c176..7af4947 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,12 +14,17 @@ def init_dist(local_rank: int, num_local_ranks: int): node_rank = int(os.getenv('RANK', 0)) assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - dist.init_process_group( - backend='nccl', - init_method=f'tcp://{ip}:{port}', - world_size=num_nodes * num_local_ranks, - rank=node_rank * num_local_ranks + local_rank - ) + import inspect + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + params['device_id'] = torch.device(f"cuda:{local_rank}") + dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank)