Add support for auto-detecting torch and transformers accelerate distributed execution

allegroai 2024-01-06 12:34:32 +02:00
parent 141a183235
commit 801c7b4cd4
4 changed files with 180 additions and 11 deletions
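
Background for the changes below: both launchers export distinctive environment variables in every worker process, and the new detection code keys off those. A minimal sketch of the check, mirroring the code added in this commit (the launcher commands in the comments are typical invocations, not taken from this diff):

import os

# `torchrun ...` / `python -m torch.distributed.run ...` (torch elastic) set TORCHELASTIC_RUN_ID
is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
# `accelerate launch ...` sets ACCELERATE_DYNAMO_MODE among its ACCELERATE_* variables
is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None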

View File

@@ -102,7 +102,7 @@ class ScriptRequirements(object):
        for fname, lines in tfmodule.items():
            modules.add('tensorflow', fname, lines)
-       # if we have torch and it supports tensorboard, we should add that as well
+       # if we have torch, and it supports tensorboard, we should add that as well
        # (because it will not be detected automatically)
        if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
            # noinspection PyBroadException
@@ -336,14 +336,14 @@ class _JupyterObserver(object):
            # noinspection PyBroadException
            try:
                # noinspection PyPackageRequirements
-               from nbconvert.exporters import PythonExporter
+               from nbconvert.exporters import PythonExporter  # noqa
                _script_exporter = PythonExporter()
            except Exception:
                _script_exporter = None
            if _script_exporter is None:
                # noinspection PyPackageRequirements
-               from nbconvert.exporters.script import ScriptExporter
+               from nbconvert.exporters.script import ScriptExporter  # noqa
                _script_exporter = ScriptExporter()
        except Exception as ex:
@@ -622,7 +622,7 @@ class ScriptInfo(object):
        # noinspection PyBroadException
        try:
            # noinspection PyPackageRequirements
-           from notebook.notebookapp import list_running_servers  # <= Notebook v6
+           from notebook.notebookapp import list_running_servers  # noqa  <= Notebook v6
            # noinspection PyBroadException
            try:
                jupyter_servers += list(list_running_servers())
@@ -637,7 +637,7 @@ class ScriptInfo(object):
        # noinspection PyBroadException
        try:
            # noinspection PyPackageRequirements
-           from jupyter_server.serverapp import list_running_servers
+           from jupyter_server.serverapp import list_running_servers  # noqa
            # noinspection PyBroadException
            try:
                jupyter_servers += list(list_running_servers())
@@ -724,7 +724,7 @@ class ScriptInfo(object):
        is_google_colab = False
        log_history = False
        colab_name = None
-       # check if this is google.colab, then there is no local file
+       # check if this is `google.colab`, then there is no local file
        is_google_colab = ScriptInfo.is_google_colab()
        if is_google_colab:
@@ -753,7 +753,7 @@ class ScriptInfo(object):
        if not entry_point.exists():
            # noinspection PyBroadException
            try:
-               alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5])+'.ipynb'
+               alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
                # now we should try to find the actual file
                entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
                if not entry_point_alternative.is_file():
@@ -828,7 +828,7 @@ class ScriptInfo(object):
        # returns tuple (notebook name, raw string notebook)
        # None, None if fails
        try:
-           from google.colab import _message
+           from google.colab import _message  # noqa
            notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
            notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
@@ -995,6 +995,10 @@ class ScriptInfo(object):
        working_dir = cls._get_working_dir(repo_root)
        entry_point = cls._get_entry_point(repo_root, script_path)

+       # check if we are running with torch distributed or transformers accelerate,
+       # and if so, change the entry point to reflect it
+       entry_point = cls._detect_distributed_execution(entry_point, log)
+
        if check_uncommitted:
            # if we have a jupyter notebook, always store the entire notebook (instead of the git diff)
            if jupyter_filepath:
@@ -1010,7 +1014,7 @@ class ScriptInfo(object):
                if len(diff) > cls.max_diff_size_bytes:
                    messages.append(
                        "======> WARNING! Git diff too large to store "
-                       "({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
+                       "({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
                    auxiliary_git_diff = diff
                    diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
                           '# full git diff available in Artifacts/auxiliary_git_diff\n' \
@@ -1065,6 +1069,52 @@ class ScriptInfo(object):
        return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
                script_requirements)

+    @classmethod
+    def _detect_distributed_execution(cls, entry_point, log):
+        # check if we are running with torch distributed or transformers accelerate,
+        # and make sure we change the entry point to reflect it
+        is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
+        is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None
+        if not is_torch_distributed and not is_transformers_distributed:
+            return entry_point
+
+        # this is a distributed execution
+        # noinspection PyBroadException
+        try:
+            from psutil import Process  # noqa
+            cmdline = Process().parent().cmdline()
+            # first find the launcher module call,
+            # "torch.distributed.run" / "torch.distributed.launch" or "accelerate.commands.*"
+            if is_torch_distributed:
+                cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
+            elif is_transformers_distributed:
+                cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("accelerate.commands."))
+            else:
+                raise Exception()  # we should not get here
+            cmdline = cmdline[cmdstart_i:]
+            # now find our entry point script among the remaining arguments
+            cmdend_i = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
+            filearg = cmdline[cmdend_i]
+            # notice --args (script args) are passed on the Args section, we skip detecting them here
+            # we also remove the file argument from the cmd (it is the last one before the script args)
+            new_cmd = cmdline[:cmdend_i]
+            # we assume our entry point is the last parameter of the execution cmd line
+            if Path(filearg).stem == Path(entry_point).stem:
+                entry_point = "-m {} {}".format(" ".join(new_cmd), entry_point)
+                if log:
+                    log.info(
+                        "{} execution detected: adjusting entrypoint to "
+                        "reflect distributed execution arguments".format(
+                            "Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
+                    )
+        except Exception:
+            if log:
+                log.warning("{} execution detected: Failed detecting launch arguments, skipping".format(
+                    "Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
+        return entry_point
+
    @staticmethod
    def __legacy_jupyter_notebook_server_json_parsing():
        # noinspection PyBroadException
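
To make the rewrite in `_detect_distributed_execution` concrete, a hypothetical worked example (the command line is illustrative, not taken from this diff): under torchrun the parent process command line might be ['python', '-m', 'torch.distributed.run', '--nproc_per_node=2', 'train.py'], and the method would produce:

from pathlib import Path

# illustrative parent cmdline and entry point (hypothetical values)
cmdline = ['python', '-m', 'torch.distributed.run', '--nproc_per_node=2', 'train.py']
entry_point = 'train.py'
# trim everything before the launcher module
start = next(i for i, c in enumerate(cmdline) if c.lower().startswith('torch.distributed.'))
cmdline = cmdline[start:]
# locate the script argument and drop it from the launcher command
end = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
print('-m {} {}'.format(' '.join(cmdline[:end]), entry_point))
# prints: -m torch.distributed.run --nproc_per_node=2 train.py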

View File

@@ -186,7 +186,15 @@ def get_node_id(default=0):
    if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
        node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank

-   # if node is is till None, use the default
+   # if node is still None, use the global RANK
+   if node_id is None:
+       # noinspection PyBroadException
+       try:
+           node_id = int(os.environ.get("RANK"))
+       except Exception:
+           pass
+
+   # if node is still None, use the default
    if node_id is None:
        node_id = default
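
As a usage note for the RANK fallback above (values are illustrative): torch elastic exports a global RANK per worker, so when no MPI variables are present the node id resolves from it.

import os

# hypothetical worker environment under torchrun --nnodes=2 --nproc_per_node=4
os.environ["RANK"] = "5"  # global rank, set by the launcher per worker
node_id = int(os.environ.get("RANK"))  # -> 5, used only when no MPI rank is present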

View File

@@ -40,7 +40,8 @@ from .backend_config.defs import get_active_config_file, get_config_file
from .backend_api.services import tasks, projects, events
from .backend_api.session.session import (
    Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
-from .backend_api.session.defs import ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG, ENV_OFFLINE_MODE, MissingConfigError
+from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
+                                       ENV_OFFLINE_MODE, MissingConfigError)
from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel
from .backend_interface.base import InterfaceBase
@@ -97,6 +98,8 @@ from .utilities.proxy_object import (
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.lowlevel.threads import get_current_thread_id
+from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
+    create_torch_distributed_anchor
from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.process.exit_hooks import ExitHooks
from .utilities.matching import matches_any_wildcard
@@ -105,6 +108,7 @@ from .utilities.networking import get_private_ip
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
+
if TYPE_CHECKING:
    import pandas
    import numpy
@@ -527,10 +531,16 @@ class Task(_Task):
        is_deferred = False
        try:
            if not running_remotely():
+               # check torch distributed local rank (sub-processes reuse the rank-0 Task)
+               _local_rank = get_torch_local_rank()
+               if _local_rank is not None and _local_rank > 0:
+                   is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
+
                # only allow if running locally and creating the first Task
                # otherwise we ignore and perform in order
                if ENV_DEFERRED_TASK_INIT.get():
                    deferred_init = True
+
                if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
                    def completed_cb(x):
                        Task.__main_task = x
@@ -571,6 +581,11 @@ class Task(_Task):
                        not auto_connect_frameworks.get('detect_repository', True)) else True,
                    auto_connect_streams=auto_connect_streams,
                )
+
+               # check if we are local rank 0 (local master),
+               # create an anchor with the task ID for the other processes
+               if _local_rank == 0:
+                   create_torch_distributed_anchor(task_id=task.id)
        except MissingConfigError as e:
            if not ENV_IGNORE_MISSING_CONFIG.get():
                raise
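
Taken together, the two Task.init changes above implement a simple handshake between the local worker processes; a minimal sketch of the flow, inferred from this diff (names match the imported helpers):

# flow for N local worker processes started by torchrun / accelerate launch:
#   local rank 0  - creates the Task, then publishes its ID via an anchor file
#   local rank >0 - waits (up to 30s) for the anchor file and reuses that Task ID
_local_rank = get_torch_local_rank()
if _local_rank is not None and _local_rank > 0:
    is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
# ... the Task is then created (rank 0) or attached to (rank > 0) ...
if _local_rank == 0:
    create_torch_distributed_anchor(task_id=task.id)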

View File

@@ -0,0 +1,96 @@
import os
from logging import getLogger
from time import sleep, time

from pathlib2 import Path


def get_torch_local_rank():
    """
    Return the local rank of this process; notice local rank 0 does not mean global rank 0.
    Return None if torch distributed is not running.
    """
    if os.environ.get("TORCHELASTIC_RUN_ID") is not None:
        # noinspection PyBroadException
        try:
            return int(os.environ.get("LOCAL_RANK"))
        except Exception:
            return None
    return None


def create_torch_distributed_anchor(task_id):
    """
    Create a temporary anchor file holding the Task ID created by local rank 0,
    so that the other local ranks can report to the same Task.
    If the calling process is not local rank 0, this is a no-op.
    Only call when running locally (i.e. without an agent);
    if running remotely there is no need to pass the Task ID, it will be passed externally.
    """
    local_file_name = ".clearml_torch_distributed_id"
    if get_torch_local_rank() != 0:
        return
    torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not torch_dist_path:
        return
    # noinspection PyBroadException
    try:
        torch_dist_path = Path(torch_dist_path).parent.parent.parent
        # create the anchor file
        with open(torch_dist_path / local_file_name, "wt") as f:
            f.write(str(task_id) + "\n")
    except Exception:
        # we failed for some reason?
        getLogger().warning("Failed creating torch task ID anchor file: {}".format(torch_dist_path))


def get_torch_distributed_anchor_task_id(timeout=None):
    """
    Wait until the temporary anchor file appears, then read the Task ID created by local rank 0.
    Only call when running locally (i.e. without an agent);
    if running remotely there is no need to pass the Task ID, it will be passed externally.

    :return: Task ID of the local task to report to
    """
    # if we are local rank 0 (or torch distributed is not running), there is nothing to read
    _local_rank = get_torch_local_rank()
    if not _local_rank:
        return
    local_file_name = ".clearml_torch_distributed_id"
    torch_dist_path = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not torch_dist_path:
        return
    task_id = None
    # noinspection PyBroadException
    try:
        torch_dist_path = Path(torch_dist_path).parent.parent.parent / local_file_name
        tic = time()
        # wait until the anchor file exists
        while not torch_dist_path.is_file():
            # if we found nothing within the timeout, return None
            if timeout is not None and time() - tic > timeout:
                getLogger().warning("Failed detecting rank zero clearml Task ID, creating a new Task")
                return None
            # wait
            sleep(0.1)
        # read the anchor file
        with open(torch_dist_path, "rt") as f:
            task_id = f.read().strip(" \n")
    except Exception:
        # we failed for some reason?
        pass
    getLogger().warning("Torch Distributed Local Rank {} Task ID {} detected".format(_local_rank, task_id))
    return task_id
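
A hypothetical end-to-end exercise of these helpers (the TORCHELASTIC_* values and paths are illustrative; torch elastic normally sets them per worker process):

import os

# illustrative environment, normally provided by torch elastic
os.makedirs("/tmp/elastic/attempt_0/0", exist_ok=True)
os.environ["TORCHELASTIC_RUN_ID"] = "demo"
os.environ["TORCHELASTIC_ERROR_FILE"] = "/tmp/elastic/attempt_0/0/error.json"

os.environ["LOCAL_RANK"] = "0"
create_torch_distributed_anchor(task_id="abc123")        # rank 0 writes the anchor

os.environ["LOCAL_RANK"] = "1"
print(get_torch_distributed_anchor_task_id(timeout=5))   # other ranks read it: abc123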