mirror of
https://github.com/clearml/clearml
synced 2025-05-02 20:11:07 +00:00
Add offline support using Task.set_offline() and Task.import_offline_session()
parent 2ec5726812
commit a8d6380696
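In practice, the commit splits a run into two phases: record everything into a local session folder while offline, then replay that stored session against a server. A minimal usage sketch (the project name, task name, and zip path are illustrative placeholders, not part of the commit):

    from trains import Task

    # phase 1: record locally; no request ever reaches the backend
    Task.set_offline(offline_mode=True)
    task = Task.init(project_name='examples', task_name='offline demo')
    task.get_logger().report_scalar('metric', 'series', value=1.0, iteration=0)
    task.close()
    # at process exit the session folder is zipped, e.g. <cache>/offline/<task-id>.zip

    # phase 2: from a machine with server access, replay the stored session
    new_task_id = Task.import_offline_session('/path/to/offline-session.zip')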
@@ -31,7 +31,7 @@ class ApiServiceProxy(object):
                 ]]

            # get the most advanced service version that supports our api
-            version = [str(v) for v in ApiServiceProxy._available_versions if Version(Session.api_version) >= v][-1]
+            version = [str(v) for v in ApiServiceProxy._available_versions if Session.check_min_api_version(v)][-1]
            self.__dict__["__wrapped_version__"] = Session.api_version
            name = ".v{}.{}".format(
                version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
@@ -8,3 +8,4 @@ ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
 ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
 ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
 ENV_HOST_VERIFY_CERT = EnvEntry("TRAINS_API_HOST_VERIFY_CERT", "ALG_API_HOST_VERIFY_CERT", type=bool, default=True)
+ENV_OFFLINE_MODE = EnvEntry("TRAINS_OFFLINE_MODE", "ALG_OFFLINE_MODE", type=bool)
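The same switch can be flipped without touching code, since `ENV_OFFLINE_MODE` is a boolean `EnvEntry`. A minimal sketch (assuming any truthy value is accepted, and that the variable is set before the library creates a `Session`):

    import os

    os.environ['TRAINS_OFFLINE_MODE'] = '1'

    from trains import Task  # Session/InterfaceBase read ENV_OFFLINE_MODE when created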
@@ -11,7 +11,8 @@ from requests.auth import HTTPBasicAuth
 from six.moves.urllib.parse import urlparse, urlunparse

 from .callresult import CallResult
-from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, ENV_FILES_HOST
+from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_WEB_HOST, \
+    ENV_FILES_HOST, ENV_OFFLINE_MODE
 from .request import Request, BatchRequest  # noqa: F401
 from .token_manager import TokenManager
 from ..config import load
@@ -50,6 +51,8 @@ class Session(TokenManager):
     _write_session_timeout = (300.0, 300.)
     _sessions_created = 0
     _ssl_error_count_verbosity = 2
+    _offline_mode = ENV_OFFLINE_MODE.get()
+    _offline_default_version = '2.5'

     _client = [(__package__.partition(".")[0], __version__)]

@@ -153,6 +156,9 @@ class Session(TokenManager):

         self.client = ", ".join("{}-{}".format(*x) for x in self._client)

+        if self._offline_mode:
+            return
+
         self.refresh_token()

         # update api version from server response
@@ -197,6 +203,9 @@ class Session(TokenManager):
         server-side permissions have changed but are not reflected in the current token. Refreshing the token will
         generate a token with the updated permissions.
         """
+        if self._offline_mode:
+            return None
+
         host = self.host
         headers = headers.copy() if headers else {}
         headers[self._WORKER_HEADER] = self.worker
@@ -406,6 +415,9 @@ class Session(TokenManager):
         """
         self.validate_request(req_obj)

+        if self._offline_mode:
+            return None
+
         if isinstance(req_obj, BatchRequest):
             # TODO: support async for batch requests as well
             if async_enable:
@@ -526,11 +538,14 @@ class Session(TokenManager):

         # If no session was created, create a default one, in order to get the backend api version.
         if cls._sessions_created <= 0:
-            # noinspection PyBroadException
-            try:
-                cls()
-            except Exception:
-                pass
+            if cls._offline_mode:
+                cls.api_version = cls._offline_default_version
+            else:
+                # noinspection PyBroadException
+                try:
+                    cls()
+                except Exception:
+                    pass

         return version_tuple(cls.api_version) >= version_tuple(str(min_api_version))
@@ -7,7 +7,7 @@ from ..backend_api import Session, CallResult
 from ..backend_api.session.session import MaxRequestSizeError
 from ..backend_api.session.response import ResponseMeta
 from ..backend_api.session import BatchRequest
-from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
+from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_OFFLINE_MODE

 from ..config import config_obj
 from ..config.defs import LOG_LEVEL_ENV_VAR
@@ -19,6 +19,7 @@ class InterfaceBase(SessionInterface):
     """ Base class for a backend manager class """
     _default_session = None
     _num_retry_warning_display = 1
+    _offline_mode = ENV_OFFLINE_MODE.get()

     @property
     def session(self):
@@ -44,6 +45,9 @@ class InterfaceBase(SessionInterface):
     @classmethod
     def _send(cls, session, req, ignore_errors=False, raise_on_errors=True, log=None, async_enable=False):
         """ Convenience send() method providing a standardized error reporting """
+        if cls._offline_mode:
+            return None
+
         num_retries = 0
         while True:
             error_msg = ''
@@ -151,7 +155,7 @@ class IdObjectBase(InterfaceBase):
             pass

     def reload(self):
-        if not self.id:
+        if not self.id and not self._offline_mode:
             raise ValueError('Failed reloading %s: missing id' % type(self).__name__)
         # noinspection PyBroadException
         try:
@@ -18,7 +18,7 @@ class StdStreamPatch(object):
         if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
             StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
             StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
-            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
+            logger._task_handler = TaskHandler(task=logger._task, capacity=100)
             # noinspection PyBroadException
             try:
                 if StdStreamPatch._stdout_original_write is None:
@@ -70,7 +70,7 @@ class StdStreamPatch(object):
                 pass

         elif DevWorker.report_stdout and not running_remotely():
-            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
+            logger._task_handler = TaskHandler(task=logger._task, capacity=100)
             if StdStreamPatch._stdout_proxy:
                 StdStreamPatch._stdout_proxy.connect(logger)
             if StdStreamPatch._stderr_proxy:
@@ -1,4 +1,7 @@
+import json
+import os
 from functools import partial
+from logging import warning
 from multiprocessing.pool import ThreadPool
 from multiprocessing import Lock
 from time import time
@@ -25,6 +28,7 @@ class Metrics(InterfaceBase):
     _file_upload_retries = 3
     _upload_pool = None
     _file_upload_pool = None
+    __offline_filename = 'metrics.jsonl'

     @property
     def storage_key_prefix(self):
@@ -43,14 +47,19 @@ class Metrics(InterfaceBase):
         finally:
             self._storage_lock.release()

-    def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
+    def __init__(self, session, task, storage_uri, storage_uri_suffix='metrics', iteration_offset=0, log=None):
         super(Metrics, self).__init__(session, log=log)
-        self._task_id = task_id
+        self._task_id = task.id
         self._task_iteration_offset = iteration_offset
         self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
         self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
         self._file_related_event_time = None
         self._file_upload_time = None
+        self._offline_log_filename = None
+        if self._offline_mode:
+            offline_folder = Path(task.get_offline_mode_folder())
+            offline_folder.mkdir(parents=True, exist_ok=True)
+            self._offline_log_filename = offline_folder / self.__offline_filename

     def write_events(self, events, async_enable=True, callback=None, **kwargs):
         """
@@ -167,6 +176,7 @@ class Metrics(InterfaceBase):
                 e.set_exception(exp)
             e.stream.close()
             if e.delete_local_file:
+                # noinspection PyBroadException
                 try:
                     Path(e.delete_local_file).unlink()
                 except Exception:
@@ -199,6 +209,11 @@ class Metrics(InterfaceBase):
         _events = [ev.get_api_event() for ev in good_events]
         batched_requests = [api_events.AddRequest(event=ev) for ev in _events if ev]
         if batched_requests:
+            if self._offline_mode:
+                with open(self._offline_log_filename.as_posix(), 'at') as f:
+                    f.write(json.dumps([b.to_dict() for b in batched_requests]) + '\n')
+                return
+
             req = api_events.AddBatchRequest(requests=batched_requests)
             return self.send(req, raise_on_errors=False)
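In offline mode each flush therefore appends one JSON-encoded list of serialized `AddRequest` payloads per line, so `metrics.jsonl` can later be replayed line by line. A minimal reader sketch (the file location is a placeholder; per-event fields vary by event type, with `task` being the one field the importer below relies on):

    import json

    with open('metrics.jsonl', 'rt') as f:
        for line in f:
            for event in json.loads(line):
                # 'task' holds the offline task id; report_offline_session() rewrites it on import
                print(event.get('task'))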
@@ -234,3 +249,69 @@ class Metrics(InterfaceBase):
                 pool.join()
         except Exception:
             pass
+
+    @classmethod
+    def report_offline_session(cls, task, folder):
+        from ... import StorageManager
+        filename = Path(folder) / cls.__offline_filename
+        if not filename.is_file():
+            return False
+        # noinspection PyProtectedMember
+        remote_url = task._get_default_report_storage_uri()
+        if remote_url and remote_url.endswith('/'):
+            remote_url = remote_url[:-1]
+        uploaded_files = set()
+        task_id = task.id
+        with open(filename, 'rt') as f:
+            i = 0
+            while True:
+                try:
+                    line = f.readline()
+                    if not line:
+                        break
+                    list_requests = json.loads(line)
+                    for r in list_requests:
+                        org_task_id = r['task']
+                        r['task'] = task_id
+                        if r.get('key') and r.get('url'):
+                            debug_sample = (Path(folder) / 'data').joinpath(*(r['key'].split('/')))
+                            r['key'] = r['key'].replace(
+                                '.{}{}'.format(org_task_id, os.sep), '.{}{}'.format(task_id, os.sep), 1)
+                            r['url'] = '{}/{}'.format(remote_url, r['key'])
+                            if debug_sample not in uploaded_files and debug_sample.is_file():
+                                uploaded_files.add(debug_sample)
+                                StorageManager.upload_file(local_file=debug_sample.as_posix(), remote_url=r['url'])
+                        elif r.get('plot_str'):
+                            # hack plotly embedded images links
+                            # noinspection PyBroadException
+                            try:
+                                task_id_sep = '.{}{}'.format(org_task_id, os.sep)
+                                plot = json.loads(r['plot_str'])
+                                if plot.get('layout', {}).get('images'):
+                                    for image in plot['layout']['images']:
+                                        if task_id_sep not in image['source']:
+                                            continue
+                                        pre, post = image['source'].split(task_id_sep, 1)
+                                        pre = os.sep.join(pre.split(os.sep)[-2:])
+                                        debug_sample = (Path(folder) / 'data').joinpath(
+                                            pre + '.{}'.format(org_task_id), post)
+                                        image['source'] = '/'.join(
+                                            [remote_url, pre + '.{}'.format(task_id), post])
+                                        if debug_sample not in uploaded_files and debug_sample.is_file():
+                                            uploaded_files.add(debug_sample)
+                                            StorageManager.upload_file(
+                                                local_file=debug_sample.as_posix(), remote_url=image['source'])
+                                r['plot_str'] = json.dumps(plot)
+                            except Exception:
+                                pass
+                    i += 1
+                except StopIteration:
+                    break
+                except Exception as ex:
+                    warning('Failed reporting metric, line {} [{}]'.format(i, ex))
+                batch_requests = api_events.AddBatchRequest(requests=list_requests)
+                res = task.session.send(batch_requests)
+                if res and not res.ok():
+                    warning("failed logging metric task to backend ({:d} lines, {})".format(
+                        len(batch_requests.requests), str(res.meta)))
+        return True
@@ -659,7 +659,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):

         # Hack: if the url doesn't start with http/s then the plotly will not be able to show it,
         # then we put the link under images not plots
-        if not url.startswith('http'):
+        if not url.startswith('http') and not self._offline_mode:
             return self.report_image_and_upload(title=title, series=series, iter=iter, path=path, image=matrix,
                                                 upload_uri=upload_uri, max_image_history=max_image_history)
@@ -71,6 +71,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):

     def _reload(self):
         """ Reload the model object """
+        if self._offline_mode:
+            return models.Model()
+
         if self.id == self._EMPTY_MODEL_ID:
             return
         res = self.send(models.GetByIdRequest(model=self.id))
@@ -186,35 +189,40 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
         task = task_id or self.data.task
         project = project_id or self.data.project
         parent = parent_id or self.data.parent
         if tags:
             extra = {'system_tags': tags or self.data.system_tags} \
                 if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
         else:
             extra = {}

-        self.send(models.EditRequest(
-            model=self.id,
+        self._edit(
             uri=uri,
             name=name,
             comment=comment,
             labels=labels,
             design=design,
-            framework=framework or self.data.framework,
-            iteration=iteration,
             task=task,
             project=project,
             parent=parent,
+            framework=framework or self.data.framework,
+            iteration=iteration,
             **extra
-        ))
-        self.reload()
+        )

     def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
              uri=None, framework=None, iteration=None):
+        return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
+                          uri=uri, framework=framework, iteration=iteration)
+
+    def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
+              uri=None, framework=None, iteration=None, **extra):
+        def offline_store(**kwargs):
+            for k, v in kwargs.items():
+                setattr(self.data, k, v or getattr(self.data, k, None))
+            return
+        if self._offline_mode:
+            return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
+                                 uri=uri, framework=framework, iteration=iteration, **extra)
+
         if tags:
-            extra = {'system_tags': tags or self.data.system_tags} \
-                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
-        else:
-            extra = {}
+            extra.update({'system_tags': tags or self.data.system_tags}
+                         if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags})

         self.send(models.EditRequest(
             model=self.id,
             uri=uri,
@@ -298,7 +306,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
             override_model_id=override_model_id, **extra))
         if self.id is None:
             # update the model id. in case it was just created, this will trigger a reload of the model object
-            self.id = res.response.id
+            self.id = res.response.id if res else None
         else:
             self.reload()
         try:
@@ -306,10 +306,11 @@ class _Arguments(object):
         # TODO: add dict prefix
         prefix = prefix or ''  # self._prefix_dict
         if prefix:
-            prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
-            cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
-            cur_params.update(prefix_dictionary)
-            self._task.set_parameters(cur_params)
+            with self._task._edit_lock:
+                prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
+                cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
+                cur_params.update(prefix_dictionary)
+                self._task.set_parameters(cur_params)
         else:
             self._task.update_parameters(dictionary)
         if not isinstance(dictionary, self._ProxyDictWrite):
@@ -1,6 +1,8 @@
+import json
 import sys
 import time
-from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord
+from pathlib2 import Path
+from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord, warning
 from logging.handlers import BufferingHandler
 from threading import Thread, Event
 from six.moves.queue import Queue
@@ -17,6 +19,7 @@ class TaskHandler(BufferingHandler):
     __wait_for_flush_timeout = 10.
     __max_event_size = 1024 * 1024
     __once = False
+    __offline_filename = 'log.jsonl'

     @property
     def task_id(self):
@@ -26,10 +29,10 @@ class TaskHandler(BufferingHandler):
     def task_id(self, value):
         self._task_id = value

-    def __init__(self, session, task_id, capacity=buffer_capacity):
+    def __init__(self, task, capacity=buffer_capacity):
         super(TaskHandler, self).__init__(capacity)
-        self.task_id = task_id
-        self.session = session
+        self.task_id = task.id
+        self.session = task.session
         self.last_timestamp = 0
         self.counter = 1
         self._last_event = None
@@ -37,6 +40,11 @@ class TaskHandler(BufferingHandler):
         self._queue = None
         self._thread = None
        self._pending = 0
+        self._offline_log_filename = None
+        if task.is_offline():
+            offline_folder = Path(task.get_offline_mode_folder())
+            offline_folder.mkdir(parents=True, exist_ok=True)
+            self._offline_log_filename = offline_folder / self.__offline_filename

     def shouldFlush(self, record):
         """
@@ -124,6 +132,7 @@ class TaskHandler(BufferingHandler):
         if not self.buffer:
             return

+        buffer = None
         self.acquire()
         if self.buffer:
             buffer = self.buffer
@@ -133,6 +142,7 @@ class TaskHandler(BufferingHandler):
         if not buffer:
             return

+        # noinspection PyBroadException
         try:
             record_events = [r for record in buffer for r in self._record_to_event(record)] + [self._last_event]
             self._last_event = None
@@ -194,11 +204,17 @@ class TaskHandler(BufferingHandler):

     def _send_events(self, a_request):
         try:
+            self._pending -= 1
+
+            if self._offline_log_filename:
+                with open(self._offline_log_filename.as_posix(), 'at') as f:
+                    f.write(json.dumps([b.to_dict() for b in a_request.requests]) + '\n')
+                return
+
             # if self._thread is None:
             #     self.__log_stderr('Task.close() flushing remaining logs ({})'.format(self._pending))
-            self._pending -= 1
             res = self.session.send(a_request)
-            if not res.ok():
+            if res and not res.ok():
                 self.__log_stderr("failed logging task to backend ({:d} lines, {})".format(
                     len(a_request.requests), str(res.meta)), level=WARNING)
         except MaxRequestSizeError:
@@ -237,3 +253,31 @@ class TaskHandler(BufferingHandler):
         write('{asctime} - {name} - {levelname} - {message}\n'.format(
             asctime=Formatter().formatTime(makeLogRecord({})),
             name='trains.log', levelname=getLevelName(level), message=msg))
+
+    @classmethod
+    def report_offline_session(cls, task, folder):
+        filename = Path(folder) / cls.__offline_filename
+        if not filename.is_file():
+            return False
+        with open(filename, 'rt') as f:
+            i = 0
+            while True:
+                try:
+                    line = f.readline()
+                    if not line:
+                        break
+                    list_requests = json.loads(line)
+                    for r in list_requests:
+                        r.pop('task', None)
+                    i += 1
+                except StopIteration:
+                    break
+                except Exception as ex:
+                    warning('Failed reporting log, line {} [{}]'.format(i, ex))
+                batch_requests = events.AddBatchRequest(
+                    requests=[events.TaskLogEvent(task=task.id, **r) for r in list_requests])
+                res = task.session.send(batch_requests)
+                if res and not res.ok():
+                    warning("failed logging task to backend ({:d} lines, {})".format(
+                        len(batch_requests.requests), str(res.meta)))
+        return True
@@ -1,5 +1,6 @@
 """ Backend task management support """
 import itertools
+import json
 import logging
 import os
 import sys
@@ -7,8 +8,10 @@ import re
 from enum import Enum
 from tempfile import gettempdir
 from multiprocessing import RLock
+from pathlib2 import Path
 from threading import Thread
 from typing import Optional, Any, Sequence, Callable, Mapping, Union, List
+from uuid import uuid4

 try:
     # noinspection PyCompatibility
@@ -25,17 +28,18 @@ from ...binding.artifacts import Artifacts
 from ...backend_interface.task.development.worker import DevWorker
 from ...backend_api import Session
 from ...backend_api.services import tasks, models, events, projects
-from pathlib2 import Path
+from ...backend_api.session.defs import ENV_OFFLINE_MODE
 from ...utilities.pyhocon import ConfigTree, ConfigFactory

-from ..base import IdObjectBase
+from ..base import IdObjectBase, InterfaceBase
 from ..metrics import Metrics, Reporter
 from ..model import Model
 from ..setupuploadmixin import SetupUploadMixin
 from ..util import make_message, get_or_create_project, get_single_result, \
     exact_match_regex
-from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
-    running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR
+from ...config import (
+    get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
+    running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
 from ...debugging import get_logger
 from ...debugging.log import LoggerRoot
 from ...storage.helper import StorageHelper, StorageError
@@ -56,6 +60,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     _force_requirements = {}

     _store_diff = config.get('development.store_uncommitted_code_diff', False)
+    _offline_filename = 'task.json'

     class TaskTypes(Enum):
         def __str__(self):
@@ -143,6 +148,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if not task_id:
             # generate a new task
             self.id = self._auto_generate(project_name=project_name, task_name=task_name, task_type=task_type)
+            if self._offline_mode:
+                self.data.id = self.id
+                self.name = task_name
         else:
             # this is an existing task, let's try to verify stuff
             self._validate()
@@ -195,7 +203,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
         # than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
         # handler instance handled them)
-        backend_handler = TaskHandler(self.session, self.task_id)
+        backend_handler = TaskHandler(task=self)

         # Add backend handler to both loggers:
         # 1. to root logger root logger
@@ -280,11 +288,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             # to the module call including all argv's
             result.script = ScriptInfo.detect_running_module(result.script)

-        self.data.script = result.script
         # Since we might run asynchronously, don't use self.data (let someone else
         # overwrite it before we have a chance to call edit)
-        self._edit(script=result.script)
-        self.reload()
+        with self._edit_lock:
+            self.reload()
+            self.data.script = result.script
+            self._edit(script=result.script)

         # if jupyter is present, requirements will be created in the background, when saving a snapshot
         if result.script and script_requirements:
             entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
@@ -304,7 +314,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                 result.script['requirements']['conda'] = conda_requirements

         self._update_requirements(result.script.get('requirements') or '')
-        self.reload()

         # we do not want to wait for the check version thread,
         # because someone might wait for us to finish the repo detection update
@@ -339,7 +348,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         )
         res = self.send(req)

-        return res.response.id
+        return res.response.id if res else 'offline-{}'.format(str(uuid4()).replace("-", ""))

     def _set_storage_uri(self, value):
         value = value.rstrip('/') if value else None
@@ -498,7 +507,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if self._metrics_manager is None:
             self._metrics_manager = Metrics(
                 session=self.session,
-                task_id=self.id,
+                task=self,
                 storage_uri=storage_uri,
                 storage_uri_suffix=self._get_output_destination_suffix('metrics'),
                 iteration_offset=self.get_initial_iteration()
@@ -523,6 +532,27 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # type: () -> Any
         """ Reload the task object from the backend """
         with self._edit_lock:
+            if self._offline_mode:
+                # noinspection PyBroadException
+                try:
+                    with open(self.get_offline_mode_folder() / self._offline_filename, 'rt') as f:
+                        stored_dict = json.load(f)
+                    stored_data = tasks.Task(**stored_dict)
+                    # add missing entries
+                    for k, v in stored_dict.items():
+                        if not hasattr(stored_data, k):
+                            setattr(stored_data, k, v)
+                    if stored_dict.get('project_name'):
+                        self._project_name = (None, stored_dict.get('project_name'))
+                except Exception:
+                    stored_data = self._data
+
+                return stored_data or tasks.Task(
+                    execution=tasks.Execution(
+                        parameters={}, artifacts=[], dataviews=[], model='',
+                        model_desc={}, model_labels={}, docker_cmd=''),
+                    output=tasks.Output())
+
             if self._reload_skip_flag and self._data:
                 return self._data
             res = self.send(tasks.GetByIdRequest(task=self.id))
@@ -774,7 +804,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

         execution = self.data.execution
         if execution is None:
-            execution = tasks.Execution(parameters=parameters)
+            execution = tasks.Execution(
+                parameters=parameters, artifacts=[], dataviews=[], model='',
+                model_desc={}, model_labels={}, docker_cmd='')
         else:
             execution.parameters = parameters
         self._edit(execution=execution)
@@ -940,7 +972,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def get_project_name(self):
         # type: () -> Optional[str]
         if self.project is None:
             return None
-        return self._project_name[1] if self._project_name and len(self._project_name) > 1 else None
+
+        if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
+            return self._project_name[1]
@@ -1232,12 +1264,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

     def _get_default_report_storage_uri(self):
         # type: () -> str
+        if self._offline_mode:
+            return str(self.get_offline_mode_folder() / 'data')
+
         if not self._files_server:
             self._files_server = Session.get_files_server_host()
         return self._files_server

     def _get_status(self):
         # type: () -> (Optional[str], Optional[str])
+        if self._offline_mode:
+            return tasks.TaskStatusEnum.created, 'offline'
+
         # noinspection PyBroadException
         try:
             all_tasks = self.send(
@@ -1296,6 +1334,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def _edit(self, **kwargs):
         # type: (**Any) -> Any
         with self._edit_lock:
+            if self._offline_mode:
+                for k, v in kwargs.items():
+                    setattr(self.data, k, v)
+                Path(self.get_offline_mode_folder()).mkdir(parents=True, exist_ok=True)
+                with open((self.get_offline_mode_folder() / self._offline_filename).as_posix(), 'wt') as f:
+                    export_data = self.data.to_dict()
+                    export_data['project_name'] = self.get_project_name()
+                    export_data['offline_folder'] = self.get_offline_mode_folder().as_posix()
+                    json.dump(export_data, f, ensure_ascii=True, sort_keys=True)
+                return None
+
             # Since we are using forced update, make sure the task status is valid
             status = self._data.status if self._data and self._reload_skip_flag else self.data.status
             if status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
@@ -1315,15 +1364,32 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # protection, Old API might not support it
         # noinspection PyBroadException
         try:
-            self.data.script.requirements = requirements
-            self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
+            with self._edit_lock:
+                self.reload()
+                self.data.script.requirements = requirements
+                if self._offline_mode:
+                    self._edit(script=self.data.script)
+                else:
+                    self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
         except Exception:
             pass

     def _update_script(self, script):
         # type: (dict) -> ()
-        self.data.script = script
-        self._edit(script=script)
+        with self._edit_lock:
+            self.reload()
+            self.data.script = script
+            self._edit(script=script)
+
+    def get_offline_mode_folder(self):
+        # type: () -> (Optional[Path])
+        """
+        Return the folder where all the task outputs and logs are stored in the offline session.
+        :return: Path object, local folder, later to be used with `report_offline_session()`
+        """
+        if not self._offline_mode:
+            return None
+        return get_offline_dir(task_id=self.task_id)

     @classmethod
     def _clone_task(
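Putting the pieces together, the session folder returned by `get_offline_mode_folder()` ends up with roughly this layout (inferred from `_offline_filename`, the two `.jsonl` writers, and the `data` sub-folder returned by `_get_default_report_storage_uri()` in offline mode; the cache root is configuration-dependent):

    # <cache>/offline/<task-id>/
    #     task.json      - Task snapshot written by Task._edit()
    #     metrics.jsonl  - one JSON list of metric events per line (Metrics)
    #     log.jsonl      - one JSON list of log events per line (TaskHandler)
    #     data/          - debug samples and artifacts staged for later upload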
@@ -1475,13 +1541,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if not PROC_MASTER_ID_ENV_VAR.get() or len(PROC_MASTER_ID_ENV_VAR.get().split(':')) < 2:
             self.__edit_lock = RLock()
         elif PROC_MASTER_ID_ENV_VAR.get().split(':')[1] == str(self.id):
-            # remove previous file lock instance, just in case.
             filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id))
-            # noinspection PyBroadException
-            try:
-                os.unlink(filename)
-            except Exception:
-                pass
+            # no need to remove previous file lock if we have a dead process, it will automatically release the lock.
+            # # noinspection PyBroadException
+            # try:
+            #     os.unlink(filename)
+            # except Exception:
+            #     pass
             # create a new file based lock
             self.__edit_lock = FileRLock(filename=filename)
         else:
@@ -1523,3 +1589,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
             PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
         return is_subprocess
+
+    @classmethod
+    def set_offline(cls, offline_mode=False):
+        # type: (bool) -> ()
+        """
+        Set offline mode, where all data and logs are stored into a local folder, for later transmission
+
+        :param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
+        :return:
+        """
+        ENV_OFFLINE_MODE.set(offline_mode)
+        InterfaceBase._offline_mode = bool(offline_mode)
+        Session._offline_mode = bool(offline_mode)
+
+    @classmethod
+    def is_offline(cls):
+        # type: () -> bool
+        """
+        Return the offline-mode state. If in offline-mode, no communication to the backend is enabled.
+
+        :return: boolean offline-mode state
+        """
+        return cls._offline_mode
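Because the flag is stored on the classes themselves (`Session._offline_mode`, `InterfaceBase._offline_mode`), it must be flipped before any session exists; user code can query the public accessor. A short sketch:

    from trains import Task

    Task.set_offline(offline_mode=True)  # sets ENV_OFFLINE_MODE plus both class flags
    assert Task.is_offline()
    # ... Task.init() and all reporting now write only to the local session folder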
@@ -52,6 +52,8 @@ def make_message(s, **kwargs):

 def get_or_create_project(session, project_name, description=None):
     res = session.send(projects.GetAllRequest(name=exact_match_regex(project_name)))
+    if not res:
+        return None
     if res.response.projects:
         return res.response.projects[0].id
     res = session.send(projects.CreateRequest(name=project_name, description=description))
@@ -19,13 +19,21 @@ def get_cache_dir():
     cache_base_dir = Path(  # noqa: F405
         expandvars(
             expanduser(
-                TRAINS_CACHE_DIR.get() or config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR  # noqa: F405
+                TRAINS_CACHE_DIR.get() or
+                config.get("storage.cache.default_base_dir") or
+                DEFAULT_CACHE_DIR  # noqa: F405
             )
         )
     )
     return cache_base_dir


+def get_offline_dir(task_id=None):
+    if not task_id:
+        return get_cache_dir() / 'offline'
+    return get_cache_dir() / 'offline' / task_id
+
+
 def get_config_for_bucket(base_url, extra_configurations=None):
     config_list = S3BucketConfigurations.from_config(config.get("aws.s3"))
@@ -982,6 +982,7 @@ class Logger(object):

         For example: ``s3://bucket/directory/``, or ``file:///tmp/debug/``.
         """
+        # noinspection PyProtectedMember
         return self._default_upload_destination or self._task._get_default_report_storage_uri()

     def flush(self):
trains/task.py (123 changes)
@@ -1,12 +1,14 @@
 import atexit
+import json
 import os
+import shutil
 import signal
 import sys
 import threading
 import time
 from argparse import ArgumentParser
-from tempfile import mkstemp
+from tempfile import mkstemp, mkdtemp
+from zipfile import ZipFile, ZIP_DEFLATED

 try:
     # noinspection PyCompatibility
@@ -25,6 +27,7 @@ from .backend_api.session.session import Session, ENV_ACCESS_KEY, ENV_SECRET_KEY
 from .backend_interface.metrics import Metrics
 from .backend_interface.model import Model as BackendModel
 from .backend_interface.task import Task as _Task
+from .backend_interface.task.log import TaskHandler
 from .backend_interface.task.development.worker import DevWorker
 from .backend_interface.task.repo import ScriptInfo
 from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
@@ -446,7 +449,9 @@ class Task(_Task):
                 not auto_connect_frameworks.get('detect_repository', True)) else True
             )
             # set defaults
-            if output_uri:
+            if cls._offline_mode:
+                task.output_uri = None
+            elif output_uri:
                 task.output_uri = output_uri
             elif cls.__default_output_uri:
                 task.output_uri = cls.__default_output_uri
@@ -530,9 +535,11 @@ class Task(_Task):
         logger = task.get_logger()
         # show the debug metrics page in the log, it is very convenient
         if not is_sub_process_task_id:
-            logger.report_text(
-                'TRAINS results page: {}'.format(task.get_output_log_web_page()),
-            )
+            if cls._offline_mode:
+                logger.report_text('TRAINS running in offline mode, session stored in {}'.format(
+                    task.get_offline_mode_folder()))
+            else:
+                logger.report_text('TRAINS results page: {}'.format(task.get_output_log_web_page()))
         # Make sure we start the dev worker if required, otherwise it will only be started when we write
         # something to the log.
         task._dev_mode_task_start()
@@ -1362,7 +1369,7 @@ class Task(_Task):
         :return: The last reported iteration number.
         """
         self._reload_last_iteration()
-        return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0)
+        return max(self.data.last_iteration or 0, self._reporter.max_iteration if self._reporter else 0)

     def set_initial_iteration(self, offset=0):
         # type: (int) -> int
@@ -1570,11 +1577,11 @@ class Task(_Task):
         :param task_data: dictionary with full Task configuration
         :return: return True if Task update was successful
         """
-        return self.import_task(task_data=task_data, target_task=self, update=True)
+        return bool(self.import_task(task_data=task_data, target_task=self, update=True))

     @classmethod
     def import_task(cls, task_data, target_task=None, update=False):
-        # type: (dict, Optional[Union[str, Task]], bool) -> bool
+        # type: (dict, Optional[Union[str, Task]], bool) -> Optional[Task]
         """
         Import (create) Task from previously exported Task configuration (see Task.export_task)
         Can also be used to edit/update an existing Task (by passing `target_task` and `update=True`).
@@ -1595,7 +1602,7 @@ class Task(_Task):
                              "received `target_task` type {}".format(type(target_task)))
         target_task.reload()
         cur_data = target_task.data.to_dict()
-        cur_data = merge_dicts(cur_data, task_data) if update else task_data
+        cur_data = merge_dicts(cur_data, task_data) if update else dict(**task_data)
         cur_data.pop('id', None)
         cur_data.pop('project', None)
         # noinspection PyProtectedMember
@@ -1604,8 +1611,79 @@ class Task(_Task):
         res = target_task._edit(**cur_data)
         if res and res.ok():
             target_task.reload()
-            return True
-        return False
+            return target_task
+        return None
+
+    @classmethod
+    def import_offline_session(cls, session_folder_zip):
+        # type: (str) -> (Optional[str])
+        """
+        Upload an offline session (execution) of a Task.
+        Full Task execution includes repository details, installed packages, artifacts, logs, metric and debug samples.
+
+        :param session_folder_zip: Path to a folder containing the session, or zip-file of the session folder.
+        :return: Newly created task ID (str)
+        """
+        print('TRAINS: Importing offline session from {}'.format(session_folder_zip))
+
+        temp_folder = None
+        if Path(session_folder_zip).is_file():
+            # unzip the file:
+            temp_folder = mkdtemp(prefix='trains-offline-')
+            ZipFile(session_folder_zip).extractall(path=temp_folder)
+            session_folder_zip = temp_folder
+
+        session_folder = Path(session_folder_zip)
+        if not session_folder.is_dir():
+            raise ValueError("Could not find the session folder / zip-file {}".format(session_folder))
+
+        try:
+            with open(session_folder / cls._offline_filename, 'rt') as f:
+                export_data = json.load(f)
+        except Exception as ex:
+            raise ValueError(
+                "Could not read Task object {}: Exception {}".format(session_folder / cls._offline_filename, ex))
+        task = cls.import_task(export_data)
+        task.mark_started(force=True)
+        # fix artifacts
+        if task.data.execution.artifacts:
+            from . import StorageManager
+            # noinspection PyProtectedMember
+            offline_folder = os.path.join(export_data.get('offline_folder', ''), 'data/')
+
+            remote_url = task._get_default_report_storage_uri()
+            if remote_url and remote_url.endswith('/'):
+                remote_url = remote_url[:-1]
+
+            for artifact in task.data.execution.artifacts:
+                local_path = artifact.uri.replace(offline_folder, '', 1)
+                local_file = session_folder / 'data' / local_path
+                if local_file.is_file():
+                    remote_path = local_path.replace(
+                        '.{}{}'.format(export_data['id'], os.sep), '.{}{}'.format(task.id, os.sep), 1)
+                    artifact.uri = '{}/{}'.format(remote_url, remote_path)
+                    StorageManager.upload_file(local_file=local_file.as_posix(), remote_url=artifact.uri)
+            # noinspection PyProtectedMember
+            task._edit(execution=task.data.execution)
+        # logs
+        TaskHandler.report_offline_session(task, session_folder)
+        # metrics
+        Metrics.report_offline_session(task, session_folder)
+        # print imported results page
+        print('TRAINS results page: {}'.format(task.get_output_log_web_page()))
+        task.completed()
+        # close task
+        task.close()
+
+        # cleanup
+        if temp_folder:
+            # noinspection PyBroadException
+            try:
+                shutil.rmtree(temp_folder)
+            except Exception:
+                pass
+
+        return task.id

     @classmethod
     def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None):
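The returned value is the id of the freshly created server-side task, so the imported run can be fetched like any other. A minimal sketch (the zip path is a placeholder for the file printed at the end of the offline run):

    from trains import Task

    new_task_id = Task.import_offline_session('/path/to/offline-session.zip')
    imported = Task.get_task(task_id=new_task_id)  # artifacts/debug samples were re-uploaded on import
    print(imported.get_output_log_web_page())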
@@ -2099,7 +2177,7 @@ class Task(_Task):
                     parent.terminate()

     def _dev_mode_setup_worker(self, model_updated=False):
-        if running_remotely() or not self.is_main_task() or self._at_exit_called:
+        if running_remotely() or not self.is_main_task() or self._at_exit_called or self._offline_mode:
             return

         if self._dev_worker:
@@ -2283,6 +2361,23 @@ class Task(_Task):
             except Exception:
                 # make sure we do not interrupt the exit process
                 pass
+
+        # make sure we store last task state
+        if self._offline_mode and not is_sub_process:
+            # noinspection PyBroadException
+            try:
+                # create zip file
+                offline_folder = self.get_offline_mode_folder()
+                zip_file = offline_folder.as_posix() + '.zip'
+                with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf:
+                    for filename in offline_folder.rglob('*'):
+                        if filename.is_file():
+                            relative_file_name = filename.relative_to(offline_folder).as_posix()
+                            zf.write(filename.as_posix(), arcname=relative_file_name)
+                print('TRAINS Task: Offline session stored in {}'.format(zip_file))
+            except Exception as ex:
+                pass
+
         # delete locking object (lock file)
         if self._edit_lock:
             # noinspection PyBroadException
@@ -2597,7 +2692,7 @@ class Task(_Task):

     @classmethod
     def __get_task_api_obj(cls, task_id, only_fields=None):
-        if not task_id:
+        if not task_id or cls._offline_mode:
             return None

         all_tasks = cls._send(