Mirror of https://github.com/clearml/clearml (synced 2025-04-05 13:15:17 +00:00)
Fix Task.launch_multi_node() not being supported when used via PyTorch Lightning
This commit is contained in:
parent aa227a0cdb
commit e27d277e40
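
For context, the scenario this commit targets looks roughly like the sketch below (project, queue, and device counts are illustrative, not taken from the commit): a script initializes a ClearML task, launches the extra nodes, and then hands control to PyTorch Lightning, which reads the MASTER_ADDR/MASTER_PORT/NODE_RANK/LOCAL_RANK/WORLD_SIZE environment variables that launch_multi_node() sets.

    from clearml import Task
    from pytorch_lightning import Trainer

    task = Task.init(project_name="examples", task_name="multi-node ddp")  # illustrative names
    task.execute_remotely(queue_name="default")  # everything below runs on the worker

    # One master node plus one extra node, 4 GPUs each (illustrative sizes)
    config = task.launch_multi_node(total_num_nodes=2, queue="default", devices=4, wait=True)

    # Lightning picks up MASTER_ADDR, MASTER_PORT, NODE_RANK, LOCAL_RANK and WORLD_SIZE
    # from the environment, so no extra wiring is needed here.
    trainer = Trainer(accelerator="gpu", devices=4, num_nodes=2, strategy="ddp")
    # trainer.fit(model, datamodule)  # model/datamodule defined elsewhere
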
@@ -180,6 +180,8 @@ class Task(_Task):
     __detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
     __default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)

+    __hidden_tag = "hidden"
+
     _launch_multi_node_section = "launch_multi_node"
     _launch_multi_node_instance_tag = "multi_node_instance"

@@ -1921,8 +1923,16 @@ class Task(_Task):
         """
         return self._get_logger(auto_connect_streams=self._log_to_backend)

-    def launch_multi_node(self, total_num_nodes, port=29500, queue=None, wait=False, addr=None):
-        # type: (int, Optional[int], Optional[str], bool, Optional[str]) -> dict
+    def launch_multi_node(
+        self,
+        total_num_nodes,  # type: int
+        port=29500,  # type: Optional[int]
+        queue=None,  # type: Optional[str]
+        wait=False,  # type: bool
+        addr=None,  # type: Optional[str]
+        devices=None,  # type: Optional[Union[int, Sequence[int]]]
+        hide_children=False  # bool
+    ):
         """
         Enqueue multiple clones of the current task to a queue, allowing the task
         to be run by multiple workers in parallel. Each task running this way is called a node.
@@ -1996,6 +2006,9 @@ class Task(_Task):
             parameter will be set to the one defined in ``MASTER_ADDR``. If neither environment variable exists,
             the value passed to the parameter will be used. If this value is None (default), the private IP of
             the machine the master node is running on will be used.
+        :param devices: The devices to use. This can be a positive number indicating the number of devices to use,
+            a sequence of indices, or the value ``-1`` to indicate that all available devices should be used.
+        :param hide_children: If True, the child tasks will be hidden. Otherwise, they will be visible in the UI.

        :return: A dictionary containing relevant information regarding the multi node run. This dictionary has the following entries:

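
To make the new parameters concrete, a hypothetical call (the task object, queue name, and counts are illustrative) could pass `devices` in any of the documented forms:

    task.launch_multi_node(total_num_nodes=2, queue="default", devices=4)       # 4 devices per node
    task.launch_multi_node(total_num_nodes=2, queue="default", devices=[0, 1])  # explicit device indices
    task.launch_multi_node(total_num_nodes=2, queue="default", devices=-1,      # all available CUDA devices
                           hide_children=True)                                  # hide the cloned child tasks
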
@@ -2006,9 +2019,12 @@ class Task(_Task):
             - `node_rank` - the rank of the current node (master has rank 0)
             - `wait` - if True, the master node will wait for the other nodes to start
         """

         def set_launch_multi_node_runtime_props(task, conf):
             # noinspection PyProtectedMember
-            task._set_runtime_properties({"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()})
+            task._set_runtime_properties(
+                {"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()}
+            )

         if total_num_nodes < 1:
             raise UsageError("total_num_nodes needs to be at least 1")
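
The set_launch_multi_node_runtime_props() helper shown above simply flattens a configuration dict into runtime properties namespaced by _launch_multi_node_section ("launch_multi_node"). A rough illustration with made-up values:

    conf = {"node_rank": 1, "wait": False}
    flattened = {"{}/{}".format("launch_multi_node", k): v for k, v in conf.items()}
    assert flattened == {"launch_multi_node/node_rank": 1, "launch_multi_node/wait": False}
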
@@ -2024,6 +2040,7 @@ class Task(_Task):
             ),
             "node_rank": 0,
             "wait": wait,
+            "devices": devices
         }
         editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue}
         editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section)
@@ -2033,23 +2050,27 @@ class Task(_Task):
         runtime_properties = self._get_runtime_properties()
         remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section))

+        current_conf = master_conf
         if remote_node_rank:
             # self is a child node, build the conf from the runtime properties
             current_conf = {
                 entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry))
                 for entry in master_conf.keys()
             }
-        else:
+        elif os.environ.get("CLEARML_MULTI_NODE_MASTER") is None:
             nodes_to_wait = []
             # self is the master node, enqueue the other nodes
             set_launch_multi_node_runtime_props(self, master_conf)
-            current_conf = master_conf
             for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)):
                 node = self.clone(source_task=self, parent=self.id)
                 node_conf = copy.deepcopy(master_conf)
                 node_conf["node_rank"] = node_rank
                 set_launch_multi_node_runtime_props(node, node_conf)
-                node.set_system_tags(node.get_system_tags() + [self._launch_multi_node_instance_tag])
+                node.set_system_tags(
+                    node.get_system_tags()
+                    + [self._launch_multi_node_instance_tag]
+                    + ([self.__hidden_tag] if hide_children else [])
+                )
                 if master_conf.get("queue"):
                     Task.enqueue(node, queue_name=master_conf["queue"])
                 else:
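
On an enqueued child node, the branch above does the reverse of the helper: it rebuilds current_conf from the task's runtime properties instead of the local arguments, while the new CLEARML_MULTI_NODE_MASTER guard appears to keep the master from enqueuing the children again if this code path runs more than once in the same (or an inherited) environment. Roughly, the readback amounts to the following sketch (keys follow the code above, values are illustrative and typically come back as strings):

    runtime_properties = {
        "launch_multi_node/node_rank": "1",
        "launch_multi_node/total_num_nodes": "2",
        "launch_multi_node/master_addr": "10.0.0.5",
        "launch_multi_node/master_port": "29500",
    }
    current_conf = {
        entry: runtime_properties.get("launch_multi_node/{}".format(entry))
        for entry in ("node_rank", "total_num_nodes", "master_addr", "master_port")
    }
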
@@ -2064,16 +2085,42 @@ class Task(_Task):
                             Task.TaskStatusEnum.stopped,
                             Task.TaskStatusEnum.closed,
                             Task.TaskStatusEnum.failed,
-                            Task.TaskStatusEnum.in_progress
+                            Task.TaskStatusEnum.in_progress,
                         ),
-                        check_interval_sec=10
+                        check_interval_sec=10,
                     )
                     self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank))
+            os.environ["CLEARML_MULTI_NODE_MASTER"] = "1"
+
+        num_devices = 1
+        if devices is not None:
+            try:
+                num_devices = int(devices)
+            except TypeError:
+                try:
+                    num_devices = len(devices)
+                except Exception as ex:
+                    raise ValueError("Failed parsing number of devices: {}".format(ex))
+            except ValueError as ex:
+                raise ValueError("Failed parsing number of devices: {}".format(ex))
+        if num_devices < 0:
+            try:
+                import torch
+
+                num_devices = torch.cuda.device_count()
+            except ImportError:
+                raise ImportError(
+                    "Could not import `torch` while finding the number of devices. "
+                    "Please install it or set `devices` to a value different than -1"
+                )
+
         os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "")
         os.environ["MASTER_PORT"] = str(current_conf.get("master_port", ""))
-        os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", ""))
-        os.environ["RANK"] = str(current_conf.get("node_rank", ""))
+        os.environ["RANK"] = str(
+            current_conf.get("node_rank", 0) * num_devices + int(os.environ.get("LOCAL_RANK", "0"))
+        )
+        os.environ["NODE_RANK"] = str(current_conf.get("node_rank", ""))
+        os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", total_num_nodes) * num_devices)

         return current_conf

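
The environment block above is the core of the Lightning fix: RANK becomes a per-process global rank (node rank × devices per node + LOCAL_RANK), NODE_RANK is exported separately, and WORLD_SIZE counts processes rather than nodes. A worked example of the arithmetic, with illustrative values:

    total_num_nodes = 2            # nodes launched via launch_multi_node
    num_devices = 4                # devices per node, from the `devices` argument
    node_rank, local_rank = 1, 2   # e.g. the third GPU process on the second node

    global_rank = node_rank * num_devices + local_rank  # -> 6
    world_size = total_num_nodes * num_devices          # -> 8
    assert (global_rank, world_size) == (6, 8)
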