Add OpenMPI/Slurm support

allegroai 2020-03-20 10:23:00 +02:00
parent 0adbd79975
commit babaf9f1ce
2 changed files with 28 additions and 4 deletions
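Note: the diff below only adds the master-node check and a force-override flag; get_node_id() itself is not part of this commit. As a rough illustrative sketch only (not code from trains), node/rank detection under OpenMPI or Slurm is typically done from the launcher's environment variables:

    import os

    def detect_node_id(default=0):
        # Illustrative sketch, not the trains implementation: OpenMPI and Slurm
        # export the node/process rank, so rank 0 can be treated as the master.
        for var in ('OMPI_COMM_WORLD_NODE_RANK', 'SLURM_NODEID'):
            value = os.environ.get(var)
            if value is not None:
                try:
                    return int(value)
                except ValueError:
                    pass
        return default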

View File

@@ -95,6 +95,14 @@ def get_node_id(default=0):
     return node_id
+def get_is_master_node():
+    global __force_master_node
+    if __force_master_node:
+        return True
+    return get_node_id(default=0) == 0
 def get_log_redirect_level():
     """ Returns which log level (and up) should be redirected to stderr. None means no redirection. """
     value = LOG_STDERR_REDIRECT_LEVEL.get()
@@ -107,3 +115,21 @@ def get_log_redirect_level():
 def dev_worker_name():
     return DEV_WORKER_NAME.get()
+def __set_is_master_node():
+    try:
+        force_master_node = os.environ.pop('TRAINS_FORCE_MASTER_NODE', None)
+    except:
+        force_master_node = None
+    if force_master_node is not None:
+        try:
+            force_master_node = bool(int(force_master_node))
+        except:
+            force_master_node = None
+    return force_master_node
+__force_master_node = __set_is_master_node()
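__set_is_master_node() runs when the config module is imported and pops TRAINS_FORCE_MASTER_NODE from the environment, so a non-zero value forces get_is_master_node() to return True regardless of the detected node id. A minimal usage sketch (project and task names are placeholders):

    import os

    # Must be set before trains is imported, since the config module pops the
    # variable at import time.
    os.environ['TRAINS_FORCE_MASTER_NODE'] = '1'

    from trains import Task

    task = Task.init(project_name='examples', task_name='forced master node')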

View File

@@ -35,7 +35,7 @@ from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
 from .binding.joblib_bind import PatchedJoblib
 from .binding.matplotlib_bind import PatchedMatplotlib
-from .config import config, DEV_TASK_NO_REUSE, get_node_id
+from .config import config, DEV_TASK_NO_REUSE, get_is_master_node
 from .config import running_remotely, get_remote_task_id
 from .config.cache import SessionCache
 from .debugging.log import LoggerRoot
@@ -240,9 +240,7 @@ class Task(_Task):
         # we could not find a task ID, revert to old stub behaviour
         if not is_sub_process_task_id:
             return _TaskStub()
-        elif running_remotely() and get_node_id(default=0) != 0:
-            print("get_node_id", get_node_id(), get_remote_task_id())
+        elif running_remotely() and not get_is_master_node():
             # make sure we only do it once per process
             cls.__forked_proc_main_pid = os.getpid()
             # make sure everyone understands we should act as if we are a subprocess (fake pid 1)
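With this change, a remotely executed process on a non-master node no longer compares get_node_id() directly; it takes the forked-subprocess path whenever get_is_master_node() is False, so every rank can call Task.init() and report into the single master task. A sketch of how a per-rank training script might look (names and metric values are placeholders; the reporting calls are the standard trains Logger API):

    from trains import Task

    # Each OpenMPI/Slurm rank runs this; only the master node creates and owns
    # the task, the other ranks behave like forked subprocesses of it.
    task = Task.init(project_name='examples', task_name='multi-node training')
    logger = task.get_logger()
    for iteration in range(10):
        loss = 1.0 / (iteration + 1)  # placeholder metric
        logger.report_scalar(title='loss', series='train', value=loss, iteration=iteration)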