import abc
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Callable

import attr
import dpath
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis

from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
    EventSettings,
    check_empty_data,
    search_company_events,
    EventType,
    get_metric_variants_condition,
    get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task


class VariantState(Base):
    """Scroll state of a single variant of a metric"""

    variant: str = StringField(required=True)
    # Events of this variant whose iteration is not greater than this value
    # are filtered out of the results. None means that all events are valid
    last_invalid_iteration: int = IntField()


class MetricState(Base):
    """Scroll state of a single metric: its variants and last event time"""

    metric: str = StringField(required=True)
    variants: Sequence[VariantState] = ListField([VariantState], required=True)
    # Max event timestamp found for the metric when the state was initialized.
    # Compared with the task metric stats on refresh to detect new events
    timestamp: int = IntField(default=0)


class TaskScrollState(Base):
    """Scroll state of a task: its metric states and the last returned iterations range"""

    task: str = StringField(required=True)
    metrics: Sequence[MetricState] = ListField([MetricState], required=True)
    # The lowest and the highest iterations returned by the last fetch.
    # None until the first fetch that returned any iterations
    last_min_iter: Optional[int] = IntField()
    last_max_iter: Optional[int] = IntField()

    def reset(self):
        """
        Reset the scrolling position of the task so that the next fetch
        starts again from the latest iterations
        """
        self.last_min_iter = self.last_max_iter = None


class MetricEventsScrollState(Base, JsonSerializableMixin):
    """The scroll state that is stored in the cache between the calls"""

    id: str = StringField(required=True)
    tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
    warning: str = StringField()


@attr.s(auto_attribs=True)
class MetricEventsResult(object):
    """The result of get_task_events: per task events and the scroll id for the next call"""

    # attr.Factory creates a new list per instance. A plain [] default would be
    # one list object shared between all the instances of the class
    metric_events: Sequence[tuple] = attr.Factory(list)
    next_scroll_id: str = None


