From 9d4f7ef8eeedd9970c2bb8efe998e07e7bb9df5b Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 18 Jun 2025 16:04:42 +0800 Subject: [PATCH] Surpass type checks --- tests/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 51ee18e..da2e12c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ import inspect +import numpy as np import os import sys -import numpy as np import torch import torch.distributed as dist from typing import Optional @@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int): 'rank': node_rank * num_local_ranks + local_rank, } if 'device_id' in sig.parameters: - params['device_id'] = torch.device(f"cuda:{local_rank}") + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda')