Add offline support using Task.set_offline() and Task.import_offline_session()

This commit is contained in:
allegroai
2020-07-30 15:03:22 +03:00
parent 2ec5726812
commit a8d6380696
15 changed files with 427 additions and 78 deletions

View File

@@ -306,10 +306,11 @@ class _Arguments(object):
# TODO: add dict prefix
prefix = prefix or '' # self._prefix_dict
if prefix:
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
cur_params.update(prefix_dictionary)
self._task.set_parameters(cur_params)
with self._task._edit_lock:
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
cur_params.update(prefix_dictionary)
self._task.set_parameters(cur_params)
else:
self._task.update_parameters(dictionary)
if not isinstance(dictionary, self._ProxyDictWrite):

View File

@@ -1,6 +1,8 @@
import json
import sys
import time
from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord
from pathlib2 import Path
from logging import LogRecord, getLogger, basicConfig, getLevelName, INFO, WARNING, Formatter, makeLogRecord, warning
from logging.handlers import BufferingHandler
from threading import Thread, Event
from six.moves.queue import Queue
@@ -17,6 +19,7 @@ class TaskHandler(BufferingHandler):
__wait_for_flush_timeout = 10.
__max_event_size = 1024 * 1024
__once = False
__offline_filename = 'log.jsonl'
@property
def task_id(self):
@@ -26,10 +29,10 @@ class TaskHandler(BufferingHandler):
def task_id(self, value):
self._task_id = value
def __init__(self, session, task_id, capacity=buffer_capacity):
def __init__(self, task, capacity=buffer_capacity):
super(TaskHandler, self).__init__(capacity)
self.task_id = task_id
self.session = session
self.task_id = task.id
self.session = task.session
self.last_timestamp = 0
self.counter = 1
self._last_event = None
@@ -37,6 +40,11 @@ class TaskHandler(BufferingHandler):
self._queue = None
self._thread = None
self._pending = 0
self._offline_log_filename = None
if task.is_offline():
offline_folder = Path(task.get_offline_mode_folder())
offline_folder.mkdir(parents=True, exist_ok=True)
self._offline_log_filename = offline_folder / self.__offline_filename
def shouldFlush(self, record):
"""
@@ -124,6 +132,7 @@ class TaskHandler(BufferingHandler):
if not self.buffer:
return
buffer = None
self.acquire()
if self.buffer:
buffer = self.buffer
@@ -133,6 +142,7 @@ class TaskHandler(BufferingHandler):
if not buffer:
return
# noinspection PyBroadException
try:
record_events = [r for record in buffer for r in self._record_to_event(record)] + [self._last_event]
self._last_event = None
@@ -194,11 +204,17 @@ class TaskHandler(BufferingHandler):
def _send_events(self, a_request):
try:
self._pending -= 1
if self._offline_log_filename:
with open(self._offline_log_filename.as_posix(), 'at') as f:
f.write(json.dumps([b.to_dict() for b in a_request.requests]) + '\n')
return
# if self._thread is None:
# self.__log_stderr('Task.close() flushing remaining logs ({})'.format(self._pending))
self._pending -= 1
res = self.session.send(a_request)
if not res.ok():
if res and not res.ok():
self.__log_stderr("failed logging task to backend ({:d} lines, {})".format(
len(a_request.requests), str(res.meta)), level=WARNING)
except MaxRequestSizeError:
@@ -237,3 +253,31 @@ class TaskHandler(BufferingHandler):
write('{asctime} - {name} - {levelname} - {message}\n'.format(
asctime=Formatter().formatTime(makeLogRecord({})),
name='trains.log', levelname=getLevelName(level), message=msg))
@classmethod
def report_offline_session(cls, task, folder):
    """
    Replay an offline-session log file into the backend.

    Each line of the offline log file (``log.jsonl``) is a JSON list of
    task-log event dicts, as written by ``_send_events()`` while in offline
    mode. Every line is re-sent as a single AddBatchRequest attributed to
    the given task (any stored 'task' field is dropped and replaced with
    ``task.id``).

    :param task: Task object whose id and session are used for reporting
    :param folder: folder containing the offline log file
    :return: True if a log file was found and processed, False otherwise
    """
    filename = Path(folder) / cls.__offline_filename
    if not filename.is_file():
        return False
    with open(filename, 'rt') as f:
        i = 0
        while True:
            try:
                line = f.readline()
                if not line:
                    break
                list_requests = json.loads(line)
                for r in list_requests:
                    # the stored event carries the offline task id; strip it so
                    # the event is attributed to the (real) reporting task below
                    r.pop('task', None)
                i += 1
            except Exception as ex:
                warning('Failed reporting log, line {} [{}]'.format(i, ex))
                # skip the malformed line: without this `continue`, list_requests
                # would be unbound (first line) or stale (previous line's batch
                # would be re-sent)
                continue
            batch_requests = events.AddBatchRequest(
                requests=[events.TaskLogEvent(task=task.id, **r) for r in list_requests])
            res = task.session.send(batch_requests)
            if res and not res.ok():
                warning("failed logging task to backend ({:d} lines, {})".format(
                    len(batch_requests.requests), str(res.meta)))
    return True

View File

@@ -1,5 +1,6 @@
""" Backend task management support """
import itertools
import json
import logging
import os
import sys
@@ -7,8 +8,10 @@ import re
from enum import Enum
from tempfile import gettempdir
from multiprocessing import RLock
from pathlib2 import Path
from threading import Thread
from typing import Optional, Any, Sequence, Callable, Mapping, Union, List
from uuid import uuid4
try:
# noinspection PyCompatibility
@@ -25,17 +28,18 @@ from ...binding.artifacts import Artifacts
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
from pathlib2 import Path
from ...backend_api.session.defs import ENV_OFFLINE_MODE
from ...utilities.pyhocon import ConfigTree, ConfigFactory
from ..base import IdObjectBase
from ..base import IdObjectBase, InterfaceBase
from ..metrics import Metrics, Reporter
from ..model import Model
from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \
exact_match_regex
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR
from ...config import (
get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend,
running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir)
from ...debugging import get_logger
from ...debugging.log import LoggerRoot
from ...storage.helper import StorageHelper, StorageError
@@ -56,6 +60,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_force_requirements = {}
_store_diff = config.get('development.store_uncommitted_code_diff', False)
_offline_filename = 'task.json'
class TaskTypes(Enum):
def __str__(self):
@@ -143,6 +148,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not task_id:
# generate a new task
self.id = self._auto_generate(project_name=project_name, task_name=task_name, task_type=task_type)
if self._offline_mode:
self.data.id = self.id
self.name = task_name
else:
# this is an existing task, let's try to verify stuff
self._validate()
@@ -195,7 +203,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
# than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
# handler instance handled them)
backend_handler = TaskHandler(self.session, self.task_id)
backend_handler = TaskHandler(task=self)
# Add backend handler to both loggers:
# 1. to the root logger
@@ -280,11 +288,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# to the module call including all argv's
result.script = ScriptInfo.detect_running_module(result.script)
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (let someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
with self._edit_lock:
self.reload()
self.data.script = result.script
self._edit(script=result.script)
# if jupyter is present, requirements will be created in the background, when saving a snapshot
if result.script and script_requirements:
entry_point_filename = None if config.get('development.force_analyze_entire_repo', False) else \
@@ -304,7 +314,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
result.script['requirements']['conda'] = conda_requirements
self._update_requirements(result.script.get('requirements') or '')
self.reload()
# we do not want to wait for the check version thread,
# because someone might wait for us to finish the repo detection update
@@ -339,7 +348,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
)
res = self.send(req)
return res.response.id
return res.response.id if res else 'offline-{}'.format(str(uuid4()).replace("-", ""))
def _set_storage_uri(self, value):
value = value.rstrip('/') if value else None
@@ -498,7 +507,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if self._metrics_manager is None:
self._metrics_manager = Metrics(
session=self.session,
task_id=self.id,
task=self,
storage_uri=storage_uri,
storage_uri_suffix=self._get_output_destination_suffix('metrics'),
iteration_offset=self.get_initial_iteration()
@@ -523,6 +532,27 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# type: () -> Any
""" Reload the task object from the backend """
with self._edit_lock:
if self._offline_mode:
# noinspection PyBroadException
try:
with open(self.get_offline_mode_folder() / self._offline_filename, 'rt') as f:
stored_dict = json.load(f)
stored_data = tasks.Task(**stored_dict)
# add missing entries
for k, v in stored_dict.items():
if not hasattr(stored_data, k):
setattr(stored_data, k, v)
if stored_dict.get('project_name'):
self._project_name = (None, stored_dict.get('project_name'))
except Exception:
stored_data = self._data
return stored_data or tasks.Task(
execution=tasks.Execution(
parameters={}, artifacts=[], dataviews=[], model='',
model_desc={}, model_labels={}, docker_cmd=''),
output=tasks.Output())
if self._reload_skip_flag and self._data:
return self._data
res = self.send(tasks.GetByIdRequest(task=self.id))
@@ -774,7 +804,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
execution = self.data.execution
if execution is None:
execution = tasks.Execution(parameters=parameters)
execution = tasks.Execution(
parameters=parameters, artifacts=[], dataviews=[], model='',
model_desc={}, model_labels={}, docker_cmd='')
else:
execution.parameters = parameters
self._edit(execution=execution)
@@ -940,7 +972,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def get_project_name(self):
# type: () -> Optional[str]
if self.project is None:
return None
return self._project_name[1] if self._project_name and len(self._project_name) > 1 else None
if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
return self._project_name[1]
@@ -1232,12 +1264,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _get_default_report_storage_uri(self):
# type: () -> str
if self._offline_mode:
return str(self.get_offline_mode_folder() / 'data')
if not self._files_server:
self._files_server = Session.get_files_server_host()
return self._files_server
def _get_status(self):
# type: () -> (Optional[str], Optional[str])
if self._offline_mode:
return tasks.TaskStatusEnum.created, 'offline'
# noinspection PyBroadException
try:
all_tasks = self.send(
@@ -1296,6 +1334,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _edit(self, **kwargs):
# type: (**Any) -> Any
with self._edit_lock:
if self._offline_mode:
for k, v in kwargs.items():
setattr(self.data, k, v)
Path(self.get_offline_mode_folder()).mkdir(parents=True, exist_ok=True)
with open((self.get_offline_mode_folder() / self._offline_filename).as_posix(), 'wt') as f:
export_data = self.data.to_dict()
export_data['project_name'] = self.get_project_name()
export_data['offline_folder'] = self.get_offline_mode_folder().as_posix()
json.dump(export_data, f, ensure_ascii=True, sort_keys=True)
return None
# Since we are using forced update, make sure the task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status
if status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
@@ -1315,15 +1364,32 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# protection, Old API might not support it
# noinspection PyBroadException
try:
self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
with self._edit_lock:
self.reload()
self.data.script.requirements = requirements
if self._offline_mode:
self._edit(script=self.data.script)
else:
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
except Exception:
pass
def _update_script(self, script):
# type: (dict) -> ()
self.data.script = script
self._edit(script=script)
with self._edit_lock:
self.reload()
self.data.script = script
self._edit(script=script)
def get_offline_mode_folder(self):
    # type: () -> (Optional[Path])
    """
    Local folder collecting all the task outputs and logs during an offline
    session. Only meaningful while offline-mode is enabled.

    :return: Path object of the local folder (later to be used with
        `report_offline_session()`), or None when not in offline-mode
    """
    return get_offline_dir(task_id=self.task_id) if self._offline_mode else None
@classmethod
def _clone_task(
@@ -1475,13 +1541,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not PROC_MASTER_ID_ENV_VAR.get() or len(PROC_MASTER_ID_ENV_VAR.get().split(':')) < 2:
self.__edit_lock = RLock()
elif PROC_MASTER_ID_ENV_VAR.get().split(':')[1] == str(self.id):
# remove previous file lock instance, just in case.
filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id))
# noinspection PyBroadException
try:
os.unlink(filename)
except Exception:
pass
# no need to remove previous file lock if we have a dead process, it will automatically release the lock.
# # noinspection PyBroadException
# try:
# os.unlink(filename)
# except Exception:
# pass
# create a new file based lock
self.__edit_lock = FileRLock(filename=filename)
else:
@@ -1523,3 +1589,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())
return is_subprocess
@classmethod
def set_offline(cls, offline_mode=False):
    # type: (bool) -> ()
    """
    Turn offline-mode on or off. While offline, no communication to the
    backend is enabled; all data and logs are stored into a local folder
    for later transmission.

    :param offline_mode: If True, offline-mode is turned on, and no
        communication to the backend is enabled.
    :return:
    """
    ENV_OFFLINE_MODE.set(offline_mode)
    # propagate the flag to both the interface layer and the API session
    InterfaceBase._offline_mode = Session._offline_mode = bool(offline_mode)
@classmethod
def is_offline(cls):
    # type: () -> bool
    """
    Return the current offline-mode state. When offline-mode is on, no
    communication with the backend takes place.

    :return: True if offline-mode is enabled, False otherwise
    """
    return cls._offline_mode