Improve server update checks

This commit is contained in:
allegroai 2019-12-14 23:33:04 +02:00
parent fe3dbc92dc
commit 0ad687008c
14 changed files with 571 additions and 28 deletions

View File

@ -0,0 +1,14 @@
from jsonmodels.fields import BoolField, DateTimeField, StringField
from jsonmodels.models import Base
class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()
enabled_time = DateTimeField(nullable=True)
enabled_version = StringField(nullable=True)
enabled_user = StringField(nullable=True)
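
For reference, a minimal usage sketch of these jsonmodels classes outside the service wiring (the version string and user id below are made-up examples, not values from this commit):

# hypothetical usage of the new request/response models
from datetime import datetime
from apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse

req = ReportStatsOptionRequest(enabled=True)
req.validate()  # jsonmodels type validation

res = ReportStatsOptionResponse(
    enabled=req.enabled,
    enabled_time=datetime.utcnow(),
    enabled_version="0.13.0",     # example version string
    enabled_user="some-user-id",  # example user id
)
print(res.to_struct())  # plain dict representation of the response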

View File

@ -161,7 +161,7 @@ class QueueMetrics:
In case no queue ids are specified the avg across all the
company queues is calculated for each metric
"""
-# self._log_current_metrics(company_id, queue_ids=queue_ids)
+# self._log_current_metrics(company, queue_ids=queue_ids)
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")

View File

@ -0,0 +1,306 @@
import logging
import queue
import random
import time
from datetime import timedelta, datetime
from time import sleep
from typing import Sequence, Optional
import dpath
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from bll.query import Builder as QueryBuilder
from bll.util import get_server_uuid
from bll.workers import WorkerStats, WorkerBLL
from config import config
from config.info import get_deployment_type
from database.model import Company, User
from database.model.queue import Queue
from database.model.task.task import Task
from utilities import safe_get
from utilities.json import dumps
from utilities.threads_manager import ThreadsManager
from version import __version__ as current_version
from .resource_monitor import ResourceMonitor
log = config.logger(__file__)
worker_bll = WorkerBLL()
class StatisticsReporter:
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
send_queue = queue.Queue()
supported = config.get("apiserver.statistics.supported", True)
@classmethod
def start(cls):
cls.start_sender()
cls.start_reporter()
@classmethod
@threads.register("reporter", daemon=True)
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
Note: in trains we usually have only a single company
"""
if not cls.supported:
return
report_interval = timedelta(
hours=config.get("apiserver.statistics.report_interval_hours", 24)
)
while True:
sleep(report_interval.total_seconds())
try:
for company in Company.objects(
defaults__stats_option__enabled=True
).only("id"):
stats = cls.get_statistics(company.id)
cls.send_queue.put(stats)
except Exception as ex:
log.exception(f"Failed collecting stats: {str(ex)}")
@classmethod
@threads.register("sender", daemon=True)
def start_sender(cls):
if not cls.supported:
return
url = config.get("apiserver.statistics.url")
retries = config.get("apiserver.statistics.max_retries", 5)
max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(retries))
session.mount("http://", adapter)
session.mount("https://", adapter)
session.headers["Content-type"] = "application/json"
WarningFilter.attach()
while True:
try:
report = cls.send_queue.get()
# Set a random backoff factor each time we send a report
adapter.max_retries.backoff_factor = random.random() * max_backoff
session.post(url, data=dumps(report))
except Exception as ex:
pass
@classmethod
def get_statistics(cls, company_id: str) -> dict:
"""
Returns a statistics report per company
"""
return {
"time": datetime.utcnow(),
"company_id": company_id,
"server": {
"version": current_version,
"deployment": get_deployment_type(),
"uuid": get_server_uuid(),
"queues": {"count": Queue.objects(company=company_id).count()},
"users": {"count": User.objects(company=company_id).count()},
"resources": cls.threads.resource_monitor.get_stats(),
"experiments": next(
iter(cls._get_experiments_stats(company_id).values()), {}
),
},
"agents": cls._get_agents_statistics(company_id),
}
@classmethod
def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]:
result = cls._get_resource_stats_per_agent(company_id, key="resources")
dpath.merge(
result, cls._get_experiments_stats_per_agent(company_id, key="experiments")
)
return [{"uuid": agent_id, **data} for agent_id, data in result.items()]
@classmethod
def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict:
agent_resource_threshold_sec = timedelta(
hours=config.get("apiserver.statistics.report_interval_hours", 24)
).total_seconds()
to_timestamp = int(time.time())
from_timestamp = to_timestamp - int(agent_resource_threshold_sec)
es_req = {
"size": 0,
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {
"categories": {
"terms": {"field": "category"},
"aggs": {"count": {"cardinality": {"field": "variant"}}},
},
"metrics": {
"terms": {"field": "metric"},
"aggs": {
"min": {"min": {"field": "value"}},
"max": {"max": {"field": "value"}},
"avg": {"avg": {"field": "value"}},
},
},
},
}
},
}
res = cls._run_worker_stats_query(company_id, es_req)
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
names = {"cpu": "num_cores"}
return {
names[c["key"]]: safe_get(c, "count/value")
for c in categories
if c["key"] in names
}
def _get_metric_fields(metrics: Sequence[dict]) -> dict:
names = {
"cpu_usage": "cpu_usage",
"memory_used": "mem_used_gb",
"memory_free": "mem_free_gb",
}
return {
names[m["key"]]: {
"min": safe_get(m, "min/value"),
"max": safe_get(m, "max/value"),
"avg": safe_get(m, "avg/value"),
}
for m in metrics
if m["key"] in names
}
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
return {
b["key"]: {
key: {
"interval_sec": agent_resource_threshold_sec,
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
}
}
for b in buckets
}
@classmethod
def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict:
agent_relevant_threshold = timedelta(
days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
)
to_timestamp = int(time.time())
from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds())
workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp)
if not workers:
return {}
stats = cls._get_experiments_stats(company_id, list(workers.keys()))
return {
worker_id: {key: {**workers[worker_id], **stat}}
for worker_id, stat in stats.items()
}
@classmethod
def _get_active_workers(
cls, company_id, from_timestamp: int, to_timestamp: int
) -> dict:
es_req = {
"size": 0,
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {"last_activity_time": {"max": {"field": "timestamp"}}},
}
},
}
res = cls._run_worker_stats_query(company_id, es_req)
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
return {
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
for b in buckets
}
@classmethod
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
return worker_bll.es_client.search(
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)
@classmethod
def _get_experiments_stats(
cls, company_id, workers: Optional[Sequence] = None
) -> dict:
pipeline = [
{
"$match": {
"company": company_id,
"started": {"$exists": True, "$ne": None},
"last_update": {"$exists": True, "$ne": None},
"status": {"$nin": ["created", "queued"]},
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,
"count": {"$sum": 1},
"avg_run_time_sec": {
"$avg": {
"$divide": [
{"$subtract": ["$last_update", "$started"]},
1000,
]
}
},
"avg_iterations": {"$avg": "$last_iteration"},
}
},
{
"$project": {
"count": 1,
"avg_run_time_sec": {"$trunc": "$avg_run_time_sec"},
"avg_iterations": {"$trunc": "$avg_iterations"},
}
},
]
return {
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
for group in Task.aggregate(*pipeline)
}
class WarningFilter(logging.Filter):
@classmethod
def attach(cls):
from urllib3.connectionpool import (
ConnectionPool,
) # required to make sure the logger is created
assert ConnectionPool # make sure import is not optimized out
logging.getLogger("urllib3.connectionpool").addFilter(cls())
def filter(self, record):
if (
record.levelno == logging.WARNING
and len(record.args) > 2
and record.args[2] == "/stats"
):
return False
return True
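
To make the collected payload easier to follow, here is an illustrative sketch of the report that get_statistics() assembles (the keys mirror the code above; every value is a made-up placeholder):

# illustrative report shape only -- all values are placeholders
example_report = {
    "time": "2019-12-14T21:33:04",  # datetime.utcnow() in the real report
    "company_id": "<company-id>",
    "server": {
        "version": "0.13.0",      # example server version
        "deployment": "manual",   # get_deployment_type()
        "uuid": "<server-uuid>",  # persisted "server.uuid" setting
        "queues": {"count": 1},
        "users": {"count": 3},
        "resources": {},          # ResourceMonitor.get_stats()
        "experiments": {"count": 10, "avg_run_time_sec": 120, "avg_iterations": 500},
    },
    "agents": [
        {
            "uuid": "<worker-id>",
            "resources": {
                "interval_sec": 86400,
                "num_cores": 8,
                "cpu_usage": {"min": 1.0, "max": 95.0, "avg": 40.0},
            },
            "experiments": {
                "last_activity_time": 1576100000000,
                "count": 2,
                "avg_run_time_sec": 300,
                "avg_iterations": 1000,
            },
        }
    ],
}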

View File

@ -101,4 +101,18 @@
# GET request timeout
request_timeout_sec: 3.0
}
statistics {
# Note: statistics are sent ONLY if the user has actively opted-in
supported: true
url: "https://updates.trains.allegro.ai/stats"
report_interval_hours: 24
agent_relevant_threshold_days: 30
max_retries: 5
max_backoff_sec: 5
}
}
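
These keys map one-to-one onto the config.get() calls made by the new StatisticsReporter; note that supported: true only enables the mechanism, and reports are generated solely for companies whose stats option was explicitly enabled. A short sketch of the reads and their code defaults (mirroring the reporter shown earlier):

# how the reporter consumes the statistics section (defaults as in the code)
supported = config.get("apiserver.statistics.supported", True)
url = config.get("apiserver.statistics.url")
report_interval_hours = config.get("apiserver.statistics.report_interval_hours", 24)
relevant_days = config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
max_retries = config.get("apiserver.statistics.max_retries", 5)
max_backoff_sec = config.get("apiserver.statistics.max_backoff_sec", 5)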

View File

@ -1,5 +1,6 @@
from functools import lru_cache
from pathlib import Path
from os import getenv
root = Path(__file__).parent.parent
@ -26,3 +27,17 @@ def get_commit_number():
return (root / "COMMIT").read_text().strip()
except FileNotFoundError:
return ""
@lru_cache()
def get_deployment_type() -> str:
value = getenv("TRAINS_SERVER_DEPLOYMENT_TYPE")
if value:
return value
try:
value = (root / "DEPLOY").read_text().strip()
except FileNotFoundError:
pass
return value or "manual"

View File

@ -1,23 +1,36 @@
-from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q
+from mongoengine import (
+Document,
+EmbeddedDocument,
+EmbeddedDocumentField,
+StringField,
+Q,
+BooleanField,
+DateTimeField,
+)
from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin
class ReportStatsOption(EmbeddedDocument):
enabled = BooleanField(default=False)  # opt-in for statistics reporting
enabled_version = StringField()  # server version when enabled
enabled_time = DateTimeField()  # time when enabled
enabled_user = StringField()  # ID of user who enabled
class CompanyDefaults(EmbeddedDocument):
cluster = StringField()
stats_option = EmbeddedDocumentField(ReportStatsOption, default=ReportStatsOption)
class Company(DbModelMixin, Document):
-meta = {
-'db_alias': Database.backend,
-'strict': strict,
-}
+meta = {"db_alias": Database.backend, "strict": strict}
id = StringField(primary_key=True)
name = StrippedStringField(unique=True, min_length=3)
-defaults = EmbeddedDocumentField(CompanyDefaults)
+defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
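
With the embedded stats_option in place, opted-in companies can be selected with a single query; this is essentially how StatisticsReporter.start_reporter() picks report targets above:

# companies that actively opted in to statistics reporting
for company in Company.objects(defaults__stats_option__enabled=True).only("id"):
    print(company.id)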

View File

@ -0,0 +1,57 @@
from typing import Any, Optional, Sequence, Tuple
from mongoengine import Document, StringField, DynamicField, Q
from mongoengine.errors import NotUniqueError
from database import Database, strict
from database.model import DbModelMixin
class Settings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
"strict": strict,
}
key = StringField(primary_key=True)
value = DynamicField()
@classmethod
def get_by_key(cls, key: str, default: Optional[Any] = None, sep: str = ".") -> Any:
key = key.strip(sep)
res = Settings.objects(key=key).first()
if not res:
return default
return res.value
@classmethod
def get_by_prefix(
cls, key_prefix: str, default: Optional[Any] = None, sep: str = "."
) -> Sequence[Tuple[str, Any]]:
key_prefix = key_prefix.strip(sep)
query = Q(key=key_prefix) | Q(key__startswith=key_prefix + sep)
res = Settings.objects(query)
if not res:
return default
return [(x.key, x.value) for x in res]
@classmethod
def set_or_add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
""" Sets a new value or adds a new key/value setting (if key does not exist) """
key = key.strip(sep)
res = Settings.objects(key=key).update(key=key, value=value, upsert=True)
# if Settings.objects(key=key).only("key"):
#
# else:
# res = Settings(key=key, value=value).save()
return bool(res)
@classmethod
def add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
""" Adds a new key/value settings. Fails if key already exists. """
key = key.strip(sep)
try:
res = Settings(key=key, value=value).save(force_insert=True)
return bool(res)
except NotUniqueError:
return False
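
A short usage sketch of the new Settings document ("server.uuid" is the key this commit actually uses; the "example.flag" key is illustrative only):

from uuid import uuid4
from database.model.settings import Settings

# add_value() inserts with force_insert, so it never overwrites an existing key --
# this is what makes the startup call in _ensure_uuid() (below) safe to repeat
Settings.add_value("server.uuid", str(uuid4()))

uid = Settings.get_by_key("server.uuid")

Settings.set_or_add_value("example.flag", True)  # upsert (illustrative key)
pairs = Settings.get_by_prefix("example")        # [(key, value), ...] or default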

View File

@ -1,6 +1,7 @@
import importlib.util
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import attr
from furl import furl
@ -15,6 +16,7 @@ from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
from database.model.user import User
from database.model.version import Version as DatabaseVersion
from elastic.apply_mappings import apply_mappings_to_host
@ -109,10 +111,7 @@ def _ensure_user(user: FixedUser, company_id: str):
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
-_ensure_auth_user(
-user_data=data,
-company_id=company_id,
-)
+_ensure_auth_user(user_data=data, company_id=company_id)
given_name, _, family_name = user.name.partition(" ")
@ -142,9 +141,7 @@ def _apply_migrations():
try:
new_scripts = {
ver: path
-for ver, path in (
-(Version(f.stem), f) for f in migration_dir.glob("*.py")
-)
+for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
@ -179,16 +176,30 @@ def _apply_migrations():
).save()
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))
def init_mongo_data():
try:
_apply_migrations()
_ensure_uuid()
company_id = _ensure_company()
_ensure_default_queue(company_id)
users = [
-{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
-{"name": "webserver", "role": Role.system, "email": "webserver@example.com"},
+{
+"name": "apiserver",
+"role": Role.system,
+"email": "apiserver@example.com",
+},
+{
+"name": "webserver",
+"role": Role.system,
+"email": "webserver@example.com",
+},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]

View File

@ -3,6 +3,25 @@ _default {
internal: true
allow_roles: ["root", "system"]
}
get_stats {
"2.1" {
description: "Get the server collected statistics."
request {
type: object
properties {
interval {
description: "The period for statistics collection in seconds."
type: long
}
}
}
response {
type: object
properties: {
}
}
}
}
config {
"2.1" {
description: "Get server configuration. Secure section is not returned."
@ -66,3 +85,39 @@ endpoints {
}
}
}
report_stats_option {
"2.4" {
description: "Get or set the report statistics option per-company"
request {
type: object
properties {
enabled {
description: "If provided, sets the report statistics option (true/false)"
type: boolean
}
}
}
response {
type: object
properties {
enabled {
description: "Returns the current report stats option value"
type: boolean
}
enabled_time {
description: "If enabled, returns the time at which option was enabled"
type: string
format: date-time
}
enabled_version {
description: "If enabled, returns the server version at the time option was enabled"
type: string
}
enabled_user {
description: "If enabled, returns Id of the user who enabled the option"
type: string
}
}
}
}
}
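
A hedged sketch of exercising the two new endpoints over HTTP; the address, credentials and exact path prefix are deployment-specific assumptions (endpoints are addressed as <service>.<action>):

import requests

base = "http://localhost:8008"           # assumed apiserver address
auth = ("<access_key>", "<secret_key>")  # assumed credentials, if auth is enabled

# opt the company in to statistics reporting
r = requests.post(f"{base}/server.report_stats_option", json={"enabled": True}, auth=auth)
print(r.json())

# fetch the statistics report the server would send
r = requests.post(f"{base}/server.get_stats", json={}, auth=auth)
print(r.json())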

View File

@ -7,15 +7,15 @@ from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
+from bll.statistics.stats_reporter import StatisticsReporter
from config import config
+from init_data import init_es_data, init_mongo_data
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from timing_context import TimingContext
-from utilities import json
-from init_data import init_es_data, init_mongo_data
from updates import check_updates_thread
+from utilities import json
app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors"))
@ -38,6 +38,7 @@ log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
check_updates_thread.start()
StatisticsReporter.start()
@app.before_first_request
@ -57,7 +58,9 @@ def before_request():
content, content_type = ServiceRepo.handle_call(call)
headers = {}
if call.result.filename:
-headers["Content-Disposition"] = f"attachment; filename={call.result.filename}"
+headers[
+"Content-Disposition"
+] = f"attachment; filename={call.result.filename}"
if call.result.headers:
headers.update(call.result.headers)
@ -71,7 +74,9 @@ def before_request():
if value is None:
response.set_cookie(key, "", expires=0)
else:
-response.set_cookie(key, value, **config.get("apiserver.auth.cookies"))
+response.set_cookie(
+key, value, **config.get("apiserver.auth.cookies")
+)
return response
except Exception as ex:

View File

@ -1,8 +1,24 @@
from datetime import datetime
from pyhocon.config_tree import NoneValue
from apierrors import errors
from apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from config.info import get_version, get_build_number, get_commit_number
from database.errors import translate_errors_context
from database.model import Company
from database.model.company import ReportStatsOption
from service_repo import ServiceRepo, APICall, endpoint
from version import __version__ as current_version
@endpoint("server.get_stats")
def get_stats(call: APICall):
call.result.data = StatisticsReporter.get_statistics(
company_id=call.identity.company
)
@endpoint("server.config") @endpoint("server.config")
@ -43,3 +59,35 @@ def info(call: APICall):
"build": get_build_number(), "build": get_build_number(),
"commit": get_commit_number(), "commit": get_commit_number(),
} }
@endpoint(
"server.report_stats_option",
request_data_model=ReportStatsOptionRequest,
response_data_model=ReportStatsOptionResponse,
)
def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest):
if not StatisticsReporter.supported:
result = ReportStatsOptionResponse(supported=False)
else:
enabled = request.enabled
with translate_errors_context():
query = Company.objects(id=company)
if enabled is None:
stats_option = query.first().defaults.stats_option
else:
stats_option = ReportStatsOption(
enabled=enabled,
enabled_time=datetime.utcnow(),
enabled_version=current_version,
enabled_user=call.identity.user,
)
updated = query.update(defaults__stats_option=stats_option)
if not updated:
raise errors.server_error.InternalError(
f"Failed setting report_stats to {enabled}"
)
result = ReportStatsOptionResponse(**stats_option.to_mongo())
call.result.data_model = result

View File

@ -108,7 +108,7 @@ class TestWorkersService(TestService):
from_date = to_date - timedelta(days=1)
# no variants
-res = self.api.workers.get_stats(
+res = self.api.workers.get_statistics(
items=[
dict(key="cpu_usage", aggregation="avg"),
dict(key="cpu_usage", aggregation="max"),
@ -142,7 +142,7 @@ class TestWorkersService(TestService):
)
# split by variants
-res = self.api.workers.get_stats(
+res = self.api.workers.get_statistics(
items=[dict(key="cpu_usage", aggregation="avg")],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
@ -165,7 +165,7 @@ class TestWorkersService(TestService):
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
-res = self.api.workers.get_stats(
+res = self.api.workers.get_statistics(
items=[dict(key="cpu_usage", aggregation="avg")],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
@ -183,12 +183,13 @@ class TestWorkersService(TestService):
# run on an empty es db since we have no way
# to pass non existing workers to this api
# res = self.api.workers.get_activity_report(
-# from_date=from_date.timestamp(),
-# to_date=to_date.timestamp(),
+# from_timestamp=from_timestamp.timestamp(),
+# to_timestamp=to_timestamp.timestamp(),
# interval=20,
# )
self._simulate_workers()
to_date = utc_now_tz_aware()
from_date = to_date - timedelta(minutes=10)

View File

@ -8,6 +8,7 @@ import requests
from semantic_version import Version
from config import config
from database.model.settings import Settings
from version import __version__ as current_version
log = config.logger(__name__)
@ -39,12 +40,15 @@ class CheckUpdatesThread(Thread):
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
-"apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates"
+"apiserver.check_for_updates.url",
+"https://updates.trains.allegro.ai/updates",
)
uid = Settings.get_by_key("server.uuid")
response = requests.get(
url,
-json={"versions": {self.component_name: str(current_version)}},
+json={"versions": {self.component_name: str(current_version)}, "uid": uid},
timeout=float(
config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
),
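
With this change, the periodic update check also identifies the installation by the persisted server uuid; the request body now looks roughly like this (component name and version below are illustrative):

# illustrative payload of the update check after this commit
payload = {
    "versions": {"trains-server": "0.13.0"},  # {self.component_name: str(current_version)}
    "uid": "<value of the persisted server.uuid setting>",
}
# sent as: requests.get(url, json=payload, timeout=...)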