Mirror of https://github.com/clearml/clearml, synced 2025-06-10 00:25:53 +00:00
Fix Task.launch_multi_node() not being supported when used via PyTorch Lightning
commit e27d277e40 (parent aa227a0cdb)
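In short, the patch makes launch_multi_node() usable with per-device launchers such as PyTorch Lightning: it adds `devices` and `hide_children` parameters, derives RANK/NODE_RANK/WORLD_SIZE from the device count, and guards the master's enqueue logic with a CLEARML_MULTI_NODE_MASTER environment flag. A minimal usage sketch, not part of the commit; the queue name and trainer settings are assumptions:

    from clearml import Task

    task = Task.init(project_name="examples", task_name="multi-node")

    # Clone and enqueue 3 additional node tasks; each node drives 4 devices.
    config = task.launch_multi_node(total_num_nodes=4, queue="default", devices=4, wait=True)

    # MASTER_ADDR, MASTER_PORT, RANK, NODE_RANK and WORLD_SIZE are now set the
    # way a DDP launcher expects, e.g. pl.Trainer(num_nodes=4, devices=4, strategy="ddp")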
@@ -180,6 +180,8 @@ class Task(_Task):
     __detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
     __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)
 
+    __hidden_tag = "hidden"
+
     _launch_multi_node_section = "launch_multi_node"
     _launch_multi_node_instance_tag = "multi_node_instance"
 
@@ -1921,8 +1923,16 @@ class Task(_Task):
         """
         return self._get_logger(auto_connect_streams=self._log_to_backend)
 
-    def launch_multi_node(self, total_num_nodes, port=29500, queue=None, wait=False, addr=None):
-        # type: (int, Optional[int], Optional[str], bool, Optional[str]) -> dict
+    def launch_multi_node(
+        self,
+        total_num_nodes,  # type: int
+        port=29500,  # type: Optional[int]
+        queue=None,  # type: Optional[str]
+        wait=False,  # type: bool
+        addr=None,  # type: Optional[str]
+        devices=None,  # type: Optional[Union[int, Sequence[int]]]
+        hide_children=False  # type: bool
+    ):
         """
         Enqueue multiple clones of the current task to a queue, allowing the task
         to be run by multiple workers in parallel. Each task running this way is called a node.
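Since the two new parameters are appended after the existing ones, calls written against the old signature keep working unchanged. Illustrative calls (not from the commit):

    task.launch_multi_node(4, 29500, "default", True)       # old positional style, still valid
    task.launch_multi_node(4, queue="default", devices=4)   # opting into the new behavior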
@@ -1996,6 +2006,9 @@ class Task(_Task):
             parameter will be set to the one defined in ``MASTER_ADDR``. If neither environment variable exists,
             the value passed to the parameter will be used. If this value is None (default), the private IP of
             the machine the master node is running on will be used.
+        :param devices: The devices to use. This can be a positive number indicating the number of devices to use,
+            a sequence of indices or the value ``-1`` to indicate that all available devices should be used.
+        :param hide_children: If True, the child tasks will be hidden. Otherwise, they will be visible in the UI.
 
         :return: A dictionary containing relevant information regarding the multi-node run. This dictionary has the following entries:
 
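The three accepted shapes of `devices` map onto the parsing logic added further down in this diff: an int is taken as the per-node device count, a sequence contributes its length, and -1 is resolved through torch. A hedged sketch (values illustrative):

    task.launch_multi_node(total_num_nodes=2, devices=4)       # count: 4 devices per node
    task.launch_multi_node(total_num_nodes=2, devices=[0, 2])  # indices: len() -> 2 devices
    task.launch_multi_node(total_num_nodes=2, devices=-1)      # all devices, via torch.cuda.device_count()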
@@ -2006,9 +2019,12 @@ class Task(_Task):
             - `node_rank` - the rank of the current node (master has rank 0)
             - `wait` - if True, the master node will wait for the other nodes to start
         """
+
         def set_launch_multi_node_runtime_props(task, conf):
             # noinspection PyProtectedMember
-            task._set_runtime_properties({"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()})
+            task._set_runtime_properties(
+                {"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()}
+            )
 
         if total_num_nodes < 1:
             raise UsageError("total_num_nodes needs to be at least 1")
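The helper stores every config entry as a task runtime property namespaced under the `launch_multi_node` section, which is how child nodes later rebuild their configuration. As a sketch, a node's runtime properties would contain entries along these lines (values illustrative):

    # launch_multi_node/master_addr -> "10.0.0.5"
    # launch_multi_node/master_port -> "29500"
    # launch_multi_node/node_rank   -> "1"
    # launch_multi_node/wait        -> "False"
    # launch_multi_node/devices     -> "4"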
@@ -2024,6 +2040,7 @@ class Task(_Task):
             ),
             "node_rank": 0,
             "wait": wait,
+            "devices": devices
         }
         editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue}
         editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section)
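With the new entry included, the master configuration assembled here looks roughly like this sketch (the `master_addr`/`master_port` entries sit in the part of the dict elided above this hunk; the address is an example value):

    master_conf = {
        "master_addr": "10.0.0.5",  # example
        "master_port": 29500,
        "node_rank": 0,
        "wait": wait,
        "devices": devices,         # newly propagated to every node
    }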
@@ -2033,23 +2050,27 @@ class Task(_Task):
         runtime_properties = self._get_runtime_properties()
         remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section))
 
+        current_conf = master_conf
         if remote_node_rank:
             # self is a child node, build the conf from the runtime properties
             current_conf = {
                 entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry))
                 for entry in master_conf.keys()
             }
-        else:
+        elif os.environ.get("CLEARML_MULTI_NODE_MASTER") is None:
             nodes_to_wait = []
             # self is the master node, enqueue the other nodes
             set_launch_multi_node_runtime_props(self, master_conf)
-            current_conf = master_conf
             for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)):
                 node = self.clone(source_task=self, parent=self.id)
                 node_conf = copy.deepcopy(master_conf)
                 node_conf["node_rank"] = node_rank
                 set_launch_multi_node_runtime_props(node, node_conf)
-                node.set_system_tags(node.get_system_tags() + [self._launch_multi_node_instance_tag])
+                node.set_system_tags(
+                    node.get_system_tags()
+                    + [self._launch_multi_node_instance_tag]
+                    + ([self.__hidden_tag] if hide_children else [])
+                )
                 if master_conf.get("queue"):
                     Task.enqueue(node, queue_name=master_conf["queue"])
                 else:
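Replacing the bare `else:` with a check on CLEARML_MULTI_NODE_MASTER is the heart of the Lightning fix: launchers like PyTorch Lightning re-execute the training script in one process per local device, so this method may run several times on the master machine, and only the first run should clone and enqueue nodes. The flag is set further down, once the nodes are enqueued, and is inherited by spawned processes. A stripped-down sketch of the pattern (the helper name is hypothetical):

    import os

    def launch_nodes_once(flag="CLEARML_MULTI_NODE_MASTER"):
        if os.environ.get(flag) is None:
            enqueue_node_clones()   # hypothetical stand-in for the cloning loop above
            os.environ[flag] = "1"  # seen by any process spawned after this point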
@@ -2064,16 +2085,42 @@ class Task(_Task):
                             Task.TaskStatusEnum.stopped,
                             Task.TaskStatusEnum.closed,
                             Task.TaskStatusEnum.failed,
-                            Task.TaskStatusEnum.in_progress
+                            Task.TaskStatusEnum.in_progress,
                         ),
-                        check_interval_sec=10
+                        check_interval_sec=10,
                     )
                     self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank))
+            os.environ["CLEARML_MULTI_NODE_MASTER"] = "1"
+
+        num_devices = 1
+        if devices is not None:
+            try:
+                num_devices = int(devices)
+            except TypeError:
+                try:
+                    num_devices = len(devices)
+                except Exception as ex:
+                    raise ValueError("Failed parsing number of devices: {}".format(ex))
+            except ValueError as ex:
+                raise ValueError("Failed parsing number of devices: {}".format(ex))
+        if num_devices < 0:
+            try:
+                import torch
+
+                num_devices = torch.cuda.device_count()
+            except ImportError:
+                raise ImportError(
+                    "Could not import `torch` while finding the number of devices. "
+                    "Please install it or set `devices` to a value different than -1"
+                )
 
         os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "")
         os.environ["MASTER_PORT"] = str(current_conf.get("master_port", ""))
-        os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", ""))
-        os.environ["RANK"] = str(current_conf.get("node_rank", ""))
+        os.environ["RANK"] = str(
+            current_conf.get("node_rank", 0) * num_devices + int(os.environ.get("LOCAL_RANK", "0"))
+        )
+        os.environ["NODE_RANK"] = str(current_conf.get("node_rank", ""))
+        os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", total_num_nodes) * num_devices)
 
         return current_conf
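The rewritten environment block is what makes per-device launchers work: before the patch, RANK was simply the node rank and WORLD_SIZE the node count, which breaks once each node runs one process per device. A worked example with assumed values total_num_nodes=2 and devices=4, on the second node (node_rank=1) in the Lightning-spawned process with LOCAL_RANK=2:

    node_rank, num_devices, local_rank = 1, 4, 2
    rank = node_rank * num_devices + local_rank  # RANK       -> 6
    world_size = 2 * num_devices                 # WORLD_SIZE -> 8 (2 nodes x 4 devices)
    # NODE_RANK keeps the per-node rank          # NODE_RANK  -> 1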