Add initial slurm support (multiple nodes sharing the same task id)

This commit is contained in:
allegroai 2020-03-12 18:12:16 +02:00
parent 5b29aa194c
commit afad6a42ea
4 changed files with 64 additions and 10 deletions

View File

@ -43,8 +43,8 @@ sdk {
subsampling: 0 subsampling: 0
} }
# Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to True, each series should have its own graph) # Support plot-per-graph fully matching Tensorboard behavior (i.e. if this is set to true, each series should have its own graph)
tensorboard_single_series_per_graph: False tensorboard_single_series_per_graph: false
} }
network { network {
@ -125,11 +125,11 @@ sdk {
log { log {
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout) # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
null_log_propagate: False null_log_propagate: false
task_log_buffer_capacity: 66 task_log_buffer_capacity: 66
# disable urllib info and lower levels # disable urllib info and lower levels
disable_urllib3_info: True disable_urllib3_info: true
} }
development { development {
@ -139,14 +139,14 @@ sdk {
task_reuse_time_window_in_hours: 72.0 task_reuse_time_window_in_hours: 72.0
# Run VCS repository detection asynchronously # Run VCS repository detection asynchronously
vcs_repo_detect_async: True vcs_repo_detect_async: true
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff_on_train: True store_uncommitted_code_diff: true
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True support_stopping: true
# Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead. # Default Task output_uri. if output_uri is not provided to Task.init, default_output_uri will be used instead.
default_output_uri: "" default_output_uri: ""
@ -160,7 +160,7 @@ sdk {
ping_period_sec: 30 ping_period_sec: 30
# Log all stdout & stderr # Log all stdout & stderr
log_stdout: True log_stdout: true
} }
} }
} }

View File

@ -1059,6 +1059,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
pid = pid or os.getpid() pid = pid or os.getpid()
if not task: if not task:
PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':') PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':')
elif isinstance(task, str):
PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + task)
else: else:
PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + str(task.id)) PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + str(task.id))
# make sure we refresh the edit lock next time we need it, # make sure we refresh the edit lock next time we need it,

View File

@ -1,5 +1,6 @@
""" Configuration module. Uses backend_config to load system configuration. """ """ Configuration module. Uses backend_config to load system configuration. """
import logging import logging
import os
from os.path import expandvars, expanduser from os.path import expandvars, expanduser
from ..backend_api import load_config from ..backend_api import load_config
@ -47,7 +48,51 @@ def get_log_to_backend(default=None):
def get_node_id(default=0): def get_node_id(default=0):
return NODE_ID_ENV_VAR.get(default=default) node_id = NODE_ID_ENV_VAR.get()
try:
mpi_world_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', os.environ.get('PMI_RANK')))
except:
mpi_world_rank = None
try:
mpi_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', os.environ.get('SLURM_PROCID')))
except:
mpi_rank = None
# if we have no node_id, use the mpi rank
if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank
# if node is is till None, use the default
if node_id is None:
node_id = default
# check if we have pyTorch node/worker ID
try:
from torch.utils.data.dataloader import get_worker_info
worker_info = get_worker_info()
if not worker_info:
torch_rank = None
else:
w_id = worker_info.id
try:
torch_rank = int(w_id)
except Exception:
# guess a number based on wid hopefully unique value
import hashlib
h = hashlib.md5()
h.update(str(w_id).encode('utf-8'))
torch_rank = int(h.hexdigest(), 16)
except Exception:
torch_rank = None
# if we also have a torch rank add it to the node rank
if torch_rank is not None:
# Since we dont know the world rank, we assume it is not bigger than 10k
node_id = (10000 * node_id) + torch_rank
return node_id
def get_log_redirect_level(): def get_log_redirect_level():

View File

@ -35,7 +35,7 @@ from .binding.frameworks.tensorflow_bind import TensorflowBinding
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.joblib_bind import PatchedJoblib from .binding.joblib_bind import PatchedJoblib
from .binding.matplotlib_bind import PatchedMatplotlib from .binding.matplotlib_bind import PatchedMatplotlib
from .config import config, DEV_TASK_NO_REUSE from .config import config, DEV_TASK_NO_REUSE, get_node_id
from .config import running_remotely, get_remote_task_id from .config import running_remotely, get_remote_task_id
from .config.cache import SessionCache from .config.cache import SessionCache
from .debugging.log import LoggerRoot from .debugging.log import LoggerRoot
@ -240,6 +240,13 @@ class Task(_Task):
# we could not find a task ID, revert to old stub behaviour # we could not find a task ID, revert to old stub behaviour
if not is_sub_process_task_id: if not is_sub_process_task_id:
return _TaskStub() return _TaskStub()
elif running_remotely() and get_node_id(default=0) != 0:
print("get_node_id", get_node_id(), get_remote_task_id())
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
# make sure everyone understands we should act as if we are a subprocess (fake pid 1)
cls.__update_master_pid_task(pid=1, task=get_remote_task_id())
else: else:
# set us as master process (without task ID) # set us as master process (without task ID)
cls.__update_master_pid_task() cls.__update_master_pid_task()