From 54ae340ccb85850607a3e25736440290501b99c2 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 2 Jan 2020 12:01:03 +0200 Subject: [PATCH] Use source task id to determine cloned task parent --- trains/backend_interface/task/task.py | 4 ++++ trains/task.py | 28 +++++++++++++++------------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 39c5c4af..ac3785c8 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -296,6 +296,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def project(self): return self.data.project + @property + def parent(self): + return self.data.parent + @property def input_model_id(self): return self.data.execution.model diff --git a/trains/task.py b/trains/task.py index ea52886e..7dec1d8d 100644 --- a/trains/task.py +++ b/trains/task.py @@ -424,15 +424,12 @@ class Task(_Task): :param source_task: Source Task object (or ID) to be cloned :type source_task: Task/str - :param name: Optional, New for the new task - :type name: str - :param comment: Optional, comment for the new task - :type comment: str - :param parent: Optional parent Task ID of the new task. - :type parent: str - :param project: Optional project ID of the new task. + :param str name: Optional, New for the new task + :param str comment: Optional, comment for the new task + :param str parent: Optional parent Task ID of the new task. + If None, parent will be set to source_task.parent, or if not available to source_task itself. + :param str project: Optional project ID of the new task. If None, the new task will inherit the cloned task's project. - :type project: str :return: a new cloned Task object """ assert isinstance(source_task, (six.string_types, Task)) @@ -440,6 +437,13 @@ class Task(_Task): raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above") task_id = source_task if isinstance(source_task, six.string_types) else source_task.id + if not parent: + if isinstance(source_task, six.string_types): + source_task = cls.get_task(task_id=source_task) + parent = source_task.id if not source_task.parent else source_task.parent + elif isinstance(parent, Task): + parent = parent.id + cloned_task_id = cls._clone_task(cloned_task_id=task_id, name=name, comment=comment, parent=parent, project=project) cloned_task = cls.get_task(task_id=cloned_task_id) @@ -573,7 +577,7 @@ class Task(_Task): def _update_config_dict(task, config_dict): task.set_model_config(config_dict=config_dict) - if not running_remotely(): + if not running_remotely() or not self.is_main_task(): self.set_model_config(config_dict=configuration) configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration) else: @@ -583,7 +587,7 @@ class Task(_Task): return configuration # it is a path to a local file - if not running_remotely(): + if not running_remotely() or not self.is_main_task(): # check if not absolute path configuration_path = Path(configuration) if not configuration_path.is_file(): @@ -620,7 +624,7 @@ class Task(_Task): raise ValueError("connect_label_enumeration supports only `dict` type, " "{} is not supported".format(type(enumeration))) - if not running_remotely(): + if not running_remotely() or not self.is_main_task(): self.set_model_label_enumeration(enumeration) else: # pop everything @@ -1154,7 +1158,7 @@ class Task(_Task): self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) - if not running_remotely(): + if not running_remotely() or not self.is_main_task(): self._arguments.copy_from_dict(flatten_dictionary(dictionary)) dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) else: