Fix initialization for torch: only call torch get_worker_info if torch was loaded

allegroai 2020-12-06 11:22:52 +02:00
parent 85f6e48d52
commit f7cb4e3e9c


@@ -1,6 +1,7 @@
 """ Configuration module. Uses backend_config to load system configuration. """
 import logging
 import os
+import sys
 from os.path import expandvars, expanduser
 
 from ..backend_api import load_config
@@ -78,14 +79,18 @@ def get_node_id(default=0):
     if node_id is None:
         node_id = default
 
-    # check if we have pyTorch node/worker ID
-    try:
-        from torch.utils.data.dataloader import get_worker_info
-        worker_info = get_worker_info()
-        if not worker_info:
-            torch_rank = None
-        else:
-            w_id = worker_info.id
-            try:
-                torch_rank = int(w_id)
-            except Exception:
+    torch_rank = None
+    # check if we have pyTorch node/worker ID (only if torch was already imported)
+    if 'torch' in sys.modules:
+        # noinspection PyBroadException
+        try:
+            from torch.utils.data.dataloader import get_worker_info  # noqa
+            worker_info = get_worker_info()
+            if not worker_info:
+                torch_rank = None
+            else:
+                w_id = worker_info.id
+                # noinspection PyBroadException
+                try:
+                    torch_rank = int(w_id)
+                except Exception:
@@ -118,8 +123,8 @@ def get_log_redirect_level():
     value = LOG_STDERR_REDIRECT_LEVEL.get()  # noqa: F405
     try:
         if value:
-            return logging._checkLevel(value)
-    except (ValueError, TypeError):
+            return logging._checkLevel(value)  # noqa
+    except (ValueError, TypeError, AttributeError):
         pass
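
The core idea of the change is to avoid importing torch as a side effect of reading the configuration: the DataLoader worker ID is only queried when torch is already present in sys.modules. A minimal, standalone sketch of that guard pattern follows; the helper name get_torch_worker_rank is illustrative and not part of the project's actual API.

import sys


def get_torch_worker_rank():
    # Return the PyTorch DataLoader worker id, or None when unavailable.
    # Checking sys.modules first means we never trigger a fresh (and
    # potentially slow) torch import just to look up a worker id.
    if 'torch' not in sys.modules:
        return None
    # noinspection PyBroadException
    try:
        from torch.utils.data.dataloader import get_worker_info  # noqa
        worker_info = get_worker_info()
        if not worker_info:
            # main process, or not inside a DataLoader worker
            return None
        return int(worker_info.id)
    except Exception:
        return None

In the main process, or when torch was never imported, the helper simply returns None, which mirrors the fallback behavior of the patched get_node_id above.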