Add session API version to exported tasks

This commit is contained in:
allegroai 2020-08-08 12:43:12 +03:00
parent fa4803cb82
commit 48ef50d41d

View File

@ -1619,6 +1619,7 @@ class Task(_Task):
export_data.get('execution', {}).pop('artifacts', None) export_data.get('execution', {}).pop('artifacts', None)
export_data.get('execution', {}).pop('model', None) export_data.get('execution', {}).pop('model', None)
export_data['project_name'] = self.get_project_name() export_data['project_name'] = self.get_project_name()
export_data['session_api_version'] = self.session.api_version
return export_data return export_data
def update_task(self, task_data): def update_task(self, task_data):
@ -1644,6 +1645,14 @@ class Task(_Task):
:param update: If True, merge task_data with current Task configuration. :param update: If True, merge task_data with current Task configuration.
:return: return True if Task was imported/updated :return: return True if Task was imported/updated
""" """
# restore original API version (otherwise, we might not be able to restore the data correctly)
force_api_version = task_data.get('session_api_version') or None
original_api_version = Session.api_version
original_force_max_api_version = Session.force_max_api_version
if force_api_version:
Session.force_max_api_version = str(force_api_version)
if not target_task: if not target_task:
project_name = task_data.get('project_name') or Task._get_project_name(task_data.get('project', '')) project_name = task_data.get('project_name') or Task._get_project_name(task_data.get('project', ''))
target_task = Task.create(project_name=project_name, task_name=task_data.get('name', None)) target_task = Task.create(project_name=project_name, task_name=task_data.get('name', None))
@ -1664,8 +1673,17 @@ class Task(_Task):
res = target_task._edit(**cur_data) res = target_task._edit(**cur_data)
if res and res.ok(): if res and res.ok():
target_task.reload() target_task.reload()
return target_task else:
return None target_task = None
# restore current api version, and return a new instance if Task with the current version
if force_api_version:
Session.force_max_api_version = original_force_max_api_version
Session.api_version = original_api_version
if target_task:
target_task = Task.get_task(task_id=target_task.id)
return target_task
@classmethod @classmethod
def import_offline_session(cls, session_folder_zip): def import_offline_session(cls, session_folder_zip):