Fix Task.launch_multi_node() not being supported when used via PyTorch Lightning

This commit is contained in:
allegroai 2024-07-04 15:29:37 +03:00
parent aa227a0cdb
commit e27d277e40

View File

@ -180,6 +180,8 @@ class Task(_Task):
__detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
__default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)
__hidden_tag = "hidden"
_launch_multi_node_section = "launch_multi_node"
_launch_multi_node_instance_tag = "multi_node_instance"
@ -1921,8 +1923,16 @@ class Task(_Task):
"""
return self._get_logger(auto_connect_streams=self._log_to_backend)
def launch_multi_node(self, total_num_nodes, port=29500, queue=None, wait=False, addr=None):
# type: (int, Optional[int], Optional[str], bool, Optional[str]) -> dict
def launch_multi_node(
self,
total_num_nodes, # type: int
port=29500, # type: Optional[int]
queue=None, # type: Optional[str]
wait=False, # type: bool
addr=None, # type: Optional[str]
devices=None, # type: Optional[Union[int, Sequence[int]]]
hide_children=False # type: bool
):
"""
Enqueue multiple clones of the current task to a queue, allowing the task
to be run by multiple workers in parallel. Each task running this way is called a node.
@ -1996,6 +2006,9 @@ class Task(_Task):
parameter will be set to the one defined in ``MASTER_ADDR``. If neither environment variables exist,
the value passed to the parameter will be used. If this value is None (default), the private IP of
the machine the master node is running on will be used.
:param devices: The devices to use. This can be a positive number indicating the number of devices to use,
a sequence of indices or the value ``-1`` to indicate all available devices should be used.
:param hide_children: If True, the child tasks will be hidden. Otherwise, they will be visible in the UI.
:return: A dictionary containing relevant information regarding the multi node run. This dictionary has the following entries:
@ -2006,9 +2019,12 @@ class Task(_Task):
- `node_rank` - the rank of the current node (master has rank 0)
- `wait` - if True, the master node will wait for the other nodes to start
"""
def set_launch_multi_node_runtime_props(task, conf):
# noinspection PyProtectedMember
task._set_runtime_properties({"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()})
task._set_runtime_properties(
{"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()}
)
if total_num_nodes < 1:
raise UsageError("total_num_nodes needs to be at least 1")
@ -2024,6 +2040,7 @@ class Task(_Task):
),
"node_rank": 0,
"wait": wait,
"devices": devices
}
editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue}
editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section)
@ -2033,23 +2050,27 @@ class Task(_Task):
runtime_properties = self._get_runtime_properties()
remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section))
current_conf = master_conf
if remote_node_rank:
# self is a child node, build the conf from the runtime properties
current_conf = {
entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry))
for entry in master_conf.keys()
}
else:
elif os.environ.get("CLEARML_MULTI_NODE_MASTER") is None:
nodes_to_wait = []
# self is the master node, enqueue the other nodes
set_launch_multi_node_runtime_props(self, master_conf)
current_conf = master_conf
for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)):
node = self.clone(source_task=self, parent=self.id)
node_conf = copy.deepcopy(master_conf)
node_conf["node_rank"] = node_rank
set_launch_multi_node_runtime_props(node, node_conf)
node.set_system_tags(node.get_system_tags() + [self._launch_multi_node_instance_tag])
node.set_system_tags(
node.get_system_tags()
+ [self._launch_multi_node_instance_tag]
+ ([self.__hidden_tag] if hide_children else [])
)
if master_conf.get("queue"):
Task.enqueue(node, queue_name=master_conf["queue"])
else:
@ -2064,16 +2085,42 @@ class Task(_Task):
Task.TaskStatusEnum.stopped,
Task.TaskStatusEnum.closed,
Task.TaskStatusEnum.failed,
Task.TaskStatusEnum.in_progress
Task.TaskStatusEnum.in_progress,
),
check_interval_sec=10
check_interval_sec=10,
)
self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank))
os.environ["CLEARML_MULTI_NODE_MASTER"] = "1"
num_devices = 1
if devices is not None:
try:
num_devices = int(devices)
except TypeError:
try:
num_devices = len(devices)
except Exception as ex:
raise ValueError("Failed parsing number of devices: {}".format(ex))
except ValueError as ex:
raise ValueError("Failed parsing number of devices: {}".format(ex))
if num_devices < 0:
try:
import torch
num_devices = torch.cuda.device_count()
except ImportError:
raise ImportError(
"Could not import `torch` while finding the number of devices. "
"Please install it or set `devices` to a value different than -1"
)
os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "")
os.environ["MASTER_PORT"] = str(current_conf.get("master_port", ""))
os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", ""))
os.environ["RANK"] = str(current_conf.get("node_rank", ""))
os.environ["RANK"] = str(
current_conf.get("node_rank", 0) * num_devices + int(os.environ.get("LOCAL_RANK", "0"))
)
os.environ["NODE_RANK"] = str(current_conf.get("node_rank", ""))
os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", total_num_nodes) * num_devices)
return current_conf