Add check for server updates

This commit is contained in:
allegroai 2019-10-28 21:49:16 +02:00
parent 5d17059cbe
commit 6101dc4f11
7 changed files with 169 additions and 10 deletions

View File

@ -88,4 +88,17 @@
task_update_timeout: 600
}
check_for_updates {
enabled: true
# Check for updates every 24 hours
check_interval_sec: 86400
url: "https://updates.trains.allegro.ai/updates"
component_name: "trains-server"
# GET request timeout
request_timeout_sec: 3.0
}
}

View File

@ -19,7 +19,6 @@ from database.utils import (
get_fields_choices,
field_does_not_exist,
field_exists,
get_fields,
)
log = config.logger("dbmodel")
@ -477,11 +476,9 @@ class GetMixin(PropsMixin):
):
params = {}
mongo_field = order_field.replace(".", "__")
if mongo_field in get_fields(cls, of_type=ListField, subfields=True):
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
params["is_list"] = True
elif mongo_field in get_fields(
cls, of_type=StringField, subfields=True
):
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params)

View File

@ -1,11 +1,12 @@
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from itertools import chain
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document
from mongoengine.base import get_document, BaseField
from database.fields import (
LengthRangeEmbeddedDocumentListField,
@ -20,6 +21,7 @@ class PropsMixin(object):
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_field_names_per_type = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
@ -30,6 +32,39 @@ class PropsMixin(object):
cls.__cached_fields = get_fields(cls)
return cls.__cached_fields
@classmethod
def get_field_names_for_type(cls, of_type=BaseField):
"""
Return field names per type including subfields
The fields of derived types are also returned
"""
assert issubclass(of_type, BaseField)
if cls.__cached_field_names_per_type is None:
fields = defaultdict(list)
for name, field in get_fields(cls, return_instance=True, subfields=True):
fields[type(field)].append(name)
for type_ in fields:
fields[type_].extend(
chain.from_iterable(
fields[other_type]
for other_type in fields
if other_type != type_ and issubclass(other_type, type_)
)
)
cls.__cached_field_names_per_type = fields
if of_type not in cls.__cached_field_names_per_type:
names = list(
chain.from_iterable(
field_names
for type_, field_names in cls.__cached_field_names_per_type.items()
if issubclass(type_, of_type)
)
)
cls.__cached_field_names_per_type[of_type] = names
return cls.__cached_field_names_per_type[of_type]
@classmethod
def get_fields_with_instance(cls, doc_cls):
if cls.__cached_fields_with_instance is None:

View File

@ -14,6 +14,8 @@ from service_repo.errors import PathParsingError
from timing_context import TimingContext
from utilities import json
from init_data import init_es_data, init_mongo_data
from updates import check_updates_thread
app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors"))
@ -35,6 +37,9 @@ ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
check_updates_thread.start()
@app.before_first_request
def before_app_first_request():
pass

View File

@ -198,8 +198,6 @@ def get_queue_metrics(
interval=req_model.interval,
queue_ids=req_model.queue_ids,
)
if not ret:
return GetMetricsResponse(queues=[])
queue_dicts = {
queue: extract_properties_to_lists(
@ -214,7 +212,7 @@ def get_queue_metrics(
dates=data["date"],
avg_waiting_times=data["avg_waiting_time"],
queue_lengths=data["queue_length"],
)
) if data else QueueMetrics(queue=queue)
for queue, data in queue_dicts.items()
]
)

110
server/updates.py Normal file
View File

@ -0,0 +1,110 @@
import os
from threading import Thread
from time import sleep
from typing import Optional
import attr
import requests
from semantic_version import Version
from config import config
from version import __version__ as current_version
log = config.logger(__name__)
class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
@attr.s(auto_attribs=True)
class _VersionResponse:
version: str
patch_upgrade: bool
description: str = None
def __init__(self):
super(CheckUpdatesThread, self).__init__(
target=self._check_updates, daemon=True
)
def start(self) -> None:
if not self._enabled:
log.info("Checking for updates is disabled")
return
super(CheckUpdatesThread, self).start()
@property
def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "trains-server")
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
"apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates"
)
response = requests.get(
url,
json={"versions": {self.component_name: str(current_version)}},
timeout=float(
config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
),
)
if not response.ok:
return
response = response.json().get(self.component_name)
if not response:
return
latest_version = response.get("version")
if not latest_version:
return
cur_version = Version(current_version)
latest_version = Version(latest_version)
if cur_version >= latest_version:
return
return self._VersionResponse(
version=str(latest_version),
patch_upgrade=(
latest_version.major == cur_version.major
and latest_version.minor == cur_version.minor
),
description=response.get("description").split("\r\n"),
)
def _check_updates(self):
while True:
# noinspection PyBroadException
try:
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception:
log.exception("Failed obtaining updates")
sleep(
max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
)
check_updates_thread = CheckUpdatesThread()

1
server/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = "0.12.0"