class MetricEventsIterator:
    """
    Base class for scrolling over the task metric events of the specific
    event type grouped by iterations. The scroll position is kept between
    the calls in the redis backed state

    Subclasses should implement the methods marked as abstract.
    NOTE(review): the class does not use abc.ABCMeta so the abstract methods
    are not enforced on instantiation
    """

    def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
        """
        :param redis: redis client used for storing the scroll states
        :param es: Elasticsearch client used for events retrieval
        :param event_type: the type of the events that this iterator returns
        """
        self.es = es
        self.event_type = event_type
        self.cache_manager = RedisCacheManager(
            state_class=MetricEventsScrollState,
            redis=redis,
            expiration_interval=EventSettings.state_expiration_sec,
        )

    def get_task_events(
        self,
        company_id: str,
        task_metrics: Mapping[str, dict],
        iter_count: int,
        navigate_earlier: bool = True,
        refresh: bool = False,
        state_id: str = None,
    ) -> MetricEventsResult:
        """
        Return the next portion of events for the requested task metrics

        :param company_id: the company that owns the events
        :param task_metrics: mapping of the task id to the requested metrics
            (metric name to variants). Falsy metrics mean no metrics filter
        :param iter_count: max amount of iterations to return per task
        :param navigate_earlier: if True then scroll towards the earlier
            iterations otherwise towards the later ones
        :param refresh: if True then re-init the states of the metrics
            that got new events since the state was created
        :param state_id: the scroll id returned from the previous call
        :return: the result with (task id, iterations) tuple per task
            and the scroll id to pass to the next call
        """
        if check_empty_data(self.es, company_id, self.event_type):
            return MetricEventsResult()

        def init_state(state_: MetricEventsScrollState):
            state_.tasks = self._init_task_states(company_id, task_metrics)

        def validate_state(state_: MetricEventsScrollState):
            """
            Validate that the metrics stored in the state are the same
            as requested in the current call.
            Refresh the state if requested
            """
            if refresh:
                self._reinit_outdated_task_states(company_id, state_, task_metrics)

        with self.cache_manager.get_or_create_state(
            state_id=state_id, init_state=init_state, validate_state=validate_state
        ) as state:
            res = MetricEventsResult(next_scroll_id=state.id)
            # True if variants were specified for any of the requested metrics
            specific_variants_requested = any(
                variants
                for metrics in task_metrics.values()
                if metrics
                for variants in metrics.values()
            )
            # each task is queried in its own worker. The worker also updates
            # the scroll position inside the passed task state
            with ThreadPoolExecutor(EventSettings.max_workers) as pool:
                res.metric_events = list(
                    pool.map(
                        partial(
                            self._get_task_metric_events,
                            company_id=company_id,
                            iter_count=iter_count,
                            navigate_earlier=navigate_earlier,
                            specific_variants_requested=specific_variants_requested,
                        ),
                        state.tasks,
                    )
                )

            return res

    def _reinit_outdated_task_states(
        self,
        company_id,
        state: MetricEventsScrollState,
        task_metrics: Mapping[str, dict],
    ):
        """
        Determine the metrics for which new event_type events were added
        since their states were initialized and re-init these states.
        The scroll position of all the tasks in the state is reset
        """
        tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
            "id", "metric_stats"
        )

        def get_last_update_times_for_task_metrics(
            task: Task,
        ) -> Mapping[str, datetime]:
            """For metrics that reported event_type events get mapping of the metric name to the last update times"""
            metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
            if not metric_stats:
                return {}

            requested_metrics = task_metrics[task.id]
            return {
                stats.metric: stats.event_stats_by_type[
                    self.event_type.value
                ].last_update
                for stats in metric_stats.values()
                if self.event_type.value in stats.event_stats_by_type
                and (not requested_metrics or stats.metric in requested_metrics)
            }

        update_times = {
            task.id: get_last_update_times_for_task_metrics(task) for task in tasks
        }
        task_metric_states = {
            task_state.task: {
                metric_state.metric: metric_state for metric_state in task_state.metrics
            }
            for task_state in state.tasks
        }
        task_metrics_to_recalc = {}
        for task, metrics_times in update_times.items():
            old_metric_states = task_metric_states[task]
            # the requested metrics may be falsy (no metrics filter)
            # as everywhere else in this class
            requested_metrics = task_metrics[task] or {}
            # NOTE(review): the state timestamp is an int while the update times
            # are annotated as datetime. Confirm the actual type of last_update
            metrics_to_recalc = {
                m: requested_metrics.get(m)
                for m, t in metrics_times.items()
                if m not in old_metric_states or old_metric_states[m].timestamp < t
            }
            if metrics_to_recalc:
                task_metrics_to_recalc[task] = metrics_to_recalc

        updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)

        def merge_with_updated_task_states(
            old_state: TaskScrollState, updates: Sequence[TaskScrollState]
        ) -> TaskScrollState:
            """
            Return the task state where the recalculated metric states replace
            the old ones. In any case the returned state has no scroll position
            """
            task = old_state.task
            updated_state = first(uts for uts in updates if uts.task == task)
            if not updated_state:
                old_state.reset()
                return old_state

            updated_metrics = {m.metric for m in updated_state.metrics}
            return TaskScrollState(
                task=task,
                metrics=[
                    *updated_state.metrics,
                    *(
                        old_metric
                        for old_metric in old_state.metrics
                        if old_metric.metric not in updated_metrics
                    ),
                ],
            )

        state.tasks = [
            merge_with_updated_task_states(task_state, updated_task_states)
            for task_state in state.tasks
        ]

    def _init_task_states(
        self, company_id: str, task_metrics: Mapping[str, dict]
    ) -> Sequence[TaskScrollState]:
        """
        Return initialized metric scroll states for the requested task metrics.
        The metric states of each task are calculated in a separate worker
        """
        with ThreadPoolExecutor(EventSettings.max_workers) as pool:
            task_metric_states = pool.map(
                partial(self._init_metric_states_for_task, company_id=company_id),
                task_metrics.items(),
            )

        # pool.map yields the results in the order of the passed items
        # so they can be zipped back with the task ids
        return [
            TaskScrollState(task=task, metrics=metric_states,)
            for task, metric_states in zip(task_metrics, task_metric_states)
        ]

    @abc.abstractmethod
    def _get_extra_conditions(self) -> Sequence[dict]:
        """
        Return the conditions that are added to the "must" part
        of every events query issued by the iterator
        """
        pass

    @abc.abstractmethod
    def _get_variant_state_aggs(
        self,
    ) -> Tuple[dict, Callable[[dict, VariantState], None]]:
        """
        Return the tuple of:
        - sub-aggregations to calculate per variant bucket (may be empty)
        - function that fills the variant state from the variant bucket
          (may be empty)
        """
        pass

    def _init_metric_states_for_task(
        self, task_metrics: Tuple[str, dict], company_id: str
    ) -> Sequence[MetricState]:
        """
        Return metric scroll states for the task filled with the variant states
        for the variants that reported any event_type events

        :param task_metrics: tuple of the task id and the requested metrics.
            Falsy metrics mean no metrics filter
        :param company_id: the company that owns the events
        """
        task, metrics = task_metrics
        must = [{"term": {"task": task}}, *self._get_extra_conditions()]
        if metrics:
            must.append(get_metric_variants_condition(metrics))
        query = {"bool": {"must": must}}

        search_args = dict(
            es=self.es, company_id=company_id, event_type=self.event_type
        )
        max_metrics, max_variants = get_max_metric_and_variant_counts(
            query=query, **search_args
        )
        # terms aggregation size should be positive
        max_variants = max(1, int(max_variants // 2))
        variant_state_aggs, fill_variant_state_data = self._get_variant_state_aggs()
        es_req: dict = {
            "size": 0,
            "query": query,
            "aggs": {
                "metrics": {
                    "terms": {
                        "field": "metric",
                        "size": max_metrics,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
                        "last_event_timestamp": {"max": {"field": "timestamp"}},
                        "variants": {
                            "terms": {
                                "field": "variant",
                                "size": max_variants,
                                "order": {"_key": "asc"},
                            },
                            **(
                                {"aggs": variant_state_aggs}
                                if variant_state_aggs
                                else {}
                            ),
                        },
                    },
                }
            },
        }

        with translate_errors_context():
            es_res = search_company_events(body=es_req, **search_args)
        if "aggregations" not in es_res:
            return []

        def init_variant_state(variant: dict):
            """
            Return new variant state for the passed variant bucket
            """
            state = VariantState(variant=variant["key"])
            if fill_variant_state_data:
                fill_variant_state_data(variant, state)

            return state

        return [
            MetricState(
                metric=metric["key"],
                timestamp=dpath.get(metric, "last_event_timestamp/value"),
                variants=[
                    init_variant_state(variant)
                    for variant in dpath.get(metric, "variants/buckets")
                ],
            )
            for metric in dpath.get(es_res, "aggregations/metrics/buckets")
        ]

    @abc.abstractmethod
    def _process_event(self, event: dict) -> dict:
        """
        Return the event to put into the result
        for the passed source of the event retrieved from ES
        """
        pass

    @abc.abstractmethod
    def _get_same_variant_events_order(self) -> dict:
        """
        Return the sort definition for the top hits
        that are retrieved per variant of an iteration
        """
        pass

    def _get_task_metric_events(
        self,
        task_state: TaskScrollState,
        company_id: str,
        iter_count: int,
        navigate_earlier: bool,
        specific_variants_requested: bool,
    ) -> Tuple:
        """
        Return task metric events grouped by iterations
        Update task scroll state

        :return: tuple of the task id and the list of
            {"iter": <iteration>, "events": <events>} items
            ordered from the latest iteration to the earliest one
        """
        if not task_state.metrics:
            return task_state.task, []

        if task_state.last_max_iter is None:
            # the first fetch is always from the latest iteration to the earlier ones
            navigate_earlier = True

        must_conditions = [
            {"term": {"task": task_state.task}},
            {"terms": {"metric": [m.metric for m in task_state.metrics]}},
            *self._get_extra_conditions(),
        ]

        # continue from the boundary of the previously returned iterations
        range_condition = None
        if navigate_earlier and task_state.last_min_iter is not None:
            range_condition = {"lt": task_state.last_min_iter}
        elif not navigate_earlier and task_state.last_max_iter is not None:
            range_condition = {"gt": task_state.last_max_iter}
        if range_condition:
            must_conditions.append({"range": {"iter": range_condition}})

        metrics_count = len(task_state.metrics)
        # split the buckets budget between the iterations and the metrics.
        # At least 1 since terms aggregation size should be positive
        max_variants = max(
            1, int(EventSettings.max_es_buckets / (metrics_count * iter_count))
        )
        es_req = {
            "size": 0,
            "query": {"bool": {"must": must_conditions}},
            "aggs": {
                "iters": {
                    "terms": {
                        "field": "iter",
                        "size": iter_count,
                        "order": {"_key": "desc" if navigate_earlier else "asc"},
                    },
                    "aggs": {
                        "metrics": {
                            "terms": {
                                "field": "metric",
                                "size": metrics_count,
                                "order": {"_key": "asc"},
                            },
                            "aggs": {
                                "variants": {
                                    "terms": {
                                        "field": "variant",
                                        "size": max_variants,
                                        "order": {"_key": "asc"},
                                    },
                                    "aggs": {
                                        "events": {
                                            "top_hits": {
                                                "sort": self._get_same_variant_events_order()
                                            }
                                        }
                                    },
                                }
                            },
                        }
                    },
                }
            },
        }
        with translate_errors_context():
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=self.event_type,
                body=es_req,
            )
        if "aggregations" not in es_res:
            return task_state.task, []

        invalid_iterations = {
            (m.metric, v.variant): v.last_invalid_iteration
            for m in task_state.metrics
            for v in m.variants
        }
        # events of the variants that are not in the state are returned only
        # if allowed by the configuration and no specific variants were requested
        allow_uninitialized = (
            False
            if specific_variants_requested
            else config.get(
                "services.events.events_retrieval.debug_images.allow_uninitialized_variants",
                False,
            )
        )

        def is_valid_event(event: dict) -> bool:
            """
            Return True if the event iteration is after
            the last invalid iteration of the event variant
            """
            key = event.get("metric"), event.get("variant")
            if key not in invalid_iterations:
                return allow_uninitialized

            max_invalid = invalid_iterations[key]
            return max_invalid is None or event.get("iter") > max_invalid

        def get_iteration_events(it_: dict) -> Sequence:
            """
            Return the processed valid events from all the metric variants
            of the passed iteration bucket
            """
            return [
                self._process_event(ev["_source"])
                for m in dpath.get(it_, "metrics/buckets")
                for v in dpath.get(m, "variants/buckets")
                for ev in dpath.get(v, "events/hits/hits")
                if is_valid_event(ev["_source"])
            ]

        iterations = []
        for it in dpath.get(es_res, "aggregations/iters/buckets"):
            events = get_iteration_events(it)
            if events:
                iterations.append({"iter": it["key"], "events": events})

        if not navigate_earlier:
            # the buckets were retrieved in the ascending order
            # while the result is always from the latest iteration
            iterations.sort(key=itemgetter("iter"), reverse=True)
        if iterations:
            task_state.last_max_iter = iterations[0]["iter"]
            task_state.last_min_iter = iterations[-1]["iter"]

        return task_state.task, iterations