Improve stability and resilience on intermittent network connections

allegroai 2019-08-19 21:17:53 +03:00
parent 0a8cf706bd
commit 3bc1ec2362
10 changed files with 92 additions and 51 deletions

View File

@ -34,7 +34,7 @@
# backoff parameters
# timeout between retries is min({backoff_max}, {backoff_factor} * (2 ^ ({number of total retries} - 1)))
backoff_factor: 1.0
backoff_max: 300.0
backoff_max: 120.0
}
wait_on_maintenance_forever: true
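
To see what the new cap means in practice, here is a minimal sketch of the backoff schedule described in the comment above (plain Python; the retry loop itself lives in the session layer, so these names are illustrative):

backoff_factor = 1.0
backoff_max = 120.0

def backoff_delay(retry_number):
    # min({backoff_max}, {backoff_factor} * (2 ^ ({number of total retries} - 1)))
    return min(backoff_max, backoff_factor * (2 ** (retry_number - 1)))

# Delays grow 1, 2, 4, ... seconds and are now capped at 120s instead of 300s:
print([backoff_delay(n) for n in range(1, 10)])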

View File

@ -4,7 +4,6 @@ import sys
import requests
from requests.adapters import HTTPAdapter
## from requests.packages.urllib3.util.retry import Retry
from urllib3.util import Retry
from urllib3 import PoolManager
import six
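
For context, this is how urllib3's Retry is typically wired into a requests session via HTTPAdapter (standard requests/urllib3 API; the retry budget and status list below are illustrative, not the values used by the session code):

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

session = requests.Session()
retry = Retry(
    total=10,                           # illustrative total retry budget
    backoff_factor=1.0,                 # matches backoff_factor in the config above
    status_forcelist=[502, 503, 504],   # retry these transient HTTP statuses
)
session.mount('http://', HTTPAdapter(max_retries=retry))
session.mount('https://', HTTPAdapter(max_retries=retry))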

View File

@ -61,13 +61,22 @@ class InterfaceBase(SessionInterface):
except requests.exceptions.BaseHTTPError as e:
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
if log:
log.warning('Failed sending %s: %s' % (str(type(req)), str(e)))
except MaxRequestSizeError as e:
res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e)))
error_msg = 'Failed sending: %s' % str(e)
except requests.exceptions.ConnectionError:
# We couldn't send the request after retrying the number of times configured in the api configuration file,
# so we end the loop and raise the exception to the upper level.
# Notice: this is a connectivity error, not a backend error.
if raise_on_errors:
raise
res = None
except Exception as e:
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
if log:
log.warning('Failed sending %s: %s' % (str(type(req)), str(e)))
if res and res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return
@ -75,10 +84,6 @@ class InterfaceBase(SessionInterface):
raise SendError(res, error_msg)
return res
# # Infrastructure error
# if log:
# log.info('retrying request %s' % str(req))
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
log=self.log, async_enable=async_enable)
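
The error policy above — re-raise a ConnectionError only when the caller asked for it, and downgrade everything else to a warning — can be sketched in isolation (a simplified stand-in, not the actual InterfaceBase code):

import requests

def send_with_policy(session, url, raise_on_errors=True, log=None):
    try:
        return session.get(url, timeout=10)
    except requests.exceptions.ConnectionError:
        # Connectivity error: the adapter already exhausted its retries,
        # so surface it to the caller only if requested.
        if raise_on_errors:
            raise
        return None
    except Exception as e:
        # Anything else is logged as a warning and swallowed.
        if log:
            log.warning('Failed sending %s: %s' % (url, str(e)))
        return None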
@ -128,7 +133,7 @@ class IdObjectBase(InterfaceBase):
@id.setter
def id(self, value):
should_reload = value is not None and value != self._id
should_reload = value is not None and self._id is not None and value != self._id
self._id = value
if should_reload:
self.reload()
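
The added `self._id is not None` condition means assigning an id for the first time no longer triggers a backend reload; only replacing an already-set id does. A minimal sketch of the new behavior (illustrative class):

class IdHolder(object):
    def __init__(self):
        self._id = None

    def reload(self):
        print('reloading from backend')

    def set_id(self, value):
        # reload only when replacing an existing, different id
        should_reload = value is not None and self._id is not None and value != self._id
        self._id = value
        if should_reload:
            self.reload()

h = IdHolder()
h.set_id('a')  # first assignment: no reload (previously this reloaded)
h.set_id('b')  # id actually changed: reloads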

View File

@ -1,5 +1,7 @@
import os
from collections import namedtuple
from functools import partial
from tempfile import mkstemp
import six
from pathlib2 import Path
@ -45,6 +47,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
_EMPTY_MODEL_ID = 'empty'
_local_model_to_id_uri = {}
@property
def model_id(self):
return self.id
@ -172,6 +176,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.upload_storage_uri = upload_storage_uri
self._create_empty_model(self.upload_storage_uri)
if model_file and uri:
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
# upload model file if needed and get uri
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
# update fields
@ -213,6 +220,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
if uploaded_uri is False:
uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri)
self.update(
uri=uploaded_uri,
task_id=task_id,
@ -234,6 +243,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return uri
else:
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
self.update(
uri=uri,
task_id=task_id,
@ -339,7 +349,14 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
""" Download the model weights into a local file in our cache """
uri = self.data.uri
helper = StorageHelper.get(uri, logger=self._log, verbose=True)
return helper.download_to_file(uri, force_cache=True)
filename = uri.split('/')[-1]
ext = '.'.join(filename.split('.')[1:])
fd, local_filename = mkstemp(suffix='.'+ext)
os.close(fd)
local_download = helper.download_to_file(uri, local_path=local_filename, overwrite_existing=True)
# save local model, so we can later query what the original one was
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
return local_download
@property
def cache_dir(self):
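
The download now lands in a fresh temp file that keeps the weight file's full extension, so anything that dispatches on the suffix (e.g. .tar.gz) still works. The naming trick in isolation (illustrative URI):

import os
from tempfile import mkstemp

uri = 'https://files.example.com/models/model.tar.gz'  # illustrative
filename = uri.split('/')[-1]
ext = '.'.join(filename.split('.')[1:])   # 'tar.gz' -- everything after the first dot
fd, local_filename = mkstemp(suffix='.' + ext)
os.close(fd)                              # only the unique path is needed here
print(local_filename)                     # e.g. /tmp/tmpXXXXXXXX.tar.gz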

View File

@ -18,14 +18,14 @@ class AccessMixin(object):
obj = self.data
props = prop_path.split('.')
for i in range(len(props)):
obj = getattr(obj, props[i], None)
if obj is None:
if not hasattr(obj, props[i]):
msg = 'Task has no %s section defined' % '.'.join(props[:i + 1])
if log_on_error:
self.log.info(msg)
if raise_on_error:
raise ValueError(msg)
return default
obj = getattr(obj, props[i], None)
return obj
def _set_task_property(self, prop_path, value, raise_on_error=True, log_on_error=True):
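
Checking hasattr() before getattr() separates a property that is genuinely missing (return the default) from one that exists but is None (a legitimate value). A simplified stand-in for the traversal above:

def get_nested(obj, prop_path, default=None):
    props = prop_path.split('.')
    for i in range(len(props)):
        # a missing attribute means the section is undefined
        if not hasattr(obj, props[i]):
            return default
        # an attribute that exists but is None is a real value
        obj = getattr(obj, props[i], None)
    return obj

class Section(object):
    value = None

class Data(object):
    section = Section()

print(get_nested(Data(), 'section.value'))              # None -- exists, not the default
print(get_nested(Data(), 'section.other', 'fallback'))  # 'fallback' -- truly missing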

View File

@ -30,22 +30,22 @@ class TaskStopSignal(object):
def test(self):
# noinspection PyBroadException
try:
status = self.task.status
status = str(self.task.status)
message = self.task.data.status_message
if status == tasks.TaskStatusEnum.in_progress and "stopping" in message:
if status == str(tasks.TaskStatusEnum.in_progress) and "stopping" in message:
return TaskStopReason.stopped
_expected_statuses = (
tasks.TaskStatusEnum.created,
tasks.TaskStatusEnum.queued,
tasks.TaskStatusEnum.in_progress,
str(tasks.TaskStatusEnum.created),
str(tasks.TaskStatusEnum.queued),
str(tasks.TaskStatusEnum.in_progress),
)
if status not in _expected_statuses and "worker" not in message:
return TaskStopReason.status_changed
if status == tasks.TaskStatusEnum.created:
if status == str(tasks.TaskStatusEnum.created):
self._task_reset_state_counter += 1
if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
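
Wrapping both sides in str() makes the comparison robust when the status arrives from the API as a plain string rather than an enum member. A minimal illustration (a mock enum, not the real backend_api one):

from enum import Enum

class TaskStatusEnum(Enum):
    in_progress = 'in_progress'

    def __str__(self):
        return self.value

status = 'in_progress'  # as returned by a raw API response
print(status == TaskStatusEnum.in_progress)            # False -- type mismatch
print(str(status) == str(TaskStatusEnum.in_progress))  # True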

View File

@ -1,5 +1,3 @@
from socket import gethostname
import attr
from threading import Thread, Event
@ -13,9 +11,9 @@ from ....backend_api.services import tasks
class DevWorker(object):
prefix = attr.ib(type=str, default="MANUAL:")
report_period = float(config.get('development.worker.report_period_sec', 30.))
report_period = float(max(config.get('development.worker.report_period_sec', 30.), 1.))
report_stdout = bool(config.get('development.worker.log_stdout', True))
ping_period = 30.
ping_period = float(max(config.get('development.worker.ping_period_sec', 30.), 1.))
def __init__(self):
self._dev_stop_signal = None
@ -51,20 +49,23 @@ class DevWorker(object):
def _daemon(self):
last_ping = time()
while self._task is not None:
if self._exit_event.wait(min(self.ping_period, self.report_period)):
return
# send ping request
if self._support_ping and (time() - last_ping) >= self.ping_period:
self.ping()
last_ping = time()
if self._dev_stop_signal:
stop_reason = self._dev_stop_signal.test()
if stop_reason and self._task:
self._task._dev_mode_stop_task(stop_reason)
try:
if self._exit_event.wait(min(self.ping_period, self.report_period)):
return
# send ping request
if self._support_ping and (time() - last_ping) >= self.ping_period:
self.ping()
last_ping = time()
if self._dev_stop_signal:
stop_reason = self._dev_stop_signal.test()
if stop_reason and self._task:
self._task._dev_mode_stop_task(stop_reason)
except Exception:
pass
def unregister(self):
self._exit_event.set()
self._dev_stop_signal = None
self._thread = None
self._task = None
self._thread = None
self._exit_event.set()
return True
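
Moving the loop body inside try/except means one failed ping or stop-signal test (e.g. during a network hiccup) no longer kills the daemon thread; the next iteration simply tries again. The pattern in isolation (illustrative names):

from threading import Event, Thread

def resilient_daemon(exit_event, period, do_work):
    while not exit_event.wait(period):
        try:
            do_work()
        except Exception:
            # swallow transient errors; retry on the next iteration
            pass

stop = Event()
worker = Thread(target=resilient_daemon, args=(stop, 1.0, lambda: None))
worker.daemon = True
worker.start()
stop.set()  # unregister-style shutdown: set the event and the loop exits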

View File

@ -1,6 +1,7 @@
import time
from logging import LogRecord, getLogger, basicConfig
from logging.handlers import BufferingHandler
from multiprocessing.pool import ThreadPool
from ...backend_api.services import events
from ...config import config
@ -27,6 +28,7 @@ class TaskHandler(BufferingHandler):
self.last_timestamp = 0
self.counter = 1
self._last_event = None
self._thread_pool = ThreadPool(processes=1)
def shouldFlush(self, record):
"""
@ -92,6 +94,7 @@ class TaskHandler(BufferingHandler):
def flush(self):
if not self.buffer:
return
self.acquire()
buffer = self.buffer
try:
@ -100,11 +103,20 @@ class TaskHandler(BufferingHandler):
self.buffer = []
record_events = [self._record_to_event(record) for record in buffer]
self._last_event = None
requests = [events.AddRequest(e) for e in record_events if e]
res = self.session.send(events.AddBatchRequest(requests=requests))
if not res.ok():
print("Failed logging task to backend ({:d} lines, {})".format(len(buffer), str(res.meta)))
batch_requests = events.AddBatchRequest(requests=[events.AddRequest(e) for e in record_events if e])
except Exception:
batch_requests = None
print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
finally:
self.release()
if batch_requests:
self._thread_pool.apply_async(self._send_events, args=(batch_requests, ))
def _send_events(self, a_request):
try:
res = self.session.send(a_request)
if not res.ok():
print("Failed logging task to backend ({:d} lines, {})".format(len(a_request.requests), str(res.meta)))
except Exception:
print("Failed logging task to backend ({:d} lines)".format(len(a_request.requests)))

View File

@ -4,8 +4,6 @@ import itertools
import logging
from enum import Enum
from threading import RLock, Thread
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse
import six
@ -300,7 +298,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property
def _status(self):
""" Return the task's cached status (don't reload if we don't have to) """
return self.data.status
return str(self.data.status)
@property
def input_model(self):
@ -349,11 +347,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
Returns a simple metrics reporter instance
"""
if self._reporter is None:
try:
storage_uri = self.get_output_destination(log_on_error=False)
except ValueError:
storage_uri = None
self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri))
self._setup_reporter()
return self._reporter
def _get_metrics_manager(self, storage_uri):
@ -366,6 +360,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
)
return self._metrics_manager
def _setup_reporter(self):
try:
storage_uri = self.get_output_destination(log_on_error=False)
except ValueError:
storage_uri = None
self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri))
return self._reporter
def _get_output_destination_suffix(self, extra_path=None):
return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)
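
Factoring the construction into _setup_reporter() keeps the property a plain lazy initializer while letting other code rebuild the reporter explicitly. The lazy-property pattern being used, in sketch form (illustrative class):

class LazyReporter(object):
    def __init__(self):
        self._reporter = None

    @property
    def reporter(self):
        if self._reporter is None:
            self._setup_reporter()
        return self._reporter

    def _setup_reporter(self):
        # stand-in for Reporter(self._get_metrics_manager(storage_uri=...))
        self._reporter = object()
        return self._reporter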
@ -403,7 +405,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def publish(self, ignore_errors=True):
""" Signal that this task will be published """
if self.status != tasks.TaskStatusEnum.stopped:
if str(self.status) != str(tasks.TaskStatusEnum.stopped):
raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse)
@ -471,7 +473,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return uri
def _conditionally_start_task(self):
if self.status == tasks.TaskStatusEnum.created:
if str(self.status) == str(tasks.TaskStatusEnum.created):
self.started()
@property
@ -700,8 +702,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _edit(self, **kwargs):
with self._edit_lock:
# Since we are using forced update, make sure the task status is valid
if not self._data or (self.data.status not in (tasks.TaskStatusEnum.created,
tasks.TaskStatusEnum.in_progress)):
if not self._data or (str(self.data.status) not in (str(tasks.TaskStatusEnum.created),
str(tasks.TaskStatusEnum.in_progress))):
raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)

View File

@ -6,7 +6,6 @@ import os
import sys
from platform import system
from ..config import config, get_log_redirect_level
from pathlib2 import Path
from six import BytesIO
from tqdm import tqdm
@ -55,6 +54,9 @@ class LoggerRoot(object):
def get_base_logger(cls, level=None, stream=sys.stdout, colored=False):
if LoggerRoot.__base_logger:
return LoggerRoot.__base_logger
# avoid nested imports
from ..config import get_log_redirect_level
LoggerRoot.__base_logger = logging.getLogger('trains')
level = level if level is not None else default_level
LoggerRoot.__base_logger.setLevel(level)
@ -152,6 +154,9 @@ def get_null_logger(name=None):
""" Get a logger with a null handler """
log = logging.getLogger(name if name else 'null')
if not log.handlers:
# avoid nested imports
from ..config import config
log.addHandler(logging.NullHandler())
log.propagate = config.get("log.null_log_propagate", False)
return log
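
Both hunks move a top-level `from ..config import ...` into the function body. The import then runs on first call instead of at module load, which is what breaks a circular-import chain between the logging and config modules. In sketch form (hypothetical two-module layout, mirroring the hunk above):

# pkg/config.py (hypothetical)
#   from .log import get_null_logger   # config needs the logger at load time
#
# pkg/log.py (hypothetical)
import logging

def get_null_logger(name=None):
    # deferred import: by the time this function is called, pkg.config has
    # finished loading, so the config -> log -> config cycle never bites
    from ..config import config  # would fail at module top level during the cycle
    log = logging.getLogger(name if name else 'null')
    if not log.handlers:
        log.addHandler(logging.NullHandler())
        log.propagate = config.get('log.null_log_propagate', False)
    return log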