diff --git a/tests/utils.py b/tests/utils.py index 51ee18e..da2e12c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ import inspect +import numpy as np import os import sys -import numpy as np import torch import torch.distributed as dist from typing import Optional @@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int): 'rank': node_rank * num_local_ranks + local_rank, } if 'device_id' in sig.parameters: - params['device_id'] = torch.device(f"cuda:{local_rank}") + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda')