Add check for server updates

This commit is contained in:
allegroai 2019-10-28 21:49:16 +02:00
parent 5d17059cbe
commit 6101dc4f11
7 changed files with 169 additions and 10 deletions

View File

@ -88,4 +88,17 @@
task_update_timeout: 600 task_update_timeout: 600
} }
check_for_updates {
enabled: true
# Check for updates every 24 hours
check_interval_sec: 86400
url: "https://updates.trains.allegro.ai/updates"
component_name: "trains-server"
# GET request timeout
request_timeout_sec: 3.0
}
} }

View File

@ -19,7 +19,6 @@ from database.utils import (
get_fields_choices, get_fields_choices,
field_does_not_exist, field_does_not_exist,
field_exists, field_exists,
get_fields,
) )
log = config.logger("dbmodel") log = config.logger("dbmodel")
@ -477,11 +476,9 @@ class GetMixin(PropsMixin):
): ):
params = {} params = {}
mongo_field = order_field.replace(".", "__") mongo_field = order_field.replace(".", "__")
if mongo_field in get_fields(cls, of_type=ListField, subfields=True): if mongo_field in cls.get_field_names_for_type(of_type=ListField):
params["is_list"] = True params["is_list"] = True
elif mongo_field in get_fields( elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
cls, of_type=StringField, subfields=True
):
params["empty_value"] = "" params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params) non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params) empty = query & field_does_not_exist(mongo_field, **params)

View File

@ -1,11 +1,12 @@
from collections import OrderedDict from collections import OrderedDict, defaultdict
from itertools import chain
from operator import attrgetter from operator import attrgetter
from threading import Lock from threading import Lock
from typing import Sequence from typing import Sequence
import six import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document from mongoengine.base import get_document, BaseField
from database.fields import ( from database.fields import (
LengthRangeEmbeddedDocumentListField, LengthRangeEmbeddedDocumentListField,
@ -20,6 +21,7 @@ class PropsMixin(object):
__cached_reference_fields = None __cached_reference_fields = None
__cached_exclude_fields = None __cached_exclude_fields = None
__cached_fields_with_instance = None __cached_fields_with_instance = None
__cached_field_names_per_type = None
__cached_dpath_computed_fields_lock = Lock() __cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None __cached_dpath_computed_fields = None
@ -30,6 +32,39 @@ class PropsMixin(object):
cls.__cached_fields = get_fields(cls) cls.__cached_fields = get_fields(cls)
return cls.__cached_fields return cls.__cached_fields
@classmethod
def get_field_names_for_type(cls, of_type=BaseField):
"""
Return field names per type including subfields
The fields of derived types are also returned
"""
assert issubclass(of_type, BaseField)
if cls.__cached_field_names_per_type is None:
fields = defaultdict(list)
for name, field in get_fields(cls, return_instance=True, subfields=True):
fields[type(field)].append(name)
for type_ in fields:
fields[type_].extend(
chain.from_iterable(
fields[other_type]
for other_type in fields
if other_type != type_ and issubclass(other_type, type_)
)
)
cls.__cached_field_names_per_type = fields
if of_type not in cls.__cached_field_names_per_type:
names = list(
chain.from_iterable(
field_names
for type_, field_names in cls.__cached_field_names_per_type.items()
if issubclass(type_, of_type)
)
)
cls.__cached_field_names_per_type[of_type] = names
return cls.__cached_field_names_per_type[of_type]
@classmethod @classmethod
def get_fields_with_instance(cls, doc_cls): def get_fields_with_instance(cls, doc_cls):
if cls.__cached_fields_with_instance is None: if cls.__cached_fields_with_instance is None:

View File

@ -14,6 +14,8 @@ from service_repo.errors import PathParsingError
from timing_context import TimingContext from timing_context import TimingContext
from utilities import json from utilities import json
from init_data import init_es_data, init_mongo_data from init_data import init_es_data, init_mongo_data
from updates import check_updates_thread
app = Flask(__name__, static_url_path="/static") app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors")) CORS(app, **config.get("apiserver.cors"))
@ -35,6 +37,9 @@ ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}") log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
check_updates_thread.start()
@app.before_first_request @app.before_first_request
def before_app_first_request(): def before_app_first_request():
pass pass

View File

@ -198,8 +198,6 @@ def get_queue_metrics(
interval=req_model.interval, interval=req_model.interval,
queue_ids=req_model.queue_ids, queue_ids=req_model.queue_ids,
) )
if not ret:
return GetMetricsResponse(queues=[])
queue_dicts = { queue_dicts = {
queue: extract_properties_to_lists( queue: extract_properties_to_lists(
@ -214,7 +212,7 @@ def get_queue_metrics(
dates=data["date"], dates=data["date"],
avg_waiting_times=data["avg_waiting_time"], avg_waiting_times=data["avg_waiting_time"],
queue_lengths=data["queue_length"], queue_lengths=data["queue_length"],
) ) if data else QueueMetrics(queue=queue)
for queue, data in queue_dicts.items() for queue, data in queue_dicts.items()
] ]
) )

110
server/updates.py Normal file
View File

@ -0,0 +1,110 @@
import os
from threading import Thread
from time import sleep
from typing import Optional
import attr
import requests
from semantic_version import Version
from config import config
from version import __version__ as current_version
log = config.logger(__name__)
class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
@attr.s(auto_attribs=True)
class _VersionResponse:
version: str
patch_upgrade: bool
description: str = None
def __init__(self):
super(CheckUpdatesThread, self).__init__(
target=self._check_updates, daemon=True
)
def start(self) -> None:
if not self._enabled:
log.info("Checking for updates is disabled")
return
super(CheckUpdatesThread, self).start()
@property
def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "trains-server")
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
"apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates"
)
response = requests.get(
url,
json={"versions": {self.component_name: str(current_version)}},
timeout=float(
config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
),
)
if not response.ok:
return
response = response.json().get(self.component_name)
if not response:
return
latest_version = response.get("version")
if not latest_version:
return
cur_version = Version(current_version)
latest_version = Version(latest_version)
if cur_version >= latest_version:
return
return self._VersionResponse(
version=str(latest_version),
patch_upgrade=(
latest_version.major == cur_version.major
and latest_version.minor == cur_version.minor
),
description=response.get("description").split("\r\n"),
)
def _check_updates(self):
while True:
# noinspection PyBroadException
try:
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception:
log.exception("Failed obtaining updates")
sleep(
max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
)
check_updates_thread = CheckUpdatesThread()

1
server/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = "0.12.0"