From b67aa05d6f5e85be4a93390557ec9836ac6aca27 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 3 May 2021 18:08:14 +0300
Subject: [PATCH] Return results per task iterations in debug images request

---
 apiserver/apimodels/events.py                 |   1 -
 apiserver/bll/event/debug_images_iterator.py  | 345 ++++++++----------
 apiserver/bll/task/task_cleanup.py            |  14 +-
 apiserver/schema/services/events.conf         |   6 +-
 apiserver/services/events.py                  |   3 +-
 .../tests/automated/test_task_debug_images.py |  15 +-
 6 files changed, 163 insertions(+), 221 deletions(-)

diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py
index 71d3ae5..e8dabbf 100644
--- a/apiserver/apimodels/events.py
+++ b/apiserver/apimodels/events.py
@@ -89,7 +89,6 @@ class IterationEvents(Base):
 
 class MetricEvents(Base):
     task: str = StringField()
-    metric: str = StringField()
     iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
 
 
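With `metric` dropped from `MetricEvents`, a single response entry now spans all of a task's metrics, and the metric name travels on each individual event instead. A minimal sketch of how a client could rebuild a per-metric view from the new shape (the sample payload is invented; only the field names come from this patch):

from collections import defaultdict

# Toy response in the new shape: one entry per task, iterations inside,
# each event carrying its own "metric"/"variant" keys.
response = {
    "metrics": [
        {
            "task": "task-1",
            "iterations": [
                {
                    "iter": 10,
                    "events": [
                        {"metric": "Metric1", "variant": "Variant1", "url": "url-a"},
                        {"metric": "Metric2", "variant": "Variant1", "url": "url-b"},
                    ],
                }
            ],
        }
    ]
}

per_metric = defaultdict(list)
for task_entry in response["metrics"]:
    for iteration in task_entry["iterations"]:
        for event in iteration["events"]:
            per_metric[event["metric"]].append((iteration["iter"], event))

assert sorted(per_metric) == ["Metric1", "Metric2"]
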
diff --git a/apiserver/bll/event/debug_images_iterator.py b/apiserver/bll/event/debug_images_iterator.py
index 31e5d3d..ea79efb 100644
--- a/apiserver/bll/event/debug_images_iterator.py
+++ b/apiserver/bll/event/debug_images_iterator.py
@@ -1,13 +1,12 @@
-from collections import defaultdict
 from concurrent.futures.thread import ThreadPoolExecutor
+from datetime import datetime
 from functools import partial
-from itertools import chain
-from operator import attrgetter, itemgetter
+from operator import itemgetter
 from typing import Sequence, Tuple, Optional, Mapping, Set
 
 import attr
 import dpath
-from boltons.iterutils import bucketize, first
+from boltons.iterutils import first
 from elasticsearch import Elasticsearch
 from jsonmodels.fields import StringField, ListField, IntField
 from jsonmodels.models import Base
@@ -27,19 +26,22 @@ from apiserver.database.model.task.task import Task
 from apiserver.timing_context import TimingContext
 
 
-class VariantScrollState(Base):
-    name: str = StringField(required=True)
-    recycle_url_marker: str = StringField()
+class VariantState(Base):
+    variant: str = StringField(required=True)
     last_invalid_iteration: int = IntField()
 
 
-class MetricScrollState(Base):
+class MetricState(Base):
+    metric: str = StringField(required=True)
+    variants: Sequence[VariantState] = ListField([VariantState], required=True)
+    timestamp: int = IntField(default=0)
+
+
+class TaskScrollState(Base):
     task: str = StringField(required=True)
-    name: str = StringField(required=True)
+    metrics: Sequence[MetricState] = ListField([MetricState], required=True)
     last_min_iter: Optional[int] = IntField()
     last_max_iter: Optional[int] = IntField()
-    timestamp: int = IntField(default=0)
-    variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
 
     def reset(self):
         """Reset the scrolling state for the metric"""
@@ -48,7 +50,7 @@
 
 class DebugImageEventsScrollState(Base, JsonSerializableMixin):
     id: str = StringField(required=True)
-    metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
+    tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
 
     warning: str = StringField()
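The scroll state becomes a three-level hierarchy: the scroll holds one `TaskScrollState` per task, each holding `MetricState` entries with `VariantState` children, and the iteration window (`last_min_iter`/`last_max_iter`) moves from the metric level to the task level. A rough sketch of the shape using plain dataclasses (the patch itself uses jsonmodels fields; this is only an illustration):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class VariantState:
    variant: str
    last_invalid_iteration: Optional[int] = None  # iterations <= this hold recycled urls

@dataclass
class MetricState:
    metric: str
    variants: List[VariantState] = field(default_factory=list)
    timestamp: int = 0  # last ES update time seen for this metric

@dataclass
class TaskScrollState:
    task: str
    metrics: List[MetricState] = field(default_factory=list)
    # the iteration window is now tracked per task, not per metric
    last_min_iter: Optional[int] = None
    last_max_iter: Optional[int] = None

    def reset(self):
        self.last_min_iter = self.last_max_iter = None

state = TaskScrollState(
    task="task-1",
    metrics=[MetricState(metric="Metric1", variants=[VariantState("Variant1")])],
)
state.reset()
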
@@ -82,7 +84,7 @@ class DebugImagesIterator:
             return DebugImagesResult()
 
         def init_state(state_: DebugImageEventsScrollState):
-            state_.metrics = self._init_metric_states(company_id, task_metrics)
+            state_.tasks = self._init_task_states(company_id, task_metrics)
 
         def validate_state(state_: DebugImageEventsScrollState):
             """
@@ -91,9 +93,7 @@ class DebugImagesIterator:
             Refresh the state if requested
             """
             if refresh:
-                self._reinit_outdated_metric_states(company_id, state_, task_metrics)
-                for metric_state in state_.metrics:
-                    metric_state.reset()
+                self._reinit_outdated_task_states(company_id, state_, task_metrics)
 
         with self.cache_manager.get_or_create_state(
             state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -108,88 +108,113 @@ class DebugImagesIterator:
                         iter_count=iter_count,
                         navigate_earlier=navigate_earlier,
                     ),
-                    state.metrics,
+                    state.tasks,
                 )
             )
         return res
 
-    def _reinit_outdated_metric_states(
+    def _reinit_outdated_task_states(
         self,
         company_id,
         state: DebugImageEventsScrollState,
         task_metrics: Mapping[str, Set[str]],
     ):
         """
-        Determines the metrics for which new debug image events were added
-        since their states were initialized and reinits these states
+        Determine the metrics for which new debug image events were added
+        since their states were initialized and re-init these states
         """
         tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
             "id", "metric_stats"
         )
 
-        def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
-            """For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
+        def get_last_update_times_for_task_metrics(
+            task: Task,
+        ) -> Mapping[str, datetime]:
+            """For metrics that reported debug image events get mapping of the metric name to the last update times"""
             metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
             if not metric_stats:
-                return []
+                return {}
 
             requested_metrics = task_metrics[task.id]
-            return [
-                (
-                    (task.id, stats.metric),
-                    stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
-                )
+            return {
+                stats.metric: stats.event_stats_by_type[
+                    self.EVENT_TYPE.value
+                ].last_update
                 for stats in metric_stats.values()
                 if self.EVENT_TYPE.value in stats.event_stats_by_type
                 and (not requested_metrics or stats.metric in requested_metrics)
-            ]
+            }
 
-        update_times = dict(
-            chain.from_iterable(
-                get_last_update_times_for_task_metrics(task) for task in tasks
-            )
-        )
-
-        metrics_to_update = defaultdict(set)
-        for (task, metric), update_time in update_times.items():
-            state_metric = first(
-                m for m in state.metrics if m.task == task and m.name == metric
-            )
-            if not state_metric or state_metric.timestamp < update_time:
-                metrics_to_update[task].add(metric)
-
-        if metrics_to_update:
-            state.metrics = [
-                *(
-                    metric
-                    for metric in state.metrics
-                    if metric.name not in metrics_to_update.get(metric.task, [])
-                ),
-                *(self._init_metric_states(company_id, metrics_to_update)),
-            ]
+        update_times = {
+            task.id: get_last_update_times_for_task_metrics(task) for task in tasks
+        }
+        task_metric_states = {
+            task_state.task: {
+                metric_state.metric: metric_state for metric_state in task_state.metrics
+            }
+            for task_state in state.tasks
+        }
+        task_metrics_to_recalc = {}
+        for task, metrics_times in update_times.items():
+            old_metric_states = task_metric_states[task]
+            metrics_to_recalc = set(
+                m
+                for m, t in metrics_times.items()
+                if m not in old_metric_states or old_metric_states[m].timestamp < t
+            )
+            if metrics_to_recalc:
+                task_metrics_to_recalc[task] = metrics_to_recalc
+
+        updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
+
+        def merge_with_updated_task_states(
+            old_state: TaskScrollState, updates: Sequence[TaskScrollState]
+        ) -> TaskScrollState:
+            task = old_state.task
+            updated_state = first(uts for uts in updates if uts.task == task)
+            if not updated_state:
+                old_state.reset()
+                return old_state
+
+            updated_metrics = [m.metric for m in updated_state.metrics]
+            return TaskScrollState(
+                task=task,
+                metrics=[
+                    *updated_state.metrics,
+                    *(
+                        old_metric
+                        for old_metric in old_state.metrics
+                        if old_metric.metric not in updated_metrics
+                    ),
+                ],
+            )
+
+        state.tasks = [
+            merge_with_updated_task_states(task_state, updated_task_states)
+            for task_state in state.tasks
+        ]
 
-    def _init_metric_states(
+    def _init_task_states(
         self, company_id: str, task_metrics: Mapping[str, Set[str]]
-    ) -> Sequence[MetricScrollState]:
+    ) -> Sequence[TaskScrollState]:
         """
         Return initialized metric scroll states for the requested task metrics
         """
         with ThreadPoolExecutor(EventSettings.max_workers) as pool:
-            return list(
-                chain.from_iterable(
-                    pool.map(
-                        partial(
-                            self._init_metric_states_for_task, company_id=company_id
-                        ),
-                        task_metrics.items(),
-                    )
-                )
-            )
+            task_metric_states = pool.map(
+                partial(self._init_metric_states_for_task, company_id=company_id),
+                task_metrics.items(),
+            )
+
+        return [
+            TaskScrollState(task=task, metrics=metric_states,)
+            for task, metric_states in zip(task_metrics, task_metric_states)
+        ]
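The refresh path above re-initializes only the metrics whose Elasticsearch data is newer than the cached `timestamp`, and merely resets tasks that had no updates. A small self-contained sketch of that staleness test, with invented timestamps:

# A metric is recalculated when it is new to the state or when ES reports a
# later update than the timestamp we cached. Names mirror the patch; data is toy.
old_metric_states = {"Metric1": {"timestamp": 100}, "Metric2": {"timestamp": 200}}
metrics_times = {"Metric1": 150, "Metric2": 180, "Metric3": 90}  # last ES updates

metrics_to_recalc = {
    m
    for m, t in metrics_times.items()
    if m not in old_metric_states or old_metric_states[m]["timestamp"] < t
}
assert metrics_to_recalc == {"Metric1", "Metric3"}  # Metric2 is still fresh
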
     def _init_metric_states_for_task(
         self, task_metrics: Tuple[str, Set[str]], company_id: str
-    ) -> Sequence[MetricScrollState]:
+    ) -> Sequence[MetricState]:
         """
         Return metric scroll states for the task filled with the variant states
         for the variants that reported any debug images
         """
@@ -249,12 +274,12 @@ class DebugImagesIterator:
         if "aggregations" not in es_res:
             return []
 
-        def init_variant_scroll_state(variant: dict):
+        def init_variant_state(variant: dict):
             """
-            Return new variant scroll state for the passed variant bucket
+            Return new variant state for the passed variant bucket
             If the image urls get recycled then fill the last_invalid_iteration field
             """
-            state = VariantScrollState(name=variant["key"])
+            state = VariantState(variant=variant["key"])
             top_iter_url = dpath.get(variant, "urls/buckets")[0]
             iters = dpath.get(top_iter_url, "iters/hits/hits")
             if len(iters) > 1:
@@ -262,102 +287,52 @@ class DebugImagesIterator:
             return state
 
         return [
-            MetricScrollState(
-                task=task,
-                name=metric["key"],
+            MetricState(
+                metric=metric["key"],
+                timestamp=dpath.get(metric, "last_event_timestamp/value"),
                 variants=[
-                    init_variant_scroll_state(variant)
+                    init_variant_state(variant)
                     for variant in dpath.get(metric, "variants/buckets")
                 ],
-                timestamp=dpath.get(metric, "last_event_timestamp/value"),
             )
             for metric in dpath.get(es_res, "aggregations/metrics/buckets")
         ]
 
     def _get_task_metric_events(
         self,
-        metric: MetricScrollState,
+        task_state: TaskScrollState,
         company_id: str,
         iter_count: int,
         navigate_earlier: bool,
     ) -> Tuple:
         """
         Return task metric events grouped by iterations
-        Update metric scroll state
+        Update task scroll state
         """
-        if metric.last_max_iter is None:
+        if not task_state.metrics:
+            return task_state.task, []
+
+        if task_state.last_max_iter is None:
             # the first fetch is always from the latest iteration to the earlier ones
             navigate_earlier = True
 
         must_conditions = [
-            {"term": {"task": metric.task}},
-            {"term": {"metric": metric.name}},
+            {"term": {"task": task_state.task}},
+            {"terms": {"metric": [m.metric for m in task_state.metrics]}},
             {"exists": {"field": "url"}},
         ]
-        must_not_conditions = []
 
         range_condition = None
-        if navigate_earlier and metric.last_min_iter is not None:
-            range_condition = {"lt": metric.last_min_iter}
-        elif not navigate_earlier and metric.last_max_iter is not None:
-            range_condition = {"gt": metric.last_max_iter}
+        if navigate_earlier and task_state.last_min_iter is not None:
+            range_condition = {"lt": task_state.last_min_iter}
+        elif not navigate_earlier and task_state.last_max_iter is not None:
+            range_condition = {"gt": task_state.last_max_iter}
         if range_condition:
             must_conditions.append({"range": {"iter": range_condition}})
 
-        if navigate_earlier:
-            """
-            When navigating to earlier iterations consider only
-            variants whose invalid iterations border is lower than
-            our starting iteration. For these variants make sure
-            that only events from the valid iterations are returned
-            """
-            if not metric.last_min_iter:
-                variants = metric.variants
-            else:
-                variants = list(
-                    v
-                    for v in metric.variants
-                    if v.last_invalid_iteration is None
-                    or v.last_invalid_iteration < metric.last_min_iter
-                )
-                if not variants:
-                    return metric.task, metric.name, []
-                must_conditions.append(
-                    {"terms": {"variant": list(v.name for v in variants)}}
-                )
-        else:
-            """
-            When navigating to later iterations all variants may be relevant.
-            For the variants whose invalid border is higher than our starting
-            iteration make sure that only events from valid iterations are returned
-            """
-            variants = list(
-                v
-                for v in metric.variants
-                if v.last_invalid_iteration is not None
-                and v.last_invalid_iteration > metric.last_max_iter
-            )
-
-            variants_conditions = [
-                {
-                    "bool": {
-                        "must": [
-                            {"term": {"variant": v.name}},
-                            {"range": {"iter": {"lte": v.last_invalid_iteration}}},
-                        ]
-                    }
-                }
-                for v in variants
-                if v.last_invalid_iteration is not None
-            ]
-            if variants_conditions:
-                must_not_conditions.append({"bool": {"should": variants_conditions}})
-
         es_req = {
             "size": 0,
-            "query": {
-                "bool": {"must": must_conditions, "must_not": must_not_conditions}
-            },
+            "query": {"bool": {"must": must_conditions}},
             "aggs": {
                 "iters": {
                     "terms": {
                         "field": "iter",
                         "size": iter_count,
                         "order": {"_key": "desc" if navigate_earlier else "asc"},
                     },
                     "aggs": {
-                        "variants": {
+                        "metrics": {
                             "terms": {
-                                "field": "variant",
-                                "size": EventSettings.max_variants_count,
+                                "field": "metric",
+                                "size": EventSettings.max_metrics_count,
                                 "order": {"_key": "asc"},
                             },
                             "aggs": {
-                                "events": {
-                                    "top_hits": {"sort": {"url": {"order": "desc"}}}
+                                "variants": {
+                                    "terms": {
+                                        "field": "variant",
+                                        "size": EventSettings.max_variants_count,
+                                        "order": {"_key": "asc"},
+                                    },
+                                    "aggs": {
+                                        "events": {
+                                            "top_hits": {
+                                                "sort": {"url": {"order": "desc"}}
+                                            }
+                                        }
+                                    },
                                 }
                             },
                         }
                     },
                 }
             },
         }
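The query now nests three terms aggregations, iters > metrics > variants, so one request returns every metric of the task per iteration instead of one metric per scroll. A sketch of walking the resulting buckets with `dpath.get`, as the module does (the `es_res` sample below is fabricated):

import dpath

es_res = {
    "aggregations": {
        "iters": {
            "buckets": [
                {
                    "key": 7,
                    "metrics": {
                        "buckets": [
                            {
                                "key": "Metric1",
                                "variants": {
                                    "buckets": [
                                        {
                                            "key": "Variant1",
                                            "events": {
                                                "hits": {"hits": [{"_source": {"iter": 7}}]}
                                            },
                                        }
                                    ]
                                },
                            }
                        ]
                    },
                }
            ]
        }
    }
}

for it in dpath.get(es_res, "aggregations/iters/buckets"):
    events = [
        ev["_source"]
        for m in dpath.get(it, "metrics/buckets")
        for v in dpath.get(m, "variants/buckets")
        for ev in dpath.get(v, "events/hits/hits")
    ]
    print(it["key"], events)  # 7 [{'iter': 7}]
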
@@ -387,74 +373,41 @@ class DebugImagesIterator:
             self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
         )
         if "aggregations" not in es_res:
-            return metric.task, metric.name, []
+            return task_state.task, []
 
-        def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
+        invalid_iterations = {
+            (m.metric, v.variant): v.last_invalid_iteration
+            for m in task_state.metrics
+            for v in m.variants
+        }
+
+        def is_valid_event(event: dict) -> bool:
+            key = event.get("metric"), event.get("variant")
+            if key not in invalid_iterations:
+                return False
+
+            max_invalid = invalid_iterations[key]
+            return max_invalid is None or event.get("iter") > max_invalid
+
+        def get_iteration_events(it_: dict) -> Sequence:
             return [
                 ev["_source"]
-                for v in variant_buckets
+                for m in dpath.get(it_, "metrics/buckets")
+                for v in dpath.get(m, "variants/buckets")
                 for ev in dpath.get(v, "events/hits/hits")
+                if is_valid_event(ev["_source"])
             ]
 
-        iterations = [
-            {
-                "iter": it["key"],
-                "events": get_iteration_events(dpath.get(it, "variants/buckets")),
-            }
-            for it in dpath.get(es_res, "aggregations/iters/buckets")
-        ]
+        iterations = []
+        for it in dpath.get(es_res, "aggregations/iters/buckets"):
+            events = get_iteration_events(it)
+            if events:
+                iterations.append({"iter": it["key"], "events": events})
+
         if not navigate_earlier:
             iterations.sort(key=itemgetter("iter"), reverse=True)
 
         if iterations:
-            metric.last_max_iter = iterations[0]["iter"]
-            metric.last_min_iter = iterations[-1]["iter"]
+            task_state.last_max_iter = iterations[0]["iter"]
+            task_state.last_min_iter = iterations[-1]["iter"]
 
-        # Commented for now since the last invalid iteration is calculated in the beginning
-        # if navigate_earlier and any(
-        #     variant.last_invalid_iteration is None for variant in variants
-        # ):
-        #     """
-        #     Variants validation flags due to recycling can
-        #     be set only on navigation to earlier frames
-        #     """
-        #     iterations = self._update_variants_invalid_iterations(variants, iterations)
-
-        return metric.task, metric.name, iterations
-
-    @staticmethod
-    def _update_variants_invalid_iterations(
-        variants: Sequence[VariantScrollState], iterations: Sequence[dict]
-    ) -> Sequence[dict]:
-        """
-        This code is currently not in used since the invalid iterations
-        are calculated during MetricState initialization
-        For variants that do not have recycle url marker set it from the
-        first event
-        For variants that do not have last_invalid_iteration set check if the
-        recycle marker was reached on a certain iteration and set it to the
-        corresponding iteration
-        For variants that have a newly set last_invalid_iteration remove
-        events from the invalid iterations
-        Return the updated iterations list
-        """
-        variants_lookup = bucketize(variants, attrgetter("name"))
-        for it in iterations:
-            iteration = it["iter"]
-            events_to_remove = []
-            for event in it["events"]:
-                variant = variants_lookup[event["variant"]][0]
-                if (
-                    variant.last_invalid_iteration
-                    and variant.last_invalid_iteration >= iteration
-                ):
-                    events_to_remove.append(event)
-                    continue
-                event_url = event.get("url")
-                if not variant.recycle_url_marker:
-                    variant.recycle_url_marker = event_url
-                elif variant.recycle_url_marker == event_url:
-                    variant.last_invalid_iteration = iteration
-                    events_to_remove.append(event)
-            if events_to_remove:
-                it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
-        return [it for it in iterations if it["events"]]
+        return task_state.task, iterations
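The per-variant filtering that used to be pushed into the ES query is now applied in Python. A standalone sketch of the `is_valid_event` rule under the same semantics (events at iterations at or below a variant's `last_invalid_iteration` are treated as recycled URLs and dropped; unknown metric/variant pairs are ignored), with toy data:

invalid_iterations = {("Metric1", "Variant1"): 3, ("Metric1", "Variant2"): None}

def is_valid_event(event: dict) -> bool:
    key = event.get("metric"), event.get("variant")
    if key not in invalid_iterations:
        return False  # pair is not tracked by the scroll state
    max_invalid = invalid_iterations[key]
    return max_invalid is None or event.get("iter") > max_invalid

assert not is_valid_event({"metric": "Metric1", "variant": "Variant1", "iter": 3})
assert is_valid_event({"metric": "Metric1", "variant": "Variant1", "iter": 4})
assert is_valid_event({"metric": "Metric1", "variant": "Variant2", "iter": 1})
assert not is_valid_event({"metric": "MetricX", "variant": "Variant1", "iter": 9})
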
diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py
index 319cfeb..3a39960 100644
--- a/apiserver/bll/task/task_cleanup.py
+++ b/apiserver/bll/task/task_cleanup.py
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from itertools import chain
 from operator import attrgetter
 from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
@@ -133,7 +132,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
     task_metrics = {task: set(metrics)}
 
     scroll_id = None
-    urls = defaultdict(set)
+    urls = set()
     while True:
         res = event_bll.debug_images_iterator.get_task_events(
             company_id=company,
@@ -142,17 +141,16 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
             state_id=scroll_id,
         )
         if not res.metric_events or not any(
-            events for _, _, events in res.metric_events
+            iterations for _, iterations in res.metric_events
         ):
             break
         scroll_id = res.next_scroll_id
-        for _, metric, iterations in res.metric_events:
-            metric_urls = set(ev.get("url") for it in iterations for ev in it["events"])
-            metric_urls.discard(None)
-            urls[metric].update(metric_urls)
+        for task, iterations in res.metric_events:
+            urls.update(ev.get("url") for it in iterations for ev in it["events"])
 
-    return set(chain.from_iterable(urls.values()))
+    urls.discard(None)
+    return urls
 
 
 def cleanup_task(
diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf
index 3d532ae..bf56546 100644
--- a/apiserver/schema/services/events.conf
+++ b/apiserver/schema/services/events.conf
@@ -193,10 +193,6 @@
                 description: "Task ID"
                 type: string
             }
-            metric {
-                description: "Metric name. If not specified then all metrics for this task will be returned"
-                type: string
-            }
         }
     }
     task_log_event {
@@ -370,7 +366,7 @@
             }
         }
         "2.7" {
-            description: "Get the debug image events for the requested amount of iterations per each task's metric"
+            description: "Get the debug image events for the requested amount of iterations per each task"
             request {
                 type: object
                 required: [
diff --git a/apiserver/services/events.py b/apiserver/services/events.py
index 18ae5ca..b792025 100644
--- a/apiserver/services/events.py
+++ b/apiserver/services/events.py
@@ -628,13 +628,12 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
         metrics=[
             MetricEvents(
                 task=task,
-                metric=metric,
                 iterations=[
                     IterationEvents(iter=iteration["iter"], events=iteration["events"])
                     for iteration in iterations
                 ],
             )
-            for (task, metric, iterations) in result.metric_events
+            for (task, iterations) in result.metric_events
         ],
     )
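A hedged usage sketch for the reworked 2.7 endpoint: paging through one task's debug images with `scroll_id`. The `api` argument stands for any client exposing `events.debug_images` the way the tests below do; it is an assumption of this sketch, not something the patch defines:

def iter_task_debug_images(api, task_id: str, iters: int = 5):
    """Yield (iteration, events) pairs for a task, page by page."""
    scroll_id = None
    while True:
        res = api.events.debug_images(
            metrics=[{"task": task_id}], iters=iters, scroll_id=scroll_id
        )
        task_events = res.metrics[0]  # one entry per requested task
        if not task_events.iterations:  # an empty page means we are done
            return
        for it in task_events.iterations:
            yield it.iter, it.events
        scroll_id = res.scroll_id  # reuse the scroll to fetch the next page
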
diff --git a/apiserver/tests/automated/test_task_debug_images.py b/apiserver/tests/automated/test_task_debug_images.py
index b9081eb..34ded85 100644
--- a/apiserver/tests/automated/test_task_debug_images.py
+++ b/apiserver/tests/automated/test_task_debug_images.py
@@ -130,11 +130,11 @@ class TestTaskDebugImages(TestService):
 
         # test empty
         res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
-        self.assertFalse(res.metrics)
+        self.assertFalse(res.metrics[0].iterations)
         res = self.api.events.debug_images(
             metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
         )
-        self.assertFalse(res.metrics)
+        self.assertFalse(res.metrics[0].iterations)
 
         # test not empty
         metrics = {
@@ -180,10 +180,9 @@ class TestTaskDebugImages(TestService):
         )
 
         # with refresh there are new metrics and existing ones are updated
-        metrics.update(update)
         self._assertTaskMetrics(
             task=task,
-            expected_metrics=metrics,
+            expected_metrics=update,
             iterations=1,
             scroll_id=scroll_id,
             refresh=True,
@@ -202,17 +201,16 @@ class TestTaskDebugImages(TestService):
         res = self.api.events.debug_images(
             metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
         )
-        self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics))
         if not iterations:
             self.assertTrue(all(m.iterations == [] for m in res.metrics))
             return res.scroll_id
 
+        expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
         for metric_data in res.metrics:
-            expected_variants = set(expected_metrics[metric_data.metric])
             self.assertEqual(len(metric_data.iterations), iterations)
             for it_data in metric_data.iterations:
                 self.assertEqual(
-                    set(e.variant for e in it_data.events), expected_variants
+                    set((e.metric, e.variant) for e in it_data.events), expected_variants
                 )
 
         return res.scroll_id
@@ -227,7 +225,7 @@ class TestTaskDebugImages(TestService):
         res = self.api.events.debug_images(
             metrics=[{"task": task, "metric": metric}], iters=5,
         )
-        self.assertFalse(res.metrics)
+        self.assertFalse(res.metrics[0].iterations)
 
         # create events
         events = [
@@ -295,7 +293,6 @@ class TestTaskDebugImages(TestService):
         )
         data = res["metrics"][0]
         self.assertEqual(data["task"], task)
-        self.assertEqual(data["metric"], metric)
 
         left_iterations = max(0, max(unique_images) - expected_page * iters)
         self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
         for it in data["iterations"]:
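For reference, the reshaped test expectation in `_assertTaskMetrics` boils down to comparing (metric, variant) pairs across the whole task rather than variant names within a single metric. A standalone illustration with toy data:

expected_metrics = {"Metric1": {"Variant1", "Variant2"}, "Metric2": {"Variant3"}}

# Flatten the per-metric variant sets into (metric, variant) pairs, as the
# updated test helper does.
expected_variants = set(
    (m, var) for m, vars_ in expected_metrics.items() for var in vars_
)

events = [
    {"metric": "Metric1", "variant": "Variant1"},
    {"metric": "Metric1", "variant": "Variant2"},
    {"metric": "Metric2", "variant": "Variant3"},
]
assert set((e["metric"], e["variant"]) for e in events) == expected_variants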