mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Fix initialization for torch: only call torch get_worker_info if torch was loaded
This commit is contained in:
parent
85f6e48d52
commit
f7cb4e3e9c
@ -1,6 +1,7 @@
|
||||
""" Configuration module. Uses backend_config to load system configuration. """
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from os.path import expandvars, expanduser
|
||||
|
||||
from ..backend_api import load_config
|
||||
@ -78,24 +79,28 @@ def get_node_id(default=0):
|
||||
if node_id is None:
|
||||
node_id = default
|
||||
|
||||
# check if we have pyTorch node/worker ID
|
||||
try:
|
||||
from torch.utils.data.dataloader import get_worker_info
|
||||
worker_info = get_worker_info()
|
||||
if not worker_info:
|
||||
torch_rank = None
|
||||
# check if we have pyTorch node/worker ID (only if torch was already imported)
|
||||
if 'torch' in sys.modules:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from torch.utils.data.dataloader import get_worker_info # noqa
|
||||
worker_info = get_worker_info()
|
||||
if not worker_info:
|
||||
torch_rank = None
|
||||
else:
|
||||
w_id = worker_info.id
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
torch_rank = int(w_id)
|
||||
except Exception:
|
||||
# guess a number based on wid hopefully unique value
|
||||
import hashlib
|
||||
h = hashlib.md5()
|
||||
h.update(str(w_id).encode('utf-8'))
|
||||
torch_rank = int(h.hexdigest(), 16)
|
||||
except Exception:
|
||||
torch_rank = None
|
||||
else:
|
||||
w_id = worker_info.id
|
||||
try:
|
||||
torch_rank = int(w_id)
|
||||
except Exception:
|
||||
# guess a number based on wid hopefully unique value
|
||||
import hashlib
|
||||
h = hashlib.md5()
|
||||
h.update(str(w_id).encode('utf-8'))
|
||||
torch_rank = int(h.hexdigest(), 16)
|
||||
except Exception:
|
||||
torch_rank = None
|
||||
|
||||
# if we also have a torch rank add it to the node rank
|
||||
if torch_rank is not None:
|
||||
@ -118,8 +123,8 @@ def get_log_redirect_level():
|
||||
value = LOG_STDERR_REDIRECT_LEVEL.get() # noqa: F405
|
||||
try:
|
||||
if value:
|
||||
return logging._checkLevel(value)
|
||||
except (ValueError, TypeError):
|
||||
return logging._checkLevel(value) # noqa
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user