diff --git a/tests/utils.py b/tests/utils.py index 7af4947..57a38b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.distributed as dist from typing import Optional +import inspect def init_dist(local_rank: int, num_local_ranks: int): @@ -14,7 +15,6 @@ def init_dist(local_rank: int, num_local_ranks: int): node_rank = int(os.getenv('RANK', 0)) assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - import inspect sig = inspect.signature(dist.init_process_group) params = { 'backend': 'nccl',