Merge pull request #222 from deepseek-ai/set_dev_id

Set `device_id` to suppress pytorch warning.
This commit is contained in:
Shangyan Zhou 2025-06-18 14:53:26 +08:00 committed by GitHub
commit a2d2354e1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import numpy as np
import torch
import torch.distributed as dist
from typing import Optional
import inspect
def init_dist(local_rank: int, num_local_ranks: int):
@ -14,12 +15,16 @@ def init_dist(local_rank: int, num_local_ranks: int):
node_rank = int(os.getenv('RANK', 0))
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group(
backend='nccl',
init_method=f'tcp://{ip}:{port}',
world_size=num_nodes * num_local_ranks,
rank=node_rank * num_local_ranks + local_rank
)
sig = inspect.signature(dist.init_process_group)
params = {
'backend': 'nccl',
'init_method': f'tcp://{ip}:{port}',
'world_size': num_nodes * num_local_ranks,
'rank': node_rank * num_local_ranks + local_rank,
}
if 'device_id' in sig.parameters:
params['device_id'] = torch.device(f"cuda:{local_rank}")
dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cuda')
torch.cuda.set_device(local_rank)