From 6101dc4f11de2922760106d2849b5230b8a22aac Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 28 Oct 2019 21:49:16 +0200 Subject: [PATCH] Add check for server updates --- server/config/default/apiserver.conf | 13 ++++ server/database/model/base.py | 7 +- server/database/props.py | 39 +++++++++- server/server.py | 5 ++ server/services/queues.py | 4 +- server/updates.py | 110 +++++++++++++++++++++++++++ server/version.py | 1 + 7 files changed, 169 insertions(+), 10 deletions(-) create mode 100644 server/updates.py create mode 100644 server/version.py diff --git a/server/config/default/apiserver.conf b/server/config/default/apiserver.conf index 74fb6ec..7bf2f75 100644 --- a/server/config/default/apiserver.conf +++ b/server/config/default/apiserver.conf @@ -88,4 +88,17 @@ task_update_timeout: 600 } + check_for_updates { + enabled: true + + # Check for updates every 24 hours + check_interval_sec: 86400 + + url: "https://updates.trains.allegro.ai/updates" + + component_name: "trains-server" + + # GET request timeout + request_timeout_sec: 3.0 + } } diff --git a/server/database/model/base.py b/server/database/model/base.py index 65598cb..e916dbf 100644 --- a/server/database/model/base.py +++ b/server/database/model/base.py @@ -19,7 +19,6 @@ from database.utils import ( get_fields_choices, field_does_not_exist, field_exists, - get_fields, ) log = config.logger("dbmodel") @@ -477,11 +476,9 @@ class GetMixin(PropsMixin): ): params = {} mongo_field = order_field.replace(".", "__") - if mongo_field in get_fields(cls, of_type=ListField, subfields=True): + if mongo_field in cls.get_field_names_for_type(of_type=ListField): params["is_list"] = True - elif mongo_field in get_fields( - cls, of_type=StringField, subfields=True - ): + elif mongo_field in cls.get_field_names_for_type(of_type=StringField): params["empty_value"] = "" non_empty = query & field_exists(mongo_field, **params) empty = query & field_does_not_exist(mongo_field, **params) diff --git a/server/database/props.py b/server/database/props.py index 4613606..a61dc0b 100644 --- a/server/database/props.py +++ b/server/database/props.py @@ -1,11 +1,12 @@ -from collections import OrderedDict +from collections import OrderedDict, defaultdict +from itertools import chain from operator import attrgetter from threading import Lock from typing import Sequence import six from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField -from mongoengine.base import get_document +from mongoengine.base import get_document, BaseField from database.fields import ( LengthRangeEmbeddedDocumentListField, @@ -20,6 +21,7 @@ class PropsMixin(object): __cached_reference_fields = None __cached_exclude_fields = None __cached_fields_with_instance = None + __cached_field_names_per_type = None __cached_dpath_computed_fields_lock = Lock() __cached_dpath_computed_fields = None @@ -30,6 +32,39 @@ class PropsMixin(object): cls.__cached_fields = get_fields(cls) return cls.__cached_fields + @classmethod + def get_field_names_for_type(cls, of_type=BaseField): + """ + Return field names per type including subfields + The fields of derived types are also returned + """ + assert issubclass(of_type, BaseField) + if cls.__cached_field_names_per_type is None: + fields = defaultdict(list) + for name, field in get_fields(cls, return_instance=True, subfields=True): + fields[type(field)].append(name) + for type_ in fields: + fields[type_].extend( + chain.from_iterable( + fields[other_type] + for other_type in fields + if other_type != type_ and issubclass(other_type, type_) + ) + ) + cls.__cached_field_names_per_type = fields + + if of_type not in cls.__cached_field_names_per_type: + names = list( + chain.from_iterable( + field_names + for type_, field_names in cls.__cached_field_names_per_type.items() + if issubclass(type_, of_type) + ) + ) + cls.__cached_field_names_per_type[of_type] = names + + return cls.__cached_field_names_per_type[of_type] + @classmethod def get_fields_with_instance(cls, doc_cls): if cls.__cached_fields_with_instance is None: diff --git a/server/server.py b/server/server.py index 0204058..937163e 100644 --- a/server/server.py +++ b/server/server.py @@ -14,6 +14,8 @@ from service_repo.errors import PathParsingError from timing_context import TimingContext from utilities import json from init_data import init_es_data, init_mongo_data +from updates import check_updates_thread + app = Flask(__name__, static_url_path="/static") CORS(app, **config.get("apiserver.cors")) @@ -35,6 +37,9 @@ ServiceRepo.load("services") log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}") +check_updates_thread.start() + + @app.before_first_request def before_app_first_request(): pass diff --git a/server/services/queues.py b/server/services/queues.py index c9f1521..4c73122 100644 --- a/server/services/queues.py +++ b/server/services/queues.py @@ -198,8 +198,6 @@ def get_queue_metrics( interval=req_model.interval, queue_ids=req_model.queue_ids, ) - if not ret: - return GetMetricsResponse(queues=[]) queue_dicts = { queue: extract_properties_to_lists( @@ -214,7 +212,7 @@ def get_queue_metrics( dates=data["date"], avg_waiting_times=data["avg_waiting_time"], queue_lengths=data["queue_length"], - ) + ) if data else QueueMetrics(queue=queue) for queue, data in queue_dicts.items() ] ) diff --git a/server/updates.py b/server/updates.py new file mode 100644 index 0000000..76b5fc6 --- /dev/null +++ b/server/updates.py @@ -0,0 +1,110 @@ +import os +from threading import Thread +from time import sleep +from typing import Optional + +import attr +import requests +from semantic_version import Version + +from config import config +from version import __version__ as current_version + +log = config.logger(__name__) + + +class CheckUpdatesThread(Thread): + _enabled = bool(config.get("apiserver.check_for_updates.enabled", True)) + + @attr.s(auto_attribs=True) + class _VersionResponse: + version: str + patch_upgrade: bool + description: str = None + + def __init__(self): + super(CheckUpdatesThread, self).__init__( + target=self._check_updates, daemon=True + ) + + def start(self) -> None: + if not self._enabled: + log.info("Checking for updates is disabled") + return + super(CheckUpdatesThread, self).start() + + @property + def component_name(self) -> str: + return config.get("apiserver.check_for_updates.component_name", "trains-server") + + def _check_new_version_available(self) -> Optional[_VersionResponse]: + url = config.get( + "apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates" + ) + + response = requests.get( + url, + json={"versions": {self.component_name: str(current_version)}}, + timeout=float( + config.get("apiserver.check_for_updates.request_timeout_sec", 3.0) + ), + ) + + if not response.ok: + return + + response = response.json().get(self.component_name) + if not response: + return + + latest_version = response.get("version") + if not latest_version: + return + + cur_version = Version(current_version) + latest_version = Version(latest_version) + if cur_version >= latest_version: + return + + return self._VersionResponse( + version=str(latest_version), + patch_upgrade=( + latest_version.major == cur_version.major + and latest_version.minor == cur_version.minor + ), + description=response.get("description").split("\r\n"), + ) + + def _check_updates(self): + while True: + # noinspection PyBroadException + try: + response = self._check_new_version_available() + if response: + if response.patch_upgrade: + log.info( + f"{self.component_name.upper()} new package available: upgrade to v{response.version} " + f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}" + ) + else: + log.info( + f"{self.component_name.upper()} new version available: upgrade to v{response.version}" + f" is recommended!" + ) + except Exception: + log.exception("Failed obtaining updates") + + sleep( + max( + float( + config.get( + "apiserver.check_for_updates.check_interval_sec", + 60 * 60 * 24, + ) + ), + 60 * 5, + ) + ) + + +check_updates_thread = CheckUpdatesThread() diff --git a/server/version.py b/server/version.py new file mode 100644 index 0000000..ea370a8 --- /dev/null +++ b/server/version.py @@ -0,0 +1 @@ +__version__ = "0.12.0"