mirror of
https://github.com/clearml/clearml-server
synced 2025-02-26 05:59:20 +00:00
Improve server update checks
This commit is contained in:
parent
fe3dbc92dc
commit
0ad687008c
14
server/apimodels/server.py
Normal file
14
server/apimodels/server.py
Normal file
@ -0,0 +1,14 @@
|
||||
from jsonmodels.fields import BoolField, DateTimeField, StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
|
||||
class ReportStatsOptionRequest(Base):
|
||||
enabled = BoolField(default=None, nullable=True)
|
||||
|
||||
|
||||
class ReportStatsOptionResponse(Base):
|
||||
supported = BoolField(default=True)
|
||||
enabled = BoolField()
|
||||
enabled_time = DateTimeField(nullable=True)
|
||||
enabled_version = StringField(nullable=True)
|
||||
enabled_user = StringField(nullable=True)
|
@ -161,7 +161,7 @@ class QueueMetrics:
|
||||
In case no queue ids are specified the avg across all the
|
||||
company queues is calculated for each metric
|
||||
"""
|
||||
# self._log_current_metrics(company_id, queue_ids=queue_ids)
|
||||
# self._log_current_metrics(company, queue_ids=queue_ids)
|
||||
|
||||
if from_date >= to_date:
|
||||
raise bad_request.FieldsValueError("from_date must be less than to_date")
|
||||
|
306
server/bll/statistics/stats_reporter.py
Normal file
306
server/bll/statistics/stats_reporter.py
Normal file
@ -0,0 +1,306 @@
|
||||
import logging
|
||||
import queue
|
||||
import random
|
||||
import time
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import dpath
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
from bll.query import Builder as QueryBuilder
|
||||
from bll.util import get_server_uuid
|
||||
from bll.workers import WorkerStats, WorkerBLL
|
||||
from config import config
|
||||
from config.info import get_deployment_type
|
||||
from database.model import Company, User
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from utilities import safe_get
|
||||
from utilities.json import dumps
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from version import __version__ as current_version
|
||||
from .resource_monitor import ResourceMonitor
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
worker_bll = WorkerBLL()
|
||||
|
||||
|
||||
class StatisticsReporter:
|
||||
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
|
||||
send_queue = queue.Queue()
|
||||
supported = config.get("apiserver.statistics.supported", True)
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
cls.start_sender()
|
||||
cls.start_reporter()
|
||||
|
||||
@classmethod
|
||||
@threads.register("reporter", daemon=True)
|
||||
def start_reporter(cls):
|
||||
"""
|
||||
Periodically send statistics reports for companies who have opted in.
|
||||
Note: in trains we usually have only a single company
|
||||
"""
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
report_interval = timedelta(
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
)
|
||||
|
||||
while True:
|
||||
|
||||
sleep(report_interval.total_seconds())
|
||||
|
||||
try:
|
||||
for company in Company.objects(
|
||||
defaults__stats_option__enabled=True
|
||||
).only("id"):
|
||||
stats = cls.get_statistics(company.id)
|
||||
cls.send_queue.put(stats)
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed collecting stats: {str(ex)}")
|
||||
|
||||
@classmethod
|
||||
@threads.register("sender", daemon=True)
|
||||
def start_sender(cls):
|
||||
if not cls.supported:
|
||||
return
|
||||
|
||||
url = config.get("apiserver.statistics.url")
|
||||
|
||||
retries = config.get("apiserver.statistics.max_retries", 5)
|
||||
max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(retries))
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
session.headers["Content-type"] = "application/json"
|
||||
|
||||
WarningFilter.attach()
|
||||
|
||||
while True:
|
||||
try:
|
||||
report = cls.send_queue.get()
|
||||
|
||||
# Set a random backoff factor each time we send a report
|
||||
adapter.max_retries.backoff_factor = random.random() * max_backoff
|
||||
|
||||
session.post(url, data=dumps(report))
|
||||
|
||||
except Exception as ex:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_statistics(cls, company_id: str) -> dict:
|
||||
"""
|
||||
Returns a statistics report per company
|
||||
"""
|
||||
return {
|
||||
"time": datetime.utcnow(),
|
||||
"company_id": company_id,
|
||||
"server": {
|
||||
"version": current_version,
|
||||
"deployment": get_deployment_type(),
|
||||
"uuid": get_server_uuid(),
|
||||
"queues": {"count": Queue.objects(company=company_id).count()},
|
||||
"users": {"count": User.objects(company=company_id).count()},
|
||||
"resources": cls.threads.resource_monitor.get_stats(),
|
||||
"experiments": next(
|
||||
iter(cls._get_experiments_stats(company_id).values()), {}
|
||||
),
|
||||
},
|
||||
"agents": cls._get_agents_statistics(company_id),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]:
|
||||
result = cls._get_resource_stats_per_agent(company_id, key="resources")
|
||||
dpath.merge(
|
||||
result, cls._get_experiments_stats_per_agent(company_id, key="experiments")
|
||||
)
|
||||
return [{"uuid": agent_id, **data} for agent_id, data in result.items()]
|
||||
|
||||
@classmethod
|
||||
def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict:
|
||||
agent_resource_threshold_sec = timedelta(
|
||||
hours=config.get("apiserver.statistics.report_interval_hours", 24)
|
||||
).total_seconds()
|
||||
to_timestamp = int(time.time())
|
||||
from_timestamp = to_timestamp - int(agent_resource_threshold_sec)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {
|
||||
"categories": {
|
||||
"terms": {"field": "category"},
|
||||
"aggs": {"count": {"cardinality": {"field": "variant"}}},
|
||||
},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": {
|
||||
"min": {"min": {"field": "value"}},
|
||||
"max": {"max": {"field": "value"}},
|
||||
"avg": {"avg": {"field": "value"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
|
||||
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
|
||||
names = {"cpu": "num_cores"}
|
||||
return {
|
||||
names[c["key"]]: safe_get(c, "count/value")
|
||||
for c in categories
|
||||
if c["key"] in names
|
||||
}
|
||||
|
||||
def _get_metric_fields(metrics: Sequence[dict]) -> dict:
|
||||
names = {
|
||||
"cpu_usage": "cpu_usage",
|
||||
"memory_used": "mem_used_gb",
|
||||
"memory_free": "mem_free_gb",
|
||||
}
|
||||
return {
|
||||
names[m["key"]]: {
|
||||
"min": safe_get(m, "min/value"),
|
||||
"max": safe_get(m, "max/value"),
|
||||
"avg": safe_get(m, "avg/value"),
|
||||
}
|
||||
for m in metrics
|
||||
if m["key"] in names
|
||||
}
|
||||
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
return {
|
||||
b["key"]: {
|
||||
key: {
|
||||
"interval_sec": agent_resource_threshold_sec,
|
||||
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
|
||||
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
|
||||
}
|
||||
}
|
||||
for b in buckets
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict:
|
||||
agent_relevant_threshold = timedelta(
|
||||
days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
|
||||
)
|
||||
to_timestamp = int(time.time())
|
||||
from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds())
|
||||
workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp)
|
||||
if not workers:
|
||||
return {}
|
||||
|
||||
stats = cls._get_experiments_stats(company_id, list(workers.keys()))
|
||||
return {
|
||||
worker_id: {key: {**workers[worker_id], **stat}}
|
||||
for worker_id, stat in stats.items()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_active_workers(
|
||||
cls, company_id, from_timestamp: int, to_timestamp: int
|
||||
) -> dict:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
|
||||
"aggs": {
|
||||
"workers": {
|
||||
"terms": {"field": "worker"},
|
||||
"aggs": {"last_activity_time": {"max": {"field": "timestamp"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
res = cls._run_worker_stats_query(company_id, es_req)
|
||||
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
|
||||
return {
|
||||
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
|
||||
for b in buckets
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_experiments_stats(
|
||||
cls, company_id, workers: Optional[Sequence] = None
|
||||
) -> dict:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"started": {"$exists": True, "$ne": None},
|
||||
"last_update": {"$exists": True, "$ne": None},
|
||||
"status": {"$nin": ["created", "queued"]},
|
||||
**({"last_worker": {"$in": workers}} if workers else {}),
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$last_worker" if workers else None,
|
||||
"count": {"$sum": 1},
|
||||
"avg_run_time_sec": {
|
||||
"$avg": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$last_update", "$started"]},
|
||||
1000,
|
||||
]
|
||||
}
|
||||
},
|
||||
"avg_iterations": {"$avg": "$last_iteration"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"count": 1,
|
||||
"avg_run_time_sec": {"$trunc": "$avg_run_time_sec"},
|
||||
"avg_iterations": {"$trunc": "$avg_iterations"},
|
||||
}
|
||||
},
|
||||
]
|
||||
return {
|
||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||
for group in Task.aggregate(*pipeline)
|
||||
}
|
||||
|
||||
|
||||
class WarningFilter(logging.Filter):
|
||||
@classmethod
|
||||
def attach(cls):
|
||||
from urllib3.connectionpool import (
|
||||
ConnectionPool,
|
||||
) # required to make sure the logger is created
|
||||
|
||||
assert ConnectionPool # make sure import is not optimized out
|
||||
|
||||
logging.getLogger("urllib3.connectionpool").addFilter(cls())
|
||||
|
||||
def filter(self, record):
|
||||
if (
|
||||
record.levelno == logging.WARNING
|
||||
and len(record.args) > 2
|
||||
and record.args[2] == "/stats"
|
||||
):
|
||||
return False
|
||||
return True
|
@ -101,4 +101,18 @@
|
||||
# GET request timeout
|
||||
request_timeout_sec: 3.0
|
||||
}
|
||||
|
||||
statistics {
|
||||
# Note: statistics are sent ONLY if the user has actively opted-in
|
||||
supported: true
|
||||
|
||||
url: "https://updates.trains.allegro.ai/stats"
|
||||
|
||||
report_interval_hours: 24
|
||||
agent_relevant_threshold_days: 30
|
||||
|
||||
max_retries: 5
|
||||
max_backoff_sec: 5
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from os import getenv
|
||||
|
||||
root = Path(__file__).parent.parent
|
||||
|
||||
@ -26,3 +27,17 @@ def get_commit_number():
|
||||
return (root / "COMMIT").read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_deployment_type() -> str:
|
||||
value = getenv("TRAINS_SERVER_DEPLOYMENT_TYPE")
|
||||
if value:
|
||||
return value
|
||||
|
||||
try:
|
||||
value = (root / "DEPLOY").read_text().strip()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return value or "manual"
|
||||
|
@ -1,23 +1,36 @@
|
||||
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
EmbeddedDocumentField,
|
||||
StringField,
|
||||
Q,
|
||||
BooleanField,
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class ReportStatsOption(EmbeddedDocument):
|
||||
enabled = BooleanField(default=False) # opt-in for statistics reporting
|
||||
enabled_version = StringField() # server version when enabled
|
||||
enabled_time = DateTimeField() # time when enabled
|
||||
enabled_user = StringField() # ID of user who enabled
|
||||
|
||||
|
||||
class CompanyDefaults(EmbeddedDocument):
|
||||
cluster = StringField()
|
||||
stats_option = EmbeddedDocumentField(ReportStatsOption, default=ReportStatsOption)
|
||||
|
||||
|
||||
class Company(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
}
|
||||
meta = {"db_alias": Database.backend, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(unique=True, min_length=3)
|
||||
defaults = EmbeddedDocumentField(CompanyDefaults)
|
||||
defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
|
57
server/database/model/settings.py
Normal file
57
server/database/model/settings.py
Normal file
@ -0,0 +1,57 @@
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
from mongoengine import Document, StringField, DynamicField, Q
|
||||
from mongoengine.errors import NotUniqueError
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class Settings(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
key = StringField(primary_key=True)
|
||||
value = DynamicField()
|
||||
|
||||
@classmethod
|
||||
def get_by_key(cls, key: str, default: Optional[Any] = None, sep: str = ".") -> Any:
|
||||
key = key.strip(sep)
|
||||
res = Settings.objects(key=key).first()
|
||||
if not res:
|
||||
return default
|
||||
return res.value
|
||||
|
||||
@classmethod
|
||||
def get_by_prefix(
|
||||
cls, key_prefix: str, default: Optional[Any] = None, sep: str = "."
|
||||
) -> Sequence[Tuple[str, Any]]:
|
||||
key_prefix = key_prefix.strip(sep)
|
||||
query = Q(key=key_prefix) | Q(key__startswith=key_prefix + sep)
|
||||
res = Settings.objects(query)
|
||||
if not res:
|
||||
return default
|
||||
return [(x.key, x.value) for x in res]
|
||||
|
||||
@classmethod
|
||||
def set_or_add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
|
||||
""" Sets a new value or adds a new key/value setting (if key does not exist) """
|
||||
key = key.strip(sep)
|
||||
res = Settings.objects(key=key).update(key=key, value=value, upsert=True)
|
||||
# if Settings.objects(key=key).only("key"):
|
||||
#
|
||||
# else:
|
||||
# res = Settings(key=key, value=value).save()
|
||||
return bool(res)
|
||||
|
||||
@classmethod
|
||||
def add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
|
||||
""" Adds a new key/value settings. Fails if key already exists. """
|
||||
key = key.strip(sep)
|
||||
try:
|
||||
res = Settings(key=key, value=value).save(force_insert=True)
|
||||
return bool(res)
|
||||
except NotUniqueError:
|
||||
return False
|
@ -1,6 +1,7 @@
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import attr
|
||||
from furl import furl
|
||||
@ -15,6 +16,7 @@ from database.model.auth import Role
|
||||
from database.model.auth import User as AuthUser, Credentials
|
||||
from database.model.company import Company
|
||||
from database.model.queue import Queue
|
||||
from database.model.settings import Settings
|
||||
from database.model.user import User
|
||||
from database.model.version import Version as DatabaseVersion
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
@ -109,10 +111,7 @@ def _ensure_user(user: FixedUser, company_id: str):
|
||||
data["email"] = f"{user.user_id}@example.com"
|
||||
data["role"] = Role.user
|
||||
|
||||
_ensure_auth_user(
|
||||
user_data=data,
|
||||
company_id=company_id,
|
||||
)
|
||||
_ensure_auth_user(user_data=data, company_id=company_id)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
@ -142,9 +141,7 @@ def _apply_migrations():
|
||||
try:
|
||||
new_scripts = {
|
||||
ver: path
|
||||
for ver, path in (
|
||||
(Version(f.stem), f) for f in migration_dir.glob("*.py")
|
||||
)
|
||||
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
|
||||
if ver > last_version
|
||||
}
|
||||
except ValueError as ex:
|
||||
@ -179,16 +176,30 @@ def _apply_migrations():
|
||||
).save()
|
||||
|
||||
|
||||
def _ensure_uuid():
|
||||
Settings.add_value("server.uuid", str(uuid4()))
|
||||
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
_apply_migrations()
|
||||
|
||||
_ensure_uuid()
|
||||
|
||||
company_id = _ensure_company()
|
||||
_ensure_default_queue(company_id)
|
||||
|
||||
users = [
|
||||
{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
|
||||
{"name": "webserver", "role": Role.system, "email": "webserver@example.com"},
|
||||
{
|
||||
"name": "apiserver",
|
||||
"role": Role.system,
|
||||
"email": "apiserver@example.com",
|
||||
},
|
||||
{
|
||||
"name": "webserver",
|
||||
"role": Role.system,
|
||||
"email": "webserver@example.com",
|
||||
},
|
||||
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
|
||||
]
|
||||
|
||||
|
@ -3,6 +3,25 @@ _default {
|
||||
internal: true
|
||||
allow_roles: ["root", "system"]
|
||||
}
|
||||
get_stats {
|
||||
"2.1" {
|
||||
description: "Get the server collected statistics."
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
interval {
|
||||
description: "The period for statistics collection in seconds."
|
||||
type: long
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties: {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
config {
|
||||
"2.1" {
|
||||
description: "Get server configuration. Secure section is not returned."
|
||||
@ -66,3 +85,39 @@ endpoints {
|
||||
}
|
||||
}
|
||||
}
|
||||
report_stats_option {
|
||||
"2.4" {
|
||||
description: "Get or set the report statistics option per-company"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
enabled {
|
||||
description: "If provided, sets the report statistics option (true/false)"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
enabled {
|
||||
description: "Returns the current report stats option value"
|
||||
type: boolean
|
||||
}
|
||||
enabled_time {
|
||||
description: "If enabled, returns the time at which option was enabled"
|
||||
type: string
|
||||
format: date-time
|
||||
}
|
||||
enabled_version {
|
||||
description: "If enabled, returns the server version at the time option was enabled"
|
||||
type: string
|
||||
}
|
||||
enabled_user {
|
||||
description: "If enabled, returns Id of the user who enabled the option"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -7,15 +7,15 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
import database
|
||||
from apierrors.base import BaseError
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config
|
||||
from init_data import init_es_data, init_mongo_data
|
||||
from service_repo import ServiceRepo, APICall
|
||||
from service_repo.auth import AuthType
|
||||
from service_repo.errors import PathParsingError
|
||||
from timing_context import TimingContext
|
||||
from utilities import json
|
||||
from init_data import init_es_data, init_mongo_data
|
||||
from updates import check_updates_thread
|
||||
|
||||
from utilities import json
|
||||
|
||||
app = Flask(__name__, static_url_path="/static")
|
||||
CORS(app, **config.get("apiserver.cors"))
|
||||
@ -38,6 +38,7 @@ log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
|
||||
|
||||
|
||||
check_updates_thread.start()
|
||||
StatisticsReporter.start()
|
||||
|
||||
|
||||
@app.before_first_request
|
||||
@ -57,7 +58,9 @@ def before_request():
|
||||
content, content_type = ServiceRepo.handle_call(call)
|
||||
headers = {}
|
||||
if call.result.filename:
|
||||
headers["Content-Disposition"] = f"attachment; filename={call.result.filename}"
|
||||
headers[
|
||||
"Content-Disposition"
|
||||
] = f"attachment; filename={call.result.filename}"
|
||||
|
||||
if call.result.headers:
|
||||
headers.update(call.result.headers)
|
||||
@ -71,7 +74,9 @@ def before_request():
|
||||
if value is None:
|
||||
response.set_cookie(key, "", expires=0)
|
||||
else:
|
||||
response.set_cookie(key, value, **config.get("apiserver.auth.cookies"))
|
||||
response.set_cookie(
|
||||
key, value, **config.get("apiserver.auth.cookies")
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
|
@ -1,8 +1,24 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pyhocon.config_tree import NoneValue
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
|
||||
from bll.statistics.stats_reporter import StatisticsReporter
|
||||
from config import config
|
||||
from config.info import get_version, get_build_number, get_commit_number
|
||||
from database.errors import translate_errors_context
|
||||
from database.model import Company
|
||||
from database.model.company import ReportStatsOption
|
||||
from service_repo import ServiceRepo, APICall, endpoint
|
||||
from version import __version__ as current_version
|
||||
|
||||
|
||||
@endpoint("server.get_stats")
|
||||
def get_stats(call: APICall):
|
||||
call.result.data = StatisticsReporter.get_statistics(
|
||||
company_id=call.identity.company
|
||||
)
|
||||
|
||||
|
||||
@endpoint("server.config")
|
||||
@ -43,3 +59,35 @@ def info(call: APICall):
|
||||
"build": get_build_number(),
|
||||
"commit": get_commit_number(),
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"server.report_stats_option",
|
||||
request_data_model=ReportStatsOptionRequest,
|
||||
response_data_model=ReportStatsOptionResponse,
|
||||
)
|
||||
def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest):
|
||||
if not StatisticsReporter.supported:
|
||||
result = ReportStatsOptionResponse(supported=False)
|
||||
else:
|
||||
enabled = request.enabled
|
||||
with translate_errors_context():
|
||||
query = Company.objects(id=company)
|
||||
if enabled is None:
|
||||
stats_option = query.first().defaults.stats_option
|
||||
else:
|
||||
stats_option = ReportStatsOption(
|
||||
enabled=enabled,
|
||||
enabled_time=datetime.utcnow(),
|
||||
enabled_version=current_version,
|
||||
enabled_user=call.identity.user,
|
||||
)
|
||||
updated = query.update(defaults__stats_option=stats_option)
|
||||
if not updated:
|
||||
raise errors.server_error.InternalError(
|
||||
f"Failed setting report_stats to {enabled}"
|
||||
)
|
||||
|
||||
result = ReportStatsOptionResponse(**stats_option.to_mongo())
|
||||
|
||||
call.result.data_model = result
|
||||
|
@ -108,7 +108,7 @@ class TestWorkersService(TestService):
|
||||
from_date = to_date - timedelta(days=1)
|
||||
|
||||
# no variants
|
||||
res = self.api.workers.get_stats(
|
||||
res = self.api.workers.get_statistics(
|
||||
items=[
|
||||
dict(key="cpu_usage", aggregation="avg"),
|
||||
dict(key="cpu_usage", aggregation="max"),
|
||||
@ -142,7 +142,7 @@ class TestWorkersService(TestService):
|
||||
)
|
||||
|
||||
# split by variants
|
||||
res = self.api.workers.get_stats(
|
||||
res = self.api.workers.get_statistics(
|
||||
items=[dict(key="cpu_usage", aggregation="avg")],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
@ -165,7 +165,7 @@ class TestWorkersService(TestService):
|
||||
|
||||
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
|
||||
|
||||
res = self.api.workers.get_stats(
|
||||
res = self.api.workers.get_statistics(
|
||||
items=[dict(key="cpu_usage", aggregation="avg")],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
@ -183,12 +183,13 @@ class TestWorkersService(TestService):
|
||||
# run on an empty es db since we have no way
|
||||
# to pass non existing workers to this api
|
||||
# res = self.api.workers.get_activity_report(
|
||||
# from_date=from_date.timestamp(),
|
||||
# to_date=to_date.timestamp(),
|
||||
# from_timestamp=from_timestamp.timestamp(),
|
||||
# to_timestamp=to_timestamp.timestamp(),
|
||||
# interval=20,
|
||||
# )
|
||||
|
||||
self._simulate_workers()
|
||||
|
||||
to_date = utc_now_tz_aware()
|
||||
from_date = to_date - timedelta(minutes=10)
|
||||
|
||||
|
@ -8,6 +8,7 @@ import requests
|
||||
from semantic_version import Version
|
||||
|
||||
from config import config
|
||||
from database.model.settings import Settings
|
||||
from version import __version__ as current_version
|
||||
|
||||
log = config.logger(__name__)
|
||||
@ -39,12 +40,15 @@ class CheckUpdatesThread(Thread):
|
||||
|
||||
def _check_new_version_available(self) -> Optional[_VersionResponse]:
|
||||
url = config.get(
|
||||
"apiserver.check_for_updates.url", "https://updates.trains.allegro.ai/updates"
|
||||
"apiserver.check_for_updates.url",
|
||||
"https://updates.trains.allegro.ai/updates",
|
||||
)
|
||||
|
||||
uid = Settings.get_by_key("server.uuid")
|
||||
|
||||
response = requests.get(
|
||||
url,
|
||||
json={"versions": {self.component_name: str(current_version)}},
|
||||
json={"versions": {self.component_name: str(current_version)}, "uid": uid},
|
||||
timeout=float(
|
||||
config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
|
||||
),
|
||||
|
Loading…
Reference in New Issue
Block a user