diff --git a/trains/task.py b/trains/task.py index 7fa063d0..ddaf2531 100644 --- a/trains/task.py +++ b/trains/task.py @@ -1619,6 +1619,7 @@ class Task(_Task): export_data.get('execution', {}).pop('artifacts', None) export_data.get('execution', {}).pop('model', None) export_data['project_name'] = self.get_project_name() + export_data['session_api_version'] = self.session.api_version return export_data def update_task(self, task_data): @@ -1644,6 +1645,14 @@ class Task(_Task): :param update: If True, merge task_data with current Task configuration. :return: return True if Task was imported/updated """ + + # restore original API version (otherwise, we might not be able to restore the data correctly) + force_api_version = task_data.get('session_api_version') or None + original_api_version = Session.api_version + original_force_max_api_version = Session.force_max_api_version + if force_api_version: + Session.force_max_api_version = str(force_api_version) + if not target_task: project_name = task_data.get('project_name') or Task._get_project_name(task_data.get('project', '')) target_task = Task.create(project_name=project_name, task_name=task_data.get('name', None)) @@ -1664,8 +1673,17 @@ class Task(_Task): res = target_task._edit(**cur_data) if res and res.ok(): target_task.reload() - return target_task - return None + else: + target_task = None + + # restore current api version, and return a new instance if Task with the current version + if force_api_version: + Session.force_max_api_version = original_force_max_api_version + Session.api_version = original_api_version + if target_task: + target_task = Task.get_task(task_id=target_task.id) + + return target_task @classmethod def import_offline_session(cls, session_folder_zip):