Add offline support using Task.set_offline() and Task.import_offline_session()

allegroai 2020-07-30 15:03:22 +03:00
parent 2ec5726812
commit a8d6380696
15 changed files with 427 additions and 78 deletions

View File

@ -31,7 +31,7 @@ class ApiServiceProxy(object):
]]
# get the most advanced service version that supports our api
version = [str(v) for v in ApiServiceProxy._available_versions if Version(Session.api_version) >= v][-1]
version = [str(v) for v in ApiServiceProxy._available_versions if Session.check_min_api_version(v)][-1]
self.__dict__["__wrapped_version__"] = Session.api_version
name = ".v{}.{}".format(
version.replace(".", "_"), self.__dict__.get("__wrapped_name__")

View File

@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True)
ENV_OFFLINE_MODE = EnvEntry("TRAINS_OFFLINE_MODE", "ALG_OFFLINE_MODE", type=bool)
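The new environment switch mirrors the programmatic Task.set_offline() added later in this commit. A minimal sketch, assuming EnvEntry's bool parsing accepts '1'; the variable must be set before trains is imported, since Session reads ENV_OFFLINE_MODE at class-definition time:

import os

# equivalent to Task.set_offline(True), via the environment
os.environ['TRAINS_OFFLINE_MODE'] = '1'

from trains import Task  # noqa: E402 - deliberately imported after the env setup

task = Task.init(project_name='examples', task_name='offline via env')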

View File

@ -11,7 +11,8 @@ from requests.auth import HTTPBasicAuth
from six.moves.urllib.parse import urlparse, urlunparse
from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, \
ENV_FILES_HOST, ENV_OFFLINE_MODE
from .request import Request, BatchRequest # noqa: F401
from .token_manager import TokenManager
from ..config import load
@ -50,6 +51,8 @@ class Session(TokenManager):
_write_session_timeout = (300.0, 300.)
_sessions_created = 0
_ssl_error_count_verbosity = 2
_offline_mode = ENV_OFFLINE_MODE.get()
_offline_default_version = '2.5'
_client = [(__package__.partition(".")[0], __version__)]
@ -153,6 +156,9 @@ class Session(TokenManager):
self.client = ", ".join("{}-{}".format(*x) for x in self._client)
if self._offline_mode:
return
self.refresh_token()
# update api version from server response
@ -197,6 +203,9 @@ class Session(TokenManager):
server-side permissions have changed but are not reflected in the current token. Refreshing the token will
generate a token with the updated permissions.
"""
if self._offline_mode:
return None
host = self.host
headers = headers.copy() if headers else {}
headers[self._WORKER_HEADER] = self.worker
@ -406,6 +415,9 @@ class Session(TokenManager):
"""
self.validate_request(req_obj)
if self._offline_mode:
return None
if isinstance(req_obj, BatchRequest):
# TODO: support async for batch requests as well
if async_enable:
@ -526,11 +538,14 @@ class Session(TokenManager):
# If no session was created, create a default one, in order to get the backend api version.
if cls._sessions_created <= 0:
# noinspection PyBroadException
try:
cls()
except Exception:
pass
if cls._offline_mode:
cls.api_version = cls._offline_default_version
else:
# noinspection PyBroadException
try:
cls()
except Exception:
pass
return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
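With this fallback, feature gates resolve locally in offline mode: no session is created and api_version becomes the hard-coded '2.5'. A sketch, assuming no session has been created yet; setting the class flag directly stands in for the environment switch:

from trains.backend_api import Session

Session._offline_mode = True  # normally set via ENV_OFFLINE_MODE / Task.set_offline()

# no backend round-trip; api_version falls back to _offline_default_version ('2.5')
print(Session.check_min_api_version('2.4'))  # True:  2.5 >= 2.4
print(Session.check_min_api_version('2.9'))  # False: 2.5 <  2.9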

View File

@ -7,7 +7,7 @@ from ..backend_api import Session, CallResult
from ..backend_api.session.session import MaxRequestSizeError
from ..backend_api.session.response import ResponseMeta
from ..backend_api.session import BatchRequest
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_OFFLINE_MODE
from ..config import config_obj
from ..config.defs import LOG_LEVEL_ENV_VAR
@ -19,6 +19,7 @@ class InterfaceBase(SessionInterface):
""" Base class for a backend manager class """
_default_session = None
_num_retry_warning_display = 1
_offline_mode = ENV_OFFLINE_MODE.get()
@property
def session(self):
@ -44,6 +45,9 @@ class InterfaceBase(SessionInterface):
@classmethod
def _send(cls, session, req, ignore_errors=False, raise_on_errors=True, log=None, async_enable=False):
""" Convenience send() method providing a standardized error reporting """
if cls._offline_mode:
return None
num_retries = 0
while True:
error_msg = ''
@ -151,7 +155,7 @@ class IdObjectBase(InterfaceBase):
pass
def reload(self):
if not self.id:
if not self.id and not self._offline_mode:
raise ValueError('Failed reloading %s: missing id' % type(self).__name__)
# noinspection PyBroadException
try:

View File

@ -18,7 +18,7 @@ class StdStreamPatch(object):
if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
logger._task_handler = TaskHandler(task=logger._task, capacity=100)
# noinspection PyBroadException
try:
if StdStreamPatch._stdout_original_write is None:
@ -70,7 +70,7 @@ class StdStreamPatch(object):
pass
elif DevWorker.report_stdout and not running_remotely():
logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
logger._task_handler = TaskHandler(task=logger._task, capacity=100)
if StdStreamPatch._stdout_proxy:
StdStreamPatch._stdout_proxy.connect(logger)
if StdStreamPatch._stderr_proxy:

View File

@ -1,4 +1,7 @@
import json
import os
from functools import partial
from logging import warning
from multiprocessing.pool import ThreadPool
from multiprocessing import Lock
from time import time
@ -25,6 +28,7 @@ class Metrics(InterfaceBase):
_file_upload_retries = 3
_upload_pool = None
_file_upload_pool = None
__offline_filename = 'metrics.jsonl'
@property
def storage_key_prefix(self):
@ -43,14 +47,19 @@ class Metrics(InterfaceBase):
finally:
self._storage_lock.release()
def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
def __init__(self, session, task, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
super(Metrics, self).__init__(session, log=log)
self._task_id = task_id
self._task_id = task.id
self._task_iteration_offset = iteration_offset
self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
self._file_related_event_time = None
self._file_upload_time = None
self._offline_log_filename = None
if self._offline_mode:
offline_folder = Path(task.get_offline_mode_folder())
offline_folder.mkdir(parents=True, exist_ok=True)
self._offline_log_filename = offline_folder / self.__offline_filename
def write_events(self, events, async_enable=True, callback=None, **kwargs):
"""
@ -167,6 +176,7 @@ class Metrics(InterfaceBase):
e.set_exception(exp)
e.stream.close()
if e.delete_local_file:
# noinspection PyBroadException
try:
Path(e.delete_local_file).unlink()
except Exception:
@ -199,6 +209,11 @@ class Metrics(InterfaceBase):
_events = [ev.get_api_event() for ev in good_events]
batched_requests = [api_events.AddRequest(event=ev) for ev in _events if ev]
if batched_requests:
if self._offline_mode:
with open(self._offline_log_filename.as_posix(), 'at') as f:
f.write(json.dumps([b.to_dict() for b in batched_requests])+'\n')
return
req = api_events.AddBatchRequest(requests=batched_requests)
return self.send(req, raise_on_errors=False)
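In offline mode each flushed batch is appended to metrics.jsonl as one JSON array per line (JSON Lines), where every element is the dict form of an events.AddRequest. A sketch of reading the file back; the path and the event field names are illustrative:

import json
from pathlib import Path

metrics_file = Path.home() / '.trains' / 'cache' / 'offline' / '<task-id>' / 'metrics.jsonl'

with open(metrics_file.as_posix(), 'rt') as f:
    for line in f:
        batch = json.loads(line)  # one flushed batch: a list of event dicts
        for event in batch:
            # typical scalar-event fields (assumed): task, metric, variant, iter, value
            print(event.get('metric'), event.get('variant'), event.get('value'))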
@ -234,3 +249,69 @@ class Metrics(InterfaceBase):
pool.join()
except Exception:
pass
@classmethod
def report_offline_session(cls, task, folder):
from ... import StorageManager
filename = Path(folder) / cls.__offline_filename
if not filename.is_file():
return False
# noinspection PyProtectedMember
remote_url = task._get_default_report_storage_uri()
if remote_url and remote_url.endswith('/'):
remote_url = remote_url[:-1]
uploaded_files = set()
task_id = task.id
with open(filename, 'rt') as f:
i = 0
while True:
try:
line = f.readline()
if not line:
break
list_requests = json.loads(line)
for r in list_requests:
org_task_id = r['task']
r['task'] = task_id
if r.get('key') and r.get('url'):
debug_sample = (Path(folder) / 'data').joinpath(*(r['key'].split('/')))
r['key'] = r['key'].replace(
'.{}{}'.format(org_task_id, os.sep), '.{}{}'.format(task_id, os.sep), 1)
r['url'] = '{}/{}'.format(remote_url, r['key'])
if debug_sample not in uploaded_files and debug_sample.is_file():
uploaded_files.add(debug_sample)
StorageManager.upload_file(local_file=debug_sample.as_posix(), remote_url=r['url'])
elif r.get('plot_str'):
# hack: fix plotly embedded image links
# noinspection PyBroadException
try:
task_id_sep = '.{}{}'.format(org_task_id, os.sep)
plot = json.loads(r['plot_str'])
if plot.get('layout', {}).get('images'):
for image in plot['layout']['images']:
if task_id_sep not in image['source']:
continue
pre, post = image['source'].split(task_id_sep, 1)
pre = os.sep.join(pre.split(os.sep)[-2:])
debug_sample = (Path(folder) / 'data').joinpath(
pre+'.{}'.format(org_task_id), post)
image['source'] = '/'.join(
[remote_url, pre + '.{}'.format(task_id), post])
if debug_sample not in uploaded_files and debug_sample.is_file():
uploaded_files.add(debug_sample)
StorageManager.upload_file(
local_file=debug_sample.as_posix(), remote_url=image['source'])
r['plot_str'] = json.dumps(plot)
except Exception:
pass
i += 1
except StopIteration:
break
except Exception as ex:
warning('Failed reporting metric, line {} [{}]'.format(i, ex))
batch_requests = api_events.AddBatchRequest(requests=list_requests)
res = task.session.send(batch_requests)
if res and not res.ok():
warning("failed logging metric task to backend ({:d} lines, {})".format(
len(batch_requests.requests), str(res.meta)))
return True
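The key/url rewrite in report_offline_session() swaps the original offline task id for the newly created one, first occurrence only. Isolated, with made-up ids:

import os

org_task_id, task_id = 'offline-abc123', '0123456789ab'
key = 'examples/run.{}{}metrics/loss/sample.png'.format(org_task_id, os.sep)

# replace only the first '.<old-id><sep>' occurrence, as the code above does
new_key = key.replace('.{}{}'.format(org_task_id, os.sep),
                      '.{}{}'.format(task_id, os.sep), 1)
print(new_key)  # examples/run.0123456789ab<sep>metrics/loss/sample.png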

View File

@ -659,7 +659,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
# Hack: if the url doesn't start with http/s, plotly will not be able to show it,
# so we put the link under images instead of plots
if not url.startswith('http'):
if not url.startswith('http') and not self._offline_mode:
return self.report_image_and_upload(title=title, series=series, iter=iter, path=path, image=matrix,
upload_uri=upload_uri, max_image_history=max_image_history)

View File

@ -71,6 +71,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def _reload(self):
""" Reload the model object """
if self._offline_mode:
return models.Model()
if self.id == self._EMPTY_MODEL_ID:
return
res = self.send(models.GetByIdRequest(model=self.id))
@ -186,35 +189,40 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
task = task_id or self.data.task
project = project_id or self.data.project
parent = parent_id or self.data.parent
if tags:
extra = {'system_tags': tags or self.data.system_tags} \
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
else:
extra = {}
self.send(models.EditRequest(
self._edit(
model=self.id,
uri=uri,
name=name,
comment=comment,
labels=labels,
design=design,
framework=framework or self.data.framework,
iteration=iteration,
task=task,
project=project,
parent=parent,
framework=framework or self.data.framework,
iteration=iteration,
**extra
))
self.reload()
)
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
uri=None, framework=None, iteration=None):
return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
uri=uri, framework=framework, iteration=iteration)
def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
uri=None, framework=None, iteration=None, **extra):
def offline_store(**kwargs):
for k, v in kwargs.items():
setattr(self.data, k, v or getattr(self.data, k, None))
return
if self._offline_mode:
return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
uri=uri, framework=framework, iteration=iteration, **extra)
if tags:
extra = {'system_tags': tags or self.data.system_tags} \
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
else:
extra = {}
extra.update({'system_tags': tags or self.data.system_tags}
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags})
self.send(models.EditRequest(
model=self.id,
uri=uri,
@ -298,7 +306,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
override_model_id=override_model_id, **extra))
if self.id is None:
# update the model id. in case it was just created, this will trigger a reload of the model object
self.id = res.response.id
self.id = res.response.id if res else None
else:
self.reload()
try:

View File

@ -306,10 +306,11 @@ class _Arguments(object):
# TODO: add dict prefix
prefix = prefix or '' # self._prefix_dict
if prefix:
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
cur_params.update(prefix_dictionary)
self._task.set_parameters(cur_params)
with self._task._edit_lock:
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
cur_params.update(prefix_dictionary)
self._task.set_parameters(cur_params)
else:
self._task.update_parameters(dictionary)
if not isinstance(dictionary, self._ProxyDictWrite):

View File

@ -1,6 +1,8 @@
import json
import sys
import time
from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord
from pathlib2 import Path
from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord, warning
from logging.handlers import BufferingHandler
from threading import Thread, Event
from six.moves.queue import Queue
@ -17,6 +19,7 @@ class TaskHandler(BufferingHandler):
__wait_for_flush_timeout = 10.
__max_event_size = 1024 * 1024
__once = False
__offline_filename = 'log.jsonl'
@property
def task_id(self):
@ -26,10 +29,10 @@ class TaskHandler(BufferingHandler):
def task_id(self, value):
self._task_id = value
def __init__(self, session, task_id, capacity=buffer_capacity):
def __init__(self, task, capacity=buffer_capacity):
super(TaskHandler, self).__init__(capacity)
self.task_id = task_id
self.session = session
self.task_id = task.id
self.session = task.session
self.last_timestamp = 0
self.counter = 1
self._last_event = None
@ -37,6 +40,11 @@ class TaskHandler(BufferingHandler):
self._queue = None
self._thread = None
self._pending = 0
self._offline_log_filename = None
if task.is_offline():
offline_folder = Path(task.get_offline_mode_folder())
offline_folder.mkdir(parents=True, exist_ok=True)
self._offline_log_filename = offline_folder / self.__offline_filename
def shouldFlush(self, record):
"""
@ -124,6 +132,7 @@ class TaskHandler(BufferingHandler):
if not self.buffer:
return
buffer = None
self.acquire()
if self.buffer:
buffer = self.buffer
@ -133,6 +142,7 @@ class TaskHandler(BufferingHandler):
if not buffer:
return
# noinspection PyBroadException
try:
record_events = [r for record in buffer for r in self._record_to_event(record)] + [self._last_event]
self._last_event = None
@ -194,11 +204,17 @@ class TaskHandler(BufferingHandler):
def _send_events(self, a_request):
try:
self._pending -= 1
if self._offline_log_filename:
with open(self._offline_log_filename.as_posix(), 'at') as f:
f.write(json.dumps([b.to_dict() for b in a_request.requests]) + '\n')
return
# if self._thread is None:
# self.__log_stderr('Task.close() flushing remaining logs ({})'.format(self._pending))
self._pending -= 1
res = self.session.send(a_request)
if not res.ok():
if res and not res.ok():
self.__log_stderr("failed logging task to backend ({:d} lines, {})".format(
len(a_request.requests), str(res.meta)), level=WARNING)
except MaxRequestSizeError:
@ -237,3 +253,31 @@ class TaskHandler(BufferingHandler):
write('{asctime} - {name} - {levelname} - {message}\n'.format(
asctime=Formatter().formatTime(makeLogRecord({})),
name='trains.log', levelname=getLevelName(level), message=msg))
@classmethod
def report_offline_session(cls, task, folder):
filename = Path(folder) / cls.__offline_filename
if not filename.is_file():
return False
with open(filename, 'rt') as f:
i = 0
while True:
try:
line = f.readline()
if not line:
break
list_requests = json.loads(line)
for r in list_requests:
r.pop('task', None)
i += 1
except StopIteration:
break
except Exception as ex:
warning('Failed reporting log, line {} [{}]'.format(i, ex))
batch_requests = events.AddBatchRequest(
requests=[events.TaskLogEvent(task=task.id, **r) for r in list_requests])
res = task.session.send(batch_requests)
if res and not res.ok():
warning("failed logging task to backend ({:d} lines, {})".format(
len(batch_requests.requests), str(res.meta)))
return True

View File

@ -1,5 +1,6 @@
""" Backend task management support """
import itertools
import json
import logging
import os
import sys
@ -7,8 +8,10 @@ import re
from enum import Enum
from tempfile import gettempdir
from multiprocessing import RLock
from pathlib2 import Path
from threading import Thread
from typing import Optional, Any, Sequence, Callable, Mapping, Union, List
from uuid import uuid4
try:
# noinspection PyCompatibility
@ -25,17 +28,18 @@ from ...binding.artifacts import Artifacts
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
from pathlib2 import Path
from ...backend_api.session.defs import ENV_OFFLINE_MODE
from ...utilities.pyhocon import ConfigTree, ConfigFactory
from ..base import IdObjectBase
from ..base import IdObjectBase, InterfaceBase
from ..metrics import Metrics, Reporter
from ..model import Model
from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \
exact_match_regex
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR
from ...config import (
get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
from ...debugging import get_logger
from ...debugging.log import LoggerRoot
from ...storage.helper import StorageHelper, StorageError
@ -56,6 +60,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_force_requirements = {}
_store_diff = config.get('development.store_uncommitted_code_diff', False)
_offline_filename = 'task.json'
class TaskTypes(Enum):
def __str__(self):
@ -143,6 +148,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not task_id:
# generate a new task
self.id = self._auto_generate(project_name=project_name, task_name=task_name, task_type=task_type)
if self._offline_mode:
self.data.id = self.id
self.name = task_name
else:
# this is an existing task, let's try to verify stuff
self._validate()
@ -195,7 +203,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
# than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
# handler instance handled them)
backend_handler = TaskHandler(self.session, self.task_id)
backend_handler = TaskHandler(task=self)
# Add backend handler to both loggers:
# 1. to the root logger
@ -280,11 +288,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# to the module call including all argv's
result.script = ScriptInfo.detect_running_module(result.script)
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (let someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
with self._edit_lock:
self.reload()
self.data.script = result.script
self._edit(script=result.script)
# if jupyter is present, requirements will be created in the background, when saving a snapshot
if result.script and script_requirements:
entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
@ -304,7 +314,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
result.script['requirements']['conda'] = conda_requirements
self._update_requirements(result.script.get('requirements') or '')
self.reload()
# we do not want to wait for the check version thread,
# because someone might wait for us to finish the repo detection update
@ -339,7 +348,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
)
res = self.send(req)
return res.response.id
return res.response.id if res else 'offline-{}'.format(str(uuid4()).replace("-", ""))
def _set_storage_uri(self, value):
value = value.rstrip('/') if value else None
@ -498,7 +507,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if self._metrics_manager is None:
self._metrics_manager = Metrics(
session=self.session,
task_id=self.id,
task=self,
storage_uri=storage_uri,
storage_uri_suffix=self._get_output_destination_suffix('metrics'),
iteration_offset=self.get_initial_iteration()
@ -523,6 +532,27 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# type: () -> Any
""" Reload the task object from the backend """
with self._edit_lock:
if self._offline_mode:
# noinspection PyBroadException
try:
with open(self.get_offline_mode_folder() / self._offline_filename, 'rt') as f:
stored_dict = json.load(f)
stored_data = tasks.Task(**stored_dict)
# add missing entries
for k, v in stored_dict.items():
if not hasattr(stored_data, k):
setattr(stored_data, k, v)
if stored_dict.get('project_name'):
self._project_name = (None, stored_dict.get('project_name'))
except Exception:
stored_data = self._data
return stored_data or tasks.Task(
execution=tasks.Execution(
parameters={}, artifacts=[], dataviews=[], model='',
model_desc={}, model_labels={}, docker_cmd=''),
output=tasks.Output())
if self._reload_skip_flag and self._data:
return self._data
res = self.send(tasks.GetByIdRequest(task=self.id))
@ -774,7 +804,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
execution = self.data.execution
if execution is None:
execution = tasks.Execution(parameters=parameters)
execution = tasks.Execution(
parameters=parameters, artifacts=[], dataviews=[], model='',
model_desc={}, model_labels={}, docker_cmd='')
else:
execution.parameters = parameters
self._edit(execution=execution)
@ -940,7 +972,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def get_project_name(self):
# type: () -> Optional[str]
if self.project is None:
return None
return self._project_name[1] if self._project_name and len(self._project_name) > 1 else None
if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
return self._project_name[1]
@ -1232,12 +1264,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _get_default_report_storage_uri(self):
# type: () -> str
if self._offline_mode:
return str(self.get_offline_mode_folder() / 'data')
if not self._files_server:
self._files_server = Session.get_files_server_host()
return self._files_server
def _get_status(self):
# type: () -> (Optional[str], Optional[str])
if self._offline_mode:
return tasks.TaskStatusEnum.created, 'offline'
# noinspection PyBroadException
try:
all_tasks = self.send(
@ -1296,6 +1334,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _edit(self, **kwargs):
# type: (**Any) -> Any
with self._edit_lock:
if self._offline_mode:
for k, v in kwargs.items():
setattr(self.data, k, v)
Path(self.get_offline_mode_folder()).mkdir(parents=True, exist_ok=True)
with open((self.get_offline_mode_folder() / self._offline_filename).as_posix(), 'wt') as f:
export_data = self.data.to_dict()
export_data['project_name'] = self.get_project_name()
export_data['offline_folder'] = self.get_offline_mode_folder().as_posix()
json.dump(export_data, f, ensure_ascii=True, sort_keys=True)
return None
# Since we are using forced update, make sure the task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status
if status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
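In offline mode every _edit() call rewrites the complete task state to task.json in the offline folder, with the two extra keys injected above. A sketch of inspecting the file; the path is illustrative:

import json

# <offline_folder>/task.json, rewritten on every offline edit
with open('/home/user/.trains/cache/offline/offline-abc123/task.json', 'rt') as f:
    stored = json.load(f)

# regular task fields plus the extras added by _edit(): project_name, offline_folder
print(stored.get('project_name'), stored.get('offline_folder'))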
@ -1315,15 +1364,32 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# protection, Old API might not support it
# noinspection PyBroadException
try:
self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
with self._edit_lock:
self.reload()
self.data.script.requirements = requirements
if self._offline_mode:
self._edit(script=self.data.script)
else:
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
except Exception:
pass
def _update_script(self, script):
# type: (dict) -> ()
self.data.script = script
self._edit(script=script)
with self._edit_lock:
self.reload()
self.data.script = script
self._edit(script=script)
def get_offline_mode_folder(self):
# type: () -> (Optional[Path])
"""
Return the folder where all the task outputs and logs are stored in the offline session.
:return: Path object, local folder, later to be used with `report_offline_session()`
"""
if not self._offline_mode:
return None
return get_offline_dir(task_id=self.task_id)
@classmethod
def _clone_task(
@ -1475,13 +1541,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not PROC_MASTER_ID_ENV_VAR.get() or len(PROC_MASTER_ID_ENV_VAR.get().split(':')) < 2:
self.__edit_lock = RLock()
elif PROC_MASTER_ID_ENV_VAR.get().split(':')[1] == str(self.id):
# remove previous file lock instance, just in case.
filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id))
# noinspection PyBroadException
try:
os.unlink(filename)
except Exception:
pass
# no need to remove a previous file lock left by a dead process; it will be released automatically.
# # noinspection PyBroadException
# try:
# os.unlink(filename)
# except Exception:
# pass
# create a new file based lock
self.__edit_lock = FileRLock(filename=filename)
else:
@ -1523,3 +1589,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
return is_subprocess
@classmethod
def set_offline(cls, offline_mode=False):
# type: (bool) -> ()
"""
Set offline mode, where all data and logs are stored in a local folder for later transmission
:param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
:return:
"""
ENV_OFFLINE_MODE.set(offline_mode)
InterfaceBase._offline_mode = bool(offline_mode)
Session._offline_mode = bool(offline_mode)
@classmethod
def is_offline(cls):
# type: () -> bool
"""
Return the offline-mode state. If in offline-mode, no communication to the backend is enabled.
:return: boolean offline-mode state
"""
return cls._offline_mode
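A minimal capture-side sketch of the two new class methods; the script body is illustrative:

from trains import Task

Task.set_offline(offline_mode=True)  # must be called before Task.init()

task = Task.init(project_name='examples', task_name='offline capture')
assert Task.is_offline()

task.get_logger().report_scalar(title='loss', series='train', value=0.5, iteration=1)
task.close()  # the offline folder is zipped and its path printed on exit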

View File

@ -52,6 +52,8 @@ def make_message(s, **kwargs):
def get_or_create_project(session, project_name, description=None):
res = session.send(projects.GetAllRequest(name=exact_match_regex(project_name)))
if not res:
return None
if res.response.projects:
return res.response.projects[0].id
res = session.send(projects.CreateRequest(name=project_name, description=description))

View File

@ -19,13 +19,21 @@ def get_cache_dir():
cache_base_dir = Path( # noqa: F405
expandvars(
expanduser(
TRAINS_CACHE_DIR.get() or config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR # noqa: F405
TRAINS_CACHE_DIR.get() or
config.get("storage.cache.default_base_dir") or
DEFAULT_CACHE_DIR # noqa: F405
)
)
)
return cache_base_dir
def get_offline_dir(task_id=None):
if not task_id:
return get_cache_dir() / 'offline'
return get_cache_dir() / 'offline' / task_id
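Assuming get_offline_dir is exported alongside get_cache_dir (e.g. from trains.config), the per-task layout resolves like this:

from trains.config import get_offline_dir

print(get_offline_dir())                       # <cache-dir>/offline
print(get_offline_dir(task_id='offline-abc'))  # <cache-dir>/offline/offline-abc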
def get_config_for_bucket(base_url, extra_configurations=None):
config_list = S3BucketConfigurations.from_config(config.get("aws.s3"))

View File

@ -982,6 +982,7 @@ class Logger(object):
For example: ``s3://bucket/directory/``, or ``file:///tmp/debug/``.
"""
# noinspection PyProtectedMember
return self._default_upload_destination or self._task._get_default_report_storage_uri()
def flush(self):

View File

@ -1,12 +1,14 @@
import atexit
import json
import os
import shutil
import signal
import sys
import threading
import time
from argparse import ArgumentParser
from tempfile import mkstemp
from tempfile import mkstemp, mkdtemp
from zipfile import ZipFile, ZIP_DEFLATED
try:
# noinspection PyCompatibility
@ -25,6 +27,7 @@ from .backend_api.session.session import Session, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
from .backend_interface.task.log import TaskHandler
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
@ -446,7 +449,9 @@ class Task(_Task):
not auto_connect_frameworks.get('detect_repository', True)) else True
)
# set defaults
if output_uri:
if cls._offline_mode:
task.output_uri = None
elif output_uri:
task.output_uri = output_uri
elif cls.__default_output_uri:
task.output_uri = cls.__default_output_uri
@ -530,9 +535,11 @@ class Task(_Task):
logger = task.get_logger()
# show the debug metrics page in the log, it is very convenient
if not is_sub_process_task_id:
logger.report_text(
'TRAINS results page: {}'.format(task.get_output_log_web_page()),
)
if cls._offline_mode:
logger.report_text('TRAINS running in offline mode, session stored in {}'.format(
task.get_offline_mode_folder()))
else:
logger.report_text('TRAINS results page: {}'.format(task.get_output_log_web_page()))
# Make sure we start the dev worker if required, otherwise it will only be started when we write
# something to the log.
task._dev_mode_task_start()
@ -1362,7 +1369,7 @@ class Task(_Task):
:return: The last reported iteration number.
"""
self._reload_last_iteration()
return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
return max(self.data.last_iteration or 0, self._reporter.max_iteration if self._reporter else 0)
def set_initial_iteration(self, offset=0):
# type: (int) -> int
@ -1570,11 +1577,11 @@ class Task(_Task):
:param task_data: dictionary with full Task configuration
:return: return True if Task update was successful
"""
return self.import_task(task_data=task_data, target_task=self, update=True)
return bool(self.import_task(task_data=task_data, target_task=self, update=True))
@classmethod
def import_task(cls, task_data, target_task=None, update=False):
# type: (dict, Optional[Union[str, Task]], bool) -> bool
# type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]
"""
Import (create) Task from previously exported Task configuration (see Task.export_task)
Can also be used to edit/update an existing Task (by passing `target_task` and `update=True`).
@ -1595,7 +1602,7 @@ class Task(_Task):
"received `target_task` type {}".format(type(target_task)))
target_task.reload()
cur_data = target_task.data.to_dict()
cur_data = merge_dicts(cur_data, task_data) if update else task_data
cur_data = merge_dicts(cur_data, task_data) if update else dict(**task_data)
cur_data.pop('id', None)
cur_data.pop('project', None)
# noinspection PyProtectedMember
@ -1604,8 +1611,79 @@ class Task(_Task):
res = target_task._edit(**cur_data)
if res and res.ok():
target_task.reload()
return True
return False
return target_task
return None
@classmethod
def import_offline_session(cls, session_folder_zip):
# type: (str) -> (Optional[str])
"""
Upload an offline session (execution) of a Task.
Full Task execution includes repository details, installed packages, artifacts, logs, metrics and debug samples.
:param session_folder_zip: Path to a folder containing the session, or zip-file of the session folder.
:return: Newly created task ID (str)
"""
print('TRAINS: Importing offline session from {}'.format(session_folder_zip))
temp_folder = None
if Path(session_folder_zip).is_file():
# unzip the file:
temp_folder = mkdtemp(prefix='trains-offline-')
ZipFile(session_folder_zip).extractall(path=temp_folder)
session_folder_zip = temp_folder
session_folder = Path(session_folder_zip)
if not session_folder.is_dir():
raise ValueError("Could not find the session folder / zip-file {}".format(session_folder))
try:
with open(session_folder / cls._offline_filename, 'rt') as f:
export_data = json.load(f)
except Exception as ex:
raise ValueError(
"Could not read Task object {}: Exception {}".format(session_folder / cls._offline_filename, ex))
task = cls.import_task(export_data)
task.mark_started(force=True)
# fix artifacts
if task.data.execution.artifacts:
from . import StorageManager
# noinspection PyProtectedMember
offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
remote_url = task._get_default_report_storage_uri()
if remote_url and remote_url.endswith('/'):
remote_url = remote_url[:-1]
for artifact in task.data.execution.artifacts:
local_path = artifact.uri.replace(offline_folder, '', 1)
local_file = session_folder / 'data' / local_path
if local_file.is_file():
remote_path = local_path.replace(
'.{}{}'.format(export_data['id'], os.sep), '.{}{}'.format(task.id, os.sep), 1)
artifact.uri = '{}/{}'.format(remote_url, remote_path)
StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri)
# noinspection PyProtectedMember
task._edit(execution=task.data.execution)
# logs
TaskHandler.report_offline_session(task, session_folder)
# metrics
Metrics.report_offline_session(task, session_folder)
# print imported results page
print('TRAINS results page: {}'.format(task.get_output_log_web_page()))
task.completed()
# close task
task.close()
# cleanup
if temp_folder:
# noinspection PyBroadException
try:
shutil.rmtree(temp_folder)
except Exception:
pass
return task.id
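Replaying a captured session against a live server then takes a single call; it accepts either the session folder or the zip created when the offline task closed (path illustrative):

from trains import Task

new_task_id = Task.import_offline_session('/path/to/offline-abc123.zip')
print('imported as task', new_task_id)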
@classmethod
def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
@ -2099,7 +2177,7 @@ class Task(_Task):
parent.terminate()
def _dev_mode_setup_worker(self, model_updated=False):
if running_remotely() or not self.is_main_task() or self._at_exit_called:
if running_remotely() or not self.is_main_task() or self._at_exit_called or self._offline_mode:
return
if self._dev_worker:
@ -2283,6 +2361,23 @@ class Task(_Task):
except Exception:
# make sure we do not interrupt the exit process
pass
# make sure we store last task state
if self._offline_mode and not is_sub_process:
# noinspection PyBroadException
try:
# create zip file
offline_folder = self.get_offline_mode_folder()
zip_file = offline_folder.as_posix() + '.zip'
with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf:
for filename in offline_folder.rglob('*'):
if filename.is_file():
relative_file_name = filename.relative_to(offline_folder).as_posix()
zf.write(filename.as_posix(), arcname=relative_file_name)
print('TRAINS Task: Offline session stored in {}'.format(zip_file))
except Exception:
pass
# delete locking object (lock file)
if self._edit_lock:
# noinspection PyBroadException
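The archive written at exit lands next to the offline folder and carries the whole session; a quick inspection sketch, path illustrative:

from zipfile import ZipFile

with ZipFile('/home/user/.trains/cache/offline/offline-abc123.zip') as zf:
    for name in zf.namelist():
        print(name)  # expect task.json, log.jsonl, metrics.jsonl, data/...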
@ -2597,7 +2692,7 @@ class Task(_Task):
@classmethod
def __get_task_api_obj(cls, task_id, only_fields=None):
if not task_id:
if not task_id or cls._offline_mode:
return None
all_tasks = cls._send(