diff --git a/tests/utils.py b/tests/utils.py index 1a9c176..57a38b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.distributed as dist from typing import Optional +import inspect def init_dist(local_rank: int, num_local_ranks: int): @@ -14,12 +15,16 @@ def init_dist(local_rank: int, num_local_ranks: int): node_rank = int(os.getenv('RANK', 0)) assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - dist.init_process_group( - backend='nccl', - init_method=f'tcp://{ip}:{port}', - world_size=num_nodes * num_local_ranks, - rank=node_rank * num_local_ranks + local_rank - ) + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + params['device_id'] = torch.device(f"cuda:{local_rank}") + dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank)