Improve stability and resilience on intermittent network connection

This commit is contained in:
allegroai 2019-08-19 21:17:53 +03:00
parent 0a8cf706bd
commit 3bc1ec2362
10 changed files with 92 additions and 51 deletions

View File

@ -34,7 +34,7 @@
# backoff parameters # backoff parameters
# timeout between retries is min({backoff_max}, {backoff factor} * (2 ^ ({number of total retries} - 1)) # timeout between retries is min({backoff_max}, {backoff factor} * (2 ^ ({number of total retries} - 1))
backoff_factor: 1.0 backoff_factor: 1.0
backoff_max: 300.0 backoff_max: 120.0
} }
wait_on_maintenance_forever: true wait_on_maintenance_forever: true

View File

@ -4,7 +4,6 @@ import sys
import requests import requests
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
## from requests.packages.urllib3.util.retry import Retry
from urllib3.util import Retry from urllib3.util import Retry
from urllib3 import PoolManager from urllib3 import PoolManager
import six import six

View File

@ -61,13 +61,22 @@ class InterfaceBase(SessionInterface):
except requests.exceptions.BaseHTTPError as e: except requests.exceptions.BaseHTTPError as e:
res = None res = None
log.error('Failed sending %s: %s' % (str(req), str(e))) if log:
log.warning('Failed sending %s: %s' % (str(type(req)), str(e)))
except MaxRequestSizeError as e: except MaxRequestSizeError as e:
res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e))) res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e)))
error_msg = 'Failed sending: %s' % str(e) error_msg = 'Failed sending: %s' % str(e)
except requests.exceptions.ConnectionError:
# We couldn't send the request for more than the retries times configure in the api configuration file,
# so we will end the loop and raise the exception to the upper level.
# Notice: this is a connectivity error and not a backend error.
if raise_on_errors:
raise
res = None
except Exception as e: except Exception as e:
res = None res = None
log.error('Failed sending %s: %s' % (str(req), str(e))) if log:
log.warning('Failed sending %s: %s' % (str(type(req)), str(e)))
if res and res.meta.result_code <= 500: if res and res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return # Proper backend error/bad status code - raise or return
@ -75,10 +84,6 @@ class InterfaceBase(SessionInterface):
raise SendError(res, error_msg) raise SendError(res, error_msg)
return res return res
# # Infrastructure error
# if log:
# log.info('retrying request %s' % str(req))
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False): def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors, return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
log=self.log, async_enable=async_enable) log=self.log, async_enable=async_enable)
@ -128,7 +133,7 @@ class IdObjectBase(InterfaceBase):
@id.setter @id.setter
def id(self, value): def id(self, value):
should_reload = value is not None and value != self._id should_reload = value is not None and self._id is not None and value != self._id
self._id = value self._id = value
if should_reload: if should_reload:
self.reload() self.reload()

View File

