Add support for auto-detecting torch distributed and transformers accelerate distributed execution

allegroai 2024-01-06 12:34:32 +02:00
parent 141a183235
commit 801c7b4cd4
4 changed files with 180 additions and 11 deletions
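For context: the detection added by this commit keys off environment variables exported by the respective launchers (see _detect_distributed_execution in the first file below). A minimal sketch of that check, assuming only those variables; the helper name here is illustrative, not part of the commit:

import os

def _guess_distributed_launcher():
    # torchrun / python -m torch.distributed.run exports TORCHELASTIC_RUN_ID to every worker
    if os.environ.get("TORCHELASTIC_RUN_ID") is not None:
        return "torch.distributed"
    # `accelerate launch` exports ACCELERATE_DYNAMO_MODE (among other ACCELERATE_* variables)
    if os.environ.get("ACCELERATE_DYNAMO_MODE") is not None:
        return "transformers accelerate"
    return None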

View File

@@ -102,7 +102,7 @@ class ScriptRequirements(object):
for fname, lines in tfmodule.items():
modules.add('tensorflow', fname, lines)
# if we have torch and it supports tensorboard, we should add that as well
# if we have torch, and it supports tensorboard, we should add that as well
# (because it will not be detected automatically)
if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
# noinspection PyBroadException
@@ -336,14 +336,14 @@ class _JupyterObserver(object):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from nbconvert.exporters import PythonExporter
from nbconvert.exporters import PythonExporter # noqa
_script_exporter = PythonExporter()
except Exception:
_script_exporter = None
if _script_exporter is None:
# noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter
from nbconvert.exporters.script import ScriptExporter # noqa
_script_exporter = ScriptExporter()
except Exception as ex:
@@ -622,7 +622,7 @@ class ScriptInfo(object):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from notebook.notebookapp import list_running_servers # <= Notebook v6
from notebook.notebookapp import list_running_servers # noqa <= Notebook v6
# noinspection PyBroadException
try:
jupyter_servers += list(list_running_servers())
@@ -637,7 +637,7 @@ class ScriptInfo(object):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from jupyter_server.serverapp import list_running_servers
from jupyter_server.serverapp import list_running_servers # noqa
# noinspection PyBroadException
try:
jupyter_servers += list(list_running_servers())
@@ -724,7 +724,7 @@ class ScriptInfo(object):
is_google_colab = False
log_history = False
colab_name = None
# check if this is google.colab, then there is no local file
# check if this is `google.colab`, then there is no local file
is_google_colab = ScriptInfo.is_google_colab()
if is_google_colab:
@@ -753,7 +753,7 @@ class ScriptInfo(object):
if not entry_point.exists():
# noinspection PyBroadException
try:
alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5])+'.ipynb'
alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
# now we should try to find the actual file
entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
if not entry_point_alternative.is_file():
@@ -828,7 +828,7 @@ class ScriptInfo(object):
# returns tuple (notebook name, raw string notebook)
# None, None if fails
try:
from google.colab import _message
from google.colab import _message # noqa
notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
@@ -995,6 +995,10 @@ class ScriptInfo(object):
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
# check if we are running with torch distributed, or transformers accelerate
# make sure we change the entry point to reflect it.
entry_point = cls._detect_distributed_execution(entry_point, log)
if check_uncommitted:
# if we have a jupyter notebook, always store the entire notebook (instead of the git diff)
if jupyter_filepath:
@@ -1010,7 +1014,7 @@ class ScriptInfo(object):
if len(diff) > cls.max_diff_size_bytes:
messages.append(
"======> WARNING! Git diff too large to store "
"({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
"({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
auxiliary_git_diff = diff
diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
'# full git diff available in Artifacts/auxiliary_git_diff\n' \
@@ -1065,6 +1069,52 @@ class ScriptInfo(object):
return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
script_requirements)
@classmethod
def _detect_distributed_execution(cls, entry_point, log):
# check if we are running with torch distributed, or transformers accelerate
# make sure we change the entry point to reflect it.
is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None
if not is_torch_distributed and not is_transformers_distributed:
return entry_point
# this is a distributed launch (torch.distributed or accelerate)
# noinspection PyBroadException
try:
from psutil import Process # noqa
cmdline = Process().parent().cmdline()
# first find the torch launcher module call "torch.distributed.run" or "torch.distributed.launch"
if is_torch_distributed:
cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
elif is_transformers_distributed:
cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("accelerate.commands."))
else:
raise Exception() # we should not get here
cmdline = cmdline[cmdstart_i:]
# locate the entry point script among the remaining arguments (match by file stem)
cmdend_i = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
filearg = cmdline[cmdend_i]
# notice: the script arguments (--args) are passed in the Args section, so we skip detecting them here
# we are also already removing the filearg from the cmd (it is the last before script args)
new_cmd = cmdline[:cmdend_i]
# we assume our entrypoint is the last parameter of the execution cmd line
if Path(filearg).stem == Path(entry_point).stem:
entry_point = "-m {} {}".format(" ".join(new_cmd), entry_point)
if log:
log.info(
"{} execution detected: adjusting entrypoint to "
"reflect distributed execution arguments".format(
"Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
)
except Exception:
if log:
log.warning("{} execution detected: Failed Detecting launch arguments, skipping".format(
"Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
return entry_point
@staticmethod
def __legacy_jupyter_notebook_server_json_parsing():
# noinspection PyBroadException

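To illustrate what _detect_distributed_execution above does with the parent command line, here is a simplified, self-contained sketch. The example argv in the comment is hypothetical; the real method also handles the accelerate launcher (accelerate.commands.*) and logs the adjustment:

from pathlib import Path
from psutil import Process

def rebuild_entry_point(entry_point):
    # the launcher (e.g. torchrun) is our parent process; inspect its command line
    cmdline = Process().parent().cmdline()
    # e.g. ['python', '-m', 'torch.distributed.run', '--nproc_per_node=2', 'train.py', '--lr', '0.1']
    start = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
    launcher = cmdline[start:]
    # locate the user script among the launcher arguments by matching the file stem
    end = next(i for i, c in enumerate(launcher) if Path(c).stem == Path(entry_point).stem)
    # keep the launcher module and its flags, drop the script (its own args go to the Args section)
    return "-m {} {}".format(" ".join(launcher[:end]), entry_point)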
View File

@@ -186,7 +186,15 @@ def get_node_id(default=0):
if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank
# if node is is till None, use the default
# if node is still None, use the global RANK
if node_id is None:
# noinspection PyBroadException
try:
node_id = int(os.environ.get("RANK"))
except Exception:
pass
# if node is still None, use the default
if node_id is None:
node_id = default

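The second file extends get_node_id() with a fallback to the global RANK exported by torch.distributed launchers. A compressed, illustrative sketch of the resulting fallback order (the function name and signature here are not the real ones):

import os

def resolve_node_id(node_id=None, default=0):
    # simplified fallback chain from get_node_id(): explicit node id / MPI rank first,
    # then the global RANK set by torch.distributed launchers, finally the default
    if node_id is None and os.environ.get("RANK") is not None:
        try:
            node_id = int(os.environ["RANK"])
        except ValueError:
            node_id = None
    return default if node_id is None else node_id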
View File

@@ -40,7 +40,8 @@ from .backend_config.defs import get_active_config_file, get_config_file
from .backend_api.services import tasks, projects, events
from .backend_api.session.session import (
Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
from .backend_api.session.defs import ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG, ENV_OFFLINE_MODE, MissingConfigError
from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
ENV_OFFLINE_MODE, MissingConfigError)
from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel
from .backend_interface.base import InterfaceBase
@@ -97,6 +98,8 @@ from .utilities.proxy_object import (
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.lowlevel.threads import get_current_thread_id
from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
create_torch_distributed_anchor
from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.process.exit_hooks import ExitHooks
from .utilities.matching import matches_any_wildcard
@@ -105,6 +108,7 @@ from .utilities.networking import get_private_ip
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
if TYPE_CHECKING:
import pandas
import numpy
@@ -527,10 +531,16 @@ class Task(_Task):
is_deferred = False
try:
if not running_remotely():
# check remote status
_local_rank = get_torch_local_rank()
if _local_rank is not None and _local_rank > 0:
is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
# only allow if running locally and creating the first Task
# otherwise we ignore and perform in order
if ENV_DEFERRED_TASK_INIT.get():
deferred_init = True
if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
def completed_cb(x):
Task.__main_task = x
@@ -571,6 +581,11 @@ class Task(_Task):
not auto_connect_frameworks.get('detect_repository', True)) else True,
auto_connect_streams=auto_connect_streams,
)
# check if we are local rank 0 (local master),
# create an anchor with task ID for the other processes
if _local_rank == 0:
create_torch_distributed_anchor(task_id=task.id)
except MissingConfigError as e:
if not ENV_IGNORE_MISSING_CONFIG.get():
raise

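The Task.init() changes above make only local rank 0 create a Task, while the other local ranks wait for its ID through the anchor helpers added in the new module (last file below). A rough sketch of that flow; the wrapper function and its create_task callable are illustrative, while the import mirrors the module added by this commit:

from clearml.utilities.lowlevel.distributed import (
    get_torch_local_rank,
    get_torch_distributed_anchor_task_id,
    create_torch_distributed_anchor,
)

def task_id_for_this_rank(create_task):
    # create_task: an illustrative callable that creates and returns a new ClearML Task
    local_rank = get_torch_local_rank()
    if local_rank is not None and local_rank > 0:
        # worker ranks: wait (up to 30 seconds) for rank 0 to publish its Task ID and reuse it
        return get_torch_distributed_anchor_task_id(timeout=30)
    task = create_task()
    if local_rank == 0:
        # local master: publish the Task ID so sibling processes on this node can attach to it
        create_torch_distributed_anchor(task_id=task.id)
    return task.id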
View File

@@ -0,0 +1,96 @@
import os
from logging import getLogger
from time import sleep, time
from pathlib2 import Path
def get_torch_local_rank():
"""
Return the local rank of the process; notice local rank 0 does not mean global rank 0.
Return None if torch distributed is not running.
"""
if os.environ.get("TORCHELASTIC_RUN_ID") is not None:
# noinspection PyBroadException
try:
return int(os.environ.get("LOCAL_RANK"))
except Exception:
return None
return None
def create_torch_distributed_anchor(task_id):
"""
This will create a temporary file used to pass the Task ID created by local_rank 0
to the other local ranks on the node; if the caller is not local_rank 0, this is a no-op.
Only call when running locally (i.e. without an agent);
if running remotely there is no need to pass the Task ID, it will be passed externally.
"""
local_file_name = ".clearml_torch_distributed_id"
if get_torch_local_rank() != 0:
return
torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
if not torch_dist_path:
return
# noinspection PyBroadException
try:
torch_dist_path = Path(torch_dist_path).parent.parent.parent
# create the file
with open(torch_dist_path / local_file_name, "wt") as f:
f.write(str(task_id) + "\n")
except Exception:
# we failed for some reason?
getLogger().warning("Failed creating torch task ID anchor file: {}".format(torch_dist_path))
def get_torch_distributed_anchor_task_id(timeout=None):
"""
This will wait until the temporary anchor file appears, then read the Task ID created by local_rank 0.
Only call when running locally (i.e. without an agent);
if running remotely there is no need to pass the Task ID, it will be passed externally.
:return: Task ID of the local task to report to, or None if it could not be detected
"""
# if we are local rank 0 (or not running distributed at all), there is nothing to wait for
_local_rank = get_torch_local_rank()
if not _local_rank:
return
local_file_name = ".clearml_torch_distributed_id"
torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
if not torch_dist_path:
return
task_id = None
# noinspection PyBroadException
try:
torch_dist_path = Path(torch_dist_path).parent.parent.parent / local_file_name
tic = time()
# wait until the distributed anchor file exists
while not torch_dist_path.is_file():
# if we found nothing, return None
if timeout is not None and time() - tic > timeout:
getLogger().warning("Failed detecting rank zero clearml Task ID, creating a new Task")
return None
# wait
sleep(0.1)
# read the Task ID from the anchor file
with open(torch_dist_path, "rt") as f:
task_id = f.read().strip(" \n")
except Exception:
# we failed for some reason?
pass
getLogger().warning("Torch Distributed Local Rank {} Task ID {} detected".format(_local_rank, task_id))
return task_id
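Stripped of the TORCHELASTIC_ERROR_FILE path handling, the anchor is a small file-based handshake between the local ranks. A minimal, self-contained sketch of the same idea (the path and function names are illustrative):

from pathlib import Path
from time import sleep, time

ANCHOR = Path("/tmp/.clearml_torch_distributed_id_example")  # illustrative location

def publish_task_id(task_id):
    # rank 0: write the Task ID for the sibling local ranks to pick up
    ANCHOR.write_text(str(task_id) + "\n")

def wait_for_task_id(timeout=30.0):
    # ranks > 0: poll until the anchor appears, then read the Task ID (None on timeout)
    deadline = time() + timeout
    while not ANCHOR.is_file():
        if time() > deadline:
            return None
        sleep(0.1)
    return ANCHOR.read_text().strip()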