mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Set device_id
to suppress pytorch warning.
This commit is contained in:
parent
77f97f79bd
commit
bf4a4a21d2
@ -14,12 +14,17 @@ def init_dist(local_rank: int, num_local_ranks: int):
|
|||||||
node_rank = int(os.getenv('RANK', 0))
|
node_rank = int(os.getenv('RANK', 0))
|
||||||
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
|
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
|
||||||
|
|
||||||
dist.init_process_group(
|
import inspect
|
||||||
backend='nccl',
|
sig = inspect.signature(dist.init_process_group)
|
||||||
init_method=f'tcp://{ip}:{port}',
|
params = {
|
||||||
world_size=num_nodes * num_local_ranks,
|
'backend': 'nccl',
|
||||||
rank=node_rank * num_local_ranks + local_rank
|
'init_method': f'tcp://{ip}:{port}',
|
||||||
)
|
'world_size': num_nodes * num_local_ranks,
|
||||||
|
'rank': node_rank * num_local_ranks + local_rank,
|
||||||
|
}
|
||||||
|
if 'device_id' in sig.parameters:
|
||||||
|
params['device_id'] = torch.device(f"cuda:{local_rank}")
|
||||||
|
dist.init_process_group(**params)
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
torch.set_default_device('cuda')
|
torch.set_default_device('cuda')
|
||||||
torch.cuda.set_device(local_rank)
|
torch.cuda.set_device(local_rank)
|
||||||
|
Loading…
Reference in New Issue
Block a user