diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 3fdc443a..d6165ab6 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -1399,6 +1399,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): task = get_single_result(entity='task', query=task_name, results=res.response.tasks) return cls(task_id=task.id) + @classmethod + def _get_project_name(cls, project_id): + res = cls._send(cls._get_default_session(), projects.GetByIdRequest(project=project_id), raise_on_errors=False) + if not res or not res.response or not res.response.project: + return None + return res.response.project.name + def _get_all_events(self, max_events=100): # type: (int) -> Any """ diff --git a/trains/task.py b/trains/task.py index 7d1abf45..0cf33db6 100644 --- a/trains/task.py +++ b/trains/task.py @@ -45,7 +45,7 @@ from .model import Model, InputModel, OutputModel, ARCHIVED_TAG from .task_parameters import TaskParameters from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ argparser_update_currenttask -from .utilities.dicts import ReadOnlyDict +from .utilities.dicts import ReadOnlyDict, merge_dicts from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \ nested_from_flat_dictionary, naive_nested_from_flat_dictionary from .utilities.resource_monitor import ResourceMonitor @@ -1501,6 +1501,72 @@ class Task(_Task): if raise_on_status and self.status in raise_on_status: raise RuntimeError("Task {} has status: {}.".format(self.task_id, self.status)) + def export_task(self): + # type: () -> dict + """ + Export Task's configuration into a dictionary (for serialization purposes). + A Task can be copied/modified by calling Task.import_task() + Notice: Export task does not include the tasks outputs, such as results + (scalar/plots etc.) or Task artifacts/models + + :return: dictionary of the Task's configuration. + """ + self.reload() + export_data = self.data.to_dict() + export_data.pop('last_metrics', None) + export_data.pop('last_iteration', None) + export_data.pop('status_changed', None) + export_data.pop('status_reason', None) + export_data.pop('status_message', None) + export_data.get('execution', {}).pop('artifacts', None) + export_data.get('execution', {}).pop('model', None) + return export_data + + def update_task(self, task_data): + # type: (dict) -> bool + """ + Update current task with configuration found on the task_data dictionary. + See also export_task() for retrieving Task configuration. + + :param task_data: dictionary with full Task configuration + :return: return True if Task update was successful + """ + return self.import_task(task_data=task_data, target_task=self, update=True) + + @classmethod + def import_task(cls, task_data, target_task=None, update=False): + # type: (dict, Optional[Union[str, Task]], bool) -> bool + """ + Import (create) Task from previously exported Task configuration (see Task.export_task) + Can also be used to edit/update an existing Task (by passing `target_task` and `update=True`). + + :param task_data: dictionary of a Task's configuration + :param target_task: Import task_data into an existing Task. Can be either task_id (str) or Task object. + :param update: If True, merge task_data with current Task configuration. + :return: return True if Task was imported/updated + """ + if not target_task: + project_name = Task._get_project_name(task_data.get('project', '')) + target_task = Task.create(project_name=project_name, task_name=task_data.get('name', None)) + elif isinstance(target_task, six.string_types): + target_task = Task.get_task(task_id=target_task) + elif not isinstance(target_task, Task): + raise ValueError( + "`target_task` must be either Task id (str) or Task object, " + "received `target_task` type {}".format(type(target_task))) + target_task.reload() + cur_data = target_task.data.to_dict() + cur_data = merge_dicts(cur_data, task_data) if update else task_data + cur_data.pop('id', None) + # noinspection PyProtectedMember + valid_fields = list(tasks.EditRequest._get_data_props().keys()) + cur_data = dict((k, cur_data[k]) for k in valid_fields if k in cur_data) + res = target_task._edit(**cur_data) + if res and res.ok(): + target_task.reload() + return True + return False + @classmethod def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None): # type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> ()