From cff98ae900b61e99d3cc275b3cadb77c2229a0ce Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 8 Jul 2022 17:29:39 +0300 Subject: [PATCH] Add support for events.get_task_single_value_metrics, events.plots, events.get_plot_sample and events.next_plot_sample --- apiserver/apimodels/events.py | 2 + apiserver/bll/event/event_bll.py | 58 ++- apiserver/bll/event/event_common.py | 84 +++- apiserver/bll/event/event_metrics.py | 176 ++++--- apiserver/bll/event/history_plot_iterator.py | 36 ++ .../bll/event/history_sample_iterator.py | 445 ++++++++++++++++++ apiserver/bll/event/metric_events_iterator.py | 440 +++++++++++++++++ apiserver/bll/event/metric_plots_iterator.py | 25 + apiserver/schema/services/events.conf | 213 +++++++++ apiserver/services/events.py | 111 +++++ 10 files changed, 1499 insertions(+), 91 deletions(-) create mode 100644 apiserver/bll/event/history_plot_iterator.py create mode 100644 apiserver/bll/event/history_sample_iterator.py create mode 100644 apiserver/bll/event/metric_events_iterator.py create mode 100644 apiserver/bll/event/metric_plots_iterator.py diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index c0e02b6..2f8cbe5 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -124,6 +124,8 @@ class DebugImageResponse(Base): scroll_id: str = StringField() +class SingleValueMetricsRequest(MultiTasksRequestBase): + pass class TaskMetricsRequest(Base): tasks: Sequence[str] = ListField( items_types=str, validators=[Length(minimum_value=1)] diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index a0d1bef..c42bf1e 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -16,13 +16,14 @@ from nested_dict import nested_dict from apiserver.bll.event.debug_sample_history import DebugSampleHistory from apiserver.bll.event.event_common import ( EventType, - EventSettings, get_index_name, check_empty_data, search_company_events, delete_company_events, MetricVariants, get_metric_variants_condition, + uncompress_plot, + get_max_metric_and_variant_counts, ) from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult from apiserver.bll.util import parallel_chunked_decorator @@ -76,6 +77,8 @@ class EventBLL(object): self.redis = redis or redman.connection("apiserver") self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis) self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis) + self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis) + self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis) self.events_iterator = EventsIterator(es=self.es) @property @@ -307,11 +310,7 @@ class EventBLL(object): @parallel_chunked_decorator(chunk_size=10) def uncompress_plots(self, plot_events: Sequence[dict]): for event in plot_events: - plot_data = event.pop(PlotFields.plot_data, None) - if plot_data and event.get(PlotFields.plot_str) is None: - event[PlotFields.plot_str] = zlib.decompress( - base64.b64decode(plot_data) - ).decode() + uncompress_plot(event) @staticmethod def _is_valid_json(text: str) -> bool: @@ -479,6 +478,13 @@ class EventBLL(object): if metric_variants: must.append(get_metric_variants_condition(metric_variants)) query = {"bool": {"must": must}} + search_args = dict( + es=self.es, company_id=company_id, event_type=event_type, routing=task_id, + ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args, + ) + max_variants = 
int(max_variants // num_last_iterations) es_req: dict = { "size": 0, @@ -486,14 +492,14 @@ class EventBLL(object): "metrics": { "terms": { "field": "metric", - "size": EventSettings.max_metrics_count, + "size": max_metrics, "order": {"_key": "asc"}, }, "aggs": { "variants": { "terms": { "field": "variant", - "size": EventSettings.max_variants_count, + "size": max_variants, "order": {"_key": "asc"}, }, "aggs": { @@ -515,9 +521,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "task_last_iter_metric_variant" ): - es_res = search_company_events( - self.es, company_id=company_id, event_type=event_type, body=es_req - ) + es_res = search_company_events(body=es_req, **search_args) if "aggregations" not in es_res: return [] @@ -763,20 +767,26 @@ class EventBLL(object): return {} query = {"bool": {"must": [{"term": {"task": task_id}}]}} + search_args = dict( + es=self.es, company_id=company_id, event_type=event_type, routing=task_id, + ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args, + ) es_req = { "size": 0, "aggs": { "metrics": { "terms": { "field": "metric", - "size": EventSettings.max_metrics_count, + "size": max_metrics, "order": {"_key": "asc"}, }, "aggs": { "variants": { "terms": { "field": "variant", - "size": EventSettings.max_variants_count, + "size": max_variants, "order": {"_key": "asc"}, } } @@ -789,9 +799,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "events_get_metrics_and_variants" ): - es_res = search_company_events( - self.es, company_id=company_id, event_type=event_type, body=es_req - ) + es_res = search_company_events(body=es_req, **search_args) metrics = {} for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"): @@ -817,6 +825,12 @@ class EventBLL(object): ] } } + search_args = dict( + es=self.es, company_id=company_id, event_type=event_type, routing=task_id, + ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args, + ) es_req = { "size": 0, "query": query, @@ -824,14 +838,14 @@ class EventBLL(object): "metrics": { "terms": { "field": "metric", - "size": EventSettings.max_metrics_count, + "size": max_metrics, "order": {"_key": "asc"}, }, "aggs": { "variants": { "terms": { "field": "variant", - "size": EventSettings.max_variants_count, + "size": max_variants, "order": {"_key": "asc"}, }, "aggs": { @@ -862,9 +876,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "events_get_metrics_and_variants" ): - es_res = search_company_events( - self.es, company_id=company_id, event_type=event_type, body=es_req - ) + es_res = search_company_events(body=es_req, **search_args) metrics = [] max_timestamp = 0 @@ -1019,9 +1031,7 @@ class EventBLL(object): { "range": { "timestamp": { - "lt": ( - es_factory.get_timestamp_millis() - timestamp_ms - ) + "lt": (es_factory.get_timestamp_millis() - timestamp_ms) } } } diff --git a/apiserver/bll/event/event_common.py b/apiserver/bll/event/event_common.py index ff70d84..f4b8b17 100644 --- a/apiserver/bll/event/event_common.py +++ b/apiserver/bll/event/event_common.py @@ -1,10 +1,15 @@ +import base64 +import zlib from enum import Enum -from typing import Union, Sequence, Mapping +from typing import Union, Sequence, Mapping, Tuple from boltons.typeutils import classproperty from elasticsearch import Elasticsearch from apiserver.config_repo import config +from apiserver.database.errors import translate_errors_context +from apiserver.timing_context import 
TimingContext +from apiserver.tools import safe_get class EventType(Enum): @@ -16,10 +21,13 @@ class EventType(Enum): all = "*" +SINGLE_SCALAR_ITERATION = -2**31 MetricVariants = Mapping[str, Sequence[str]] class EventSettings: + _max_es_allowed_aggregation_buckets = 10000 + @classproperty def max_workers(self): return config.get("services.events.events_retrieval.max_metrics_concurrency", 4) @@ -31,12 +39,18 @@ class EventSettings: ) @classproperty - def max_metrics_count(self): - return config.get("services.events.events_retrieval.max_metrics_count", 100) - - @classproperty - def max_variants_count(self): - return config.get("services.events.events_retrieval.max_variants_count", 100) + def max_es_buckets(self): + percentage = ( + min( + 100, + config.get( + "services.events.events_retrieval.dynamic_metrics_count_threshold", + 80, + ), + ) + / 100 + ) + return int(self._max_es_allowed_aggregation_buckets * percentage) def get_index_name(company_id: str, event_type: str): @@ -78,6 +92,46 @@ def count_company_events( return es.count(index=es_index, body=body, **kwargs) +def get_max_metric_and_variant_counts( + es: Elasticsearch, + company_id: Union[str, Sequence[str]], + event_type: EventType, + query: dict, + **kwargs, +) -> Tuple[int, int]: + dynamic = config.get( + "services.events.events_retrieval.dynamic_metrics_count", False + ) + max_metrics_count = config.get( + "services.events.events_retrieval.max_metrics_count", 100 + ) + max_variants_count = config.get( + "services.events.events_retrieval.max_variants_count", 100 + ) + if not dynamic: + return max_metrics_count, max_variants_count + + es_req: dict = { + "size": 0, + "query": query, + "aggs": {"metrics_count": {"cardinality": {"field": "metric"}}}, + } + with translate_errors_context(), TimingContext( + "es", "get_max_metric_and_variant_counts" + ): + es_res = search_company_events( + es, company_id=company_id, event_type=event_type, body=es_req, **kwargs, + ) + + metrics_count = safe_get( + es_res, "aggregations/metrics_count/value", max_metrics_count + ) + if not metrics_count: + return max_metrics_count, max_variants_count + + return metrics_count, int(EventSettings.max_es_buckets / metrics_count) + + def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence: conditions = [ { @@ -94,3 +148,19 @@ def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence: ] return {"bool": {"should": conditions}} + + +class PlotFields: + valid_plot = "valid_plot" + plot_len = "plot_len" + plot_str = "plot_str" + plot_data = "plot_data" + source_urls = "source_urls" + + +def uncompress_plot(event: dict): + plot_data = event.pop(PlotFields.plot_data, None) + if plot_data and event.get(PlotFields.plot_str) is None: + event[PlotFields.plot_str] = zlib.decompress( + base64.b64decode(plot_data) + ).decode() diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py index 01d8cb2..bdc2179 100644 --- a/apiserver/bll/event/event_metrics.py +++ b/apiserver/bll/event/event_metrics.py @@ -4,8 +4,9 @@ from collections import defaultdict from concurrent.futures.thread import ThreadPoolExecutor from functools import partial from operator import itemgetter -from typing import Sequence, Tuple +from typing import Sequence, Tuple, Mapping +from boltons.iterutils import bucketize from elasticsearch import Elasticsearch from mongoengine import Q @@ -17,6 +18,8 @@ from apiserver.bll.event.event_common import ( check_empty_data, MetricVariants, get_metric_variants_condition, + 
get_max_metric_and_variant_counts, + SINGLE_SCALAR_ITERATION, ) from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum from apiserver.config_repo import config @@ -166,6 +169,58 @@ class EventMetrics: return res + def get_task_single_value_metrics( + self, company_id: str, task_ids: Sequence[str] + ) -> Mapping[str, dict]: + """ + For the requested tasks return all the events delivered for the single iteration (-2**31) + """ + if check_empty_data( + self.es, company_id=company_id, event_type=EventType.metrics_scalar + ): + return {} + + with TimingContext("es", "get_task_single_value_metrics"): + task_events = self._get_task_single_value_metrics(company_id, task_ids) + + def _get_value(event: dict): + return { + field: event.get(field) + for field in ("metric", "variant", "value", "timestamp") + } + + return { + task: [_get_value(e) for e in events] + for task, events in bucketize(task_events, itemgetter("task")).items() + } + + def _get_task_single_value_metrics( + self, company_id: str, task_ids: Sequence[str] + ) -> Sequence[dict]: + es_req = { + "size": 10000, + "query": { + "bool": { + "must": [ + {"terms": {"task": task_ids}}, + {"term": {"iter": SINGLE_SCALAR_ITERATION}}, + ] + } + }, + } + with translate_errors_context(): + es_res = search_company_events( + body=es_req, + es=self.es, + company_id=company_id, + event_type=EventType.metrics_scalar, + routing=",".join(task_ids), + ) + if not es_res["hits"]["total"]["value"]: + return [] + + return [hit["_source"] for hit in es_res["hits"]["hits"]] + MetricInterval = Tuple[str, str, int, int] MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]] @@ -219,11 +274,17 @@ class EventMetrics: Return the list og metric variant intervals as the following tuple: (metric, variant, interval, samples) """ - must = [{"term": {"task": task_id}}] + must = self._task_conditions(task_id) if metric_variants: must.append(get_metric_variants_condition(metric_variants)) query = {"bool": {"must": must}} - + search_args = dict( + es=self.es, company_id=company_id, event_type=event_type, routing=task_id, + ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args, + ) + max_variants = int(max_variants // 2) es_req = { "size": 0, "query": query, @@ -231,14 +292,14 @@ class EventMetrics: "metrics": { "terms": { "field": "metric", - "size": EventSettings.max_metrics_count, + "size": max_metrics, "order": {"_key": "asc"}, }, "aggs": { "variants": { "terms": { "field": "variant", - "size": EventSettings.max_variants_count, + "size": max_variants, "order": {"_key": "asc"}, }, "aggs": { @@ -253,9 +314,7 @@ class EventMetrics: } with translate_errors_context(), TimingContext("es", "task_stats_get_interval"): - es_res = search_company_events( - self.es, company_id=company_id, event_type=event_type, body=es_req, - ) + es_res = search_company_events(body=es_req, **search_args) aggs_result = es_res.get("aggregations") if not aggs_result: @@ -307,33 +366,42 @@ class EventMetrics: """ interval, metrics = metrics_interval aggregation = self._add_aggregation_average(key.get_aggregation(interval)) - aggs = { - "metrics": { - "terms": { - "field": "metric", - "size": EventSettings.max_metrics_count, - "order": {"_key": "asc"}, - }, - "aggs": { - "variants": { - "terms": { - "field": "variant", - "size": EventSettings.max_variants_count, - "order": {"_key": "asc"}, - }, - "aggs": aggregation, - } - }, - } - } - aggs_result = self._query_aggregation_for_task_metrics( - company_id=company_id, - event_type=event_type, - 
aggs=aggs, - task_id=task_id, - metrics=metrics, + query = self._get_task_metrics_query(task_id=task_id, metrics=metrics) + search_args = dict( + es=self.es, company_id=company_id, event_type=event_type, routing=task_id, ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args, + ) + max_variants = int(max_variants // 2) + es_req = { + "size": 0, + "query": query, + "aggs": { + "metrics": { + "terms": { + "field": "metric", + "size": max_metrics, + "order": {"_key": "asc"}, + }, + "aggs": { + "variants": { + "terms": { + "field": "variant", + "size": max_variants, + "order": {"_key": "asc"}, + }, + "aggs": aggregation, + } + }, + } + }, + } + with translate_errors_context(): + es_res = search_company_events(body=es_req, **search_args) + + aggs_result = es_res.get("aggregations") if not aggs_result: return {} @@ -360,19 +428,18 @@ class EventMetrics: for key, value in aggregation.items() } - def _query_aggregation_for_task_metrics( - self, - company_id: str, - event_type: EventType, - aggs: dict, - task_id: str, - metrics: Sequence[Tuple[str, str]], - ) -> dict: - """ - Return the result of elastic search query for the given aggregation filtered - by the given task_ids and metrics - """ - must = [{"term": {"task": task_id}}] + @staticmethod + def _task_conditions(task_id: str) -> list: + return [ + {"term": {"task": task_id}}, + {"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}}, + ] + + @classmethod + def _get_task_metrics_query( + cls, task_id: str, metrics: Sequence[Tuple[str, str]], + ): + must = cls._task_conditions(task_id) if metrics: should = [ { @@ -387,18 +454,7 @@ class EventMetrics: ] must.append({"bool": {"should": should}}) - es_req = { - "size": 0, - "query": {"bool": {"must": must}}, - "aggs": aggs, - } - - with translate_errors_context(), TimingContext("es", "task_stats_scalar"): - es_res = search_company_events( - self.es, company_id=company_id, event_type=event_type, body=es_req, - ) - - return es_res.get("aggregations") + return {"bool": {"must": must}} def get_task_metrics( self, company_id, task_ids: Sequence, event_type: EventType @@ -426,12 +482,12 @@ class EventMetrics: ) -> Sequence: es_req = { "size": 0, - "query": {"bool": {"must": [{"term": {"task": task_id}}]}}, + "query": {"bool": {"must": self._task_conditions(task_id)}}, "aggs": { "metrics": { "terms": { "field": "metric", - "size": EventSettings.max_metrics_count, + "size": EventSettings.max_es_buckets, "order": {"_key": "asc"}, } } diff --git a/apiserver/bll/event/history_plot_iterator.py b/apiserver/bll/event/history_plot_iterator.py new file mode 100644 index 0000000..a9a67f8 --- /dev/null +++ b/apiserver/bll/event/history_plot_iterator.py @@ -0,0 +1,36 @@ +from typing import Sequence, Tuple, Callable + +from elasticsearch import Elasticsearch +from redis.client import StrictRedis + +from .event_common import EventType, uncompress_plot +from .history_sample_iterator import HistorySampleIterator, VariantState + + +class HistoryPlotIterator(HistorySampleIterator): + def __init__(self, redis: StrictRedis, es: Elasticsearch): + super().__init__(redis, es, EventType.metrics_plot) + + def _get_extra_conditions(self) -> Sequence[dict]: + return [] + + def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict: + return {"terms": {"variant": [v.name for v in variants]}} + + def _process_event(self, event: dict) -> dict: + uncompress_plot(event) + return event + + def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]: + # The min 
iteration is the lowest iteration for which the variant reported a plot event (plots have no recycled image URLs to skip)
+        aggs = {
+            "last_iter": {"max": {"field": "iter"}},
+            "first_iter": {"min": {"field": "iter"}},
+        }
+
+        def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
+            min_iter = int(variant_bucket["first_iter"]["value"])
+            max_iter = int(variant_bucket["last_iter"]["value"])
+            return min_iter, max_iter
+
+        return aggs, get_min_max_data
diff --git a/apiserver/bll/event/history_sample_iterator.py b/apiserver/bll/event/history_sample_iterator.py
new file mode 100644
index 0000000..3c0c535
--- /dev/null
+++ b/apiserver/bll/event/history_sample_iterator.py
@@ -0,0 +1,445 @@
+import abc
+import operator
+from operator import attrgetter
+from typing import Sequence, Tuple, Optional, Callable, Mapping
+
+import attr
+from boltons.iterutils import first, bucketize
+from elasticsearch import Elasticsearch
+from jsonmodels.fields import StringField, ListField, IntField, BoolField
+from jsonmodels.models import Base
+from redis import StrictRedis
+
+from apiserver.apierrors import errors
+from apiserver.apimodels import JsonSerializableMixin
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    EventType,
+    check_empty_data,
+    search_company_events,
+    get_max_metric_and_variant_counts,
+)
+from apiserver.bll.redis_cache_manager import RedisCacheManager
+from apiserver.database.errors import translate_errors_context
+from apiserver.timing_context import TimingContext
+from apiserver.utilities.dicts import nested_get
+
+
+class VariantState(Base):
+    name: str = StringField(required=True)
+    metric: str = StringField(default=None)
+    min_iteration: int = IntField()
+    max_iteration: int = IntField()
+
+
+class HistorySampleState(Base, JsonSerializableMixin):
+    id: str = StringField(required=True)
+    iteration: int = IntField()
+    variant: str = StringField()
+    task: str = StringField()
+    metric: str = StringField()
+    reached_first: bool = BoolField()
+    reached_last: bool = BoolField()
+    variant_states: Sequence[VariantState] = ListField([VariantState])
+    warning: str = StringField()
+    navigate_current_metric = BoolField(default=True)
+
+
+@attr.s(auto_attribs=True)
+class HistorySampleResult(object):
+    scroll_id: str = None
+    event: dict = None
+    min_iteration: int = None
+    max_iteration: int = None
+
+
+class HistorySampleIterator(abc.ABC):
+    def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
+        self.es = es
+        self.event_type = event_type
+        self.cache_manager = RedisCacheManager(
+            state_class=HistorySampleState,
+            redis=redis,
+            expiration_interval=EventSettings.state_expiration_sec,
+        )
+
+    def get_next_sample(
+        self, company_id: str, task: str, state_id: str, navigate_earlier: bool
+    ) -> HistorySampleResult:
+        """
+        Get the sample for the next/prev variant on the current iteration.
+        If it does not exist then try getting the sample for the first/last variant of the next/prev iteration
+        """
+        res = HistorySampleResult(scroll_id=state_id)
+        state = self.cache_manager.get_state(state_id)
+        if not state or state.task != task:
+            raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
+
+        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
+            return res
+
+        event = self._get_next_for_current_iteration(
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
+        ) or self._get_next_for_another_iteration(
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
+        )
+        if not event:
+            return res
+
+        self._fill_res_and_update_state(event=event,
res=res, state=state) + self.cache_manager.set_state(state=state) + return res + + def _fill_res_and_update_state( + self, event: dict, res: HistorySampleResult, state: HistorySampleState + ): + self._process_event(event) + state.variant = event["variant"] + state.metric = event["metric"] + state.iteration = event["iter"] + res.event = event + var_state = first( + vs + for vs in state.variant_states + if vs.name == state.variant and vs.metric == state.metric + ) + if var_state: + res.min_iteration = var_state.min_iteration + res.max_iteration = var_state.max_iteration + + @abc.abstractmethod + def _get_extra_conditions(self) -> Sequence[dict]: + pass + + @abc.abstractmethod + def _process_event(self, event: dict) -> dict: + pass + + @abc.abstractmethod + def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict: + pass + + def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict: + metrics = bucketize(variants, key=attrgetter("metric")) + metrics_conditions = [ + { + "bool": { + "must": [ + {"term": {"metric": metric}}, + self._get_variants_conditions(vs), + ] + } + } + for metric, vs in metrics.items() + ] + return {"bool": {"should": metrics_conditions}} + + def _get_next_for_current_iteration( + self, company_id: str, navigate_earlier: bool, state: HistorySampleState + ) -> Optional[dict]: + """ + Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration + Only variants for which the iteration falls into their valid range are considered + Return None if no such variant or sample is found + """ + if state.navigate_current_metric: + variants = [ + var_state + for var_state in state.variant_states + if var_state.metric == state.metric + ] + else: + variants = state.variant_states + + cmp = operator.lt if navigate_earlier else operator.gt + variants = [ + var_state + for var_state in variants + if cmp((var_state.metric, var_state.name), (state.metric, state.variant)) + and var_state.min_iteration <= state.iteration + ] + if not variants: + return + + must_conditions = [ + {"term": {"task": state.task}}, + {"term": {"iter": state.iteration}}, + self._get_metric_variants_condition(variants), + *self._get_extra_conditions(), + ] + order = "desc" if navigate_earlier else "asc" + es_req = { + "size": 1, + "sort": [{"metric": order}, {"variant": order}], + "query": {"bool": {"must": must_conditions}}, + } + + with translate_errors_context(), TimingContext( + "es", "get_next_for_current_iteration" + ): + es_res = search_company_events( + self.es, + company_id=company_id, + event_type=self.event_type, + body=es_req, + routing=state.task, + ) + + hits = nested_get(es_res, ("hits", "hits")) + if not hits: + return + + return hits[0]["_source"] + + def _get_next_for_another_iteration( + self, company_id: str, navigate_earlier: bool, state: HistorySampleState + ) -> Optional[dict]: + """ + Get the sample for the first variant for the next iteration (if navigate_earlier is set to False) + or from the last variant for the previous iteration (otherwise) + The variants for which the sample falls in invalid range are discarded + If no suitable sample is found then None is returned + """ + if state.navigate_current_metric: + variants = [ + var_state + for var_state in state.variant_states + if var_state.metric == state.metric + ] + else: + variants = state.variant_states + + if navigate_earlier: + range_operator = "lt" + order = "desc" + variants = [ + var_state + for var_state in variants + if 
var_state.min_iteration < state.iteration + ] + else: + range_operator = "gt" + order = "asc" + variants = variants + + if not variants: + return + + must_conditions = [ + {"term": {"task": state.task}}, + self._get_metric_variants_condition(variants), + {"range": {"iter": {range_operator: state.iteration}}}, + *self._get_extra_conditions(), + ] + es_req = { + "size": 1, + "sort": [{"iter": order}, {"metric": order}, {"variant": order}], + "query": {"bool": {"must": must_conditions}}, + } + with translate_errors_context(), TimingContext( + "es", "get_next_for_another_iteration" + ): + es_res = search_company_events( + self.es, + company_id=company_id, + event_type=self.event_type, + body=es_req, + routing=state.task, + ) + + hits = nested_get(es_res, ("hits", "hits")) + if not hits: + return + + return hits[0]["_source"] + + def get_sample_for_variant( + self, + company_id: str, + task: str, + metric: str, + variant: str, + iteration: Optional[int] = None, + refresh: bool = False, + state_id: str = None, + navigate_current_metric: bool = True, + ) -> HistorySampleResult: + """ + Get the sample for the requested iteration or the latest before it + If the iteration is not passed then get the latest event + """ + res = HistorySampleResult() + if check_empty_data(self.es, company_id=company_id, event_type=self.event_type): + return res + + def init_state(state_: HistorySampleState): + state_.task = task + state_.metric = metric + state_.navigate_current_metric = navigate_current_metric + self._reset_variant_states(company_id=company_id, state=state_) + + def validate_state(state_: HistorySampleState): + if ( + state_.task != task + or state_.navigate_current_metric != navigate_current_metric + or (state_.navigate_current_metric and state_.metric != metric) + ): + raise errors.bad_request.InvalidScrollId( + "Task and metric stored in the state do not match the passed ones", + scroll_id=state_.id, + ) + # fix old variant states: + for vs in state_.variant_states: + if vs.metric is None: + vs.metric = metric + if refresh: + self._reset_variant_states(company_id=company_id, state=state_) + + state: HistorySampleState + with self.cache_manager.get_or_create_state( + state_id=state_id, init_state=init_state, validate_state=validate_state, + ) as state: + res.scroll_id = state.id + + var_state = first( + vs + for vs in state.variant_states + if vs.name == variant and vs.metric == metric + ) + if not var_state: + return res + + res.min_iteration = var_state.min_iteration + res.max_iteration = var_state.max_iteration + + must_conditions = [ + {"term": {"task": task}}, + {"term": {"metric": metric}}, + {"term": {"variant": variant}}, + *self._get_extra_conditions(), + ] + if iteration is not None: + must_conditions.append( + { + "range": { + "iter": {"lte": iteration, "gte": var_state.min_iteration} + } + } + ) + else: + must_conditions.append( + {"range": {"iter": {"gte": var_state.min_iteration}}} + ) + + es_req = { + "size": 1, + "sort": {"iter": "desc"}, + "query": {"bool": {"must": must_conditions}}, + } + + with translate_errors_context(), TimingContext( + "es", "get_history_sample_for_variant" + ): + es_res = search_company_events( + self.es, + company_id=company_id, + event_type=self.event_type, + body=es_req, + routing=task, + ) + + hits = nested_get(es_res, ("hits", "hits")) + if not hits: + return res + + self._fill_res_and_update_state( + event=hits[0]["_source"], res=res, state=state + ) + return res + + def _reset_variant_states(self, company_id: str, state: HistorySampleState): + metrics 
= self._get_metric_variant_iterations( + company_id=company_id, + task=state.task, + metric=state.metric if state.navigate_current_metric else None, + ) + state.variant_states = [ + VariantState( + metric=metric, + name=var_name, + min_iteration=min_iter, + max_iteration=max_iter, + ) + for metric, variants in metrics.items() + for var_name, min_iter, max_iter in variants + ] + + @abc.abstractmethod + def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]: + pass + + def _get_metric_variant_iterations( + self, company_id: str, task: str, metric: str, + ) -> Mapping[str, Tuple[str, str, int, int]]: + """ + Return valid min and max iterations that the task reported events of the required type + """ + must = [ + {"term": {"task": task}}, + *self._get_extra_conditions(), + ] + if metric is not None: + must.append({"term": {"metric": metric}}) + query = {"bool": {"must": must}} + + search_args = dict( + es=self.es, company_id=company_id, event_type=self.event_type, routing=task, + ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args + ) + max_variants = int(max_variants // 2) + min_max_aggs, get_min_max_data = self._get_min_max_aggs() + es_req: dict = { + "size": 0, + "query": query, + "aggs": { + "metrics": { + "terms": { + "field": "metric", + "size": max_metrics, + "order": {"_key": "asc"}, + }, + "aggs": { + "variants": { + "terms": { + "field": "variant", + "size": max_variants, + "order": {"_key": "asc"}, + }, + "aggs": min_max_aggs, + } + }, + } + }, + } + + with translate_errors_context(), TimingContext( + "es", "get_history_sample_iterations" + ): + es_res = search_company_events(body=es_req, **search_args,) + + def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]: + variant = variant_bucket["key"] + min_iter, max_iter = get_min_max_data(variant_bucket) + return variant, min_iter, max_iter + + return { + metric_bucket["key"]: [ + get_variant_data(variant_bucket) + for variant_bucket in nested_get(metric_bucket, ("variants", "buckets")) + ] + for metric_bucket in nested_get( + es_res, ("aggregations", "metrics", "buckets") + ) + } diff --git a/apiserver/bll/event/metric_events_iterator.py b/apiserver/bll/event/metric_events_iterator.py new file mode 100644 index 0000000..37c2251 --- /dev/null +++ b/apiserver/bll/event/metric_events_iterator.py @@ -0,0 +1,440 @@ +import abc +from concurrent.futures.thread import ThreadPoolExecutor +from datetime import datetime +from functools import partial +from operator import itemgetter +from typing import Sequence, Tuple, Optional, Mapping, Callable + +import attr +import dpath +from boltons.iterutils import first +from elasticsearch import Elasticsearch +from jsonmodels.fields import StringField, ListField, IntField +from jsonmodels.models import Base +from redis import StrictRedis + +from apiserver.apimodels import JsonSerializableMixin +from apiserver.bll.event.event_common import ( + EventSettings, + check_empty_data, + search_company_events, + EventType, + get_metric_variants_condition, get_max_metric_and_variant_counts, +) +from apiserver.bll.redis_cache_manager import RedisCacheManager +from apiserver.config_repo import config +from apiserver.database.errors import translate_errors_context +from apiserver.database.model.task.metrics import MetricEventStats +from apiserver.database.model.task.task import Task +from apiserver.timing_context import TimingContext + + +class VariantState(Base): + variant: str = StringField(required=True) + last_invalid_iteration: int = 
IntField() + + +class MetricState(Base): + metric: str = StringField(required=True) + variants: Sequence[VariantState] = ListField([VariantState], required=True) + timestamp: int = IntField(default=0) + + +class TaskScrollState(Base): + task: str = StringField(required=True) + metrics: Sequence[MetricState] = ListField([MetricState], required=True) + last_min_iter: Optional[int] = IntField() + last_max_iter: Optional[int] = IntField() + + def reset(self): + """Reset the scrolling state for the metric""" + self.last_min_iter = self.last_max_iter = None + + +class MetricEventsScrollState(Base, JsonSerializableMixin): + id: str = StringField(required=True) + tasks: Sequence[TaskScrollState] = ListField([TaskScrollState]) + warning: str = StringField() + + +@attr.s(auto_attribs=True) +class MetricEventsResult(object): + metric_events: Sequence[tuple] = [] + next_scroll_id: str = None + + +class MetricEventsIterator: + def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType): + self.es = es + self.event_type = event_type + self.cache_manager = RedisCacheManager( + state_class=MetricEventsScrollState, + redis=redis, + expiration_interval=EventSettings.state_expiration_sec, + ) + + def get_task_events( + self, + company_id: str, + task_metrics: Mapping[str, dict], + iter_count: int, + navigate_earlier: bool = True, + refresh: bool = False, + state_id: str = None, + ) -> MetricEventsResult: + if check_empty_data(self.es, company_id, self.event_type): + return MetricEventsResult() + + def init_state(state_: MetricEventsScrollState): + state_.tasks = self._init_task_states(company_id, task_metrics) + + def validate_state(state_: MetricEventsScrollState): + """ + Validate that the metrics stored in the state are the same + as requested in the current call. 
+            Refresh the state if requested
+            """
+            if refresh:
+                self._reinit_outdated_task_states(company_id, state_, task_metrics)
+
+        with self.cache_manager.get_or_create_state(
+            state_id=state_id, init_state=init_state, validate_state=validate_state
+        ) as state:
+            res = MetricEventsResult(next_scroll_id=state.id)
+            specific_variants_requested = any(
+                variants
+                for t, metrics in task_metrics.items()
+                if metrics
+                for m, variants in metrics.items()
+            )
+            with ThreadPoolExecutor(EventSettings.max_workers) as pool:
+                res.metric_events = list(
+                    pool.map(
+                        partial(
+                            self._get_task_metric_events,
+                            company_id=company_id,
+                            iter_count=iter_count,
+                            navigate_earlier=navigate_earlier,
+                            specific_variants_requested=specific_variants_requested,
+                        ),
+                        state.tasks,
+                    )
+                )
+
+            return res
+
+    def _reinit_outdated_task_states(
+        self,
+        company_id,
+        state: MetricEventsScrollState,
+        task_metrics: Mapping[str, dict],
+    ):
+        """
+        Determine the metrics for which new event_type events were added
+        since their states were initialized and re-init these states
+        """
+        tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
+            "id", "metric_stats"
+        )
+
+        def get_last_update_times_for_task_metrics(
+            task: Task,
+        ) -> Mapping[str, datetime]:
+            """For metrics that reported event_type events get mapping of the metric name to the last update times"""
+            metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
+            if not metric_stats:
+                return {}
+
+            requested_metrics = task_metrics[task.id]
+            return {
+                stats.metric: stats.event_stats_by_type[
+                    self.event_type.value
+                ].last_update
+                for stats in metric_stats.values()
+                if self.event_type.value in stats.event_stats_by_type
+                and (not requested_metrics or stats.metric in requested_metrics)
+            }
+
+        update_times = {
+            task.id: get_last_update_times_for_task_metrics(task) for task in tasks
+        }
+        task_metric_states = {
+            task_state.task: {
+                metric_state.metric: metric_state for metric_state in task_state.metrics
+            }
+            for task_state in state.tasks
+        }
+        task_metrics_to_recalc = {}
+        for task, metrics_times in update_times.items():
+            old_metric_states = task_metric_states[task]
+            metrics_to_recalc = {
+                m: task_metrics[task].get(m)
+                for m, t in metrics_times.items()
+                if m not in old_metric_states or old_metric_states[m].timestamp < t
+            }
+            if metrics_to_recalc:
+                task_metrics_to_recalc[task] = metrics_to_recalc
+
+        updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
+
+        def merge_with_updated_task_states(
+            old_state: TaskScrollState, updates: Sequence[TaskScrollState]
+        ) -> TaskScrollState:
+            task = old_state.task
+            updated_state = first(uts for uts in updates if uts.task == task)
+            if not updated_state:
+                old_state.reset()
+                return old_state
+
+            updated_metrics = [m.metric for m in updated_state.metrics]
+            return TaskScrollState(
+                task=task,
+                metrics=[
+                    *updated_state.metrics,
+                    *(
+                        old_metric
+                        for old_metric in old_state.metrics
+                        if old_metric.metric not in updated_metrics
+                    ),
+                ],
+            )
+
+        state.tasks = [
+            merge_with_updated_task_states(task_state, updated_task_states)
+            for task_state in state.tasks
+        ]
+
+    def _init_task_states(
+        self, company_id: str, task_metrics: Mapping[str, dict]
+    ) -> Sequence[TaskScrollState]:
+        """
+        Return initialized metric scroll states for the requested task metrics
+        """
+        with ThreadPoolExecutor(EventSettings.max_workers) as pool:
+            task_metric_states = pool.map(
+                partial(self._init_metric_states_for_task, company_id=company_id),
+                task_metrics.items(),
+            )
+
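+        # pool.map returns results in the order of its input, and dict iteration
+        # order is stable, so zipping the task_metrics keys with the mapped
+        # results below pairs each task with its own metric states
+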
return [ + TaskScrollState(task=task, metrics=metric_states,) + for task, metric_states in zip(task_metrics, task_metric_states) + ] + + @abc.abstractmethod + def _get_extra_conditions(self) -> Sequence[dict]: + pass + + @abc.abstractmethod + def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]: + pass + + def _init_metric_states_for_task( + self, task_metrics: Tuple[str, dict], company_id: str + ) -> Sequence[MetricState]: + """ + Return metric scroll states for the task filled with the variant states + for the variants that reported any event_type events + """ + task, metrics = task_metrics + must = [{"term": {"task": task}}, *self._get_extra_conditions()] + if metrics: + must.append(get_metric_variants_condition(metrics)) + query = {"bool": {"must": must}} + + search_args = dict( + es=self.es, company_id=company_id, event_type=self.event_type, routing=task, + ) + max_metrics, max_variants = get_max_metric_and_variant_counts( + query=query, **search_args + ) + max_variants = int(max_variants // 2) + variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs() + es_req: dict = { + "size": 0, + "query": query, + "aggs": { + "metrics": { + "terms": { + "field": "metric", + "size": max_metrics, + "order": {"_key": "asc"}, + }, + "aggs": { + "last_event_timestamp": {"max": {"field": "timestamp"}}, + "variants": { + "terms": { + "field": "variant", + "size": max_variants, + "order": {"_key": "asc"}, + }, + **({"aggs": variant_state_aggs} if variant_state_aggs else {}), + }, + }, + } + }, + } + + with translate_errors_context(), TimingContext("es", "_init_metric_states"): + es_res = search_company_events(body=es_req, **search_args) + if "aggregations" not in es_res: + return [] + + def init_variant_state(variant: dict): + """ + Return new variant state for the passed variant bucket + """ + state = VariantState(variant=variant["key"]) + if fill_variant_state_data: + fill_variant_state_data(variant, state) + + return state + + return [ + MetricState( + metric=metric["key"], + timestamp=dpath.get(metric, "last_event_timestamp/value"), + variants=[ + init_variant_state(variant) + for variant in dpath.get(metric, "variants/buckets") + ], + ) + for metric in dpath.get(es_res, "aggregations/metrics/buckets") + ] + + @abc.abstractmethod + def _process_event(self, event: dict) -> dict: + pass + + @abc.abstractmethod + def _get_same_variant_events_order(self) -> dict: + pass + + def _get_task_metric_events( + self, + task_state: TaskScrollState, + company_id: str, + iter_count: int, + navigate_earlier: bool, + specific_variants_requested: bool, + ) -> Tuple: + """ + Return task metric events grouped by iterations + Update task scroll state + """ + if not task_state.metrics: + return task_state.task, [] + + if task_state.last_max_iter is None: + # the first fetch is always from the latest iteration to the earlier ones + navigate_earlier = True + + must_conditions = [ + {"term": {"task": task_state.task}}, + {"terms": {"metric": [m.metric for m in task_state.metrics]}}, + *self._get_extra_conditions(), + ] + + range_condition = None + if navigate_earlier and task_state.last_min_iter is not None: + range_condition = {"lt": task_state.last_min_iter} + elif not navigate_earlier and task_state.last_max_iter is not None: + range_condition = {"gt": task_state.last_max_iter} + if range_condition: + must_conditions.append({"range": {"iter": range_condition}}) + + metrics_count = len(task_state.metrics) + max_variants = int(EventSettings.max_es_buckets / 
(metrics_count * iter_count)) + es_req = { + "size": 0, + "query": {"bool": {"must": must_conditions}}, + "aggs": { + "iters": { + "terms": { + "field": "iter", + "size": iter_count, + "order": {"_key": "desc" if navigate_earlier else "asc"}, + }, + "aggs": { + "metrics": { + "terms": { + "field": "metric", + "size": metrics_count, + "order": {"_key": "asc"}, + }, + "aggs": { + "variants": { + "terms": { + "field": "variant", + "size": max_variants, + "order": {"_key": "asc"}, + }, + "aggs": { + "events": { + "top_hits": { + "sort": self._get_same_variant_events_order() + } + } + }, + } + }, + } + }, + } + }, + } + with translate_errors_context(), TimingContext("es", "_get_task_metric_events"): + es_res = search_company_events( + self.es, + company_id=company_id, + event_type=self.event_type, + body=es_req, + routing=task_state.task, + ) + if "aggregations" not in es_res: + return task_state.task, [] + + invalid_iterations = { + (m.metric, v.variant): v.last_invalid_iteration + for m in task_state.metrics + for v in m.variants + } + allow_uninitialized = ( + False + if specific_variants_requested + else config.get( + "services.events.events_retrieval.debug_images.allow_uninitialized_variants", + False, + ) + ) + + def is_valid_event(event: dict) -> bool: + key = event.get("metric"), event.get("variant") + if key not in invalid_iterations: + return allow_uninitialized + + max_invalid = invalid_iterations[key] + return max_invalid is None or event.get("iter") > max_invalid + + def get_iteration_events(it_: dict) -> Sequence: + return [ + self._process_event(ev["_source"]) + for m in dpath.get(it_, "metrics/buckets") + for v in dpath.get(m, "variants/buckets") + for ev in dpath.get(v, "events/hits/hits") + if is_valid_event(ev["_source"]) + ] + + iterations = [] + for it in dpath.get(es_res, "aggregations/iters/buckets"): + events = get_iteration_events(it) + if events: + iterations.append({"iter": it["key"], "events": events}) + + if not navigate_earlier: + iterations.sort(key=itemgetter("iter"), reverse=True) + if iterations: + task_state.last_max_iter = iterations[0]["iter"] + task_state.last_min_iter = iterations[-1]["iter"] + + return task_state.task, iterations diff --git a/apiserver/bll/event/metric_plots_iterator.py b/apiserver/bll/event/metric_plots_iterator.py new file mode 100644 index 0000000..096b9ec --- /dev/null +++ b/apiserver/bll/event/metric_plots_iterator.py @@ -0,0 +1,25 @@ +from typing import Sequence + +from elasticsearch import Elasticsearch +from redis.client import StrictRedis + +from .event_common import EventType, uncompress_plot +from .metric_events_iterator import MetricEventsIterator + + +class MetricPlotsIterator(MetricEventsIterator): + def __init__(self, redis: StrictRedis, es: Elasticsearch): + super().__init__(redis, es, EventType.metrics_plot) + + def _get_extra_conditions(self) -> Sequence[dict]: + return [] + + def _get_variant_state_aggs(self): + return None, None + + def _process_event(self, event: dict) -> dict: + uncompress_plot(event) + return event + + def _get_same_variant_events_order(self) -> dict: + return {"timestamp": {"order": "desc"}} diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index ee7c5a7..91c442a 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -302,6 +302,48 @@ _definitions { } } } + plots_response_task_metrics { + type: object + properties { + task { + type: string + description: Task ID + } + iterations { + type: array + items { + type: 
object
+          properties {
+            iter {
+              type: integer
+              description: Iteration number
+            }
+            events {
+              type: array
+              items {
+                type: object
+                description: Plot event
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  plots_response {
+    type: object
+    properties {
+      scroll_id {
+        type: string
+        description: "Scroll ID for getting more results"
+      }
+      metrics {
+        type: array
+        description: "Plot events grouped by tasks and iterations"
+        items {"$ref": "#/definitions/plots_response_task_metrics"}
+      }
+    }
+  }
   debug_image_sample_response {
     type: object
     properties {
@@ -323,6 +365,27 @@ _definitions {
       }
     }
   }
+  plot_sample_response {
+    type: object
+    properties {
+      scroll_id {
+        type: string
+        description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
+      }
+      event {
+        type: object
+        description: "Plot event"
+      }
+      min_iteration {
+        type: integer
+        description: "minimal valid iteration for the variant"
+      }
+      max_iteration {
+        type: integer
+        description: "maximal valid iteration for the variant"
+      }
+    }
+  }
 }
 add {
   "2.1" {
@@ -486,6 +549,41 @@ debug_images {
         }
     }
 }
+plots {
+    "999.0" {
+        description: "Get plot events for the requested number of iterations for each task"
+        request {
+            type: object
+            required: [
+                metrics
+            ]
+            properties {
+                metrics {
+                    type: array
+                    description: List of metrics and variants
+                    items { "$ref": "#/definitions/task_metric_variants" }
+                }
+                iters {
+                    type: integer
+                    description: "Max number of latest iterations for which to return plots"
+                }
+                navigate_earlier {
+                    type: boolean
+                    description: "If set then events are retrieved from latest iterations to earliest ones. Otherwise from earliest iterations to the latest. The default is True"
+                }
+                refresh {
+                    type: boolean
+                    description: "If set then scroll will be moved to the latest iterations. The default is False"
+                }
+                scroll_id {
+                    type: string
+                    description: "Scroll ID of previous call (used for getting more results)"
+                }
+            }
+        }
+        response {"$ref": "#/definitions/plots_response"}
+    }
+}
 get_debug_image_sample {
     "2.12": {
         description: "Return the debug image per metric and variant for the provided iteration"
@@ -547,6 +645,72 @@ next_debug_image_sample {
         response {"$ref": "#/definitions/debug_image_sample_response"}
    }
}
+get_plot_sample {
+    "999.0": {
+        description: "Return the plot per metric and variant for the provided iteration"
+        request {
+            type: object
+            required: [task, metric, variant]
+            properties {
+                task {
+                    description: "Task ID"
+                    type: string
+                }
+                metric {
+                    description: "Metric name"
+                    type: string
+                }
+                variant {
+                    description: "Metric variant"
+                    type: string
+                }
+                iteration {
+                    description: "The iteration to get the plot from. If not specified then the latest reported plot is retrieved"
+                    type: integer
+                }
+                refresh {
+                    description: "If set then scroll state will be refreshed to reflect the latest changes in the plots"
+                    type: boolean
+                }
+                scroll_id {
+                    type: string
+                    description: "Scroll ID from the previous call to get_plot_sample or empty"
+                }
+                navigate_current_metric {
+                    description: If set then subsequent navigation with next_plot_sample is done on the plots for the passed metric only. Otherwise for all the metrics
+                    type: boolean
+                    default: true
+                }
+            }
+        }
+        response {"$ref": "#/definitions/plot_sample_response"}
+    }
+}
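+# A hypothetical exchange for the sample-navigation calls (all values below are
+# illustrative, not part of the schema): get_plot_sample returns the latest plot
+# reported for the variant together with a scroll_id, e.g.
+#   request:  {"task": "<task-id>", "metric": "confusion", "variant": "val"}
+#   response: {"scroll_id": "...", "event": {...}, "min_iteration": 0, "max_iteration": 900}
+# The scroll_id is then passed to next_plot_sample (below) to step across
+# variants and iterations.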
+next_plot_sample {
+    "999.0": {
+        description: "Get the plot for the next variant for the same iteration or for the next iteration"
+        request {
+            type: object
+            required: [task, scroll_id]
+            properties {
+                task {
+                    description: "Task ID"
+                    type: string
+                }
+                scroll_id {
+                    type: string
+                    description: "Scroll ID from the previous call to get_plot_sample"
+                }
+                navigate_earlier {
+                    type: boolean
+                    description: """If set then get either the previous variant event from the current iteration or (if it does not exist) the last variant event from the previous iteration.
+                    Otherwise get the next variant event from the current iteration or the first variant event from the next iteration"""
+                }
+            }
+        }
+        response {"$ref": "#/definitions/plot_sample_response"}
    }
}
 get_task_metrics{
    "2.7": {
        description: "For each task, get a list of metrics for which the requested event type was reported"
@@ -1112,6 +1276,55 @@ multi_task_scalar_metrics_iter_histogram {
        }
    }
}
+get_task_single_value_metrics {
+    "999.0" {
+        description: Get single value metrics for the passed tasks
+        request {
+            type: object
+            required: [tasks]
+            properties {
+                tasks {
+                    description: "List of task IDs"
+                    type: array
+                    items {
+                        type: string
+                        description: "Task ID"
+                    }
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                tasks {
+                    description: Single value metrics grouped by task
+                    type: array
+                    items {
+                        type: object
+                        properties {
+                            task {
+                                type: string
+                                description: Task ID
+                            }
+                            values {
+                                type: array
+                                items {
+                                    type: object
+                                    properties {
+                                        metric { type: string }
+                                        variant { type: string }
+                                        value { type: number }
+                                        timestamp { type: number }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
 get_task_latest_scalar_values {
    "2.1" {
        description: "Get the tasks's latest scalar values"
diff --git a/apiserver/services/events.py b/apiserver/services/events.py
index f739490..ef37238 100644
--- a/apiserver/services/events.py
+++ b/apiserver/services/events.py
@@ -27,6 +27,7 @@ from apiserver.apimodels.events import (
     ScalarMetricsIterRawRequest,
     ClearScrollRequest,
     ClearTaskLogRequest,
+    SingleValueMetricsRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.event.event_common import EventType, MetricVariants
@@ -450,6 +451,30 @@ def multi_task_scalar_metrics_iter_histogram(
     )
+
+@endpoint("events.get_task_single_value_metrics")
+def get_task_single_value_metrics(
+    call, company_id: str, request: SingleValueMetricsRequest
+):
+    task_ids = call.data["tasks"]
+    tasks = task_bll.assert_exists(
+        company_id=call.identity.company,
+        only=("id", "name", "company", "company_origin"),
+        task_ids=task_ids,
+        allow_public=True,
+    )
+
+    companies = {t.get_index_company() for t in tasks}
+    if len(companies) > 1:
+        raise errors.bad_request.InvalidTaskId(
+            "only tasks from the same company are supported"
+        )
+
+    res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids)
+    call.result.data = dict(
+        tasks=[{"task": task, "values": values} for task, values in res.items()]
+    )
+
+
 @endpoint("events.get_multi_task_plots", required_fields=["tasks"])
 def get_multi_task_plots_v1_7(call, company_id, _):
     task_ids = call.data["tasks"]
@@ -613,6 +638,56 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
     )
+
+@endpoint(
+    "events.plots",
+    request_data_model=MetricEventsRequest,
+    response_data_model=MetricEventsResponse,
+)
+def task_plots(call, company_id, request: MetricEventsRequest):
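+    # Build a per-task mapping of the requested metrics, e.g. (hypothetical values):
+    #   {"<task-id>": {"metric_a": ["variant_1"], "metric_b": []}}
+    # A None metric means "all metrics of this task", so the task's mapping is
+    # cleared below to disable metric filtering for it
+ 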
task_metrics = defaultdict(dict) + for tm in request.metrics: + task_metrics[tm.task][tm.metric] = tm.variants + for metrics in task_metrics.values(): + if None in metrics: + metrics.clear() + + tasks = task_bll.assert_exists( + company_id, + task_ids=list(task_metrics), + allow_public=True, + only=("company", "company_origin"), + ) + + companies = {t.get_index_company() for t in tasks} + if len(companies) > 1: + raise errors.bad_request.InvalidTaskId( + "only tasks from the same company are supported" + ) + + result = event_bll.plots_iterator.get_task_events( + company_id=next(iter(companies)), + task_metrics=task_metrics, + iter_count=request.iters, + navigate_earlier=request.navigate_earlier, + refresh=request.refresh, + state_id=request.scroll_id, + ) + + call.result.data_model = MetricEventsResponse( + scroll_id=result.next_scroll_id, + metrics=[ + MetricEvents( + task=task, + iterations=[ + IterationEvents(iter=iteration["iter"], events=iteration["events"]) + for iteration in iterations + ], + ) + for (task, iterations) in result.metric_events + ], + ) + + @endpoint("events.debug_images", required_fields=["task"]) def get_debug_images_v1_7(call, company_id, _): task_id = call.data["task"] @@ -769,6 +844,42 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque call.result.data = attr.asdict(res, recurse=False) +@endpoint( + "events.get_plot_sample", request_data_model=GetHistorySampleRequest, +) +def get_plot_sample(call, company_id, request: GetHistorySampleRequest): + task = task_bll.assert_exists( + company_id, task_ids=[request.task], allow_public=True, only=("company",) + )[0] + res = event_bll.plot_sample_history.get_sample_for_variant( + company_id=task.company, + task=request.task, + metric=request.metric, + variant=request.variant, + iteration=request.iteration, + refresh=request.refresh, + state_id=request.scroll_id, + navigate_current_metric=request.navigate_current_metric, + ) + call.result.data = attr.asdict(res, recurse=False) + + +@endpoint( + "events.next_plot_sample", request_data_model=NextHistorySampleRequest, +) +def next_plot_sample(call, company_id, request: NextHistorySampleRequest): + task = task_bll.assert_exists( + company_id, task_ids=[request.task], allow_public=True, only=("company",) + )[0] + res = event_bll.plot_sample_history.get_next_sample( + company_id=task.company, + task=request.task, + state_id=request.scroll_id, + navigate_earlier=request.navigate_earlier, + ) + call.result.data = attr.asdict(res, recurse=False) + + @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest) def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest): task = task_bll.assert_exists(
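
A minimal client-side sketch of the endpoints this patch adds (not part of the patch itself). It assumes an apiserver reachable at localhost:8008 and a requests session that already carries valid authentication; the endpoint names and payload fields come from the schema changes above, while the host, the task ID, and the {"meta": ..., "data": ...} response envelope are assumptions:

import requests

BASE = "http://localhost:8008"  # assumed apiserver address
session = requests.Session()    # assumed to already hold valid auth cookies/headers

# Single-value scalars (reported at iteration -2**31) for a set of tasks
res = session.post(
    f"{BASE}/events.get_task_single_value_metrics",
    json={"tasks": ["<task-id>"]},
).json()

# Plot events for the latest iteration of a task; an entry without "metric"
# requests all metrics for that task
payload = {"metrics": [{"task": "<task-id>"}], "iters": 1}
res = session.post(f"{BASE}/events.plots", json=payload).json()

# Passing the returned scroll_id back pages on to earlier iterations
payload["scroll_id"] = res["data"]["scroll_id"]
res = session.post(f"{BASE}/events.plots", json=payload).json()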