From 3bc1ec23625d09e7184119c94ef8342e4a608d07 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 19 Aug 2019 21:17:53 +0300 Subject: [PATCH] Improve stability and resilience on intermittent network connection --- trains/backend_api/config/default/api.conf | 2 +- trains/backend_api/utils.py | 1 - trains/backend_interface/base.py | 19 +++++++---- trains/backend_interface/model.py | 19 ++++++++++- trains/backend_interface/task/access.py | 4 +-- .../task/development/stop_signal.py | 12 +++---- .../task/development/worker.py | 33 ++++++++++--------- trains/backend_interface/task/log.py | 20 ++++++++--- trains/backend_interface/task/task.py | 26 ++++++++------- trains/debugging/log.py | 7 +++- 10 files changed, 92 insertions(+), 51 deletions(-) diff --git a/trains/backend_api/config/default/api.conf b/trains/backend_api/config/default/api.conf index de96becc..f457a41a 100644 --- a/trains/backend_api/config/default/api.conf +++ b/trains/backend_api/config/default/api.conf @@ -34,7 +34,7 @@ # backoff parameters # timeout between retries is min({backoff_max}, {backoff factor} * (2 ^ ({number of total retries} - 1)) backoff_factor: 1.0 - backoff_max: 300.0 + backoff_max: 120.0 } wait_on_maintenance_forever: true diff --git a/trains/backend_api/utils.py b/trains/backend_api/utils.py index 7bf47c07..d764d059 100644 --- a/trains/backend_api/utils.py +++ b/trains/backend_api/utils.py @@ -4,7 +4,6 @@ import sys import requests from requests.adapters import HTTPAdapter -## from requests.packages.urllib3.util.retry import Retry from urllib3.util import Retry from urllib3 import PoolManager import six diff --git a/trains/backend_interface/base.py b/trains/backend_interface/base.py index 3c4a40fb..b75b1033 100644 --- a/trains/backend_interface/base.py +++ b/trains/backend_interface/base.py @@ -61,13 +61,22 @@ class InterfaceBase(SessionInterface): except requests.exceptions.BaseHTTPError as e: res = None - log.error('Failed sending %s: %s' % (str(req), str(e))) + if log: + log.warning('Failed sending %s: %s' % (str(type(req)), str(e))) except MaxRequestSizeError as e: res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e))) error_msg = 'Failed sending: %s' % str(e) + except requests.exceptions.ConnectionError: + # We couldn't send the request for more than the retries times configure in the api configuration file, + # so we will end the loop and raise the exception to the upper level. + # Notice: this is a connectivity error and not a backend error. + if raise_on_errors: + raise + res = None except Exception as e: res = None - log.error('Failed sending %s: %s' % (str(req), str(e))) + if log: + log.warning('Failed sending %s: %s' % (str(type(req)), str(e))) if res and res.meta.result_code <= 500: # Proper backend error/bad status code - raise or return @@ -75,10 +84,6 @@ class InterfaceBase(SessionInterface): raise SendError(res, error_msg) return res - # # Infrastructure error - # if log: - # log.info('retrying request %s' % str(req)) - def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False): return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors, log=self.log, async_enable=async_enable) @@ -128,7 +133,7 @@ class IdObjectBase(InterfaceBase): @id.setter def id(self, value): - should_reload = value is not None and value != self._id + should_reload = value is not None and self._id is not None and value != self._id self._id = value if should_reload: self.reload() diff --git a/trains/backend_interface/model.py b/trains/backend_interface/model.py index 19aabe79..fa2c9120 100644 --- a/trains/backend_interface/model.py +++ b/trains/backend_interface/model.py @@ -1,5 +1,7 @@ +import os from collections import namedtuple from functools import partial +from tempfile import mkstemp import six from pathlib2 import Path @@ -45,6 +47,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): _EMPTY_MODEL_ID = 'empty' + _local_model_to_id_uri = {} + @property def model_id(self): return self.id @@ -172,6 +176,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): self.upload_storage_uri = upload_storage_uri self._create_empty_model(self.upload_storage_uri) + if model_file and uri: + Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri) + # upload model file if needed and get uri uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri) # update fields @@ -213,6 +220,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): if uploaded_uri is False: uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri) + Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri) + self.update( uri=uploaded_uri, task_id=task_id, @@ -234,6 +243,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): return uri else: uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename) + Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri) self.update( uri=uri, task_id=task_id, @@ -339,7 +349,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): """ Download the model weights into a local file in our cache """ uri = self.data.uri helper = StorageHelper.get(uri, logger=self._log, verbose=True) - return helper.download_to_file(uri, force_cache=True) + filename = uri.split('/')[-1] + ext = '.'.join(filename.split('.')[1:]) + fd, local_filename = mkstemp(suffix='.'+ext) + os.close(fd) + local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True) + # save local model, so we can later query what was the original one + Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) + return local_download @property def cache_dir(self): diff --git a/trains/backend_interface/task/access.py b/trains/backend_interface/task/access.py index 9dc4040e..3d19097c 100644 --- a/trains/backend_interface/task/access.py +++ b/trains/backend_interface/task/access.py @@ -18,14 +18,14 @@ class AccessMixin(object): obj = self.data props = prop_path.split('.') for i in range(len(props)): - obj = getattr(obj, props[i], None) - if obj is None: + if not hasattr(obj, props[i]): msg = 'Task has no %s section defined' % '.'.join(props[:i + 1]) if log_on_error: self.log.info(msg) if raise_on_error: raise ValueError(msg) return default + obj = getattr(obj, props[i], None) return obj def _set_task_property(self, prop_path, value, raise_on_error=True, log_on_error=True): diff --git a/trains/backend_interface/task/development/stop_signal.py b/trains/backend_interface/task/development/stop_signal.py index 24f146a2..98af69a6 100644 --- a/trains/backend_interface/task/development/stop_signal.py +++ b/trains/backend_interface/task/development/stop_signal.py @@ -30,22 +30,22 @@ class TaskStopSignal(object): def test(self): # noinspection PyBroadException try: - status = self.task.status + status = str(self.task.status) message = self.task.data.status_message - if status == tasks.TaskStatusEnum.in_progress and "stopping" in message: + if status == str(tasks.TaskStatusEnum.in_progress) and "stopping" in message: return TaskStopReason.stopped _expected_statuses = ( - tasks.TaskStatusEnum.created, - tasks.TaskStatusEnum.queued, - tasks.TaskStatusEnum.in_progress, + str(tasks.TaskStatusEnum.created), + str(tasks.TaskStatusEnum.queued), + str(tasks.TaskStatusEnum.in_progress), ) if status not in _expected_statuses and "worker" not in message: return TaskStopReason.status_changed - if status == tasks.TaskStatusEnum.created: + if status == str(tasks.TaskStatusEnum.created): self._task_reset_state_counter += 1 if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests: diff --git a/trains/backend_interface/task/development/worker.py b/trains/backend_interface/task/development/worker.py index a6caf77f..2189af86 100644 --- a/trains/backend_interface/task/development/worker.py +++ b/trains/backend_interface/task/development/worker.py @@ -1,5 +1,3 @@ -from socket import gethostname - import attr from threading import Thread, Event @@ -13,9 +11,9 @@ from ....backend_api.services import tasks class DevWorker(object): prefix = attr.ib(type=str, default="MANUAL:") - report_period = float(config.get('development.worker.report_period_sec', 30.)) + report_period = float(max(config.get('development.worker.report_period_sec', 30.), 1.)) report_stdout = bool(config.get('development.worker.log_stdout', True)) - ping_period = 30. + ping_period = float(max(config.get('development.worker.ping_period_sec', 30.), 1.)) def __init__(self): self._dev_stop_signal = None @@ -51,20 +49,23 @@ class DevWorker(object): def _daemon(self): last_ping = time() while self._task is not None: - if self._exit_event.wait(min(self.ping_period, self.report_period)): - return - # send ping request - if self._support_ping and (time() - last_ping) >= self.ping_period: - self.ping() - last_ping = time() - if self._dev_stop_signal: - stop_reason = self._dev_stop_signal.test() - if stop_reason and self._task: - self._task._dev_mode_stop_task(stop_reason) + try: + if self._exit_event.wait(min(self.ping_period, self.report_period)): + return + # send ping request + if self._support_ping and (time() - last_ping) >= self.ping_period: + self.ping() + last_ping = time() + if self._dev_stop_signal: + stop_reason = self._dev_stop_signal.test() + if stop_reason and self._task: + self._task._dev_mode_stop_task(stop_reason) + except Exception: + pass def unregister(self): - self._exit_event.set() self._dev_stop_signal = None - self._thread = None self._task = None + self._thread = None + self._exit_event.set() return True diff --git a/trains/backend_interface/task/log.py b/trains/backend_interface/task/log.py index 4a91d8b2..f855263c 100644 --- a/trains/backend_interface/task/log.py +++ b/trains/backend_interface/task/log.py @@ -1,6 +1,7 @@ import time from logging import LogRecord, getLogger, basicConfig from logging.handlers import BufferingHandler +from multiprocessing.pool import ThreadPool from ...backend_api.services import events from ...config import config @@ -27,6 +28,7 @@ class TaskHandler(BufferingHandler): self.last_timestamp = 0 self.counter = 1 self._last_event = None + self._thread_pool = ThreadPool(processes=1) def shouldFlush(self, record): """ @@ -92,6 +94,7 @@ class TaskHandler(BufferingHandler): def flush(self): if not self.buffer: return + self.acquire() buffer = self.buffer try: @@ -100,11 +103,20 @@ class TaskHandler(BufferingHandler): self.buffer = [] record_events = [self._record_to_event(record) for record in buffer] self._last_event = None - requests = [events.AddRequest(e) for e in record_events if e] - res = self.session.send(events.AddBatchRequest(requests=requests)) - if not res.ok(): - print("Failed logging task to backend ({:d} lines, {})".format(len(buffer), str(res.meta))) + batch_requests = events.AddBatchRequest(requests=[events.AddRequest(e) for e in record_events if e]) except Exception: + batch_requests = None print("Failed logging task to backend ({:d} lines)".format(len(buffer))) finally: self.release() + + if batch_requests: + self._thread_pool.apply_async(self._send_events, args=(batch_requests, )) + + def _send_events(self, a_request): + try: + res = self.session.send(a_request) + if not res.ok(): + print("Failed logging task to backend ({:d} lines, {})".format(len(a_request.requests), str(res.meta))) + except Exception: + print("Failed logging task to backend ({:d} lines)".format(len(a_request.requests))) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index ab0c415f..cec414ad 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -4,8 +4,6 @@ import itertools import logging from enum import Enum from threading import RLock, Thread -from copy import copy -from six.moves.urllib.parse import urlparse, urlunparse import six @@ -300,7 +298,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): @property def _status(self): """ Return the task's cached status (don't reload if we don't have to) """ - return self.data.status + return str(self.data.status) @property def input_model(self): @@ -349,11 +347,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): Returns a simple metrics reporter instance """ if self._reporter is None: - try: - storage_uri = self.get_output_destination(log_on_error=False) - except ValueError: - storage_uri = None - self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri)) + self._setup_reporter() return self._reporter def _get_metrics_manager(self, storage_uri): @@ -366,6 +360,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): ) return self._metrics_manager + def _setup_reporter(self): + try: + storage_uri = self.get_output_destination(log_on_error=False) + except ValueError: + storage_uri = None + self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri)) + return self._reporter + def _get_output_destination_suffix(self, extra_path=None): return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x) @@ -403,7 +405,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def publish(self, ignore_errors=True): """ Signal that this task will be published """ - if self.status != tasks.TaskStatusEnum.stopped: + if str(self.status) != str(tasks.TaskStatusEnum.stopped): raise ValueError("Can't publish, Task is not stopped") resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors) assert isinstance(resp.response, tasks.PublishResponse) @@ -471,7 +473,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): return uri def _conditionally_start_task(self): - if self.status == tasks.TaskStatusEnum.created: + if str(self.status) == str(tasks.TaskStatusEnum.created): self.started() @property @@ -700,8 +702,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def _edit(self, **kwargs): with self._edit_lock: # Since we ae using forced update, make sure he task status is valid - if not self._data or (self.data.status not in (tasks.TaskStatusEnum.created, - tasks.TaskStatusEnum.in_progress)): + if not self._data or (str(self.data.status) not in (str(tasks.TaskStatusEnum.created), + str(tasks.TaskStatusEnum.in_progress))): raise ValueError('Task object can only be updated if created or in_progress') res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) diff --git a/trains/debugging/log.py b/trains/debugging/log.py index 6d83c698..619c769f 100644 --- a/trains/debugging/log.py +++ b/trains/debugging/log.py @@ -6,7 +6,6 @@ import os import sys from platform import system -from ..config import config, get_log_redirect_level from pathlib2 import Path from six import BytesIO from tqdm import tqdm @@ -55,6 +54,9 @@ class LoggerRoot(object): def get_base_logger(cls, level=None, stream=sys.stdout, colored=False): if LoggerRoot.__base_logger: return LoggerRoot.__base_logger + # avoid nested imports + from ..config import get_log_redirect_level + LoggerRoot.__base_logger = logging.getLogger('trains') level = level if level is not None else default_level LoggerRoot.__base_logger.setLevel(level) @@ -152,6 +154,9 @@ def get_null_logger(name=None): """ Get a logger with a null handler """ log = logging.getLogger(name if name else 'null') if not log.handlers: + # avoid nested imports + from ..config import config + log.addHandler(logging.NullHandler()) log.propagate = config.get("log.null_log_propagate", False) return log