Surpass type checks

This commit is contained in:
Chenggang Zhao 2025-06-18 16:04:42 +08:00
parent b56f7c2c8c
commit 9d4f7ef8ee

View File

@ -1,7 +1,7 @@
import inspect
import numpy as np
import os
import sys
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional
@ -23,7 +23,8 @@ def init_dist(local_rank: int, num_local_ranks: int):
'rank': node_rank * num_local_ranks + local_rank,
}
if 'device_id' in sig.parameters:
params['device_id'] = torch.device(f"cuda:{local_rank}")
# noinspection PyTypeChecker
params['device_id'] = torch.device(f'cuda:{local_rank}')
dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda')