mirror of https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Support multi-node training using Task.launch_multi_node()
This commit is contained in:
parent a0bc87ab5c
commit 5625a4c485
clearml/task.py 143
@@ -1,4 +1,5 @@
import atexit
import copy
import json
import os
import shutil
@@ -100,6 +101,7 @@ from .utilities.lowlevel.threads import get_current_thread_id
from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.matching import matches_any_wildcard
from .utilities.parallel import FutureTaskCaller
from .utilities.networking import get_private_ip
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments

@@ -174,6 +176,9 @@ class Task(_Task):
    __detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
    __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)

    _launch_multi_node_section = "launch_multi_node"
    _launch_multi_node_instance_tag = "multi_node_instance"

    class _ConnectedParametersType(object):
        argparse = "argument_parser"
        dictionary = "dictionary"
@@ -1639,6 +1644,144 @@ class Task(_Task):
        """
        return self._get_logger(auto_connect_streams=self._log_to_backend)

    def launch_multi_node(self, total_num_nodes, port=29500, queue=None, wait=False, addr=None):
        # type: (int, Optional[int], Optional[str], bool, Optional[str]) -> dict
        """
        Enqueue multiple clones of the current task to a queue, allowing the task
        to be run by multiple workers in parallel. Each task running this way is called a node.
        Each node has a rank. The node that initialized the execution of the other nodes
        is called the `master node`, and it has a rank equal to 0.

        A dictionary named `multi_node_instance` will be connected to the tasks.
        One can use this dictionary to modify the behaviour of this function when running remotely.
        The contents of this dictionary correspond to the parameters of this function, and they are:

          - `total_num_nodes` - the total number of nodes, including the master node
          - `queue` - the queue to enqueue the nodes to
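
        Being `connect`-ed, these values appear in the task's configuration in the
        ClearML Web UI and can be edited there, like any other connected dictionary,
        before the cloned task is enqueued.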

        The following environment variables will be set:

          - `MASTER_ADDR` - the address of the machine that the master node is running on
          - `MASTER_PORT` - the open port of the machine that the master node is running on
          - `WORLD_SIZE` - the total number of nodes, including the master
          - `RANK` - the rank of the current node (master has rank 0)

        One may use this function in conjunction with PyTorch's distributed communication package.
        Note that `Task.launch_multi_node` should be called before `torch.distributed.init_process_group`.
        For example:

        .. code-block:: py

            from clearml import Task
            import torch
            import torch.distributed as dist

            def run(rank, size):
                print('World size is ', size)
                tensor = torch.zeros(1)
                if rank == 0:
                    for i in range(1, size):
                        tensor += 1
                        dist.send(tensor=tensor, dst=i)
                        print('Sending from rank ', rank, ' to rank ', i, ' data: ', tensor[0])
                else:
                    dist.recv(tensor=tensor, src=0)
                    print('Rank ', rank, ' received data: ', tensor[0])

            if __name__ == '__main__':
                task = Task.init('some_name', 'some_name')
                task.execute_remotely(queue_name='queue')
                config = task.launch_multi_node(4)
                dist.init_process_group('gloo')
                run(config.get('node_rank'), config.get('total_num_nodes'))
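
        In this example, `dist.init_process_group` is called without an explicit `init_method`,
        rank, or world size, so it falls back to PyTorch's environment-variable rendezvous
        ("env://") and reads the `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE` and `RANK`
        variables set by `launch_multi_node` above.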

        :param total_num_nodes: The total number of nodes to be enqueued, including the master node,
            which should already be enqueued when running remotely
        :param port: Port opened by the master node. If the environment variable `CLEARML_MULTI_NODE_MASTER_DEF_PORT`
            is set, the value of this parameter will be set to the one defined in `CLEARML_MULTI_NODE_MASTER_DEF_PORT`.
            If `CLEARML_MULTI_NODE_MASTER_DEF_PORT` doesn't exist, but `MASTER_PORT` does, then the value of this
            parameter will be set to the one defined in `MASTER_PORT`. If neither environment variable exists,
            the value passed to the parameter will be used
        :param queue: The queue to enqueue the nodes to. Can be different from the queue the master
            node is enqueued to. If None, the nodes will be enqueued to the same queue as the master node
        :param wait: If True, the master node will wait for the other nodes to start
        :param addr: The address of the master node's worker. If not set, it defaults to the private IP
            of the machine the master node is running on

        :return: A dictionary containing relevant information regarding the multi-node run. This dictionary
            has the following entries:

          - `master_addr` - the address of the machine that the master node is running on
          - `master_port` - the open port of the machine that the master node is running on
          - `total_num_nodes` - the total number of nodes, including the master
          - `queue` - the queue the nodes are enqueued to, excluding the master
          - `node_rank` - the rank of the current node (master has rank 0)
          - `wait` - if True, the master node will wait for the other nodes to start
        """
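        # helper: store the multi-node configuration as task runtime properties,
        # so that cloned (child) nodes can read back their rank and the master's
        # connection info when they come up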
        def set_launch_multi_node_runtime_props(task, conf):
            # noinspection PyProtectedMember
            task._set_runtime_properties(
                {"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()}
            )

        if total_num_nodes < 1:
            raise UsageError("total_num_nodes needs to be at least 1")
        if running_remotely() and not (self.data.execution and self.data.execution.queue) and not queue:
            raise UsageError("Master task is not enqueued to any queue and the queue parameter is None")

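        # resolve the master's address and rendezvous port; the port resolution order
        # is CLEARML_MULTI_NODE_MASTER_DEF_PORT, then MASTER_PORT, then the `port` argument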
        master_conf = {
            "master_addr": addr or get_private_ip(),
            "master_port": int(os.environ.get("CLEARML_MULTI_NODE_MASTER_DEF_PORT", os.environ.get("MASTER_PORT", port))),
            "node_rank": 0,
            "wait": wait
        }
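        # expose total_num_nodes and queue as an editable configuration section
        # (`multi_node_instance`); when running remotely, connect() returns the
        # values as they were possibly edited in the UI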
        editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue}
        editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section)
        if not running_remotely():
            return master_conf
        master_conf.update(editable_conf)
        runtime_properties = self._get_runtime_properties()
        remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section))

        if remote_node_rank:
            # self is a child node, build the conf from the runtime properties
            current_conf = {
                entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry))
                for entry in master_conf.keys()
            }
        else:
            nodes_to_wait = []
            # self is the master node, enqueue the other nodes
            set_launch_multi_node_runtime_props(self, master_conf)
            current_conf = master_conf
            for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)):
                node = self.clone(source_task=self)
                node_conf = copy.deepcopy(master_conf)
                node_conf["node_rank"] = node_rank
                set_launch_multi_node_runtime_props(node, node_conf)
                node.set_system_tags(node.get_system_tags() + [self._launch_multi_node_instance_tag])
                if master_conf.get("queue"):
                    Task.enqueue(node, queue_name=master_conf["queue"])
                else:
                    Task.enqueue(node, queue_id=self.data.execution.queue)
                if master_conf.get("wait"):
                    nodes_to_wait.append(node)
            for node_to_wait, rank in zip(nodes_to_wait, range(1, master_conf.get("total_num_nodes", total_num_nodes))):
                self.log.info("Waiting for node with task ID {} and rank {}".format(node_to_wait.id, rank))
                node_to_wait.wait_for_status(
                    status=(
                        Task.TaskStatusEnum.completed,
                        Task.TaskStatusEnum.stopped,
                        Task.TaskStatusEnum.closed,
                        Task.TaskStatusEnum.failed,
                        Task.TaskStatusEnum.in_progress
                    ),
                    check_interval_sec=10
                )
                self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank))

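        # export the rendezvous info through the standard torch.distributed
        # environment variables, so init_process_group() can pick them up via "env://"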
os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "")
|
||||
os.environ["MASTER_PORT"] = str(current_conf.get("master_port", ""))
|
||||
os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", ""))
|
||||
os.environ["RANK"] = str(current_conf.get("node_rank", ""))
|
||||
|
||||
return current_conf
|
||||
|
||||
def mark_started(self, force=False):
|
||||
# type: (bool) -> ()
|
||||
"""
|
||||