From 0ad687008c545827e6214f584d88257c9ea9f8db Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 14 Dec 2019 23:33:04 +0200 Subject: [PATCH] Improve server update checks --- server/apimodels/server.py | 14 + server/bll/queue/queue_metrics.py | 2 +- server/bll/statistics/stats_reporter.py | 306 +++++++++++++++++++++ server/config/default/apiserver.conf | 14 + server/config/info.py | 15 + server/database/model/company.py | 25 +- server/database/model/settings.py | 57 ++++ server/init_data.py | 29 +- server/{service_repo => }/redis_manager.py | 0 server/schema/services/server.conf | 55 ++++ server/server.py | 15 +- server/services/server/__init__.py | 48 ++++ server/tests/automated/test_workers.py | 11 +- server/updates.py | 8 +- 14 files changed, 571 insertions(+), 28 deletions(-) create mode 100644 server/apimodels/server.py create mode 100644 server/bll/statistics/stats_reporter.py create mode 100644 server/database/model/settings.py rename server/{service_repo => }/redis_manager.py (100%) diff --git a/server/apimodels/server.py b/server/apimodels/server.py new file mode 100644 index 0000000..53eea57 --- /dev/null +++ b/server/apimodels/server.py @@ -0,0 +1,14 @@ +from jsonmodels.fields import BoolField, DateTimeField, StringField +from jsonmodels.models import Base + + +class ReportStatsOptionRequest(Base): + enabled = BoolField(default=None, nullable=True) + + +class ReportStatsOptionResponse(Base): + supported = BoolField(default=True) + enabled = BoolField() + enabled_time = DateTimeField(nullable=True) + enabled_version = StringField(nullable=True) + enabled_user = StringField(nullable=True) diff --git a/server/bll/queue/queue_metrics.py b/server/bll/queue/queue_metrics.py index 0c287e2..41d7df1 100644 --- a/server/bll/queue/queue_metrics.py +++ b/server/bll/queue/queue_metrics.py @@ -161,7 +161,7 @@ class QueueMetrics: In case no queue ids are specified the avg across all the company queues is calculated for each metric """ - # self._log_current_metrics(company_id, queue_ids=queue_ids) + # self._log_current_metrics(company, queue_ids=queue_ids) if from_date >= to_date: raise bad_request.FieldsValueError("from_date must be less than to_date") diff --git a/server/bll/statistics/stats_reporter.py b/server/bll/statistics/stats_reporter.py new file mode 100644 index 0000000..601f1a3 --- /dev/null +++ b/server/bll/statistics/stats_reporter.py @@ -0,0 +1,306 @@ +import logging +import queue +import random +import time +from datetime import timedelta, datetime +from time import sleep +from typing import Sequence, Optional + +import dpath +import requests +from requests.adapters import HTTPAdapter +from requests.packages.urllib3.util.retry import Retry + +from bll.query import Builder as QueryBuilder +from bll.util import get_server_uuid +from bll.workers import WorkerStats, WorkerBLL +from config import config +from config.info import get_deployment_type +from database.model import Company, User +from database.model.queue import Queue +from database.model.task.task import Task +from utilities import safe_get +from utilities.json import dumps +from utilities.threads_manager import ThreadsManager +from version import __version__ as current_version +from .resource_monitor import ResourceMonitor + +log = config.logger(__file__) + +worker_bll = WorkerBLL() + + +class StatisticsReporter: + threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor) + send_queue = queue.Queue() + supported = config.get("apiserver.statistics.supported", True) + + @classmethod + def start(cls): + 
cls.start_sender() + cls.start_reporter() + + @classmethod + @threads.register("reporter", daemon=True) + def start_reporter(cls): + """ + Periodically send statistics reports for companies who have opted in. + Note: in trains we usually have only a single company + """ + if not cls.supported: + return + + report_interval = timedelta( + hours=config.get("apiserver.statistics.report_interval_hours", 24) + ) + + while True: + + sleep(report_interval.total_seconds()) + + try: + for company in Company.objects( + defaults__stats_option__enabled=True + ).only("id"): + stats = cls.get_statistics(company.id) + cls.send_queue.put(stats) + + except Exception as ex: + log.exception(f"Failed collecting stats: {str(ex)}") + + @classmethod + @threads.register("sender", daemon=True) + def start_sender(cls): + if not cls.supported: + return + + url = config.get("apiserver.statistics.url") + + retries = config.get("apiserver.statistics.max_retries", 5) + max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5) + session = requests.Session() + adapter = HTTPAdapter(max_retries=Retry(retries)) + session.mount("http://", adapter) + session.mount("https://", adapter) + session.headers["Content-type"] = "application/json" + + WarningFilter.attach() + + while True: + try: + report = cls.send_queue.get() + + # Set a random backoff factor each time we send a report + adapter.max_retries.backoff_factor = random.random() * max_backoff + + session.post(url, data=dumps(report)) + + except Exception as ex: + pass + + @classmethod + def get_statistics(cls, company_id: str) -> dict: + """ + Returns a statistics report per company + """ + return { + "time": datetime.utcnow(), + "company_id": company_id, + "server": { + "version": current_version, + "deployment": get_deployment_type(), + "uuid": get_server_uuid(), + "queues": {"count": Queue.objects(company=company_id).count()}, + "users": {"count": User.objects(company=company_id).count()}, + "resources": cls.threads.resource_monitor.get_stats(), + "experiments": next( + iter(cls._get_experiments_stats(company_id).values()), {} + ), + }, + "agents": cls._get_agents_statistics(company_id), + } + + @classmethod + def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]: + result = cls._get_resource_stats_per_agent(company_id, key="resources") + dpath.merge( + result, cls._get_experiments_stats_per_agent(company_id, key="experiments") + ) + return [{"uuid": agent_id, **data} for agent_id, data in result.items()] + + @classmethod + def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict: + agent_resource_threshold_sec = timedelta( + hours=config.get("apiserver.statistics.report_interval_hours", 24) + ).total_seconds() + to_timestamp = int(time.time()) + from_timestamp = to_timestamp - int(agent_resource_threshold_sec) + es_req = { + "size": 0, + "query": QueryBuilder.dates_range(from_timestamp, to_timestamp), + "aggs": { + "workers": { + "terms": {"field": "worker"}, + "aggs": { + "categories": { + "terms": {"field": "category"}, + "aggs": {"count": {"cardinality": {"field": "variant"}}}, + }, + "metrics": { + "terms": {"field": "metric"}, + "aggs": { + "min": {"min": {"field": "value"}}, + "max": {"max": {"field": "value"}}, + "avg": {"avg": {"field": "value"}}, + }, + }, + }, + } + }, + } + res = cls._run_worker_stats_query(company_id, es_req) + + def _get_cardinality_fields(categories: Sequence[dict]) -> dict: + names = {"cpu": "num_cores"} + return { + names[c["key"]]: safe_get(c, "count/value") + for c in categories + if c["key"] in 
names + } + + def _get_metric_fields(metrics: Sequence[dict]) -> dict: + names = { + "cpu_usage": "cpu_usage", + "memory_used": "mem_used_gb", + "memory_free": "mem_free_gb", + } + return { + names[m["key"]]: { + "min": safe_get(m, "min/value"), + "max": safe_get(m, "max/value"), + "avg": safe_get(m, "avg/value"), + } + for m in metrics + if m["key"] in names + } + + buckets = safe_get(res, "aggregations/workers/buckets", default=[]) + return { + b["key"]: { + key: { + "interval_sec": agent_resource_threshold_sec, + **_get_cardinality_fields(safe_get(b, "categories/buckets", [])), + **_get_metric_fields(safe_get(b, "metrics/buckets", [])), + } + } + for b in buckets + } + + @classmethod + def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict: + agent_relevant_threshold = timedelta( + days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30) + ) + to_timestamp = int(time.time()) + from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds()) + workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp) + if not workers: + return {} + + stats = cls._get_experiments_stats(company_id, list(workers.keys())) + return { + worker_id: {key: {**workers[worker_id], **stat}} + for worker_id, stat in stats.items() + } + + @classmethod + def _get_active_workers( + cls, company_id, from_timestamp: int, to_timestamp: int + ) -> dict: + es_req = { + "size": 0, + "query": QueryBuilder.dates_range(from_timestamp, to_timestamp), + "aggs": { + "workers": { + "terms": {"field": "worker"}, + "aggs": {"last_activity_time": {"max": {"field": "timestamp"}}}, + } + }, + } + res = cls._run_worker_stats_query(company_id, es_req) + buckets = safe_get(res, "aggregations/workers/buckets", default=[]) + return { + b["key"]: {"last_activity_time": b["last_activity_time"]["value"]} + for b in buckets + } + + @classmethod + def _run_worker_stats_query(cls, company_id, es_req) -> dict: + return worker_bll.es_client.search( + index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*", + doc_type="stat", + body=es_req, + ) + + @classmethod + def _get_experiments_stats( + cls, company_id, workers: Optional[Sequence] = None + ) -> dict: + pipeline = [ + { + "$match": { + "company": company_id, + "started": {"$exists": True, "$ne": None}, + "last_update": {"$exists": True, "$ne": None}, + "status": {"$nin": ["created", "queued"]}, + **({"last_worker": {"$in": workers}} if workers else {}), + } + }, + { + "$group": { + "_id": "$last_worker" if workers else None, + "count": {"$sum": 1}, + "avg_run_time_sec": { + "$avg": { + "$divide": [ + {"$subtract": ["$last_update", "$started"]}, + 1000, + ] + } + }, + "avg_iterations": {"$avg": "$last_iteration"}, + } + }, + { + "$project": { + "count": 1, + "avg_run_time_sec": {"$trunc": "$avg_run_time_sec"}, + "avg_iterations": {"$trunc": "$avg_iterations"}, + } + }, + ] + return { + group["_id"]: {k: v for k, v in group.items() if k != "_id"} + for group in Task.aggregate(*pipeline) + } + + +class WarningFilter(logging.Filter): + @classmethod + def attach(cls): + from urllib3.connectionpool import ( + ConnectionPool, + ) # required to make sure the logger is created + + assert ConnectionPool # make sure import is not optimized out + + logging.getLogger("urllib3.connectionpool").addFilter(cls()) + + def filter(self, record): + if ( + record.levelno == logging.WARNING + and len(record.args) > 2 + and record.args[2] == "/stats" + ): + return False + return True diff --git a/server/config/default/apiserver.conf 
b/server/config/default/apiserver.conf index 7bf2f75..f5be1e3 100644 --- a/server/config/default/apiserver.conf +++ b/server/config/default/apiserver.conf @@ -101,4 +101,18 @@ # GET request timeout request_timeout_sec: 3.0 } + + statistics { + # Note: statistics are sent ONLY if the user has actively opted-in + supported: true + + url: "https://updates.trains.allegro.ai/stats" + + report_interval_hours: 24 + agent_relevant_threshold_days: 30 + + max_retries: 5 + max_backoff_sec: 5 + } + } diff --git a/server/config/info.py b/server/config/info.py index 2475ec5..4569159 100644 --- a/server/config/info.py +++ b/server/config/info.py @@ -1,5 +1,6 @@ from functools import lru_cache from pathlib import Path +from os import getenv root = Path(__file__).parent.parent @@ -26,3 +27,17 @@ def get_commit_number(): return (root / "COMMIT").read_text().strip() except FileNotFoundError: return "" + + +@lru_cache() +def get_deployment_type() -> str: + value = getenv("TRAINS_SERVER_DEPLOYMENT_TYPE") + if value: + return value + + try: + value = (root / "DEPLOY").read_text().strip() + except FileNotFoundError: + pass + + return value or "manual" diff --git a/server/database/model/company.py b/server/database/model/company.py index b4aae5c..4eaf26b 100644 --- a/server/database/model/company.py +++ b/server/database/model/company.py @@ -1,23 +1,36 @@ -from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q +from mongoengine import ( + Document, + EmbeddedDocument, + EmbeddedDocumentField, + StringField, + Q, + BooleanField, + DateTimeField, +) from database import Database, strict from database.fields import StrippedStringField from database.model import DbModelMixin +class ReportStatsOption(EmbeddedDocument): + enabled = BooleanField(default=False) # opt-in for statistics reporting + enabled_version = StringField() # server version when enabled + enabled_time = DateTimeField() # time when enabled + enabled_user = StringField() # ID of user who enabled + + class CompanyDefaults(EmbeddedDocument): cluster = StringField() + stats_option = EmbeddedDocumentField(ReportStatsOption, default=ReportStatsOption) class Company(DbModelMixin, Document): - meta = { - 'db_alias': Database.backend, - 'strict': strict, - } + meta = {"db_alias": Database.backend, "strict": strict} id = StringField(primary_key=True) name = StrippedStringField(unique=True, min_length=3) - defaults = EmbeddedDocumentField(CompanyDefaults) + defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults) @classmethod def _prepare_perm_query(cls, company, allow_public=False): diff --git a/server/database/model/settings.py b/server/database/model/settings.py new file mode 100644 index 0000000..76675ae --- /dev/null +++ b/server/database/model/settings.py @@ -0,0 +1,57 @@ +from typing import Any, Optional, Sequence, Tuple + +from mongoengine import Document, StringField, DynamicField, Q +from mongoengine.errors import NotUniqueError + +from database import Database, strict +from database.model import DbModelMixin + + +class Settings(DbModelMixin, Document): + meta = { + "db_alias": Database.backend, + "strict": strict, + } + + key = StringField(primary_key=True) + value = DynamicField() + + @classmethod + def get_by_key(cls, key: str, default: Optional[Any] = None, sep: str = ".") -> Any: + key = key.strip(sep) + res = Settings.objects(key=key).first() + if not res: + return default + return res.value + + @classmethod + def get_by_prefix( + cls, key_prefix: str, default: Optional[Any] = None, sep: str = 
"." + ) -> Sequence[Tuple[str, Any]]: + key_prefix = key_prefix.strip(sep) + query = Q(key=key_prefix) | Q(key__startswith=key_prefix + sep) + res = Settings.objects(query) + if not res: + return default + return [(x.key, x.value) for x in res] + + @classmethod + def set_or_add_value(cls, key: str, value: Any, sep: str = ".") -> bool: + """ Sets a new value or adds a new key/value setting (if key does not exist) """ + key = key.strip(sep) + res = Settings.objects(key=key).update(key=key, value=value, upsert=True) + # if Settings.objects(key=key).only("key"): + # + # else: + # res = Settings(key=key, value=value).save() + return bool(res) + + @classmethod + def add_value(cls, key: str, value: Any, sep: str = ".") -> bool: + """ Adds a new key/value settings. Fails if key already exists. """ + key = key.strip(sep) + try: + res = Settings(key=key, value=value).save(force_insert=True) + return bool(res) + except NotUniqueError: + return False diff --git a/server/init_data.py b/server/init_data.py index 599eaf5..0e81b84 100644 --- a/server/init_data.py +++ b/server/init_data.py @@ -1,6 +1,7 @@ import importlib.util from datetime import datetime from pathlib import Path +from uuid import uuid4 import attr from furl import furl @@ -15,6 +16,7 @@ from database.model.auth import Role from database.model.auth import User as AuthUser, Credentials from database.model.company import Company from database.model.queue import Queue +from database.model.settings import Settings from database.model.user import User from database.model.version import Version as DatabaseVersion from elastic.apply_mappings import apply_mappings_to_host @@ -109,10 +111,7 @@ def _ensure_user(user: FixedUser, company_id: str): data["email"] = f"{user.user_id}@example.com" data["role"] = Role.user - _ensure_auth_user( - user_data=data, - company_id=company_id, - ) + _ensure_auth_user(user_data=data, company_id=company_id) given_name, _, family_name = user.name.partition(" ") @@ -142,9 +141,7 @@ def _apply_migrations(): try: new_scripts = { ver: path - for ver, path in ( - (Version(f.stem), f) for f in migration_dir.glob("*.py") - ) + for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py")) if ver > last_version } except ValueError as ex: @@ -179,16 +176,30 @@ def _apply_migrations(): ).save() +def _ensure_uuid(): + Settings.add_value("server.uuid", str(uuid4())) + + def init_mongo_data(): try: _apply_migrations() + _ensure_uuid() + company_id = _ensure_company() _ensure_default_queue(company_id) users = [ - {"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"}, - {"name": "webserver", "role": Role.system, "email": "webserver@example.com"}, + { + "name": "apiserver", + "role": Role.system, + "email": "apiserver@example.com", + }, + { + "name": "webserver", + "role": Role.system, + "email": "webserver@example.com", + }, {"name": "tests", "role": Role.user, "email": "tests@example.com"}, ] diff --git a/server/service_repo/redis_manager.py b/server/redis_manager.py similarity index 100% rename from server/service_repo/redis_manager.py rename to server/redis_manager.py diff --git a/server/schema/services/server.conf b/server/schema/services/server.conf index ea8b6b5..fff029c 100644 --- a/server/schema/services/server.conf +++ b/server/schema/services/server.conf @@ -3,6 +3,25 @@ _default { internal: true allow_roles: ["root", "system"] } +get_stats { + "2.1" { + description: "Get the server collected statistics." 
+ request { + type: object + properties { + interval { + description: "The period for statistics collection in seconds." + type: long + } + } + } + response { + type: object + properties: { + } + } + } +} config { "2.1" { description: "Get server configuration. Secure section is not returned." @@ -66,3 +85,39 @@ endpoints { } } } +report_stats_option { + "2.4" { + description: "Get or set the report statistics option per-company" + request { + type: object + properties { + enabled { + description: "If provided, sets the report statistics option (true/false)" + type: boolean + } + } + } + response { + type: object + properties { + enabled { + description: "Returns the current report stats option value" + type: boolean + } + enabled_time { + description: "If enabled, returns the time at which option was enabled" + type: string + format: date-time + } + enabled_version { + description: "If enabled, returns the server version at the time option was enabled" + type: string + } + enabled_user { + description: "If enabled, returns Id of the user who enabled the option" + type: string + } + } + } + } +} \ No newline at end of file diff --git a/server/server.py b/server/server.py index 937163e..16a903b 100644 --- a/server/server.py +++ b/server/server.py @@ -7,15 +7,15 @@ from werkzeug.exceptions import BadRequest import database from apierrors.base import BaseError +from bll.statistics.stats_reporter import StatisticsReporter from config import config +from init_data import init_es_data, init_mongo_data from service_repo import ServiceRepo, APICall from service_repo.auth import AuthType from service_repo.errors import PathParsingError from timing_context import TimingContext -from utilities import json -from init_data import init_es_data, init_mongo_data from updates import check_updates_thread - +from utilities import json app = Flask(__name__, static_url_path="/static") CORS(app, **config.get("apiserver.cors")) @@ -38,6 +38,7 @@ log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}") check_updates_thread.start() +StatisticsReporter.start() @app.before_first_request @@ -57,7 +58,9 @@ def before_request(): content, content_type = ServiceRepo.handle_call(call) headers = {} if call.result.filename: - headers["Content-Disposition"] = f"attachment; filename={call.result.filename}" + headers[ + "Content-Disposition" + ] = f"attachment; filename={call.result.filename}" if call.result.headers: headers.update(call.result.headers) @@ -71,7 +74,9 @@ def before_request(): if value is None: response.set_cookie(key, "", expires=0) else: - response.set_cookie(key, value, **config.get("apiserver.auth.cookies")) + response.set_cookie( + key, value, **config.get("apiserver.auth.cookies") + ) return response except Exception as ex: diff --git a/server/services/server/__init__.py b/server/services/server/__init__.py index fd914d6..24643bb 100644 --- a/server/services/server/__init__.py +++ b/server/services/server/__init__.py @@ -1,8 +1,24 @@ +from datetime import datetime + from pyhocon.config_tree import NoneValue +from apierrors import errors +from apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse +from bll.statistics.stats_reporter import StatisticsReporter from config import config from config.info import get_version, get_build_number, get_commit_number +from database.errors import translate_errors_context +from database.model import Company +from database.model.company import ReportStatsOption from service_repo import ServiceRepo, APICall, endpoint +from version import 
__version__ as current_version + + +@endpoint("server.get_stats") +def get_stats(call: APICall): + call.result.data = StatisticsReporter.get_statistics( + company_id=call.identity.company + ) @endpoint("server.config") @@ -43,3 +59,35 @@ def info(call: APICall): "build": get_build_number(), "commit": get_commit_number(), } + + +@endpoint( + "server.report_stats_option", + request_data_model=ReportStatsOptionRequest, + response_data_model=ReportStatsOptionResponse, +) +def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest): + if not StatisticsReporter.supported: + result = ReportStatsOptionResponse(supported=False) + else: + enabled = request.enabled + with translate_errors_context(): + query = Company.objects(id=company) + if enabled is None: + stats_option = query.first().defaults.stats_option + else: + stats_option = ReportStatsOption( + enabled=enabled, + enabled_time=datetime.utcnow(), + enabled_version=current_version, + enabled_user=call.identity.user, + ) + updated = query.update(defaults__stats_option=stats_option) + if not updated: + raise errors.server_error.InternalError( + f"Failed setting report_stats to {enabled}" + ) + + result = ReportStatsOptionResponse(**stats_option.to_mongo()) + + call.result.data_model = result diff --git a/server/tests/automated/test_workers.py b/server/tests/automated/test_workers.py index 935370e..02c9fbf 100644 --- a/server/tests/automated/test_workers.py +++ b/server/tests/automated/test_workers.py @@ -108,7 +108,7 @@ class TestWorkersService(TestService): from_date = to_date - timedelta(days=1) # no variants - res = self.api.workers.get_stats( + res = self.api.workers.get_statistics( items=[ dict(key="cpu_usage", aggregation="avg"), dict(key="cpu_usage", aggregation="max"), @@ -142,7 +142,7 @@ class TestWorkersService(TestService): ) # split by variants - res = self.api.workers.get_stats( + res = self.api.workers.get_statistics( items=[dict(key="cpu_usage", aggregation="avg")], from_date=from_date.timestamp(), to_date=to_date.timestamp(), @@ -165,7 +165,7 @@ class TestWorkersService(TestService): assert all(_check_metric_and_variants(worker) for worker in res["workers"]) - res = self.api.workers.get_stats( + res = self.api.workers.get_statistics( items=[dict(key="cpu_usage", aggregation="avg")], from_date=from_date.timestamp(), to_date=to_date.timestamp(), @@ -183,12 +183,13 @@ class TestWorkersService(TestService): # run on an empty es db since we have no way # to pass non existing workers to this api # res = self.api.workers.get_activity_report( - # from_date=from_date.timestamp(), - # to_date=to_date.timestamp(), + # from_timestamp=from_timestamp.timestamp(), + # to_timestamp=to_timestamp.timestamp(), # interval=20, # ) self._simulate_workers() + to_date = utc_now_tz_aware() from_date = to_date - timedelta(minutes=10) diff --git a/server/updates.py b/server/updates.py index 76b5fc6..3144d7e 100644 --- a/server/updates.py +++ b/server/updates.py @@ -8,6 +8,7 @@ import requests from semantic_version import Version from config import config +from database.model.settings import Settings from version import __version__ as current_version log = config.logger(__name__) @@ -39,12 +40,15 @@ class CheckUpdatesThread(Thread): def _check_new_version_available(self) -> Optional[_VersionResponse]: url = config.get( - "apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates" + "apiserver.check_for_updates.url", + "https://updates.trains.allegro.ai/updates", ) + uid = Settings.get_by_key("server.uuid") + response 
= requests.get( url, - json={"versions": {self.component_name: str(current_version)}}, + json={"versions": {self.component_name: str(current_version)}, "uid": uid}, timeout=float( config.get("apiserver.check_for_updates.request_timeout_sec", 3.0) ),
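
Reviewer sketch (not part of the commit): a minimal way to exercise the new
opt-in flow end to end. Only the endpoint names (server.report_stats_option,
server.get_stats) and their request/response fields are taken from this patch;
the base URL, port, and basic-auth credentials below are assumptions about a
typical local deployment and will differ per setup. Note that both endpoints
inherit the _default section of schema/services/server.conf above (internal,
allow_roles ["root", "system"]), so the credentials used must map to a
suitably privileged user.

    # stats_optin_check.py -- hypothetical smoke test; adjust BASE/AUTH first
    import requests

    BASE = "http://localhost:8008"           # assumed apiserver address
    AUTH = ("<access_key>", "<secret_key>")  # assumed basic-auth credentials

    # Opt in: the handler stores enabled=True along with the current time,
    # server version and calling user id (see report_stats in
    # services/server/__init__.py above)
    res = requests.post(
        f"{BASE}/server.report_stats_option", json={"enabled": True}, auth=AUTH
    )
    print(res.json())

    # Read the option back without changing it: omitting "enabled" makes the
    # handler return the stored ReportStatsOption unmodified
    res = requests.post(f"{BASE}/server.report_stats_option", json={}, auth=AUTH)
    print(res.json())

    # Fetch the same report StatisticsReporter.get_statistics() would send
    res = requests.post(f"{BASE}/server.get_stats", json={}, auth=AUTH)
    print(res.json())

Once enabled, the reporter thread wakes every report_interval_hours (24 by
default), queues a report for each opted-in company, and the sender thread
posts it to apiserver.statistics.url with a randomized retry backoff.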