mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Enhance utilities
This commit is contained in:
parent
fbb9af21f7
commit
e4ceeb2c11
clearml
@ -113,6 +113,35 @@ def get_log_to_backend(default=None):
|
||||
return LOG_TO_BACKEND_ENV_VAR.get(default=default) # noqa: F405
|
||||
|
||||
|
||||
def get_node_count():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
|
||||
return mpi_world_rank
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_JOB_NUM_NODES')))
|
||||
return mpi_rank
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# check if we have pyTorch node/worker ID (only if torch was already imported)
|
||||
if 'torch' in sys.modules:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from torch.utils.data.dataloader import get_worker_info # noqa
|
||||
worker_info = get_worker_info()
|
||||
if worker_info:
|
||||
return int(worker_info.num_workers)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_node_id(default=0):
|
||||
node_id = NODE_ID_ENV_VAR.get() # noqa: F405
|
||||
|
||||
@ -124,7 +153,9 @@ def get_node_id(default=0):
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID')))
|
||||
mpi_rank = int(os.environ.get(
|
||||
'OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID', os.environ.get('SLURM_NODEID')))
|
||||
)
|
||||
except Exception:
|
||||
mpi_rank = None
|
||||
|
||||
|
@ -91,9 +91,8 @@ def verify_basic_type(a_dict_list, basic_types=None):
|
||||
all(verify_basic_type(v) for v in a_dict_list.values())
|
||||
|
||||
|
||||
def flatten_dictionary(a_dict, prefix=''):
|
||||
def flatten_dictionary(a_dict, prefix='', sep='/'):
|
||||
flat_dict = {}
|
||||
sep = '/'
|
||||
basic_types = (float, int, bool, six.string_types, )
|
||||
for k, v in a_dict.items():
|
||||
k = str(k)
|
||||
@ -102,7 +101,7 @@ def flatten_dictionary(a_dict, prefix=''):
|
||||
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
|
||||
flat_dict[prefix + k] = v
|
||||
elif isinstance(v, dict):
|
||||
nested_flat_dict = flatten_dictionary(v, prefix=prefix + k + sep)
|
||||
nested_flat_dict = flatten_dictionary(v, prefix=prefix + k + sep, sep=sep)
|
||||
if nested_flat_dict:
|
||||
flat_dict.update(nested_flat_dict)
|
||||
else:
|
||||
@ -114,9 +113,8 @@ def flatten_dictionary(a_dict, prefix=''):
|
||||
return flat_dict
|
||||
|
||||
|
||||
def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
|
||||
def nested_from_flat_dictionary(a_dict, flat_dict, prefix='', sep='/'):
|
||||
basic_types = (float, int, bool, six.string_types, )
|
||||
sep = '/'
|
||||
org_dict = copy(a_dict)
|
||||
for k, v in org_dict.items():
|
||||
k = str(k)
|
||||
@ -125,7 +123,7 @@ def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
|
||||
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
|
||||
a_dict[k] = flat_dict.get(prefix + k, v)
|
||||
elif isinstance(v, dict):
|
||||
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep) or v
|
||||
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep, sep=sep) or v
|
||||
else:
|
||||
# this is a mixture of list and dict, or any other object,
|
||||
# leave it as is, we have nothing to do with it.
|
||||
@ -145,7 +143,7 @@ def naive_nested_from_flat_dictionary(flat_dict, sep='/'):
|
||||
k[len(sub_prefix) + 1:]: v
|
||||
for k, v in bucket
|
||||
if len(k) > len(sub_prefix)
|
||||
}
|
||||
}, sep=sep
|
||||
)
|
||||
)
|
||||
for sub_prefix, bucket in (
|
||||
|
Loading…
Reference in New Issue
Block a user