@ -1,5 +1,7 @@
import os
from collections import namedtuple from collections import namedtuple
from functools import partial from functools import partial
from tempfile import mkstemp
import six import six
from pathlib2 import Path from pathlib2 import Path
@ -45,6 +47,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
_EMPTY_MODEL_ID = 'empty' _EMPTY_MODEL_ID = 'empty'
_local_model_to_id_uri = {}
@property @property
def model_id(self): def model_id(self):
return self.id return self.id
@ -172,6 +176,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.upload_storage_uri = upload_storage_uri self.upload_storage_uri = upload_storage_uri
self._create_empty_model(self.upload_storage_uri) self._create_empty_model(self.upload_storage_uri)
if model_file and uri:
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
# upload model file if needed and get uri # upload model file if needed and get uri
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri) uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
# update fields # update fields
@ -213,6 +220,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
if uploaded_uri is False: if uploaded_uri is False:
uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri) uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri)
self.update( self.update(
uri=uploaded_uri, uri=uploaded_uri,
task_id=task_id, task_id=task_id,
@ -234,6 +243,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return uri return uri
else: else:
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename) uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
self.update( self.update(
uri=uri, uri=uri,
task_id=task_id, task_id=task_id,
@ -339,7 +349,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
""" Download the model weights into a local file in our cache """ """ Download the model weights into a local file in our cache """
uri = self.data.uri uri = self.data.uri
helper = StorageHelper.get(uri, logger=self._log, verbose=True) helper = StorageHelper.get(uri, logger=self._log, verbose=True)
return helper.download_to_file(uri, force_cache=True) filename = uri.split('/')[-1]
ext = '.'.join(filename.split('.')[1:])
fd, local_filename = mkstemp(suffix='.'+ext)
os.close(fd)
local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True)
# save local model, so we can later query what was the original one
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
return local_download
@property @property
def cache_dir(self): def cache_dir(self):

View File

@ -18,14 +18,14 @@ class AccessMixin(object):
obj = self.data obj = self.data
props = prop_path.split('.') props = prop_path.split('.')
for i in range(len(props)): for i in range(len(props)):
obj = getattr(obj, props[i], None) if not hasattr(obj, props[i]):
if obj is None:
msg = 'Task has no %s section defined' % '.'.join(props[:i + 1]) msg = 'Task has no %s section defined' % '.'.join(props[:i + 1])
if log_on_error: if log_on_error:
self.log.info(msg) self.log.info(msg)
if raise_on_error: if raise_on_error:
raise ValueError(msg) raise ValueError(msg)
return default return default
obj = getattr(obj, props[i], None)
return obj return obj
def _set_task_property(self, prop_path, value, raise_on_error=True, log_on_error=True): def _set_task_property(self, prop_path, value, raise_on_error=True, log_on_error=True):

View File

@ -30,22 +30,22 @@ class TaskStopSignal(object):
def test(self): def test(self):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
status = self.task.status status = str(self.task.status)
message = self.task.data.status_message message = self.task.data.status_message
if status == tasks.TaskStatusEnum.in_progress and "stopping" in message: if status == str(tasks.TaskStatusEnum.in_progress) and "stopping" in message:
return TaskStopReason.stopped return TaskStopReason.stopped
_expected_statuses = ( _expected_statuses = (
tasks.TaskStatusEnum.created, str(tasks.TaskStatusEnum.created),
tasks.TaskStatusEnum.queued, str(tasks.TaskStatusEnum.queued),
tasks.TaskStatusEnum.in_progress, str(tasks.TaskStatusEnum.in_progress),
) )
if status not in _expected_statuses and "worker" not in message: if status not in _expected_statuses and "worker" not in message:
return TaskStopReason.status_changed return TaskStopReason.status_changed
if status == tasks.TaskStatusEnum.created: if status == str(tasks.TaskStatusEnum.created):
self._task_reset_state_counter += 1 self._task_reset_state_counter += 1
if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests: if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:

View File

@ -1,5 +1,3 @@
from socket import gethostname
import attr import attr
from threading import Thread, Event from threading import Thread, Event
@ -13,9 +11,9 @@ from ....backend_api.services import tasks
class DevWorker(object): class DevWorker(object):
prefix = attr.ib(type=str, default="MANUAL:") prefix = attr.ib(type=str, default="MANUAL:")
report_period = float(config.get('development.worker.report_period_sec', 30.)) report_period = float(max(config.get('development.worker.report_period_sec', 30.), 1.))
report_stdout = bool(config.get('development.worker.log_stdout', True)) report_stdout = bool(config.get('development.worker.log_stdout', True))
ping_period = 30. ping_period = float(max(config.get('development.worker.ping_period_sec', 30.), 1.))
def __init__(self): def __init__(self):
self._dev_stop_signal = None self._dev_stop_signal = None
@ -51,20 +49,23 @@ class DevWorker(object):
def _daemon(self): def _daemon(self):
last_ping = time() last_ping = time()
while self._task is not None: while self._task is not None:
if self._exit_event.wait(min(self.ping_period, self.report_period)): try:
return if self._exit_event.wait(min(self.ping_period, self.report_period)):
# send ping request return
if self._support_ping and (time() - last_ping) >= self.ping_period: # send ping request
self.ping() if self._support_ping and (time() - last_ping) >= self.ping_period:
last_ping = time() self.ping()
if self._dev_stop_signal: last_ping = time()
stop_reason = self._dev_stop_signal.test() if self._dev_stop_signal:
if stop_reason and self._task: stop_reason = self._dev_stop_signal.test()
self._task._dev_mode_stop_task(stop_reason) if stop_reason and self._task:
self._task._dev_mode_stop_task(stop_reason)
except Exception:
pass
def unregister(self): def unregister(self):
self._exit_event.set()
self._dev_stop_signal = None self._dev_stop_signal = None
self._thread = None
self._task = None self._task = None
self._thread = None
self._exit_event.set()
return True return True

View File

@ -1,6 +1,7 @@
import time import time
from logging import LogRecord, getLogger, basicConfig from logging import LogRecord, getLogger, basicConfig
from logging.handlers import BufferingHandler from logging.handlers import BufferingHandler
from multiprocessing.pool import ThreadPool
from ...backend_api.services import events from ...backend_api.services import events
from ...config import config from ...config import config
@ -27,6 +28,7 @@ class TaskHandler(BufferingHandler):
self.last_timestamp = 0 self.last_timestamp = 0
self.counter = 1 self.counter = 1
self._last_event = None self._last_event = None
self._thread_pool = ThreadPool(processes=1)
def shouldFlush(self, record): def shouldFlush(self, record):
""" """
@ -92,6 +94,7 @@ class TaskHandler(BufferingHandler):
def flush(self): def flush(self):
if not self.buffer: if not self.buffer:
return return
self.acquire() self.acquire()
buffer = self.buffer buffer = self.buffer
try: try:
@ -100,11 +103,20 @@ class TaskHandler(BufferingHandler):
self.buffer = [] self.buffer = []
record_events = [self._record_to_event(record) for record in buffer] record_events = [self._record_to_event(record) for record in buffer]
self._last_event = None self._last_event = None
requests = [events.AddRequest(e) for e in record_events if e] batch_requests = events.AddBatchRequest(requests=[events.AddRequest(e) for e in record_events if e])
res = self.session.send(events.AddBatchRequest(requests=requests))
if not res.ok():
print("Failed logging task to backend ({:d} lines, {})".format(len(buffer), str(res.meta)))
except Exception: except Exception:
batch_requests = None
print("Failed logging task to backend ({:d} lines)".format(len(buffer))) print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
finally: finally:
self.release() self.release()
if batch_requests:
self._thread_pool.apply_async(self._send_events, args=(batch_requests, ))
def _send_events(self, a_request):
try:
res = self.session.send(a_request)
if not res.ok():
print("Failed logging task to backend ({:d} lines, {})".format(len(a_request.requests), str(res.meta)))
except Exception:
print("Failed logging task to backend ({:d} lines)".format(len(a_request.requests)))

View File

@ -4,8 +4,6 @@ import itertools
import logging import logging
from enum import Enum from enum import Enum
from threading import RLock, Thread from threading import RLock, Thread
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse
import six import six
@ -300,7 +298,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def _status(self): def _status(self):
""" Return the task's cached status (don't reload if we don't have to) """ """ Return the task's cached status (don't reload if we don't have to) """
return self.data.status return str(self.data.status)
@property @property
def input_model(self): def input_model(self):
@ -349,11 +347,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
Returns a simple metrics reporter instance Returns a simple metrics reporter instance
""" """
if self._reporter is None: if self._reporter is None:
try: self._setup_reporter()
storage_uri = self.get_output_destination(log_on_error=False)
except ValueError:
storage_uri = None
self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri))
return self._reporter return self._reporter
def _get_metrics_manager(self, storage_uri): def _get_metrics_manager(self, storage_uri):
@ -366,6 +360,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
) )
return self._metrics_manager return self._metrics_manager
def _setup_reporter(self):
try:
storage_uri = self.get_output_destination(log_on_error=False)
except ValueError:
storage_uri = None
self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri))
return self._reporter
def _get_output_destination_suffix(self, extra_path=None): def _get_output_destination_suffix(self, extra_path=None):
return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x) return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)
@ -403,7 +405,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def publish(self, ignore_errors=True): def publish(self, ignore_errors=True):
""" Signal that this task will be published """ """ Signal that this task will be published """
if self.status != tasks.TaskStatusEnum.stopped: if str(self.status) != str(tasks.TaskStatusEnum.stopped):
raise ValueError("Can't publish, Task is not stopped") raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors) resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse) assert isinstance(resp.response, tasks.PublishResponse)
@ -471,7 +473,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return uri return uri
def _conditionally_start_task(self): def _conditionally_start_task(self):
if self.status == tasks.TaskStatusEnum.created: if str(self.status) == str(tasks.TaskStatusEnum.created):
self.started() self.started()
@property @property
@ -700,8 +702,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _edit(self, **kwargs): def _edit(self, **kwargs):
with self._edit_lock: with self._edit_lock:
# Since we ae using forced update, make sure he task status is valid # Since we ae using forced update, make sure he task status is valid
if not self._data or (self.data.status not in (tasks.TaskStatusEnum.created, if not self._data or (str(self.data.status) not in (str(tasks.TaskStatusEnum.created),
tasks.TaskStatusEnum.in_progress)): str(tasks.TaskStatusEnum.in_progress))):
raise ValueError('Task object can only be updated if created or in_progress') raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)

