Fix initialization for torch: only call torch get_worker_info if torch was loaded

allegroai 2020-12-06 11:22:52 +02:00
parent 85f6e48d52
commit f7cb4e3e9c


@@ -1,6 +1,7 @@
 """ Configuration module. Uses backend_config to load system configuration. """
 import logging
 import os
+import sys
 from os.path import expandvars, expanduser
 from ..backend_api import load_config
@@ -78,24 +79,28 @@ def get_node_id(default=0):
     if node_id is None:
         node_id = default
-    # check if we have pyTorch node/worker ID
-    try:
-        from torch.utils.data.dataloader import get_worker_info
-        worker_info = get_worker_info()
-        if not worker_info:
-            torch_rank = None
-        else:
-            w_id = worker_info.id
-            try:
-                torch_rank = int(w_id)
-            except Exception:
-                # guess a number based on wid hopefully unique value
-                import hashlib
-                h = hashlib.md5()
-                h.update(str(w_id).encode('utf-8'))
-                torch_rank = int(h.hexdigest(), 16)
-    except Exception:
-        torch_rank = None
+    torch_rank = None
+    # check if we have pyTorch node/worker ID (only if torch was already imported)
+    if 'torch' in sys.modules:
+        # noinspection PyBroadException
+        try:
+            from torch.utils.data.dataloader import get_worker_info  # noqa
+            worker_info = get_worker_info()
+            if not worker_info:
+                torch_rank = None
+            else:
+                w_id = worker_info.id
+                # noinspection PyBroadException
+                try:
+                    torch_rank = int(w_id)
+                except Exception:
+                    # guess a number based on wid hopefully unique value
+                    import hashlib
+                    h = hashlib.md5()
+                    h.update(str(w_id).encode('utf-8'))
+                    torch_rank = int(h.hexdigest(), 16)
+        except Exception:
+            torch_rank = None
     # if we also have a torch rank add it to the node rank
     if torch_rank is not None:
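
Taken out of diff context, the change amounts to the following guard pattern, shown here as a minimal standalone sketch (the helper name _torch_worker_rank is hypothetical, for illustration only): probing sys.modules is a plain dict lookup that never imports torch, so processes that never loaded torch skip the whole path, and a non-numeric worker id is reduced to a stable integer via its md5 digest.

import hashlib
import sys


def _torch_worker_rank():
    # Hypothetical helper sketching the pattern above; not part of the codebase.
    # 'torch' in sys.modules never triggers an import, so environments without
    # torch (or with a broken install) return immediately.
    if 'torch' not in sys.modules:
        return None
    # noinspection PyBroadException
    try:
        from torch.utils.data.dataloader import get_worker_info
        worker_info = get_worker_info()
        if not worker_info:
            return None
        try:
            # DataLoader worker ids are normally small integers
            return int(worker_info.id)
        except Exception:
            # otherwise derive a stable integer from the id's string form
            digest = hashlib.md5(str(worker_info.id).encode('utf-8'))
            return int(digest.hexdigest(), 16)
    except Exception:
        return None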
@@ -118,8 +123,8 @@ def get_log_redirect_level():
     value = LOG_STDERR_REDIRECT_LEVEL.get()  # noqa: F405
     try:
         if value:
-            return logging._checkLevel(value)
-    except (ValueError, TypeError):
+            return logging._checkLevel(value)  # noqa
+    except (ValueError, TypeError, AttributeError):
         pass
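
For context: logging._checkLevel is a private helper in CPython's logging module that maps a level name or number to its numeric value (e.g. 'INFO' -> 20), raising ValueError for unknown names and TypeError for unsupported types; because it is not public API it may be absent in some interpreters, which is what the added AttributeError (and the # noqa) defend against. A minimal sketch of the defensive call, with a hypothetical wrapper name:

import logging


def _safe_check_level(value):
    # Hypothetical wrapper for illustration; mirrors the hardened call above.
    try:
        if value:
            return logging._checkLevel(value)  # noqa: 'INFO' -> 20, 10 -> 10
    except (ValueError, TypeError, AttributeError):
        pass
    return None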