mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 10:43:10 +00:00
Add support for events.get_task_single_value_metrics, events.plots, events.get_plot_sample and events.next_plot_sample
This commit is contained in:
parent
9b108740da
commit
cff98ae900
@ -124,6 +124,8 @@ class DebugImageResponse(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
pass
|
||||
class TaskMetricsRequest(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
|
@ -16,13 +16,14 @@ from nested_dict import nested_dict
|
||||
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
get_index_name,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
delete_company_events,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
uncompress_plot,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
@ -76,6 +77,8 @@ class EventBLL(object):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
|
||||
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
|
||||
self.events_iterator = EventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
@ -307,11 +310,7 @@ class EventBLL(object):
|
||||
@parallel_chunked_decorator(chunk_size=10)
|
||||
def uncompress_plots(self, plot_events: Sequence[dict]):
|
||||
for event in plot_events:
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
uncompress_plot(event)
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_json(text: str) -> bool:
|
||||
@ -479,6 +478,13 @@ class EventBLL(object):
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type, routing=task_id,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // num_last_iterations)
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
@ -486,14 +492,14 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@ -515,9 +521,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
@ -763,20 +767,26 @@ class EventBLL(object):
|
||||
return {}
|
||||
|
||||
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type, routing=task_id,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
@ -789,9 +799,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@ -817,6 +825,12 @@ class EventBLL(object):
|
||||
]
|
||||
}
|
||||
}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type, routing=task_id,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
@ -824,14 +838,14 @@ class EventBLL(object):
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@ -862,9 +876,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@ -1019,9 +1031,7 @@ class EventBLL(object):
|
||||
{
|
||||
"range": {
|
||||
"timestamp": {
|
||||
"lt": (
|
||||
es_factory.get_timestamp_millis() - timestamp_ms
|
||||
)
|
||||
"lt": (es_factory.get_timestamp_millis() - timestamp_ms)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,15 @@
|
||||
import base64
|
||||
import zlib
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence, Mapping
|
||||
from typing import Union, Sequence, Mapping, Tuple
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
@ -16,10 +21,13 @@ class EventType(Enum):
|
||||
all = "*"
|
||||
|
||||
|
||||
SINGLE_SCALAR_ITERATION = -2**31
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
_max_es_allowed_aggregation_buckets = 10000
|
||||
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
|
||||
@ -31,12 +39,18 @@ class EventSettings:
|
||||
)
|
||||
|
||||
@classproperty
|
||||
def max_metrics_count(self):
|
||||
return config.get("services.events.events_retrieval.max_metrics_count", 100)
|
||||
|
||||
@classproperty
|
||||
def max_variants_count(self):
|
||||
return config.get("services.events.events_retrieval.max_variants_count", 100)
|
||||
def max_es_buckets(self):
|
||||
percentage = (
|
||||
min(
|
||||
100,
|
||||
config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count_threshold",
|
||||
80,
|
||||
),
|
||||
)
|
||||
/ 100
|
||||
)
|
||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
@ -78,6 +92,46 @@ def count_company_events(
|
||||
return es.count(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_max_metric_and_variant_counts(
|
||||
es: Elasticsearch,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
query: dict,
|
||||
**kwargs,
|
||||
) -> Tuple[int, int]:
|
||||
dynamic = config.get(
|
||||
"services.events.events_retrieval.dynamic_metrics_count", False
|
||||
)
|
||||
max_metrics_count = config.get(
|
||||
"services.events.events_retrieval.max_metrics_count", 100
|
||||
)
|
||||
max_variants_count = config.get(
|
||||
"services.events.events_retrieval.max_variants_count", 100
|
||||
)
|
||||
if not dynamic:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {"metrics_count": {"cardinality": {"field": "metric"}}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_max_metric_and_variant_counts"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
es, company_id=company_id, event_type=event_type, body=es_req, **kwargs,
|
||||
)
|
||||
|
||||
metrics_count = safe_get(
|
||||
es_res, "aggregations/metrics_count/value", max_metrics_count
|
||||
)
|
||||
if not metrics_count:
|
||||
return max_metrics_count, max_variants_count
|
||||
|
||||
return metrics_count, int(EventSettings.max_es_buckets / metrics_count)
|
||||
|
||||
|
||||
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
@ -94,3 +148,19 @@ def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
]
|
||||
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
|
||||
|
||||
def uncompress_plot(event: dict):
|
||||
plot_data = event.pop(PlotFields.plot_data, None)
|
||||
if plot_data and event.get(PlotFields.plot_str) is None:
|
||||
event[PlotFields.plot_str] = zlib.decompress(
|
||||
base64.b64decode(plot_data)
|
||||
).decode()
|
||||
|
@ -4,8 +4,9 @@ from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Mapping
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
@ -17,6 +18,8 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
SINGLE_SCALAR_ITERATION,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
@ -166,6 +169,58 @@ class EventMetrics:
|
||||
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
if check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
):
|
||||
return {}
|
||||
|
||||
with TimingContext("es", "get_task_single_value_metrics"):
|
||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
||||
|
||||
def _get_value(event: dict):
|
||||
return {
|
||||
field: event.get(field)
|
||||
for field in ("metric", "variant", "value", "timestamp")
|
||||
}
|
||||
|
||||
return {
|
||||
task: [_get_value(e) for e in events]
|
||||
for task, events in bucketize(task_events, itemgetter("task")).items()
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
) -> Sequence[dict]:
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"terms": {"task": task_ids}},
|
||||
{"term": {"iter": SINGLE_SCALAR_ITERATION}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.metrics_scalar,
|
||||
routing=",".join(task_ids),
|
||||
)
|
||||
if not es_res["hits"]["total"]["value"]:
|
||||
return []
|
||||
|
||||
return [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@ -219,11 +274,17 @@ class EventMetrics:
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
must = self._task_conditions(task_id)
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type, routing=task_id,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
@ -231,14 +292,14 @@ class EventMetrics:
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
@ -253,9 +314,7 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
@ -307,33 +366,42 @@ class EventMetrics:
|
||||
"""
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
aggs=aggs,
|
||||
task_id=task_id,
|
||||
metrics=metrics,
|
||||
query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type, routing=task_id,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
@ -360,19 +428,18 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
aggs: dict,
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
@staticmethod
|
||||
def _task_conditions(task_id: str) -> list:
|
||||
return [
|
||||
{"term": {"task": task_id}},
|
||||
{"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}},
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_task_metrics_query(
|
||||
cls, task_id: str, metrics: Sequence[Tuple[str, str]],
|
||||
):
|
||||
must = cls._task_conditions(task_id)
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
@ -387,18 +454,7 @@ class EventMetrics:
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
return {"bool": {"must": must}}
|
||||
|
||||
def get_task_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
@ -426,12 +482,12 @@ class EventMetrics:
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": {"bool": {"must": self._task_conditions(task_id)}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"size": EventSettings.max_es_buckets,
|
||||
"order": {"_key": "asc"},
|
||||
}
|
||||
}
|
||||
|
36
apiserver/bll/event/history_plot_iterator.py
Normal file
36
apiserver/bll/event/history_plot_iterator.py
Normal file
@ -0,0 +1,36 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
||||
|
||||
|
||||
class HistoryPlotIterator(HistorySampleIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_plot)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return []
|
||||
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
return {"terms": {"variant": [v.name for v in variants]}}
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
uncompress_plot(event)
|
||||
return event
|
||||
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
# The min iteration is the lowest iteration that contains non-recycled image url
|
||||
aggs = {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"first_iter": {"min": {"field": "iter"}},
|
||||
}
|
||||
|
||||
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
|
||||
min_iter = int(variant_bucket["first_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return min_iter, max_iter
|
||||
|
||||
return aggs, get_min_max_data
|
445
apiserver/bll/event/history_sample_iterator.py
Normal file
445
apiserver/bll/event/history_sample_iterator.py
Normal file
@ -0,0 +1,445 @@
|
||||
import abc
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Callable, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
EventType,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class HistorySampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
reached_first: bool = BoolField()
|
||||
reached_last: bool = BoolField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class HistorySampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistorySampleIterator(abc.ABC):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=HistorySampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
|
||||
) -> HistorySampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = HistorySampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
event = self._get_next_for_current_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
) or self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def _fill_res_and_update_state(
|
||||
self, event: dict, res: HistorySampleResult, state: HistorySampleState
|
||||
):
|
||||
self._process_event(event)
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
pass
|
||||
|
||||
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
self._get_variants_conditions(vs),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, vs in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_variants_condition(variants),
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_current_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
routing=state.task,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_variants_condition(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_next_for_another_iteration"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
routing=state.task,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> HistorySampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = HistorySampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: HistorySampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: HistorySampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: HistorySampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_history_sample_for_variant"
|
||||
):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
routing=task,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
pass
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Tuple[str, str, int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type, routing=task,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
min_max_aggs, get_min_max_data = self._get_min_max_aggs()
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": min_max_aggs,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "get_history_sample_iterations"
|
||||
):
|
||||
es_res = search_company_events(body=es_req, **search_args,)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
min_iter, max_iter = get_min_max_data(variant_bucket)
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
440
apiserver/bll/event/metric_events_iterator.py
Normal file
440
apiserver/bll/event/metric_events_iterator.py
Normal file
@ -0,0 +1,440 @@
|
||||
import abc
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Callable
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition, get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.metrics import MetricEventStats
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
variant: str = StringField(required=True)
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[VariantState] = ListField([VariantState], required=True)
|
||||
timestamp: int = IntField(default=0)
|
||||
|
||||
|
||||
class TaskScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
self.last_min_iter = self.last_max_iter = None
|
||||
|
||||
|
||||
class MetricEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricEventsResult(object):
|
||||
metric_events: Sequence[tuple] = []
|
||||
next_scroll_id: str = None
|
||||
|
||||
|
||||
class MetricEventsIterator:
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=MetricEventsScrollState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> MetricEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.event_type):
|
||||
return MetricEventsResult()
|
||||
|
||||
def init_state(state_: MetricEventsScrollState):
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
|
||||
def validate_state(state_: MetricEventsScrollState):
|
||||
"""
|
||||
Validate that the metrics stored in the state are the same
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if refresh:
|
||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
) as state:
|
||||
res = MetricEventsResult(next_scroll_id=state.id)
|
||||
specific_variants_requested = any(
|
||||
variants
|
||||
for t, metrics in task_metrics.items()
|
||||
if metrics
|
||||
for m, variants in metrics.items()
|
||||
)
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
res.metric_events = list(
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
company_id=company_id,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
specific_variants_requested=specific_variants_requested,
|
||||
),
|
||||
state.tasks,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
state: MetricEventsScrollState,
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
"""
|
||||
Determine the metrics for which new event_type events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
) -> Mapping[str, datetime]:
|
||||
"""For metrics that reported event_type events get mapping of the metric name to the last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return {}
|
||||
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return {
|
||||
stats.metric: stats.event_stats_by_type[
|
||||
self.event_type.value
|
||||
].last_update
|
||||
for stats in metric_stats.values()
|
||||
if self.event_type.value in stats.event_stats_by_type
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
}
|
||||
|
||||
update_times = {
|
||||
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
}
|
||||
task_metric_states = {
|
||||
task_state.task: {
|
||||
metric_state.metric: metric_state for metric_state in task_state.metrics
|
||||
}
|
||||
for task_state in state.tasks
|
||||
}
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = {
|
||||
m: task_metrics[task].get(m)
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
}
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
|
||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||
) -> TaskScrollState:
|
||||
task = old_state.task
|
||||
updated_state = first(uts for uts in updates if uts.task == task)
|
||||
if not updated_state:
|
||||
old_state.reset()
|
||||
return old_state
|
||||
|
||||
updated_metrics = [m.metric for m in updated_state.metrics]
|
||||
return TaskScrollState(
|
||||
task=task,
|
||||
metrics=[
|
||||
*updated_state.metrics,
|
||||
*(
|
||||
old_metric
|
||||
for old_metric in old_state.metrics
|
||||
if old_metric.metric not in updated_metrics
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
state.tasks = [
|
||||
merge_with_updated_task_states(task_state, updated_task_states)
|
||||
for task_state in state.tasks
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
return [
|
||||
TaskScrollState(task=task, metrics=metric_states,)
|
||||
for task, metric_states in zip(task_metrics, task_metric_states)
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
|
||||
pass
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any event_type events
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type, routing=task,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_event_timestamp": {"max": {"field": "timestamp"}},
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
**({"aggs": variant_state_aggs} if variant_state_aggs else {}),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_state(variant: dict):
|
||||
"""
|
||||
Return new variant state for the passed variant bucket
|
||||
"""
|
||||
state = VariantState(variant=variant["key"])
|
||||
if fill_variant_state_data:
|
||||
fill_variant_state_data(variant, state)
|
||||
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricState(
|
||||
metric=metric["key"],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
variants=[
|
||||
init_variant_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
pass
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
specific_variants_requested: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update task scroll state
|
||||
"""
|
||||
if not task_state.metrics:
|
||||
return task_state.task, []
|
||||
|
||||
if task_state.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task_state.task}},
|
||||
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and task_state.last_min_iter is not None:
|
||||
range_condition = {"lt": task_state.last_min_iter}
|
||||
elif not navigate_earlier and task_state.last_max_iter is not None:
|
||||
range_condition = {"gt": task_state.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
metrics_count = len(task_state.metrics)
|
||||
max_variants = int(EventSettings.max_es_buckets / (metrics_count * iter_count))
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": self._get_same_variant_events_order()
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metric_events"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
routing=task_state.task,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return task_state.task, []
|
||||
|
||||
invalid_iterations = {
|
||||
(m.metric, v.variant): v.last_invalid_iteration
|
||||
for m in task_state.metrics
|
||||
for v in m.variants
|
||||
}
|
||||
allow_uninitialized = (
|
||||
False
|
||||
if specific_variants_requested
|
||||
else config.get(
|
||||
"services.events.events_retrieval.debug_images.allow_uninitialized_variants",
|
||||
False,
|
||||
)
|
||||
)
|
||||
|
||||
def is_valid_event(event: dict) -> bool:
|
||||
key = event.get("metric"), event.get("variant")
|
||||
if key not in invalid_iterations:
|
||||
return allow_uninitialized
|
||||
|
||||
max_invalid = invalid_iterations[key]
|
||||
return max_invalid is None or event.get("iter") > max_invalid
|
||||
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
self._process_event(ev["_source"])
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
if is_valid_event(ev["_source"])
|
||||
]
|
||||
|
||||
iterations = []
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets"):
|
||||
events = get_iteration_events(it)
|
||||
if events:
|
||||
iterations.append({"iter": it["key"], "events": events})
|
||||
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
task_state.last_max_iter = iterations[0]["iter"]
|
||||
task_state.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
return task_state.task, iterations
|
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
25
apiserver/bll/event/metric_plots_iterator.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Sequence
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .metric_events_iterator import MetricEventsIterator
|
||||
|
||||
|
||||
class MetricPlotsIterator(MetricEventsIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_plot)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return []
|
||||
|
||||
def _get_variant_state_aggs(self):
|
||||
return None, None
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
uncompress_plot(event)
|
||||
return event
|
||||
|
||||
def _get_same_variant_events_order(self) -> dict:
|
||||
return {"timestamp": {"order": "desc"}}
|
@ -302,6 +302,48 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
plots_response_task_metrics {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: Task ID
|
||||
}
|
||||
iterations {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
iter {
|
||||
type: integer
|
||||
description: Iteration number
|
||||
}
|
||||
events {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
description: Plot event
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
plots_response {
|
||||
type: object
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
metrics {
|
||||
type: array
|
||||
description: "Plot events grouped by tasks and iterations"
|
||||
items {"$ref": "#/definitions/plots_response_task_metrics"}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_image_sample_response {
|
||||
type: object
|
||||
properties {
|
||||
@ -323,6 +365,27 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
plot_sample_response {
|
||||
type: object
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
|
||||
}
|
||||
event {
|
||||
type: object
|
||||
description: "Plot event"
|
||||
}
|
||||
min_iteration {
|
||||
type: integer
|
||||
description: "minimal valid iteration for the variant"
|
||||
}
|
||||
max_iteration {
|
||||
type: integer
|
||||
description: "maximal valid iteration for the variant"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
add {
|
||||
"2.1" {
|
||||
@ -486,6 +549,41 @@ debug_images {
|
||||
}
|
||||
}
|
||||
}
|
||||
plots {
|
||||
"999.0" {
|
||||
description: "Get plot events for the requested amount of iterations per each task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
metrics
|
||||
]
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/task_metric_variants" }
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then events are retreived from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
|
||||
}
|
||||
refresh {
|
||||
type: boolean
|
||||
description: "If set then scroll will be moved to the latest iterations. The default is False"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {"$ref": "#/definitions/plots_response"}
|
||||
}
|
||||
}
|
||||
get_debug_image_sample {
|
||||
"2.12": {
|
||||
description: "Return the debug image per metric and variant for the provided iteration"
|
||||
@ -547,6 +645,72 @@ next_debug_image_sample {
|
||||
response {"$ref": "#/definitions/debug_image_sample_response"}
|
||||
}
|
||||
}
|
||||
get_plot_sample {
|
||||
"999.0": {
|
||||
description: "Return the plot per metric and variant for the provided iteration"
|
||||
request {
|
||||
type: object
|
||||
required: [task, metric, variant]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Metric variant"
|
||||
type: string
|
||||
}
|
||||
iteration {
|
||||
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
|
||||
type: integer
|
||||
}
|
||||
refresh {
|
||||
description: "If set then scroll state will be refreshed to reflect the latest changes in the plots"
|
||||
type: boolean
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID from the previous call to get_plot_sample or empty"
|
||||
}
|
||||
navigate_current_metric {
|
||||
description: If set then subsequent navigation with next_plot_sample is done on the plots for the passed metric only. Otherwise for all the metrics
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {"$ref": "#/definitions/plot_sample_response"}
|
||||
}
|
||||
}
|
||||
next_plot_sample {
|
||||
"999.0": {
|
||||
description: "Get the plot for the next variant for the same iteration or for the next iteration"
|
||||
request {
|
||||
type: object
|
||||
required: [task, scroll_id]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID from the previous call to get_plot_sample"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration.
|
||||
Otherwise next variant event from the current iteration or first variant event from the next iteration"""
|
||||
}
|
||||
}
|
||||
}
|
||||
response {"$ref": "#/definitions/plot_sample_response"}
|
||||
}
|
||||
}
|
||||
get_task_metrics{
|
||||
"2.7": {
|
||||
description: "For each task, get a list of metrics for which the requested event type was reported"
|
||||
@ -1112,6 +1276,55 @@ multi_task_scalar_metrics_iter_histogram {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_single_value_metrics {
|
||||
"999.0" {
|
||||
description: Get single value metrics for the passed tasks
|
||||
request {
|
||||
type: object
|
||||
required: [tasks]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task Task IDs"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tasks {
|
||||
description: Single value metrics grouped by task
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: Task ID
|
||||
}
|
||||
values {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
metric { type: string }
|
||||
variant { type: string}
|
||||
value { type: number }
|
||||
timestamp { type: number }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_latest_scalar_values {
|
||||
"2.1" {
|
||||
description: "Get the tasks's latest scalar values"
|
||||
|
@ -27,6 +27,7 @@ from apiserver.apimodels.events import (
|
||||
ScalarMetricsIterRawRequest,
|
||||
ClearScrollRequest,
|
||||
ClearTaskLogRequest,
|
||||
SingleValueMetricsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
@ -450,6 +451,30 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_single_value_metrics")
|
||||
def get_task_single_value_metrics(
|
||||
call, company_id: str, request: SingleValueMetricsRequest
|
||||
):
|
||||
task_ids = call.data["tasks"]
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids)
|
||||
call.result.data = dict(
|
||||
tasks=[{"task": task, "values": values} for task, values in res.items()]
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
|
||||
def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
@ -613,6 +638,56 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.plots",
|
||||
request_data_model=MetricEventsRequest,
|
||||
response_data_model=MetricEventsResponse,
|
||||
)
|
||||
def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
task_metrics = defaultdict(dict)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task][tm.metric] = tm.variants
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.plots_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
refresh=request.refresh,
|
||||
state_id=request.scroll_id,
|
||||
)
|
||||
|
||||
call.result.data_model = MetricEventsResponse(
|
||||
scroll_id=result.next_scroll_id,
|
||||
metrics=[
|
||||
MetricEvents(
|
||||
task=task,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, iterations) in result.metric_events
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", required_fields=["task"])
|
||||
def get_debug_images_v1_7(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
@ -769,6 +844,42 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
|
||||
call.result.data = attr.asdict(res, recurse=False)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
|
||||
)
|
||||
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_sample_for_variant(
|
||||
company_id=task.company,
|
||||
task=request.task,
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
iteration=request.iteration,
|
||||
refresh=request.refresh,
|
||||
state_id=request.scroll_id,
|
||||
navigate_current_metric=request.navigate_current_metric,
|
||||
)
|
||||
call.result.data = attr.asdict(res, recurse=False)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
|
||||
)
|
||||
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_next_sample(
|
||||
company_id=task.company,
|
||||
task=request.task,
|
||||
state_id=request.scroll_id,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
)
|
||||
call.result.data = attr.asdict(res, recurse=False)
|
||||
|
||||
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
task = task_bll.assert_exists(
|
||||
|
Loading…
Reference in New Issue
Block a user