Add offline support using Task.set_offline() and Task.import_offline_session()

This commit is contained in:
allegroai
2020-07-30 15:03:22 +03:00
parent 2ec5726812
commit a8d6380696
15 changed files with 427 additions and 78 deletions

View File

@@ -1,12 +1,14 @@
import atexit
import json
import os
import shutil
import signal
import sys
import threading
import time
from argparse import ArgumentParser
from tempfile import mkstemp
from tempfile import mkstemp, mkdtemp
from zipfile import ZipFile, ZIP_DEFLATED
try:
# noinspection PyCompatibility
@@ -25,6 +27,7 @@ from .backend_api.session.session import Session, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
from .backend_interface.task.log import TaskHandler
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
@@ -446,7 +449,9 @@ class Task(_Task):
not auto_connect_frameworks.get('detect_repository', True)) else True
)
# set defaults
if output_uri:
if cls._offline_mode:
task.output_uri = None
elif output_uri:
task.output_uri = output_uri
elif cls.__default_output_uri:
task.output_uri = cls.__default_output_uri
@@ -530,9 +535,11 @@ class Task(_Task):
logger = task.get_logger()
# show the debug metrics page in the log, it is very convenient
if not is_sub_process_task_id:
logger.report_text(
'TRAINS results page: {}'.format(task.get_output_log_web_page()),
)
if cls._offline_mode:
logger.report_text('TRAINS running in offline mode, session stored in {}'.format(
task.get_offline_mode_folder()))
else:
logger.report_text('TRAINS results page: {}'.format(task.get_output_log_web_page()))
# Make sure we start the dev worker if required, otherwise it will only be started when we write
# something to the log.
task._dev_mode_task_start()
@@ -1362,7 +1369,7 @@ class Task(_Task):
:return: The last reported iteration number.
"""
self._reload_last_iteration()
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
return max(self.data.last_iteration or 0, self._reporter.max_iteration if self._reporter else 0)
def set_initial_iteration(self, offset=0):
# type: (int) -> int
@@ -1570,11 +1577,11 @@ class Task(_Task):
:param task_data: dictionary with full Task configuration
:return: return True if Task update was successful
"""
return self.import_task(task_data=task_data, target_task=self, update=True)
return bool(self.import_task(task_data=task_data, target_task=self, update=True))
@classmethod
def import_task(cls, task_data, target_task=None, update=False):
# type: (dict, Optional[Union[str, Task]], bool) -> bool
# type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]
"""
Import (create) Task from previously exported Task configuration (see Task.export_task)
Can also be used to edit/update an existing Task (by passing `target_task` and `update=True`).
@@ -1595,7 +1602,7 @@ class Task(_Task):
"received `target_task` type {}".format(type(target_task)))
target_task.reload()
cur_data = target_task.data.to_dict()
cur_data = merge_dicts(cur_data, task_data) if update else task_data
cur_data = merge_dicts(cur_data, task_data) if update else dict(**task_data)
cur_data.pop('id', None)
cur_data.pop('project', None)
# noinspection PyProtectedMember
@@ -1604,8 +1611,79 @@ class Task(_Task):
res = target_task._edit(**cur_data)
if res and res.ok():
target_task.reload()
return True
return False
return target_task
return None
@classmethod
def import_offline_session(cls, session_folder_zip):
    # type: (str) -> (Optional[str])
    """
    Upload an offline session (execution) of a Task.
    Full Task execution includes repository details, installed packages, artifacts, logs, metric and debug samples.

    :param session_folder_zip: Path to a folder containing the session, or zip-file of the session folder.
    :return: Newly created task ID (str)
    :raises ValueError: If the session folder / zip-file cannot be found, the stored Task object
        cannot be read, or the Task could not be created from the stored data.
    """
    print('TRAINS: Importing offline session from {}'.format(session_folder_zip))

    temp_folder = None
    try:
        if Path(session_folder_zip).is_file():
            # unzip the file into a temporary folder (always removed in the `finally` clause below)
            temp_folder = mkdtemp(prefix='trains-offline-')
            # use a context manager so the archive handle is closed once extraction is done
            with ZipFile(session_folder_zip) as zf:
                zf.extractall(path=temp_folder)
            session_folder_zip = temp_folder

        session_folder = Path(session_folder_zip)
        if not session_folder.is_dir():
            raise ValueError("Could not find the session folder / zip-file {}".format(session_folder))

        task_file = session_folder / cls._offline_filename
        try:
            with open(task_file.as_posix(), 'rt') as f:
                export_data = json.load(f)
        except Exception as ex:
            raise ValueError(
                "Could not read Task object {}: Exception {}".format(task_file, ex))

        task = cls.import_task(export_data)
        # import_task is annotated as returning Optional[Task]; fail with a clear error
        # instead of an AttributeError on the next line
        if task is None:
            raise ValueError("Could not create Task from offline session {}".format(session_folder))
        task.mark_started(force=True)

        # fix artifacts: upload the locally stored files and point the artifact entries at the uploaded copies
        if task.data.execution.artifacts:
            from . import StorageManager
            # artifact URIs recorded while offline are prefixed with the original session's data folder
            offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
            # noinspection PyProtectedMember
            remote_url = task._get_default_report_storage_uri()
            if not remote_url:
                # NOTE(review): without a destination the URIs would be rewritten to "None/...";
                # leave the artifact entries untouched instead - confirm this is the desired fallback
                print('TRAINS: Could not resolve default storage destination, skipping artifacts upload')
            else:
                if remote_url.endswith('/'):
                    remote_url = remote_url[:-1]

                for artifact in task.data.execution.artifacts:
                    local_path = artifact.uri.replace(offline_folder, '', 1)
                    local_file = session_folder / 'data' / local_path
                    if local_file.is_file():
                        # replace the offline task id embedded in the path with the newly created task id
                        remote_path = local_path.replace(
                            '.{}{}'.format(export_data['id'], os.sep), '.{}{}'.format(task.id, os.sep), 1)
                        artifact.uri = '{}/{}'.format(remote_url, remote_path)
                        StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri)
                # noinspection PyProtectedMember
                task._edit(execution=task.data.execution)

        # logs
        TaskHandler.report_offline_session(task, session_folder)
        # metrics
        Metrics.report_offline_session(task, session_folder)

        # print imported results page
        print('TRAINS results page: {}'.format(task.get_output_log_web_page()))
        task.completed()
        # close task
        task.close()

        return task.id
    finally:
        # cleanup: best effort removal of the extracted session, also when the import failed midway
        if temp_folder:
            shutil.rmtree(temp_folder, ignore_errors=True)
@classmethod
def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
@@ -2099,7 +2177,7 @@ class Task(_Task):
parent.terminate()
def _dev_mode_setup_worker(self, model_updated=False):
if running_remotely() or not self.is_main_task() or self._at_exit_called:
if running_remotely() or not self.is_main_task() or self._at_exit_called or self._offline_mode:
return
if self._dev_worker:
@@ -2283,6 +2361,23 @@ class Task(_Task):
except Exception:
# make sure we do not interrupt the exit process
pass
# make sure we store last task state
if self._offline_mode and not is_sub_process:
# noinspection PyBroadException
try:
# create zip file
offline_folder = self.get_offline_mode_folder()
zip_file = offline_folder.as_posix() + '.zip'
with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf:
for filename in offline_folder.rglob('*'):
if filename.is_file():
relative_file_name = filename.relative_to(offline_folder).as_posix()
zf.write(filename.as_posix(), arcname=relative_file_name)
print('TRAINS Task: Offline session stored in {}'.format(zip_file))
except Exception as ex:
pass
# delete locking object (lock file)
if self._edit_lock:
# noinspection PyBroadException
@@ -2597,7 +2692,7 @@ class Task(_Task):
@classmethod
def __get_task_api_obj(cls, task_id, only_fields=None):
if not task_id:
if not task_id or cls._offline_mode:
return None
all_tasks = cls._send(