1
0
mirror of https://github.com/clearml/clearml synced 2025-03-03 18:52:12 +00:00

Enhance utilities

This commit is contained in:
allegroai 2021-10-16 23:07:33 +03:00
parent fbb9af21f7
commit e4ceeb2c11
2 changed files with 37 additions and 8 deletions
clearml

View File

@ -113,6 +113,35 @@ def get_log_to_backend(default=None):
return LOG_TO_BACKEND_ENV_VAR.get(default=default) # noqa: F405
def get_node_count():
# noinspection PyBroadException
try:
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
return mpi_world_rank
except Exception:
pass
# noinspection PyBroadException
try:
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_JOB_NUM_NODES')))
return mpi_rank
except Exception:
pass
# check if we have pyTorch node/worker ID (only if torch was already imported)
if 'torch' in sys.modules:
# noinspection PyBroadException
try:
from torch.utils.data.dataloader import get_worker_info # noqa
worker_info = get_worker_info()
if worker_info:
return int(worker_info.num_workers)
except Exception:
pass
return None
def get_node_id(default=0):
node_id = NODE_ID_ENV_VAR.get() # noqa: F405
@ -124,7 +153,9 @@ def get_node_id(default=0):
# noinspection PyBroadException
try:
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID')))
mpi_rank = int(os.environ.get(
'OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID', os.environ.get('SLURM_NODEID')))
)
except Exception:
mpi_rank = None

View File

@ -91,9 +91,8 @@ def verify_basic_type(a_dict_list, basic_types=None):
all(verify_basic_type(v) for v in a_dict_list.values())
def flatten_dictionary(a_dict, prefix=''):
def flatten_dictionary(a_dict, prefix='', sep='/'):
flat_dict = {}
sep = '/'
basic_types = (float, int, bool, six.string_types, )
for k, v in a_dict.items():
k = str(k)
@ -102,7 +101,7 @@ def flatten_dictionary(a_dict, prefix=''):
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
flat_dict[prefix + k] = v
elif isinstance(v, dict):
nested_flat_dict = flatten_dictionary(v, prefix=prefix + k + sep)
nested_flat_dict = flatten_dictionary(v, prefix=prefix + k + sep, sep=sep)
if nested_flat_dict:
flat_dict.update(nested_flat_dict)
else:
@ -114,9 +113,8 @@ def flatten_dictionary(a_dict, prefix=''):
return flat_dict
def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
def nested_from_flat_dictionary(a_dict, flat_dict, prefix='', sep='/'):
basic_types = (float, int, bool, six.string_types, )
sep = '/'
org_dict = copy(a_dict)
for k, v in org_dict.items():
k = str(k)
@ -125,7 +123,7 @@ def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
a_dict[k] = flat_dict.get(prefix + k, v)
elif isinstance(v, dict):
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep) or v
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep, sep=sep) or v
else:
# this is a mixture of list and dict, or any other object,
# leave it as is, we have nothing to do with it.
@ -145,7 +143,7 @@ def naive_nested_from_flat_dictionary(flat_dict, sep='/'):
k[len(sub_prefix) + 1:]: v
for k, v in bucket
if len(k) > len(sub_prefix)
}
}, sep=sep
)
)
for sub_prefix, bucket in (