diff --git a/trains/backend_api/api_proxy.py b/trains/backend_api/api_proxy.py
index a68a4578..2fc02d5b 100644
--- a/trains/backend_api/api_proxy.py
+++ b/trains/backend_api/api_proxy.py
@@ -31,7 +31,7 @@ class ApiServiceProxy(object):
                 ]]

             # get the most advanced service version that supports our api
-            version = [str(v) for v in ApiServiceProxy._available_versions if Version(Session.api_version) >= v][-1]
+            version = [str(v) for v in ApiServiceProxy._available_versions if Session.check_min_api_version(v)][-1]
             self.__dict__["__wrapped_version__"] = Session.api_version
             name = ".v{}.{}".format(
                 version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
diff --git a/trains/backend_api/session/defs.py b/trains/backend_api/session/defs.py
index 5ea6a97d..c89ee867 100644
--- a/trains/backend_api/session/defs.py
+++ b/trains/backend_api/session/defs.py
@@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
 ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
 ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
 ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True)
+ENV_OFFLINE_MODE = EnvEntry("TRAINS_OFFLINE_MODE", "ALG_OFFLINE_MODE", type=bool)
diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py
index 8e55e4ac..6ca3c54d 100644
--- a/trains/backend_api/session/session.py
+++ b/trains/backend_api/session/session.py
@@ -11,7 +11,8 @@ from requests.auth import HTTPBasicAuth
 from six.moves.urllib.parse import urlparse, urlunparse

 from .callresult import CallResult
-from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
+from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, \
+    ENV_FILES_HOST, ENV_OFFLINE_MODE
 from .request import Request, BatchRequest  # noqa: F401
 from .token_manager import TokenManager
 from ..config import load
@@ -50,6 +51,8 @@ class Session(TokenManager):
     _write_session_timeout = (300.0, 300.)
     _sessions_created = 0
     _ssl_error_count_verbosity = 2
+    _offline_mode = ENV_OFFLINE_MODE.get()
+    _offline_default_version = '2.5'

     _client = [(__package__.partition(".")[0], __version__)]

@@ -153,6 +156,9 @@ class Session(TokenManager):

         self.client = ", ".join("{}-{}".format(*x) for x in self._client)

+        if self._offline_mode:
+            return
+
         self.refresh_token()

         # update api version from server response
@@ -197,6 +203,9 @@ class Session(TokenManager):
             server-side permissions have changed but are not reflected in the current token. Refreshing the token
             will generate a token with the updated permissions.
         """
+        if self._offline_mode:
+            return None
+
         host = self.host
         headers = headers.copy() if headers else {}
         headers[self._WORKER_HEADER] = self.worker
@@ -406,6 +415,9 @@ class Session(TokenManager):
         """
         self.validate_request(req_obj)

+        if self._offline_mode:
+            return None
+
         if isinstance(req_obj, BatchRequest):
             # TODO: support async for batch requests as well
             if async_enable:
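With the guards above in place, a Session constructed in offline mode skips refresh_token(), and both send_request() and send() return None, so nothing ever reaches the API server. A minimal sketch of driving this through the new environment variable (the variable name comes from defs.py above; the script body is hypothetical, and assumes the variable is set before trains is first imported, since Session._offline_mode is read when the class is defined):

    import os

    # must happen before the first `import trains`, because Session._offline_mode
    # is evaluated at class-definition time
    os.environ['TRAINS_OFFLINE_MODE'] = '1'

    from trains import Task

    task = Task.init(project_name='examples', task_name='offline run')  # no backend calls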
@@ -526,11 +538,14 @@ class Session(TokenManager):
         # If no session was created, create a default one, in order to get the backend api version.
         if cls._sessions_created <= 0:
-            # noinspection PyBroadException
-            try:
-                cls()
-            except Exception:
-                pass
+            if cls._offline_mode:
+                cls.api_version = cls._offline_default_version
+            else:
+                # noinspection PyBroadException
+                try:
+                    cls()
+                except Exception:
+                    pass

         return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
diff --git a/trains/backend_interface/base.py b/trains/backend_interface/base.py
index df228f71..54858944 100644
--- a/trains/backend_interface/base.py
+++ b/trains/backend_interface/base.py
@@ -7,7 +7,7 @@ from ..backend_api import Session, CallResult
 from ..backend_api.session.session import MaxRequestSizeError
 from ..backend_api.session.response import ResponseMeta
 from ..backend_api.session import BatchRequest
-from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
+from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_OFFLINE_MODE
 from ..config import config_obj
 from ..config.defs import LOG_LEVEL_ENV_VAR

@@ -19,6 +19,7 @@ class InterfaceBase(SessionInterface):
     """ Base class for a backend manager class """
     _default_session = None
     _num_retry_warning_display = 1
+    _offline_mode = ENV_OFFLINE_MODE.get()

     @property
     def session(self):
@@ -44,6 +45,9 @@ class InterfaceBase(SessionInterface):
     @classmethod
     def _send(cls, session, req, ignore_errors=False, raise_on_errors=True, log=None, async_enable=False):
         """ Convenience send() method providing a standardized error reporting """
+        if cls._offline_mode:
+            return None
+
         num_retries = 0
         while True:
             error_msg = ''
@@ -151,7 +155,7 @@ class IdObjectBase(InterfaceBase):
             pass

     def reload(self):
-        if not self.id:
+        if not self.id and not self._offline_mode:
             raise ValueError('Failed reloading %s: missing id' % type(self).__name__)
         # noinspection PyBroadException
         try:
diff --git a/trains/backend_interface/logger.py b/trains/backend_interface/logger.py
index e3c6a2a2..be6c3268 100644
--- a/trains/backend_interface/logger.py
+++ b/trains/backend_interface/logger.py
@@ -18,7 +18,7 @@ class StdStreamPatch(object):
         if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
             StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
             StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
-            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
+            logger._task_handler = TaskHandler(task=logger._task, capacity=100)
             # noinspection PyBroadException
             try:
                 if StdStreamPatch._stdout_original_write is None:
@@ -70,7 +70,7 @@ class StdStreamPatch(object):
                 pass

         elif DevWorker.report_stdout and not running_remotely():
-            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
+            logger._task_handler = TaskHandler(task=logger._task, capacity=100)
             if StdStreamPatch._stdout_proxy:
                 StdStreamPatch._stdout_proxy.connect(logger)
             if StdStreamPatch._stderr_proxy:
diff --git a/trains/backend_interface/metrics/interface.py b/trains/backend_interface/metrics/interface.py
index 6f1ec090..cc9aed89 100644
--- a/trains/backend_interface/metrics/interface.py
+++ b/trains/backend_interface/metrics/interface.py
@@ -1,4 +1,7 @@
+import json
+import os
 from functools import partial
+from logging import warning
 from multiprocessing.pool import ThreadPool
 from multiprocessing import Lock
 from time import time
@@ -25,6 +28,7 @@ class Metrics(InterfaceBase):
     _file_upload_retries = 3
     _upload_pool = None
     _file_upload_pool = None
+    __offline_filename = 'metrics.jsonl'

     @property
     def storage_key_prefix(self):
@@ -43,14 +47,19 @@ class Metrics(InterfaceBase):
         finally:
             self._storage_lock.release()

-    def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
+    def __init__(self, session, task, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
         super(Metrics, self).__init__(session, log=log)
-        self._task_id = task_id
+        self._task_id = task.id
         self._task_iteration_offset = iteration_offset
         self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
         self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
         self._file_related_event_time = None
         self._file_upload_time = None
+        self._offline_log_filename = None
+        if self._offline_mode:
+            offline_folder = Path(task.get_offline_mode_folder())
+            offline_folder.mkdir(parents=True, exist_ok=True)
+            self._offline_log_filename = offline_folder / self.__offline_filename

     def write_events(self, events, async_enable=True, callback=None, **kwargs):
         """
@@ -167,6 +176,7 @@ class Metrics(InterfaceBase):
                 e.set_exception(exp)
             e.stream.close()
             if e.delete_local_file:
+                # noinspection PyBroadException
                 try:
                     Path(e.delete_local_file).unlink()
                 except Exception:
@@ -199,6 +209,11 @@ class Metrics(InterfaceBase):
         _events = [ev.get_api_event() for ev in good_events]
         batched_requests = [api_events.AddRequest(event=ev) for ev in _events if ev]
         if batched_requests:
+            if self._offline_mode:
+                with open(self._offline_log_filename.as_posix(), 'at') as f:
+                    f.write(json.dumps([b.to_dict() for b in batched_requests])+'\n')
+                return
+
             req = api_events.AddBatchRequest(requests=batched_requests)
             return self.send(req, raise_on_errors=False)

@@ -234,3 +249,69 @@ class Metrics(InterfaceBase):
                 pool.join()
         except Exception:
             pass
+
+    @classmethod
+    def report_offline_session(cls, task, folder):
+        from ... import StorageManager
+        filename = Path(folder) / cls.__offline_filename
+        if not filename.is_file():
+            return False
+        # noinspection PyProtectedMember
+        remote_url = task._get_default_report_storage_uri()
+        if remote_url and remote_url.endswith('/'):
+            remote_url = remote_url[:-1]
+        uploaded_files = set()
+        task_id = task.id
+        with open(filename, 'rt') as f:
+            i = 0
+            while True:
+                try:
+                    line = f.readline()
+                    if not line:
+                        break
+                    list_requests = json.loads(line)
+                    for r in list_requests:
+                        org_task_id = r['task']
+                        r['task'] = task_id
+                        if r.get('key') and r.get('url'):
+                            debug_sample = (Path(folder) / 'data').joinpath(*(r['key'].split('/')))
+                            r['key'] = r['key'].replace(
+                                '.{}{}'.format(org_task_id, os.sep), '.{}{}'.format(task_id, os.sep), 1)
+                            r['url'] = '{}/{}'.format(remote_url, r['key'])
+                            if debug_sample not in uploaded_files and debug_sample.is_file():
+                                uploaded_files.add(debug_sample)
+                                StorageManager.upload_file(local_file=debug_sample.as_posix(), remote_url=r['url'])
+                        elif r.get('plot_str'):
+                            # hack plotly embedded images links
+                            # noinspection PyBroadException
+                            try:
+                                task_id_sep = '.{}{}'.format(org_task_id, os.sep)
+                                plot = json.loads(r['plot_str'])
+                                if plot.get('layout', {}).get('images'):
+                                    for image in plot['layout']['images']:
+                                        if task_id_sep not in image['source']:
+                                            continue
+                                        pre, post = image['source'].split(task_id_sep, 1)
+                                        pre = os.sep.join(pre.split(os.sep)[-2:])
+                                        debug_sample = (Path(folder) / 'data').joinpath(
+                                            pre+'.{}'.format(org_task_id), post)
+                                        image['source'] = '/'.join(
+                                            [remote_url, pre + '.{}'.format(task_id), post])
+                                        if debug_sample not in uploaded_files and debug_sample.is_file():
+                                            uploaded_files.add(debug_sample)
+                                            StorageManager.upload_file(
+                                                local_file=debug_sample.as_posix(), remote_url=image['source'])
+                                r['plot_str'] = json.dumps(plot)
+                            except Exception:
+                                pass
+                    i += 1
+                except StopIteration:
+                    break
+                except Exception as ex:
+                    warning('Failed reporting metric, line {} [{}]'.format(i, ex))
+                batch_requests = api_events.AddBatchRequest(requests=list_requests)
+                res = task.session.send(batch_requests)
+                if res and not res.ok():
+                    warning("failed logging metric task to backend ({:d} lines, {})".format(
+                        len(batch_requests.requests), str(res.meta)))
+        return True
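For context on the format being replayed here: every flush in offline mode appends one line to metrics.jsonl, and each line is a JSON array holding the serialized AddRequest dicts of a single batch. A rough sketch of walking such a file (the filename is the constant defined above; the exact keys vary by event type, so treat the printed fields as an assumption):

    import json

    with open('metrics.jsonl', 'rt') as f:
        for line in f:
            batch = json.loads(line)  # one list of event dicts per flushed batch
            for ev in batch:
                # report_offline_session() rewrites the stored offline task id
                # and re-uploads any referenced debug samples before sending
                print(ev.get('task'), ev.get('metric'), ev.get('variant'))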
diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index 2ab976b3..09774c4c 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -659,7 +659,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):

         # Hack: if the url doesn't start with http/s then the plotly will not be able to show it,
         # then we put the link under images not plots
-        if not url.startswith('http'):
+        if not url.startswith('http') and not self._offline_mode:
             return self.report_image_and_upload(title=title, series=series, iter=iter, path=path, image=matrix,
                                                 upload_uri=upload_uri, max_image_history=max_image_history)
diff --git a/trains/backend_interface/model.py b/trains/backend_interface/model.py
index 52d578da..9ffd6ec1 100644
--- a/trains/backend_interface/model.py
+++ b/trains/backend_interface/model.py
@@ -71,6 +71,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):

     def _reload(self):
         """ Reload the model object """
+        if self._offline_mode:
+            return models.Model()
+
         if self.id == self._EMPTY_MODEL_ID:
             return
         res = self.send(models.GetByIdRequest(model=self.id))
@@ -186,35 +189,40 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
         task = task_id or self.data.task
         project = project_id or self.data.project
         parent = parent_id or self.data.parent
-        if tags:
-            extra = {'system_tags': tags or self.data.system_tags} \
-                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
-        else:
-            extra = {}
-        self.send(models.EditRequest(
+        self._edit(
             model=self.id,
             uri=uri,
             name=name,
             comment=comment,
             labels=labels,
             design=design,
+            framework=framework or self.data.framework,
+            iteration=iteration,
             task=task,
             project=project,
             parent=parent,
-            framework=framework or self.data.framework,
-            iteration=iteration,
-            **extra
-        ))
-        self.reload()
+        )

     def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
              uri=None, framework=None, iteration=None):
+        return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
+                          uri=uri, framework=framework, iteration=iteration)
+
+    def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
+              uri=None, framework=None, iteration=None, **extra):
+        def offline_store(**kwargs):
+            for k, v in kwargs.items():
+                setattr(self.data, k, v or getattr(self.data, k, None))
+            return
+        if self._offline_mode:
+            return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
+                                 uri=uri, framework=framework, iteration=iteration, **extra)
+
         if tags:
-            extra = {'system_tags': tags or self.data.system_tags} \
-                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
-        else:
-            extra = {}
+            extra.update({'system_tags': tags or self.data.system_tags}
+                         if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags})
+
         self.send(models.EditRequest(
             model=self.id,
             uri=uri,
@@ -298,7 +306,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
             override_model_id=override_model_id, **extra))
         if self.id is None:
             # update the model id. in case it was just created, this will trigger a reload of the model object
-            self.id = res.response.id
+            self.id = res.response.id if res else None
         else:
             self.reload()
         try:
diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py
index 5712ec13..97b58bc3 100644
--- a/trains/backend_interface/task/args.py
+++ b/trains/backend_interface/task/args.py
@@ -306,10 +306,11 @@ class _Arguments(object):
         # TODO: add dict prefix
         prefix = prefix or ''  # self._prefix_dict
         if prefix:
-            prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
-            cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
-            cur_params.update(prefix_dictionary)
-            self._task.set_parameters(cur_params)
+            with self._task._edit_lock:
+                prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
+                cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
+                cur_params.update(prefix_dictionary)
+                self._task.set_parameters(cur_params)
         else:
             self._task.update_parameters(dictionary)
         if not isinstance(dictionary, self._ProxyDictWrite):
diff --git a/trains/backend_interface/task/log.py b/trains/backend_interface/task/log.py
index 99ede809..411db7fa 100644
--- a/trains/backend_interface/task/log.py
+++ b/trains/backend_interface/task/log.py
@@ -1,6 +1,8 @@
+import json
 import sys
 import time
-from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord
+from pathlib2 import Path
+from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord, warning
 from logging.handlers import BufferingHandler
 from threading import Thread, Event
 from six.moves.queue import Queue
@@ -17,6 +19,7 @@ class TaskHandler(BufferingHandler):
     __wait_for_flush_timeout = 10.
     __max_event_size = 1024 * 1024
     __once = False
+    __offline_filename = 'log.jsonl'

     @property
     def task_id(self):
@@ -26,10 +29,10 @@ class TaskHandler(BufferingHandler):
     def task_id(self, value):
         self._task_id = value

-    def __init__(self, session, task_id, capacity=buffer_capacity):
+    def __init__(self, task, capacity=buffer_capacity):
         super(TaskHandler, self).__init__(capacity)
-        self.task_id = task_id
-        self.session = session
+        self.task_id = task.id
+        self.session = task.session
         self.last_timestamp = 0
         self.counter = 1
         self._last_event = None
@@ -37,6 +40,11 @@ class TaskHandler(BufferingHandler):
         self._queue = None
         self._thread = None
         self._pending = 0
+        self._offline_log_filename = None
+        if task.is_offline():
+            offline_folder = Path(task.get_offline_mode_folder())
+            offline_folder.mkdir(parents=True, exist_ok=True)
+            self._offline_log_filename = offline_folder / self.__offline_filename

     def shouldFlush(self, record):
         """
@@ -124,6 +132,7 @@ class TaskHandler(BufferingHandler):
         if not self.buffer:
             return

+        buffer = None
         self.acquire()
         if self.buffer:
             buffer = self.buffer
@@ -133,6 +142,7 @@ class TaskHandler(BufferingHandler):
         if not buffer:
             return

+        # noinspection PyBroadException
         try:
             record_events = [r for record in buffer for r in self._record_to_event(record)] + [self._last_event]
             self._last_event = None
@@ -194,11 +204,17 @@ class TaskHandler(BufferingHandler):

     def _send_events(self, a_request):
         try:
+            self._pending -= 1
+
+            if self._offline_log_filename:
+                with open(self._offline_log_filename.as_posix(), 'at') as f:
+                    f.write(json.dumps([b.to_dict() for b in a_request.requests]) + '\n')
+                return
+
             # if self._thread is None:
             #     self.__log_stderr('Task.close() flushing remaining logs ({})'.format(self._pending))
-            self._pending -= 1
             res = self.session.send(a_request)
-            if not res.ok():
+            if res and not res.ok():
                 self.__log_stderr("failed logging task to backend ({:d} lines, {})".format(
                     len(a_request.requests), str(res.meta)), level=WARNING)
         except MaxRequestSizeError:
@@ -237,3 +253,31 @@ class TaskHandler(BufferingHandler):
         write('{asctime} - {name} - {levelname} - {message}\n'.format(
             asctime=Formatter().formatTime(makeLogRecord({})),
             name='trains.log', levelname=getLevelName(level), message=msg))
+
+    @classmethod
+    def report_offline_session(cls, task, folder):
+        filename = Path(folder) / cls.__offline_filename
+        if not filename.is_file():
+            return False
+        with open(filename, 'rt') as f:
+            i = 0
+            while True:
+                try:
+                    line = f.readline()
+                    if not line:
+                        break
+                    list_requests = json.loads(line)
+                    for r in list_requests:
+                        r.pop('task', None)
+                    i += 1
+                except StopIteration:
+                    break
+                except Exception as ex:
+                    warning('Failed reporting log, line {} [{}]'.format(i, ex))
+                batch_requests = events.AddBatchRequest(
+                    requests=[events.TaskLogEvent(task=task.id, **r) for r in list_requests])
+                res = task.session.send(batch_requests)
+                if res and not res.ok():
+                    warning("failed logging task to backend ({:d} lines, {})".format(
+                        len(batch_requests.requests), str(res.meta)))
+        return True
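Combined with metrics.jsonl above and task.json below, an offline session folder ends up looking roughly like this (the cache root shown assumes the default storage.cache.default_base_dir and may differ per configuration):

    <cache>/offline/<task-id>/
        task.json       # serialized task object, written by Task._edit()
        log.jsonl       # one JSON array of log events per line (TaskHandler)
        metrics.jsonl   # one JSON array of metric events per line (Metrics)
        data/           # debug samples and artifacts referenced by the events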
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index 659b6af7..0308fa91 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -1,5 +1,6 @@
 """ Backend task management support """
 import itertools
+import json
 import logging
 import os
 import sys
@@ -7,8 +8,10 @@ import re
 from enum import Enum
 from tempfile import gettempdir
 from multiprocessing import RLock
+from pathlib2 import Path
 from threading import Thread
 from typing import Optional, Any, Sequence, Callable, Mapping, Union, List
+from uuid import uuid4

 try:
     # noinspection PyCompatibility
@@ -25,17 +28,18 @@ from ...binding.artifacts import Artifacts
 from ...backend_interface.task.development.worker import DevWorker
 from ...backend_api import Session
 from ...backend_api.services import tasks, models, events, projects
-from pathlib2 import Path
+from ...backend_api.session.defs import ENV_OFFLINE_MODE
 from ...utilities.pyhocon import ConfigTree, ConfigFactory
-from ..base import IdObjectBase
+from ..base import IdObjectBase, InterfaceBase
 from ..metrics import Metrics, Reporter
 from ..model import Model
 from ..setupuploadmixin import SetupUploadMixin
 from ..util import make_message, get_or_create_project, get_single_result, \
     exact_match_regex
-from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
-    running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR
+from ...config import (
+    get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
+    running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
 from ...debugging import get_logger
 from ...debugging.log import LoggerRoot
 from ...storage.helper import StorageHelper, StorageError
@@ -56,6 +60,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     _force_requirements = {}

     _store_diff = config.get('development.store_uncommitted_code_diff', False)
+    _offline_filename = 'task.json'

     class TaskTypes(Enum):
         def __str__(self):
@@ -143,6 +148,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if not task_id:
             # generate a new task
             self.id = self._auto_generate(project_name=project_name, task_name=task_name, task_type=task_type)
+            if self._offline_mode:
+                self.data.id = self.id
+                self.name = task_name
         else:
             # this is an existing task, let's try to verify stuff
             self._validate()
@@ -195,7 +203,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
         # than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
         # handler instance handled them)
-        backend_handler = TaskHandler(self.session, self.task_id)
+        backend_handler = TaskHandler(task=self)

         # Add backend handler to both loggers:
         # 1. to root logger
@@ -280,11 +288,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             # to the module call including all argv's
             result.script = ScriptInfo.detect_running_module(result.script)

-        self.data.script = result.script
         # Since we might run asynchronously, don't use self.data (let someone else
         # overwrite it before we have a chance to call edit)
-        self._edit(script=result.script)
-        self.reload()
+        with self._edit_lock:
+            self.reload()
+            self.data.script = result.script
+            self._edit(script=result.script)
+
         # if jupyter is present, requirements will be created in the background, when saving a snapshot
         if result.script and script_requirements:
             entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
@@ -304,7 +314,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                 result.script['requirements']['conda'] = conda_requirements

         self._update_requirements(result.script.get('requirements') or '')
-        self.reload()

         # we do not want to wait for the check version thread,
         # because someone might wait for us to finish the repo detection update
@@ -339,7 +348,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         )

         res = self.send(req)
-        return res.response.id
+        return res.response.id if res else 'offline-{}'.format(str(uuid4()).replace("-", ""))

     def _set_storage_uri(self, value):
         value = value.rstrip('/') if value else None
@@ -498,7 +507,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if self._metrics_manager is None:
             self._metrics_manager = Metrics(
                 session=self.session,
-                task_id=self.id,
+                task=self,
                 storage_uri=storage_uri,
                 storage_uri_suffix=self._get_output_destination_suffix('metrics'),
                 iteration_offset=self.get_initial_iteration()
@@ -523,6 +532,27 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # type: () -> Any
         """ Reload the task object from the backend """
         with self._edit_lock:
+            if self._offline_mode:
+                # noinspection PyBroadException
+                try:
+                    with open(self.get_offline_mode_folder() / self._offline_filename, 'rt') as f:
+                        stored_dict = json.load(f)
+                    stored_data = tasks.Task(**stored_dict)
+                    # add missing entries
+                    for k, v in stored_dict.items():
+                        if not hasattr(stored_data, k):
+                            setattr(stored_data, k, v)
+                    if stored_dict.get('project_name'):
+                        self._project_name = (None, stored_dict.get('project_name'))
+                except Exception:
+                    stored_data = self._data
+
+                return stored_data or tasks.Task(
+                    execution=tasks.Execution(
+                        parameters={}, artifacts=[], dataviews=[], model='',
+                        model_desc={}, model_labels={}, docker_cmd=''),
+                    output=tasks.Output())
+
             if self._reload_skip_flag and self._data:
                 return self._data
             res = self.send(tasks.GetByIdRequest(task=self.id))
@@ -774,7 +804,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

         execution = self.data.execution
         if execution is None:
-            execution = tasks.Execution(parameters=parameters)
+            execution = tasks.Execution(
+                parameters=parameters, artifacts=[], dataviews=[], model='',
+                model_desc={}, model_labels={}, docker_cmd='')
         else:
             execution.parameters = parameters
         self._edit(execution=execution)
@@ -940,7 +972,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def get_project_name(self):
         # type: () -> Optional[str]
         if self.project is None:
-            return None
+            return self._project_name[1] if self._project_name and len(self._project_name) > 1 else None
         if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
             return self._project_name[1]

@@ -1232,12 +1264,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def _get_default_report_storage_uri(self):
         # type: () -> str
+        if self._offline_mode:
+            return str(self.get_offline_mode_folder() / 'data')
+
         if not self._files_server:
             self._files_server = Session.get_files_server_host()
         return self._files_server

     def _get_status(self):
         # type: () -> (Optional[str], Optional[str])
+        if self._offline_mode:
+            return tasks.TaskStatusEnum.created, 'offline'
+
         # noinspection PyBroadException
         try:
             all_tasks = self.send(
@@ -1296,6 +1334,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def _edit(self, **kwargs):
         # type: (**Any) -> Any
         with self._edit_lock:
+            if self._offline_mode:
+                for k, v in kwargs.items():
+                    setattr(self.data, k, v)
+                Path(self.get_offline_mode_folder()).mkdir(parents=True, exist_ok=True)
+                with open((self.get_offline_mode_folder() / self._offline_filename).as_posix(), 'wt') as f:
+                    export_data = self.data.to_dict()
+                    export_data['project_name'] = self.get_project_name()
+                    export_data['offline_folder'] = self.get_offline_mode_folder().as_posix()
+                    json.dump(export_data, f, ensure_ascii=True, sort_keys=True)
+                return None
+
             # Since we are using forced update, make sure the task status is valid
             status = self._data.status if self._data and self._reload_skip_flag else self.data.status
             if status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
@@ -1315,15 +1364,32 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # protection, Old API might not support it
         # noinspection PyBroadException
         try:
-            self.data.script.requirements = requirements
-            self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
+            with self._edit_lock:
+                self.reload()
+                self.data.script.requirements = requirements
+                if self._offline_mode:
+                    self._edit(script=self.data.script)
+                else:
+                    self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
         except Exception:
             pass

     def _update_script(self, script):
         # type: (dict) -> ()
-        self.data.script = script
-        self._edit(script=script)
+        with self._edit_lock:
+            self.reload()
+            self.data.script = script
+            self._edit(script=script)
+
+    def get_offline_mode_folder(self):
+        # type: () -> (Optional[Path])
+        """
+        Return the folder where all the task outputs and logs are stored in the offline session.
+
+        :return: Path object, local folder, later to be used with `report_offline_session()`
+        """
+        if not self._offline_mode:
+            return None
+        return get_offline_dir(task_id=self.task_id)

     @classmethod
     def _clone_task(
@@ -1475,13 +1541,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if not PROC_MASTER_ID_ENV_VAR.get() or len(PROC_MASTER_ID_ENV_VAR.get().split(':')) < 2:
             self.__edit_lock = RLock()
         elif PROC_MASTER_ID_ENV_VAR.get().split(':')[1] == str(self.id):
-            # remove previous file lock instance, just in case.
             filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id))
-            # noinspection PyBroadException
-            try:
-                os.unlink(filename)
-            except Exception:
-                pass
+            # no need to remove previous file lock if we have a dead process, it will automatically release the lock.
+            # # noinspection PyBroadException
+            # try:
+            #     os.unlink(filename)
+            # except Exception:
+            #     pass
             # create a new file based lock
             self.__edit_lock = FileRLock(filename=filename)
         else:
@@ -1523,3 +1589,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
             PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
         return is_subprocess
+
+    @classmethod
+    def set_offline(cls, offline_mode=False):
+        # type: (bool) -> ()
+        """
+        Set offline mode, where all data and logs are stored into a local folder, for later transmission
+
+        :param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
+        :return:
+        """
+        ENV_OFFLINE_MODE.set(offline_mode)
+        InterfaceBase._offline_mode = bool(offline_mode)
+        Session._offline_mode = bool(offline_mode)
+
+    @classmethod
+    def is_offline(cls):
+        # type: () -> bool
+        """
+        Return offline-mode state. If in offline-mode, no communication to the backend is enabled.
+
+        :return: boolean offline-mode state
+        """
+        return cls._offline_mode
diff --git a/trains/backend_interface/util.py b/trains/backend_interface/util.py
index abc61859..e4ae7d26 100644
--- a/trains/backend_interface/util.py
+++ b/trains/backend_interface/util.py
@@ -52,6 +52,8 @@ def make_message(s, **kwargs):

 def get_or_create_project(session, project_name, description=None):
     res = session.send(projects.GetAllRequest(name=exact_match_regex(project_name)))
+    if not res:
+        return None
     if res.response.projects:
         return res.response.projects[0].id
     res = session.send(projects.CreateRequest(name=project_name, description=description))
diff --git a/trains/config/__init__.py b/trains/config/__init__.py
index b7b11373..43030de3 100644
--- a/trains/config/__init__.py
+++ b/trains/config/__init__.py
@@ -19,13 +19,21 @@ def get_cache_dir():
     cache_base_dir = Path(  # noqa: F405
         expandvars(
             expanduser(
-                TRAINS_CACHE_DIR.get() or config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR  # noqa: F405
+                TRAINS_CACHE_DIR.get() or
+                config.get("storage.cache.default_base_dir") or
+                DEFAULT_CACHE_DIR  # noqa: F405
             )
         )
     )
     return cache_base_dir


+def get_offline_dir(task_id=None):
+    if not task_id:
+        return get_cache_dir() / 'offline'
+    return get_cache_dir() / 'offline' / task_id
+
+
 def get_config_for_bucket(base_url, extra_configurations=None):
     config_list = S3BucketConfigurations.from_config(config.get("aws.s3"))
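Usage-wise, set_offline() is the programmatic twin of the TRAINS_OFFLINE_MODE environment variable. A minimal sketch (only set_offline() and the folder/zip behavior come from this patch; the training lines are a hypothetical script):

    from trains import Task

    # flip offline mode on before Task.init(); nothing is sent to the backend
    Task.set_offline(offline_mode=True)

    task = Task.init(project_name='examples', task_name='offline experiment')
    task.get_logger().report_scalar('loss', 'train', value=0.5, iteration=1)
    # at process exit the session folder is zipped and its path printed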
""" + # noinspection PyProtectedMember return self._default_upload_destination or self._task._get_default_report_storage_uri() def flush(self): diff --git a/trains/task.py b/trains/task.py index 0bba0743..376f3b1d 100644 --- a/trains/task.py +++ b/trains/task.py @@ -1,12 +1,14 @@ import atexit +import json import os +import shutil import signal import sys import threading import time from argparse import ArgumentParser -from tempfile import mkstemp - +from tempfile import mkstemp, mkdtemp +from zipfile import ZipFile, ZIP_DEFLATED try: # noinspection PyCompatibility @@ -25,6 +27,7 @@ from .backend_api.session.session import Session, ENV_ACCESS_KEY, ENV_SECRET_KEY from .backend_interface.metrics import Metrics from .backend_interface.model import Model as BackendModel from .backend_interface.task import Task as _Task +from .backend_interface.task.log import TaskHandler from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.repo import ScriptInfo from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive @@ -446,7 +449,9 @@ class Task(_Task): not auto_connect_frameworks.get('detect_repository', True)) else True ) # set defaults - if output_uri: + if cls._offline_mode: + task.output_uri = None + elif output_uri: task.output_uri = output_uri elif cls.__default_output_uri: task.output_uri = cls.__default_output_uri @@ -530,9 +535,11 @@ class Task(_Task): logger = task.get_logger() # show the debug metrics page in the log, it is very convenient if not is_sub_process_task_id: - logger.report_text( - 'TRAINS results page: {}'.format(task.get_output_log_web_page()), - ) + if cls._offline_mode: + logger.report_text('TRAINS running in offline mode, session stored in {}'.format( + task.get_offline_mode_folder())) + else: + logger.report_text('TRAINS results page: {}'.format(task.get_output_log_web_page())) # Make sure we start the dev worker if required, otherwise it will only be started when we write # something to the log. task._dev_mode_task_start() @@ -1362,7 +1369,7 @@ class Task(_Task): :return: The last reported iteration number. """ self._reload_last_iteration() - return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0) + return max(self.data.last_iteration or 0, self._reporter.max_iteration if self._reporter else 0) def set_initial_iteration(self, offset=0): # type: (int) -> int @@ -1570,11 +1577,11 @@ class Task(_Task): :param task_data: dictionary with full Task configuration :return: return True if Task update was successful """ - return self.import_task(task_data=task_data, target_task=self, update=True) + return bool(self.import_task(task_data=task_data, target_task=self, update=True)) @classmethod def import_task(cls, task_data, target_task=None, update=False): - # type: (dict, Optional[Union[str, Task]], bool) -> bool + # type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task] """ Import (create) Task from previously exported Task configuration (see Task.export_task) Can also be used to edit/update an existing Task (by passing `target_task` and `update=True`). 
@@ -1595,7 +1602,7 @@ class Task(_Task):
                                  "received `target_task` type {}".format(type(target_task)))
             target_task.reload()
             cur_data = target_task.data.to_dict()
-            cur_data = merge_dicts(cur_data, task_data) if update else task_data
+            cur_data = merge_dicts(cur_data, task_data) if update else dict(**task_data)
             cur_data.pop('id', None)
             cur_data.pop('project', None)
             # noinspection PyProtectedMember
@@ -1604,8 +1611,79 @@ class Task(_Task):
             res = target_task._edit(**cur_data)
             if res and res.ok():
                 target_task.reload()
-                return True
-        return False
+                return target_task
+        return None
+
+    @classmethod
+    def import_offline_session(cls, session_folder_zip):
+        # type: (str) -> (Optional[str])
+        """
+        Upload an offline session (execution) of a Task.
+        Full Task execution includes repository details, installed packages, artifacts, logs, metrics and debug samples.
+
+        :param session_folder_zip: Path to a folder containing the session, or zip-file of the session folder.
+        :return: Newly created task ID (str)
+        """
+        print('TRAINS: Importing offline session from {}'.format(session_folder_zip))
+
+        temp_folder = None
+        if Path(session_folder_zip).is_file():
+            # unzip the file:
+            temp_folder = mkdtemp(prefix='trains-offline-')
+            ZipFile(session_folder_zip).extractall(path=temp_folder)
+            session_folder_zip = temp_folder
+
+        session_folder = Path(session_folder_zip)
+        if not session_folder.is_dir():
+            raise ValueError("Could not find the session folder / zip-file {}".format(session_folder))
+
+        try:
+            with open(session_folder / cls._offline_filename, 'rt') as f:
+                export_data = json.load(f)
+        except Exception as ex:
+            raise ValueError(
+                "Could not read Task object {}: Exception {}".format(session_folder / cls._offline_filename, ex))
+        task = cls.import_task(export_data)
+        task.mark_started(force=True)
+        # fix artifacts
+        if task.data.execution.artifacts:
+            from . import StorageManager
+            # noinspection PyProtectedMember
+            offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
+
+            remote_url = task._get_default_report_storage_uri()
+            if remote_url and remote_url.endswith('/'):
+                remote_url = remote_url[:-1]
+
+            for artifact in task.data.execution.artifacts:
+                local_path = artifact.uri.replace(offline_folder, '', 1)
+                local_file = session_folder / 'data' / local_path
+                if local_file.is_file():
+                    remote_path = local_path.replace(
+                        '.{}{}'.format(export_data['id'], os.sep), '.{}{}'.format(task.id, os.sep), 1)
+                    artifact.uri = '{}/{}'.format(remote_url, remote_path)
+                    StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri)
+            # noinspection PyProtectedMember
+            task._edit(execution=task.data.execution)
+        # logs
+        TaskHandler.report_offline_session(task, session_folder)
+        # metrics
+        Metrics.report_offline_session(task, session_folder)
+        # print imported results page
+        print('TRAINS results page: {}'.format(task.get_output_log_web_page()))
+        task.completed()
+        # close task
+        task.close()
+
+        # cleanup
+        if temp_folder:
+            # noinspection PyBroadException
+            try:
+                shutil.rmtree(temp_folder)
+            except Exception:
+                pass
+
+        return task.id

     @classmethod
     def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
@@ -2099,7 +2177,7 @@ class Task(_Task):
                 parent.terminate()

     def _dev_mode_setup_worker(self, model_updated=False):
-        if running_remotely() or not self.is_main_task() or self._at_exit_called:
+        if running_remotely() or not self.is_main_task() or self._at_exit_called or self._offline_mode:
             return

         if self._dev_worker:
@@ -2283,6 +2361,23 @@ class Task(_Task):
             except Exception:
                 # make sure we do not interrupt the exit process
                 pass
+
+        # make sure we store last task state
+        if self._offline_mode and not is_sub_process:
+            # noinspection PyBroadException
+            try:
+                # create zip file
+                offline_folder = self.get_offline_mode_folder()
+                zip_file = offline_folder.as_posix() + '.zip'
+                with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf:
+                    for filename in offline_folder.rglob('*'):
+                        if filename.is_file():
+                            relative_file_name = filename.relative_to(offline_folder).as_posix()
+                            zf.write(filename.as_posix(), arcname=relative_file_name)
+                print('TRAINS Task: Offline session stored in {}'.format(zip_file))
+            except Exception as ex:
+                pass
+
         # delete locking object (lock file)
         if self._edit_lock:
             # noinspection PyBroadException
@@ -2597,7 +2692,7 @@ class Task(_Task):

     @classmethod
     def __get_task_api_obj(cls, task_id, only_fields=None):
-        if not task_id:
+        if not task_id or cls._offline_mode:
             return None

         all_tasks = cls._send(
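Putting both halves together, the intended round trip is: run with offline mode on, take the zip printed at exit, then import it from a machine with backend connectivity. A sketch (the zip path below is hypothetical; use whatever the offline run printed):

    import os
    from trains import Task

    # offline mode must be off here so the session can reach the backend
    zip_path = os.path.expanduser('~/.trains/cache/offline/offline-abc123.zip')
    new_task_id = Task.import_offline_session(session_folder_zip=zip_path)
    print('imported as task {}'.format(new_task_id))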