mirror of https://github.com/clearml/clearml-server
synced 2025-02-13 08:05:43 +00:00
Add initial support for project ordering. Add support for sortable task duration (used by the UI in the experiments table). Add support for project name in the worker's current task info. Add support for results and artifacts in pre-populated examples. Add demo server features.
305 lines
10 KiB
Python

import logging
import queue
import random
import time
from datetime import timedelta, datetime
from time import sleep
from typing import Sequence, Optional

import dpath
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

from bll.query import Builder as QueryBuilder
from bll.util import get_server_uuid
from bll.workers import WorkerStats, WorkerBLL
from config import config
from config.info import get_deployment_type
from database.model import Company, User
from database.model.queue import Queue
from database.model.task.task import Task
from utilities import safe_get
from utilities.json import dumps
from utilities.threads_manager import ThreadsManager
from version import __version__ as current_version
from .resource_monitor import ResourceMonitor

log = config.logger(__file__)

worker_bll = WorkerBLL()


class StatisticsReporter:
    threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
    send_queue = queue.Queue()
    supported = config.get("apiserver.statistics.supported", True)

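    # Usage sketch (an assumption: the call site is the API server's startup
    # code, which is not shown in this file):
    #
    #     StatisticsReporter.start()
    #
    # This kicks off the two daemon threads registered below on the
    # ThreadsManager: a "reporter" that periodically collects per-company
    # statistics and a "sender" that posts them to the statistics endpoint.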
    @classmethod
    def start(cls):
        cls.start_sender()
        cls.start_reporter()

    @classmethod
    @threads.register("reporter", daemon=True)
    def start_reporter(cls):
        """
        Periodically send statistics reports for companies that have opted in.
        Note: in trains we usually have only a single company
        """
        if not cls.supported:
            return

        report_interval = timedelta(
            hours=config.get("apiserver.statistics.report_interval_hours", 24)
        )
        sleep(report_interval.total_seconds())
        while not ThreadsManager.terminating:
            try:
                for company in Company.objects(
                    defaults__stats_option__enabled=True
                ).only("id"):
                    stats = cls.get_statistics(company.id)
                    cls.send_queue.put(stats)

            except Exception as ex:
                log.exception(f"Failed collecting stats: {str(ex)}")

            sleep(report_interval.total_seconds())

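    # The sender drains send_queue and POSTs each report as JSON to the URL
    # configured under apiserver.statistics.url. Transport retries are handled
    # by the HTTPAdapter's Retry policy, with a randomized backoff factor set
    # per report (see below).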
    @classmethod
    @threads.register("sender", daemon=True)
    def start_sender(cls):
        if not cls.supported:
            return

        url = config.get("apiserver.statistics.url")

        retries = config.get("apiserver.statistics.max_retries", 5)
        max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
        session = requests.Session()
        adapter = HTTPAdapter(max_retries=Retry(retries))
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        session.headers["Content-type"] = "application/json"

        WarningFilter.attach()

        while not ThreadsManager.terminating:
            try:
                report = cls.send_queue.get()

                # Set a random backoff factor each time we send a report
                adapter.max_retries.backoff_factor = random.random() * max_backoff

                session.post(url, data=dumps(report))

            except Exception as ex:
                # Failures (e.g. network errors) are swallowed so the sender
                # thread keeps running
                pass

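    # The report assembled below has roughly this shape (values are filled in
    # at runtime, and per-agent entries may omit either section):
    #
    #     {
    #         "time": ..., "company_id": ...,
    #         "server": {"version": ..., "deployment": ..., "uuid": ...,
    #                    "queues": {"count": ...}, "users": {"count": ...},
    #                    "resources": {...}, "experiments": {...}},
    #         "agents": [{"uuid": ..., "resources": {...}, "experiments": {...}}, ...],
    #     }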
    @classmethod
    def get_statistics(cls, company_id: str) -> dict:
        """
        Returns a statistics report per company
        """
        return {
            "time": datetime.utcnow(),
            "company_id": company_id,
            "server": {
                "version": current_version,
                "deployment": get_deployment_type(),
                "uuid": get_server_uuid(),
                "queues": {"count": Queue.objects(company=company_id).count()},
                "users": {"count": User.objects(company=company_id).count()},
                "resources": cls.threads.resource_monitor.get_stats(),
                "experiments": next(
                    iter(cls._get_experiments_stats(company_id).values()), {}
                ),
            },
            "agents": cls._get_agents_statistics(company_id),
        }

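    # Per-agent statistics: resource stats and experiment stats are each keyed
    # by worker id, deep-merged with dpath, and flattened into a list of
    # {"uuid": <worker id>, ...} entries.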
    @classmethod
    def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]:
        result = cls._get_resource_stats_per_agent(company_id, key="resources")
        dpath.merge(
            result, cls._get_experiments_stats_per_agent(company_id, key="experiments")
        )
        return [{"uuid": agent_id, **data} for agent_id, data in result.items()]

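    # Resource stats per agent, from an Elasticsearch terms aggregation on
    # "worker" over the last report interval: "categories" counts distinct
    # variants per category (used below to derive num_cores from the "cpu"
    # category) and "metrics" computes min/max/avg of "value" per metric.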
    @classmethod
    def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict:
        agent_resource_threshold_sec = timedelta(
            hours=config.get("apiserver.statistics.report_interval_hours", 24)
        ).total_seconds()
        to_timestamp = int(time.time())
        from_timestamp = to_timestamp - int(agent_resource_threshold_sec)
        es_req = {
            "size": 0,
            "query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
            "aggs": {
                "workers": {
                    "terms": {"field": "worker"},
                    "aggs": {
                        "categories": {
                            "terms": {"field": "category"},
                            "aggs": {"count": {"cardinality": {"field": "variant"}}},
                        },
                        "metrics": {
                            "terms": {"field": "metric"},
                            "aggs": {
                                "min": {"min": {"field": "value"}},
                                "max": {"max": {"field": "value"}},
                                "avg": {"avg": {"field": "value"}},
                            },
                        },
                    },
                }
            },
        }
        res = cls._run_worker_stats_query(company_id, es_req)

        def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
            names = {"cpu": "num_cores"}
            return {
                names[c["key"]]: safe_get(c, "count/value")
                for c in categories
                if c["key"] in names
            }

        def _get_metric_fields(metrics: Sequence[dict]) -> dict:
            names = {
                "cpu_usage": "cpu_usage",
                "memory_used": "mem_used_gb",
                "memory_free": "mem_free_gb",
            }
            return {
                names[m["key"]]: {
                    "min": safe_get(m, "min/value"),
                    "max": safe_get(m, "max/value"),
                    "avg": safe_get(m, "avg/value"),
                }
                for m in metrics
                if m["key"] in names
            }

        buckets = safe_get(res, "aggregations/workers/buckets", default=[])
        return {
            b["key"]: {
                key: {
                    "interval_sec": agent_resource_threshold_sec,
                    **_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
                    **_get_metric_fields(safe_get(b, "metrics/buckets", [])),
                }
            }
            for b in buckets
        }

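    # Experiment stats per agent, restricted to workers that were active
    # within the last agent_relevant_threshold_days (30 by default); each
    # entry combines the worker's last activity time with its task stats.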
    @classmethod
    def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict:
        agent_relevant_threshold = timedelta(
            days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
        )
        to_timestamp = int(time.time())
        from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds())
        workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp)
        if not workers:
            return {}

        stats = cls._get_experiments_stats(company_id, list(workers.keys()))
        return {
            worker_id: {key: {**workers[worker_id], **stat}}
            for worker_id, stat in stats.items()
        }

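    # Workers that reported stats in the given time range, mapped to their
    # latest activity timestamp (a max aggregation on "timestamp" per worker).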
    @classmethod
    def _get_active_workers(
        cls, company_id, from_timestamp: int, to_timestamp: int
    ) -> dict:
        es_req = {
            "size": 0,
            "query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
            "aggs": {
                "workers": {
                    "terms": {"field": "worker"},
                    "aggs": {"last_activity_time": {"max": {"field": "timestamp"}}},
                }
            },
        }
        res = cls._run_worker_stats_query(company_id, es_req)
        buckets = safe_get(res, "aggregations/workers/buckets", default=[])
        return {
            b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
            for b in buckets
        }

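    # Searches all of the company's worker-stats indices (prefix + wildcard).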
    @classmethod
    def _run_worker_stats_query(cls, company_id, es_req) -> dict:
        return worker_bll.es_client.search(
            index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
            body=es_req,
        )

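    # Task statistics from a MongoDB aggregation: match the company's tasks
    # that have started, have a last_update and are past created/queued
    # (optionally limited to the given workers), group by last_worker (or a
    # single group when no workers are given), and compute the task count,
    # the average run time in seconds ((last_update - started) / 1000) and
    # the average last_iteration, truncated in the $project stage.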
    @classmethod
    def _get_experiments_stats(
        cls, company_id, workers: Optional[Sequence] = None
    ) -> dict:
        pipeline = [
            {
                "$match": {
                    "company": company_id,
                    "started": {"$exists": True, "$ne": None},
                    "last_update": {"$exists": True, "$ne": None},
                    "status": {"$nin": ["created", "queued"]},
                    **({"last_worker": {"$in": workers}} if workers else {}),
                }
            },
            {
                "$group": {
                    "_id": "$last_worker" if workers else None,
                    "count": {"$sum": 1},
                    "avg_run_time_sec": {
                        "$avg": {
                            "$divide": [
                                {"$subtract": ["$last_update", "$started"]},
                                1000,
                            ]
                        }
                    },
                    "avg_iterations": {"$avg": "$last_iteration"},
                }
            },
            {
                "$project": {
                    "count": 1,
                    "avg_run_time_sec": {"$trunc": "$avg_run_time_sec"},
                    "avg_iterations": {"$trunc": "$avg_iterations"},
                }
            },
        ]
        return {
            group["_id"]: {k: v for k, v in group.items() if k != "_id"}
            for group in Task.aggregate(pipeline)
        }


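# Filters out urllib3 connection-pool WARNING records whose request path is
# "/stats", presumably the retry warnings emitted while posting statistics
# reports, so they do not clutter the server log.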
class WarningFilter(logging.Filter):
    @classmethod
    def attach(cls):
        from urllib3.connectionpool import (
            ConnectionPool,
        )  # required to make sure the logger is created

        assert ConnectionPool  # make sure import is not optimized out

        logging.getLogger("urllib3.connectionpool").addFilter(cls())

    def filter(self, record):
        if (
            record.levelno == logging.WARNING
            and len(record.args) > 2
            and record.args[2] == "/stats"
        ):
            return False
        return True