mirror of https://github.com/clearml/clearml (synced 2025-06-23 01:55:38 +00:00)
Add support for auto detecting torch and transformers accelerate distributed execution
commit 801c7b4cd4 (parent 141a183235)
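In short: when a training script is launched through torch.distributed (torchrun) or Hugging Face's transformers accelerate, ClearML now (a) recognizes the launcher, (b) rewrites the recorded entry point so a remote re-run reproduces the full distributed command line, and (c) makes every local rank report into the Task created by local rank 0. Detection keys on environment variables the launchers themselves export; a minimal sketch of that fingerprinting (the variable names are the ones used in the diff below, the helper itself is illustrative):

    import os

    def detect_launcher():
        # torchrun / python -m torch.distributed.run exports TORCHELASTIC_RUN_ID
        if os.environ.get("TORCHELASTIC_RUN_ID") is not None:
            return "torch.distributed"
        # `accelerate launch` exports ACCELERATE_DYNAMO_MODE (among others)
        if os.environ.get("ACCELERATE_DYNAMO_MODE") is not None:
            return "transformers.accelerate"
        return None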
clearml/backend_interface/task/repo/scriptinfo.py

@@ -102,7 +102,7 @@ class ScriptRequirements(object):
             for fname, lines in tfmodule.items():
                 modules.add('tensorflow', fname, lines)
 
-        # if we have torch and it supports tensorboard, we should add that as well
+        # if we have torch, and it supports tensorboard, we should add that as well
         # (because it will not be detected automatically)
         if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
             # noinspection PyBroadException
@@ -336,14 +336,14 @@ class _JupyterObserver(object):
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements
-            from nbconvert.exporters import PythonExporter
+            from nbconvert.exporters import PythonExporter  # noqa
             _script_exporter = PythonExporter()
         except Exception:
             _script_exporter = None
 
         if _script_exporter is None:
             # noinspection PyPackageRequirements
-            from nbconvert.exporters.script import ScriptExporter
+            from nbconvert.exporters.script import ScriptExporter  # noqa
             _script_exporter = ScriptExporter()
 
     except Exception as ex:
@@ -622,7 +622,7 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements
-            from notebook.notebookapp import list_running_servers  # <= Notebook v6
+            from notebook.notebookapp import list_running_servers  # noqa <= Notebook v6
             # noinspection PyBroadException
             try:
                 jupyter_servers += list(list_running_servers())
@@ -637,7 +637,7 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements
-            from jupyter_server.serverapp import list_running_servers
+            from jupyter_server.serverapp import list_running_servers  # noqa
             # noinspection PyBroadException
             try:
                 jupyter_servers += list(list_running_servers())
@@ -724,7 +724,7 @@ class ScriptInfo(object):
         is_google_colab = False
         log_history = False
         colab_name = None
-        # check if this is google.colab, then there is no local file
+        # check if this is `google.colab`, then there is no local file
         is_google_colab = ScriptInfo.is_google_colab()
 
         if is_google_colab:
@@ -753,7 +753,7 @@ class ScriptInfo(object):
         if not entry_point.exists():
             # noinspection PyBroadException
             try:
-                alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5])+'.ipynb'
+                alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
                 # now we should try to find the actual file
                 entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
                 if not entry_point_alternative.is_file():
@@ -828,7 +828,7 @@ class ScriptInfo(object):
         # returns tuple (notebook name, raw string notebook)
         # None, None if fails
         try:
-            from google.colab import _message
+            from google.colab import _message  # noqa
 
             notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
             notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
@@ -995,6 +995,10 @@ class ScriptInfo(object):
         working_dir = cls._get_working_dir(repo_root)
         entry_point = cls._get_entry_point(repo_root, script_path)
 
+        # check if we are running with torch distributed, or transformers accelerate
+        # make sure we change the entry point to reflect it.
+        entry_point = cls._detect_distributed_execution(entry_point, log)
+
         if check_uncommitted:
             # if we have a jupyter notebook, always store the entire notebook (instead of the git diff)
             if jupyter_filepath:
@@ -1010,7 +1014,7 @@ class ScriptInfo(object):
                 if len(diff) > cls.max_diff_size_bytes:
                     messages.append(
                         "======> WARNING! Git diff too large to store "
-                        "({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
+                        "({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
                     auxiliary_git_diff = diff
                     diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
                            '# full git diff available in Artifacts/auxiliary_git_diff\n' \
@@ -1065,6 +1069,52 @@ class ScriptInfo(object):
         return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
                 script_requirements)
 
+    @classmethod
+    def _detect_distributed_execution(cls, entry_point, log):
+        # check if we are running with torch distributed, or transformers accelerate
+        # make sure we change the entry point to reflect it.
+        is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
+        is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None
+        if not is_torch_distributed and not is_transformers_distributed:
+            return entry_point
+
+        # this is torch distributed / accelerate
+        # noinspection PyBroadException
+        try:
+            from psutil import Process  # noqa
+            cmdline = Process().parent().cmdline()
+            # first find the launcher module call, "torch.distributed.run" or "torch.distributed.launch"
+            if is_torch_distributed:
+                cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
+            elif is_transformers_distributed:
+                cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("accelerate.commands."))
+            else:
+                raise Exception()  # we should not get here
+
+            cmdline = cmdline[cmdstart_i:]
+            # reverse look into the paths
+            cmdend_i = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
+            filearg = cmdline[cmdend_i]
+            # notice --args (script args) are passed on the Args section, we skip detecting them here
+            # we are also already removing the filearg from the cmd (it is the last before script args)
+            new_cmd = cmdline[:cmdend_i]
+
+            # we assume our entrypoint is the last parameter of the execution cmd line
+            if Path(filearg).stem == Path(entry_point).stem:
+                entry_point = "-m {} {}".format(" ".join(new_cmd), entry_point)
+                if log:
+                    log.info(
+                        "{} execution detected: adjusting entrypoint to "
+                        "reflect distributed execution arguments".format(
+                            "Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
+                    )
+        except Exception:
+            if log:
+                log.warning("{} execution detected: failed detecting launch arguments, skipping".format(
+                    "Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
+
+        return entry_point
+
     @staticmethod
     def __legacy_jupyter_notebook_server_json_parsing():
         # noinspection PyBroadException
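A worked example of the command-line slicing above, with an illustrative parent cmdline (in the real code psutil's Process().parent().cmdline() supplies the list, and Path comes from pathlib2):

    from pathlib import Path  # stdlib pathlib behaves the same here

    cmdline = ["/usr/bin/python", "-m", "torch.distributed.run",
               "--nproc_per_node", "2", "train.py", "--lr", "0.1"]
    entry_point = "train.py"

    # locate the launcher module token, then cut everything before it
    start = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
    cmdline = cmdline[start:]
    # locate the script argument by matching file stems; script args after it are ignored
    end = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
    print("-m {} {}".format(" ".join(cmdline[:end]), entry_point))
    # -> -m torch.distributed.run --nproc_per_node 2 train.py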
clearml/config/__init__.py

@@ -186,7 +186,15 @@ def get_node_id(default=0):
     if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
         node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank
 
-    # if node is is till None, use the default
+    # if node is still None, use the global RANK
+    if node_id is None:
+        # noinspection PyBroadException
+        try:
+            node_id = int(os.environ.get("RANK"))
+        except Exception:
+            pass
+
+    # if node is still None, use the default
     if node_id is None:
         node_id = default
 
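The resulting lookup order in get_node_id, condensed into one sketch (only the RANK fallback is new here; the node-id and MPI-rank lookups happen in parts of the function not shown in this hunk, so their variable names below are illustrative):

    import os

    def node_id_precedence(default=0):
        # illustrative names except RANK, which is the one added by this commit
        for var in ("NODE_ID", "MPI_WORLD_RANK", "MPI_RANK", "RANK"):
            # noinspection PyBroadException
            try:
                return int(os.environ[var])
            except Exception:
                continue
        return default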
clearml/task.py

@@ -40,7 +40,8 @@ from .backend_config.defs import get_active_config_file, get_config_file
 from .backend_api.services import tasks, projects, events
 from .backend_api.session.session import (
     Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
-from .backend_api.session.defs import ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG, ENV_OFFLINE_MODE, MissingConfigError
+from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
+                                       ENV_OFFLINE_MODE, MissingConfigError)
 from .backend_interface.metrics import Metrics
 from .backend_interface.model import Model as BackendModel
 from .backend_interface.base import InterfaceBase
@@ -97,6 +98,8 @@ from .utilities.proxy_object import (
 from .utilities.resource_monitor import ResourceMonitor
 from .utilities.seed import make_deterministic
 from .utilities.lowlevel.threads import get_current_thread_id
+from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
+    create_torch_distributed_anchor
 from .utilities.process.mp import BackgroundMonitor, leave_process
 from .utilities.process.exit_hooks import ExitHooks
 from .utilities.matching import matches_any_wildcard
@@ -105,6 +108,7 @@ from .utilities.networking import get_private_ip
 # noinspection PyProtectedMember
 from .backend_interface.task.args import _Arguments
 
+
 if TYPE_CHECKING:
     import pandas
     import numpy
@@ -527,10 +531,16 @@ class Task(_Task):
         is_deferred = False
         try:
             if not running_remotely():
+                # check torch distributed local rank
+                _local_rank = get_torch_local_rank()
+                if _local_rank is not None and _local_rank > 0:
+                    is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
+
                 # only allow if running locally and creating the first Task
                 # otherwise we ignore and perform in order
                 if ENV_DEFERRED_TASK_INIT.get():
                     deferred_init = True
 
             if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
                 def completed_cb(x):
                     Task.__main_task = x
@@ -571,6 +581,11 @@ class Task(_Task):
                     not auto_connect_frameworks.get('detect_repository', True)) else True,
                 auto_connect_streams=auto_connect_streams,
             )
+            # check if we are local rank 0 (local master),
+            # create an anchor with task ID for the other processes
+            if _local_rank == 0:
+                create_torch_distributed_anchor(task_id=task.id)
+
         except MissingConfigError as e:
             if not ENV_IGNORE_MISSING_CONFIG.get():
                 raise
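Putting the two Task.init hooks together, the per-process flow under torchrun is roughly as follows (a simplified sketch of the logic above, not the actual implementation):

    from clearml.utilities.lowlevel.distributed import (
        create_torch_distributed_anchor,
        get_torch_distributed_anchor_task_id,
        get_torch_local_rank,
    )

    local_rank = get_torch_local_rank()  # None unless torchrun exported TORCHELASTIC_RUN_ID
    if local_rank is not None and local_rank > 0:
        # non-master local ranks wait (up to 30s) for rank 0's Task ID and attach to it
        task_id = get_torch_distributed_anchor_task_id(timeout=30)
    elif local_rank == 0:
        task_id = "..."  # Task.init() creates the Task normally at this point
        # rank 0 then publishes its Task ID for the other local processes
        create_torch_distributed_anchor(task_id=task_id)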
clearml/utilities/lowlevel/distributed.py (new file, 96 lines)

@@ -0,0 +1,96 @@
+import os
+from logging import getLogger
+from time import sleep, time
+
+from pathlib2 import Path
+
+
+def get_torch_local_rank():
+    """
+    return the local rank of the process, notice local rank 0 does not mean global rank 0
+    return None if no torch distributed is running
+    """
+    if os.environ.get("TORCHELASTIC_RUN_ID") is not None:
+        # noinspection PyBroadException
+        try:
+            return int(os.environ.get("LOCAL_RANK"))
+        except Exception:
+            return None
+
+    return None
+
+
+def create_torch_distributed_anchor(task_id):
+    """
+    This will create a temporary file to pass the Task ID created by local_rank 0
+    to the rest of the local processes (a no-op when called from any other local rank)
+
+    Only call when running locally (i.e. without an agent),
+    if running remotely there is no need to pass Task ID, it will be passed externally
+    """
+    local_file_name = ".clearml_torch_distributed_id"
+
+    if get_torch_local_rank() != 0:
+        return
+
+    torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
+
+    if not torch_dist_path:
+        return
+
+    # noinspection PyBroadException
+    try:
+        torch_dist_path = Path(torch_dist_path).parent.parent.parent
+        # create the file
+        with open(torch_dist_path / local_file_name, "wt") as f:
+            f.write(str(task_id) + "\n")
+    except Exception:
+        # we failed for some reason?
+        getLogger().warning("Failed creating torch task ID anchor file: {}".format(torch_dist_path))
+
+
+def get_torch_distributed_anchor_task_id(timeout=None):
+    """
+    This will wait until a temporary file appears and read the Task ID created by local_rank 0
+
+    Only call when running locally (i.e. without an agent),
+    if running remotely there is no need to pass Task ID, it will be passed externally
+
+    :return: Task ID of the local task to report to
+    """
+    # check that we are not local rank 0
+    _local_rank = get_torch_local_rank()
+    if not _local_rank:
+        return
+
+    local_file_name = ".clearml_torch_distributed_id"
+
+    torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
+    if not torch_dist_path:
+        return
+
+    task_id = None
+    # noinspection PyBroadException
+    try:
+        torch_dist_path = Path(torch_dist_path).parent.parent.parent / local_file_name
+
+        tic = time()
+        # wait until the distributed anchor file exists
+        while not torch_dist_path.is_file():
+            # if we found nothing, return None
+            if timeout is not None and time() - tic > timeout:
+                getLogger().warning("Failed detecting rank zero clearml Task ID, creating a new Task")
+                return None
+            # wait
+            sleep(0.1)
+
+        # read the file
+        with open(torch_dist_path, "rt") as f:
+            task_id = f.read().strip(" \n")
+    except Exception:
+        # we failed for some reason?
+        pass
+
+    getLogger().warning("Torch Distributed Local Rank {} Task ID {} detected".format(_local_rank, task_id))
+    return task_id
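The anchor location piggybacks on TORCHELASTIC_ERROR_FILE: torch elastic points it at a per-worker error file nested under the shared run directory, so three .parent hops land on a directory every local rank can see. An illustration with a hypothetical path (the exact directory layout is an assumption, not taken from the diff):

    from pathlib2 import Path  # as in the module above; stdlib pathlib works identically

    # hypothetical value torch elastic might export for local rank 3, attempt 0:
    error_file = "/tmp/torchelastic_abcd/1234_run/attempt_0/3/error.json"
    run_dir = Path(error_file).parent.parent.parent
    print(run_dir)                                    # /tmp/torchelastic_abcd/1234_run
    print(run_dir / ".clearml_torch_distributed_id")  # the shared anchor file path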