Fix initialization for torch: only call torch get_worker_info if torch was loaded

This commit is contained in:
allegroai 2020-12-06 11:22:52 +02:00
parent 85f6e48d52
commit f7cb4e3e9c

View File

@@ -1,6 +1,7 @@
""" Configuration module. Uses backend_config to load system configuration. """ """ Configuration module. Uses backend_config to load system configuration. """
import logging import logging
import os import os
import sys
from os.path import expandvars, expanduser from os.path import expandvars, expanduser
from ..backend_api import load_config from ..backend_api import load_config
@@ -78,24 +79,28 @@ def get_node_id(default=0):
if node_id is None: if node_id is None:
node_id = default node_id = default
# check if we have pyTorch node/worker ID torch_rank = None
try: # check if we have pyTorch node/worker ID (only if torch was already imported)
from torch.utils.data.dataloader import get_worker_info if 'torch' in sys.modules:
worker_info = get_worker_info() # noinspection PyBroadException
if not worker_info: try:
from torch.utils.data.dataloader import get_worker_info # noqa
worker_info = get_worker_info()
if not worker_info:
torch_rank = None
else:
w_id = worker_info.id
# noinspection PyBroadException
try:
torch_rank = int(w_id)
except Exception:
# guess a number based on wid hopefully unique value
import hashlib
h = hashlib.md5()
h.update(str(w_id).encode('utf-8'))
torch_rank = int(h.hexdigest(), 16)
except Exception:
torch_rank = None torch_rank = None
else:
w_id = worker_info.id
try:
torch_rank = int(w_id)
except Exception:
# guess a number based on wid hopefully unique value
import hashlib
h = hashlib.md5()
h.update(str(w_id).encode('utf-8'))
torch_rank = int(h.hexdigest(), 16)
except Exception:
torch_rank = None
# if we also have a torch rank add it to the node rank # if we also have a torch rank add it to the node rank
if torch_rank is not None: if torch_rank is not None:
@@ -118,8 +123,8 @@ def get_log_redirect_level():
value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405 value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405
try: try:
if value: if value:
return logging._checkLevel(value) return logging._checkLevel(value) # noqa
except (ValueError, TypeError): except (ValueError, TypeError, AttributeError):
pass pass