Mirror of https://github.com/clearml/clearml (synced 2025-06-23 01:55:38 +00:00)

Commit 801c7b4cd4 (parent 141a183235)
Add support for auto detecting torch and transformers accelerate distributed execution
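
The change keys its detection off environment variables set by the distributed launchers. As orientation only (an editor's sketch, not part of the commit), these are the variables the diff below actually inspects; the launch commands in the comment are illustrative:

# Illustrative environment seen by a worker process launched with, e.g.,
#   python -m torch.distributed.run --nproc_per_node=2 train.py
#   accelerate launch train.py
import os

is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None            # set by torchrun / torch.distributed.run
is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None  # treated as the accelerate marker by the new code
local_rank = os.environ.get("LOCAL_RANK")  # per-node rank, read by get_torch_local_rank()
global_rank = os.environ.get("RANK")       # global rank, used as a node-id fallback in get_node_id()

print(is_torch_distributed, is_transformers_distributed, local_rank, global_rank)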
@@ -102,7 +102,7 @@ class ScriptRequirements(object):
         for fname, lines in tfmodule.items():
             modules.add('tensorflow', fname, lines)

-        # if we have torch and it supports tensorboard, we should add that as well
+        # if we have torch, and it supports tensorboard, we should add that as well
         # (because it will not be detected automatically)
         if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
             # noinspection PyBroadException
@@ -336,14 +336,14 @@ class _JupyterObserver(object):
             # noinspection PyBroadException
             try:
                 # noinspection PyPackageRequirements
-                from nbconvert.exporters import PythonExporter
+                from nbconvert.exporters import PythonExporter  # noqa
                 _script_exporter = PythonExporter()
             except Exception:
                 _script_exporter = None

             if _script_exporter is None:
                 # noinspection PyPackageRequirements
-                from nbconvert.exporters.script import ScriptExporter
+                from nbconvert.exporters.script import ScriptExporter  # noqa
                 _script_exporter = ScriptExporter()

         except Exception as ex:
@@ -622,7 +622,7 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements
-            from notebook.notebookapp import list_running_servers  # <= Notebook v6
+            from notebook.notebookapp import list_running_servers  # noqa <= Notebook v6
             # noinspection PyBroadException
             try:
                 jupyter_servers += list(list_running_servers())
@@ -637,7 +637,7 @@ class ScriptInfo(object):
         # noinspection PyBroadException
         try:
             # noinspection PyPackageRequirements
-            from jupyter_server.serverapp import list_running_servers
+            from jupyter_server.serverapp import list_running_servers  # noqa
             # noinspection PyBroadException
             try:
                 jupyter_servers += list(list_running_servers())
@@ -724,7 +724,7 @@ class ScriptInfo(object):
         is_google_colab = False
         log_history = False
         colab_name = None
-        # check if this is google.colab, then there is no local file
+        # check if this is `google.colab`, then there is no local file
         is_google_colab = ScriptInfo.is_google_colab()

         if is_google_colab:
@@ -753,7 +753,7 @@ class ScriptInfo(object):
         if not entry_point.exists():
             # noinspection PyBroadException
             try:
-                alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5])+'.ipynb'
+                alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
                 # now we should try to find the actual file
                 entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
                 if not entry_point_alternative.is_file():
@@ -828,7 +828,7 @@ class ScriptInfo(object):
         # returns tuple (notebook name, raw string notebook)
         # None, None if fails
         try:
-            from google.colab import _message
+            from google.colab import _message  # noqa

             notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
             notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
@@ -995,6 +995,10 @@ class ScriptInfo(object):
         working_dir = cls._get_working_dir(repo_root)
         entry_point = cls._get_entry_point(repo_root, script_path)

+        # check if we are running with torch distributed, or transformers accelerate
+        # make sure we change the entry point to reflect it.
+        entry_point = cls._detect_distributed_execution(entry_point, log)
+
         if check_uncommitted:
             # if we have a jupyter notebook, always store the entire notebook (instead of the git diff)
             if jupyter_filepath:
@@ -1010,7 +1014,7 @@ class ScriptInfo(object):
             if len(diff) > cls.max_diff_size_bytes:
                 messages.append(
                     "======> WARNING! Git diff too large to store "
-                    "({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
+                    "({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
                 auxiliary_git_diff = diff
                 diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
                        '# full git diff available in Artifacts/auxiliary_git_diff\n' \
@@ -1065,6 +1069,52 @@ class ScriptInfo(object):
         return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
                 script_requirements)

+    @classmethod
+    def _detect_distributed_execution(cls, entry_point, log):
+        # check if we are running with torch distributed, or transformers accelerate
+        # make sure we change the entry point to reflect it.
+        is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
+        is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None
+        if not is_torch_distributed and not is_transformers_distributed:
+            return entry_point
+
+        # this is torch distributed or transformers accelerate
+        # noinspection PyBroadException
+        try:
+            from psutil import Process  # noqa
+            cmdline = Process().parent().cmdline()
+            # first find the torch module call "torch.distributed.run" or "torch.distributed.launch"
+            if is_torch_distributed:
+                cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
+            elif is_transformers_distributed:
+                cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("accelerate.commands."))
+            else:
+                raise Exception()  # we should not get here
+
+            cmdline = cmdline[cmdstart_i:]
+            # reverse look into the paths
+            cmdend_i = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
+            filearg = cmdline[cmdend_i]
+            # notice --args (script args) are passed on the Args section, we skip detecting them here
+            # we are also already removing the filearg from the cmd (it is the last before script args)
+            new_cmd = cmdline[:cmdend_i]
+
+            # we assume our entrypoint is the last parameter of the execution cmd line
+            if Path(filearg).stem == Path(entry_point).stem:
+                entry_point = "-m {} {}".format(" ".join(new_cmd), entry_point)
+                if log:
+                    log.info(
+                        "{} execution detected: adjusting entrypoint to "
+                        "reflect distributed execution arguments".format(
+                            "Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
+                    )
+        except Exception:
+            if log:
+                log.warning("{} execution detected: Failed Detecting launch arguments, skipping".format(
+                    "Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
+
+        return entry_point
+
     @staticmethod
     def __legacy_jupyter_notebook_server_json_parsing():
         # noinspection PyBroadException
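
For orientation (an editor's sketch, not part of the commit): the method above rewrites a plain script entry point into a "-m <launcher args> <script>" form by re-parsing the launcher's command line from the parent process. A standalone approximation, assuming a hypothetical torchrun-style parent cmdline:

# Editor's sketch of the entry-point rewrite, using an assumed example cmdline
# (the real code obtains it via psutil's Process().parent().cmdline()).
from pathlib import Path

entry_point = "train.py"
cmdline = ["python", "-m", "torch.distributed.run", "--nproc_per_node=2", "train.py"]

# find the launcher module call, then keep everything up to (but excluding) the script file
start = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
cmdline = cmdline[start:]
end = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
new_entry_point = "-m {} {}".format(" ".join(cmdline[:end]), entry_point)

print(new_entry_point)  # -m torch.distributed.run --nproc_per_node=2 train.py

Script arguments themselves are not reconstructed here; as the comments in the method note, they are handled by the Args section.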
@@ -186,7 +186,15 @@ def get_node_id(default=0):
     if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
         node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank

-    # if node is is till None, use the default
+    # if node is still None, use the global RANK
     if node_id is None:
+        # noinspection PyBroadException
+        try:
+            node_id = int(os.environ.get("RANK"))
+        except Exception:
+            pass
+
+    # if node is still None, use the default
+    if node_id is None:
         node_id = default

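
As a side note (editor's addition, hypothetical values): torchrun exports a global RANK per worker, which is what the new fallback reads before giving up and using the default:

# Editor's sketch of the fallback order added above (values are hypothetical).
import os

os.environ.setdefault("RANK", "3")  # torchrun would set this per worker

node_id = None  # no explicit node id and no MPI rank variables were found
if node_id is None:
    try:
        node_id = int(os.environ.get("RANK"))
    except Exception:
        pass
if node_id is None:
    node_id = 0  # the `default` argument

print(node_id)  # 3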
@@ -40,7 +40,8 @@ from .backend_config.defs import get_active_config_file, get_config_file
 from .backend_api.services import tasks, projects, events
 from .backend_api.session.session import (
     Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
-from .backend_api.session.defs import ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG, ENV_OFFLINE_MODE, MissingConfigError
+from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
+                                       ENV_OFFLINE_MODE, MissingConfigError)
 from .backend_interface.metrics import Metrics
 from .backend_interface.model import Model as BackendModel
 from .backend_interface.base import InterfaceBase
@@ -97,6 +98,8 @@ from .utilities.proxy_object import (
 from .utilities.resource_monitor import ResourceMonitor
 from .utilities.seed import make_deterministic
 from .utilities.lowlevel.threads import get_current_thread_id
+from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
+    create_torch_distributed_anchor
 from .utilities.process.mp import BackgroundMonitor, leave_process
 from .utilities.process.exit_hooks import ExitHooks
 from .utilities.matching import matches_any_wildcard
@@ -105,6 +108,7 @@ from .utilities.networking import get_private_ip
 # noinspection PyProtectedMember
 from .backend_interface.task.args import _Arguments

+
 if TYPE_CHECKING:
     import pandas
     import numpy
@@ -527,10 +531,16 @@ class Task(_Task):
         is_deferred = False
         try:
             if not running_remotely():
+                # check remote status
+                _local_rank = get_torch_local_rank()
+                if _local_rank is not None and _local_rank > 0:
+                    is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
+
                 # only allow if running locally and creating the first Task
                 # otherwise we ignore and perform in order
                 if ENV_DEFERRED_TASK_INIT.get():
                     deferred_init = True
+
                 if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
                     def completed_cb(x):
                         Task.__main_task = x
@@ -571,6 +581,11 @@ class Task(_Task):
                         not auto_connect_frameworks.get('detect_repository', True)) else True,
                     auto_connect_streams=auto_connect_streams,
                 )
+                # check if we are local rank 0 (local master),
+                # create an anchor with task ID for the other processes
+                if _local_rank == 0:
+                    create_torch_distributed_anchor(task_id=task.id)
+
         except MissingConfigError as e:
             if not ENV_IGNORE_MISSING_CONFIG.get():
                 raise
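
Taken together (an editor's summary sketch, not code from the commit), the Task-side wiring means that under a local torchrun launch, local rank 0 creates the Task and publishes its ID, while ranks > 0 wait for that ID instead of creating duplicate Tasks. Roughly, with create_task standing in as a hypothetical placeholder for the existing Task creation path:

# Editor's sketch of the intended flow inside Task.init() for a local torchrun launch.
from clearml.utilities.lowlevel.distributed import (
    get_torch_local_rank,
    get_torch_distributed_anchor_task_id,
    create_torch_distributed_anchor,
)

def sketch_task_init(create_task):
    is_sub_process_task_id = None
    local_rank = get_torch_local_rank()

    if local_rank is not None and local_rank > 0:
        # non-master local ranks wait (up to 30 seconds) for the anchor written by rank 0
        is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)

    # create_task is a hypothetical placeholder for the normal Task creation / reuse logic
    task = create_task(existing_task_id=is_sub_process_task_id)

    if local_rank == 0:
        # the local master publishes its Task ID for the other ranks on this node
        create_torch_distributed_anchor(task_id=task.id)
    return task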
clearml/utilities/lowlevel/distributed.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import os
from logging import getLogger
from time import sleep, time

from pathlib2 import Path


def get_torch_local_rank():
    """
    Return the local rank of the process; notice local rank 0 does not mean global rank 0.
    Return None if no torch distributed execution is running.
    """
    if os.environ.get("TORCHELASTIC_RUN_ID") is not None:
        # noinspection PyBroadException
        try:
            return int(os.environ.get("LOCAL_RANK"))
        except Exception:
            return None

    return None


def create_torch_distributed_anchor(task_id):
    """
    Create a temporary anchor file holding the Task ID created by local_rank 0 of the torch distributed execution.
    If a non-zero local rank calls this function, it returns without doing anything.

    Only call when running locally (i.e. without an agent);
    if running remotely there is no need to pass the Task ID, it will be passed externally.
    """
    local_file_name = ".clearml_torch_distributed_id"

    if get_torch_local_rank() != 0:
        return

    torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")

    if not torch_dist_path:
        return

    # noinspection PyBroadException
    try:
        torch_dist_path = Path(torch_dist_path).parent.parent.parent
        # create the anchor file
        with open(torch_dist_path / local_file_name, "wt") as f:
            f.write(str(task_id) + "\n")
    except Exception:
        # we failed for some reason?
        getLogger().warning("Failed creating torch task ID anchor file: {}".format(torch_dist_path))


def get_torch_distributed_anchor_task_id(timeout=None):
    """
    Wait until the temporary anchor file appears, then read the Task ID created by local_rank 0
    of the torch distributed execution.

    Only call when running locally (i.e. without an agent);
    if running remotely there is no need to pass the Task ID, it will be passed externally.

    :return: Task ID of the local task to report to
    """

    # check that we are not local rank 0
    _local_rank = get_torch_local_rank()
    if not _local_rank:
        return

    local_file_name = ".clearml_torch_distributed_id"

    torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not torch_dist_path:
        return

    task_id = None
    # noinspection PyBroadException
    try:
        torch_dist_path = Path(torch_dist_path).parent.parent.parent / local_file_name

        tic = time()
        # wait until the anchor file exists
        while not torch_dist_path.is_file():
            # if we found nothing, return None
            if timeout is not None and time() - tic > timeout:
                getLogger().warning("Failed detecting rank zero clearml Task ID, creating a new Task")
                return None
            # wait
            sleep(0.1)

        # read the anchor file
        with open(torch_dist_path, "rt") as f:
            task_id = f.read().strip(" \n")
    except Exception:
        # we failed for some reason?
        pass

    getLogger().warning("Torch Distributed Local Rank {} Task ID {} detected".format(_local_rank, task_id))
    return task_id
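
Usage sketch (editor's addition): the anchor file lives three directory levels above TORCHELASTIC_ERROR_FILE, so every local rank resolves the same path. Assuming a torchelastic-style environment with illustrative values:

# Editor's sketch: anchor-file round trip between local rank 0 and another local rank.
# The environment values below are illustrative assumptions, not produced by clearml.
import os

os.makedirs("/tmp/torchelastic/run0/attempt_0/0", exist_ok=True)
os.environ["TORCHELASTIC_RUN_ID"] = "demo"
os.environ["TORCHELASTIC_ERROR_FILE"] = "/tmp/torchelastic/run0/attempt_0/0/error.json"

from clearml.utilities.lowlevel.distributed import (
    create_torch_distributed_anchor,
    get_torch_distributed_anchor_task_id,
)

# local rank 0 writes /tmp/torchelastic/run0/.clearml_torch_distributed_id
os.environ["LOCAL_RANK"] = "0"
create_torch_distributed_anchor(task_id="aabbccdd")

# local ranks > 0 poll for that file (up to `timeout` seconds) and read the Task ID back
os.environ["LOCAL_RANK"] = "1"
print(get_torch_distributed_anchor_task_id(timeout=5))  # aabbccdd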