Use source task id to determine cloned task parent

This commit is contained in:
allegroai 2020-01-02 12:01:03 +02:00
parent 62d5535351
commit 54ae340ccb
2 changed files with 20 additions and 12 deletions

View File

@ -296,6 +296,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def project(self):
return self.data.project
@property
def parent(self):
return self.data.parent
@property
def input_model_id(self):
return self.data.execution.model

View File

@ -424,15 +424,12 @@ class Task(_Task):
:param source_task: Source Task object (or ID) to be cloned
:type source_task: Task/str
:param name: Optional, New for the new task
:type name: str
:param comment: Optional, comment for the new task
:type comment: str
:param parent: Optional parent Task ID of the new task.
:type parent: str
:param project: Optional project ID of the new task.
:param str name: Optional, New for the new task
:param str comment: Optional, comment for the new task
:param str parent: Optional parent Task ID of the new task.
If None, parent will be set to source_task.parent, or if not available to source_task itself.
:param str project: Optional project ID of the new task.
If None, the new task will inherit the cloned task's project.
:type project: str
:return: a new cloned Task object
"""
assert isinstance(source_task, (six.string_types, Task))
@ -440,6 +437,13 @@ class Task(_Task):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
task_id = source_task if isinstance(source_task, six.string_types) else source_task.id
if not parent:
if isinstance(source_task, six.string_types):
source_task = cls.get_task(task_id=source_task)
parent = source_task.id if not source_task.parent else source_task.parent
elif isinstance(parent, Task):
parent = parent.id
cloned_task_id = cls._clone_task(cloned_task_id=task_id, name=name, comment=comment,
parent=parent, project=project)
cloned_task = cls.get_task(task_id=cloned_task_id)
@ -573,7 +577,7 @@ class Task(_Task):
def _update_config_dict(task, config_dict):
task.set_model_config(config_dict=config_dict)
if not running_remotely():
if not running_remotely() or not self.is_main_task():
self.set_model_config(config_dict=configuration)
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
else:
@ -583,7 +587,7 @@ class Task(_Task):
return configuration
# it is a path to a local file
if not running_remotely():
if not running_remotely() or not self.is_main_task():
# check if not absolute path
configuration_path = Path(configuration)
if not configuration_path.is_file():
@ -620,7 +624,7 @@ class Task(_Task):
raise ValueError("connect_label_enumeration supports only `dict` type, "
"{} is not supported".format(type(enumeration)))
if not running_remotely():
if not running_remotely() or not self.is_main_task():
self.set_model_label_enumeration(enumeration)
else:
# pop everything
@ -1154,7 +1158,7 @@ class Task(_Task):
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if not running_remotely():
if not running_remotely() or not self.is_main_task():
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
else: