Surpass type checks

This commit is contained in:
Chenggang Zhao 2025-06-18 16:04:42 +08:00
parent b56f7c2c8c
commit 9d4f7ef8ee

View File

@ -1,7 +1,7 @@
import inspect import inspect
import numpy as np
import os import os
import sys import sys
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from typing import Optional from typing import Optional
@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int):
'rank': node_rank * num_local_ranks + local_rank, 'rank': node_rank * num_local_ranks + local_rank,
} }
if 'device_id' in sig.parameters: if 'device_id' in sig.parameters:
params['device_id'] = torch.device(f"cuda:{local_rank}") # noinspection PyTypeChecker
params['device_id'] = torch.device(f'cuda:{local_rank}')
dist.init_process_group(**params) dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda') torch.set_default_device('cuda')