Mirror of https://github.com/clearml/clearml-server (synced 2025-06-23 08:45:30 +00:00)
Add check for server updates
This commit is contained in:
parent
5d17059cbe
commit
6101dc4f11
@@ -88,4 +88,17 @@
         task_update_timeout: 600
     }
 
+    check_for_updates {
+        enabled: true
+
+        # Check for updates every 24 hours
+        check_interval_sec: 86400
+
+        url: "https://updates.trains.allegro.ai/updates"
+
+        component_name: "trains-server"
+
+        # GET request timeout
+        request_timeout_sec: 3.0
+    }
 }
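These keys are exposed to the server code as the apiserver.check_for_updates.* configuration paths and are read by the update checker added in server/updates.py (see the new file below). A minimal sketch of that mapping, with defaults mirroring the block above:

    from config import config

    enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
    check_interval_sec = float(config.get("apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24))
    updates_url = config.get("apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates")
    component_name = config.get("apiserver.check_for_updates.component_name", "trains-server")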
@@ -19,7 +19,6 @@ from database.utils import (
     get_fields_choices,
     field_does_not_exist,
     field_exists,
-    get_fields,
 )
 
 log = config.logger("dbmodel")
@@ -477,11 +476,9 @@ class GetMixin(PropsMixin):
         ):
             params = {}
             mongo_field = order_field.replace(".", "__")
-            if mongo_field in get_fields(cls, of_type=ListField, subfields=True):
+            if mongo_field in cls.get_field_names_for_type(of_type=ListField):
                 params["is_list"] = True
-            elif mongo_field in get_fields(
-                cls, of_type=StringField, subfields=True
-            ):
+            elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
                 params["empty_value"] = ""
             non_empty = query & field_exists(mongo_field, **params)
             empty = query & field_does_not_exist(mongo_field, **params)
@@ -1,11 +1,12 @@
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from itertools import chain
 from operator import attrgetter
 from threading import Lock
 from typing import Sequence
 
 import six
 from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
-from mongoengine.base import get_document
+from mongoengine.base import get_document, BaseField
 
 from database.fields import (
     LengthRangeEmbeddedDocumentListField,
@@ -20,6 +21,7 @@ class PropsMixin(object):
     __cached_reference_fields = None
     __cached_exclude_fields = None
     __cached_fields_with_instance = None
+    __cached_field_names_per_type = None
 
     __cached_dpath_computed_fields_lock = Lock()
     __cached_dpath_computed_fields = None
@@ -30,6 +32,39 @@ class PropsMixin(object):
         cls.__cached_fields = get_fields(cls)
         return cls.__cached_fields
 
+    @classmethod
+    def get_field_names_for_type(cls, of_type=BaseField):
+        """
+        Return field names per type including subfields
+        The fields of derived types are also returned
+        """
+        assert issubclass(of_type, BaseField)
+        if cls.__cached_field_names_per_type is None:
+            fields = defaultdict(list)
+            for name, field in get_fields(cls, return_instance=True, subfields=True):
+                fields[type(field)].append(name)
+            for type_ in fields:
+                fields[type_].extend(
+                    chain.from_iterable(
+                        fields[other_type]
+                        for other_type in fields
+                        if other_type != type_ and issubclass(other_type, type_)
+                    )
+                )
+            cls.__cached_field_names_per_type = fields
+
+        if of_type not in cls.__cached_field_names_per_type:
+            names = list(
+                chain.from_iterable(
+                    field_names
+                    for type_, field_names in cls.__cached_field_names_per_type.items()
+                    if issubclass(type_, of_type)
+                )
+            )
+            cls.__cached_field_names_per_type[of_type] = names
+
+        return cls.__cached_field_names_per_type[of_type]
+
     @classmethod
     def get_fields_with_instance(cls, doc_cls):
         if cls.__cached_fields_with_instance is None:
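A rough usage sketch of the new classmethod: any document class mixing in PropsMixin gets cached field-name lookups per mongoengine field type, where asking for a base type (e.g. StringField) also returns fields of derived types, and subfields are included per the docstring. The document class below is illustrative only and not part of this commit:

    from mongoengine import Document, EmailField, ListField, StringField

    from database.props import PropsMixin  # module patched above; import path assumed

    class Example(PropsMixin, Document):   # hypothetical document, for illustration only
        name = StringField()
        contact = EmailField()             # EmailField derives from StringField
        tags = ListField(StringField())

    Example.get_field_names_for_type(of_type=StringField)  # e.g. ["name", "contact", ...]
    Example.get_field_names_for_type(of_type=ListField)    # e.g. ["tags"]

This is what lets GetMixin (earlier hunk) replace get_fields(cls, of_type=..., subfields=True) calls with cached cls.get_field_names_for_type(of_type=...) lookups.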
@@ -14,6 +14,8 @@ from service_repo.errors import PathParsingError
 from timing_context import TimingContext
 from utilities import json
 from init_data import init_es_data, init_mongo_data
+from updates import check_updates_thread
+
 
 app = Flask(__name__, static_url_path="/static")
 CORS(app, **config.get("apiserver.cors"))
@@ -35,6 +37,9 @@ ServiceRepo.load("services")
 log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
 
 
+check_updates_thread.start()
+
+
 @app.before_first_request
 def before_app_first_request():
     pass
@@ -198,8 +198,6 @@ def get_queue_metrics(
         interval=req_model.interval,
         queue_ids=req_model.queue_ids,
     )
-    if not ret:
-        return GetMetricsResponse(queues=[])
 
     queue_dicts = {
         queue: extract_properties_to_lists(
@@ -214,7 +212,7 @@ def get_queue_metrics(
             dates=data["date"],
             avg_waiting_times=data["avg_waiting_time"],
             queue_lengths=data["queue_length"],
-        )
+        ) if data else QueueMetrics(queue=queue)
         for queue, data in queue_dicts.items()
     ]
 )
server/updates.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import os
+from threading import Thread
+from time import sleep
+from typing import Optional
+
+import attr
+import requests
+from semantic_version import Version
+
+from config import config
+from version import __version__ as current_version
+
+log = config.logger(__name__)
+
+
+class CheckUpdatesThread(Thread):
+    _enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
+
+    @attr.s(auto_attribs=True)
+    class _VersionResponse:
+        version: str
+        patch_upgrade: bool
+        description: str = None
+
+    def __init__(self):
+        super(CheckUpdatesThread, self).__init__(
+            target=self._check_updates, daemon=True
+        )
+
+    def start(self) -> None:
+        if not self._enabled:
+            log.info("Checking for updates is disabled")
+            return
+        super(CheckUpdatesThread, self).start()
+
+    @property
+    def component_name(self) -> str:
+        return config.get("apiserver.check_for_updates.component_name", "trains-server")
+
+    def _check_new_version_available(self) -> Optional[_VersionResponse]:
+        url = config.get(
+            "apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates"
+        )
+
+        response = requests.get(
+            url,
+            json={"versions": {self.component_name: str(current_version)}},
+            timeout=float(
+                config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
+            ),
+        )
+
+        if not response.ok:
+            return
+
+        response = response.json().get(self.component_name)
+        if not response:
+            return
+
+        latest_version = response.get("version")
+        if not latest_version:
+            return
+
+        cur_version = Version(current_version)
+        latest_version = Version(latest_version)
+        if cur_version >= latest_version:
+            return
+
+        return self._VersionResponse(
+            version=str(latest_version),
+            patch_upgrade=(
+                latest_version.major == cur_version.major
+                and latest_version.minor == cur_version.minor
+            ),
+            description=response.get("description").split("\r\n"),
+        )
+
+    def _check_updates(self):
+        while True:
+            # noinspection PyBroadException
+            try:
+                response = self._check_new_version_available()
+                if response:
+                    if response.patch_upgrade:
+                        log.info(
+                            f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
+                            f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
+                        )
+                    else:
+                        log.info(
+                            f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
+                            f" is recommended!"
+                        )
+            except Exception:
+                log.exception("Failed obtaining updates")
+
+            sleep(
+                max(
+                    float(
+                        config.get(
+                            "apiserver.check_for_updates.check_interval_sec",
+                            60 * 60 * 24,
+                        )
+                    ),
+                    60 * 5,
+                )
+            )
+
+
+check_updates_thread = CheckUpdatesThread()
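The wire format the checker expects is only implied by the parsing code above; a hypothetical exchange could look like the following sketch (illustrative, not a documented API):

    import requests

    resp = requests.get(
        "https://updates.trains.allegro.ai/updates",
        json={"versions": {"trains-server": "0.12.0"}},  # report the running component version
        timeout=3.0,
    )
    # Assumed response body, keyed by component name:
    # {"trains-server": {"version": "0.12.1", "description": "Fix A\r\nFix B"}}
    info = resp.json().get("trains-server") or {}
    latest = info.get("version")                           # compared against the running version
    notes = (info.get("description") or "").split("\r\n")  # CRLF-separated release notes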
server/version.py (new file, 1 line)
@@ -0,0 +1 @@
+__version__ = "0.12.0"