View File

@ -6,7 +6,6 @@ import os
import sys import sys
from platform import system from platform import system
from ..config import config, get_log_redirect_level
from pathlib2 import Path from pathlib2 import Path
from six import BytesIO from six import BytesIO
from tqdm import tqdm from tqdm import tqdm
@ -55,6 +54,9 @@ class LoggerRoot(object):
def get_base_logger(cls, level=None, stream=sys.stdout, colored=False): def get_base_logger(cls, level=None, stream=sys.stdout, colored=False):
if LoggerRoot.__base_logger: if LoggerRoot.__base_logger:
return LoggerRoot.__base_logger return LoggerRoot.__base_logger
# avoid nested imports
from ..config import get_log_redirect_level
LoggerRoot.__base_logger = logging.getLogger('trains') LoggerRoot.__base_logger = logging.getLogger('trains')
level = level if level is not None else default_level level = level if level is not None else default_level
LoggerRoot.__base_logger.setLevel(level) LoggerRoot.__base_logger.setLevel(level)
@ -152,6 +154,9 @@ def get_null_logger(name=None):
""" Get a logger with a null handler """ """ Get a logger with a null handler """
log = logging.getLogger(name if name else 'null') log = logging.getLogger(name if name else 'null')
if not log.handlers: if not log.handlers:
# avoid nested imports
from ..config import config
log.addHandler(logging.NullHandler()) log.addHandler(logging.NullHandler())
log.propagate = config.get("log.null_log_propagate", False) log.propagate = config.get("log.null_log_propagate", False)
return log return log