Add Task export_task / import_task / update_task (Issue #128)

This commit is contained in:
allegroai 2020-07-06 21:02:34 +03:00
parent 04ab5ca99c
commit 8af97dbab1
2 changed files with 74 additions and 1 deletions

View File

@ -1399,6 +1399,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
task = get_single_result(entity='task', query=task_name, results=res.response.tasks)
return cls(task_id=task.id)
@classmethod
def _get_project_name(cls, project_id):
res = cls._send(cls._get_default_session(), projects.GetByIdRequest(project=project_id), raise_on_errors=False)
if not res or not res.response or not res.response.project:
return None
return res.response.project.name
def _get_all_events(self, max_events=100):
# type: (int) -> Any
"""

View File

@ -45,7 +45,7 @@ from .model import Model, InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.dicts import ReadOnlyDict
from .utilities.dicts import ReadOnlyDict, merge_dicts
from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \
nested_from_flat_dictionary, naive_nested_from_flat_dictionary
from .utilities.resource_monitor import ResourceMonitor
@ -1501,6 +1501,72 @@ class Task(_Task):
if raise_on_status and self.status in raise_on_status:
raise RuntimeError("Task {} has status: {}.".format(self.task_id, self.status))
def export_task(self):
# type: () -> dict
"""
Export Task's configuration into a dictionary (for serialization purposes).
A Task can be copied/modified by calling Task.import_task()
Notice: Export task does not include the tasks outputs, such as results
(scalar/plots etc.) or Task artifacts/models
:return: dictionary of the Task's configuration.
"""
self.reload()
export_data = self.data.to_dict()
export_data.pop('last_metrics', None)
export_data.pop('last_iteration', None)
export_data.pop('status_changed', None)
export_data.pop('status_reason', None)
export_data.pop('status_message', None)
export_data.get('execution', {}).pop('artifacts', None)
export_data.get('execution', {}).pop('model', None)
return export_data
def update_task(self, task_data):
# type: (dict) -> bool
"""
Update current task with configuration found on the task_data dictionary.
See also export_task() for retrieving Task configuration.
:param task_data: dictionary with full Task configuration
:return: return True if Task update was successful
"""
return self.import_task(task_data=task_data, target_task=self, update=True)
@classmethod
def import_task(cls, task_data, target_task=None, update=False):
# type: (dict, Optional[Union[str, Task]], bool) -> bool
"""
Import (create) Task from previously exported Task configuration (see Task.export_task)
Can also be used to edit/update an existing Task (by passing `target_task` and `update=True`).
:param task_data: dictionary of a Task's configuration
:param target_task: Import task_data into an existing Task. Can be either task_id (str) or Task object.
:param update: If True, merge task_data with current Task configuration.
:return: return True if Task was imported/updated
"""
if not target_task:
project_name = Task._get_project_name(task_data.get('project', ''))
target_task = Task.create(project_name=project_name, task_name=task_data.get('name', None))
elif isinstance(target_task, six.string_types):
target_task = Task.get_task(task_id=target_task)
elif not isinstance(target_task, Task):
raise ValueError(
"`target_task` must be either Task id (str) or Task object, "
"received `target_task` type {}".format(type(target_task)))
target_task.reload()
cur_data = target_task.data.to_dict()
cur_data = merge_dicts(cur_data, task_data) if update else task_data
cur_data.pop('id', None)
# noinspection PyProtectedMember
valid_fields = list(tasks.EditRequest._get_data_props().keys())
cur_data = dict((k, cur_data[k]) for k in valid_fields if k in cur_data)
res = target_task._edit(**cur_data)
if res and res.ok():
target_task.reload()
return True
return False
@classmethod
def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
# type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> ()