diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index bfe38eb..71d3ae5 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -38,7 +38,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): class TaskMetric(Base): task: str = StringField(required=True) - metric: str = StringField(required=True) + metric: str = StringField(default=None) class DebugImagesRequest(Base): diff --git a/apiserver/bll/event/debug_images_iterator.py b/apiserver/bll/event/debug_images_iterator.py index 09470dd..31e5d3d 100644 --- a/apiserver/bll/event/debug_images_iterator.py +++ b/apiserver/bll/event/debug_images_iterator.py @@ -3,17 +3,16 @@ from concurrent.futures.thread import ThreadPoolExecutor from functools import partial from itertools import chain from operator import attrgetter, itemgetter -from typing import Sequence, Tuple, Optional, Mapping +from typing import Sequence, Tuple, Optional, Mapping, Set import attr import dpath -from boltons.iterutils import bucketize +from boltons.iterutils import bucketize, first from elasticsearch import Elasticsearch from jsonmodels.fields import StringField, ListField, IntField from jsonmodels.models import Base from redis import StrictRedis -from apiserver.apierrors import errors from apiserver.apimodels import JsonSerializableMixin from apiserver.bll.event.event_common import ( EventSettings, @@ -73,7 +72,7 @@ class DebugImagesIterator: def get_task_events( self, company_id: str, - metrics: Sequence[Tuple[str, str]], + task_metrics: Mapping[str, Set[str]], iter_count: int, navigate_earlier: bool = True, refresh: bool = False, @@ -83,8 +82,7 @@ class DebugImagesIterator: return DebugImagesResult() def init_state(state_: DebugImageEventsScrollState): - unique_metrics = set(metrics) - state_.metrics = self._init_metric_states(company_id, list(unique_metrics)) + state_.metrics = self._init_metric_states(company_id, task_metrics) def validate_state(state_: DebugImageEventsScrollState): """ @@ -92,14 +90,8 @@ class DebugImagesIterator: as requested in the current call. Refresh the state if requested """ - state_metrics = set((m.task, m.name) for m in state_.metrics) - if state_metrics != set(metrics): - raise errors.bad_request.InvalidScrollId( - "Task metrics stored in the state do not match the passed ones", - scroll_id=state_.id, - ) if refresh: - self._reinit_outdated_metric_states(company_id, state_) + self._reinit_outdated_metric_states(company_id, state_, task_metrics) for metric_state in state_.metrics: metric_state.reset() @@ -123,14 +115,16 @@ class DebugImagesIterator: return res def _reinit_outdated_metric_states( - self, company_id, state: DebugImageEventsScrollState + self, + company_id, + state: DebugImageEventsScrollState, + task_metrics: Mapping[str, Set[str]], ): """ Determines the metrics for which new debug image events were added since their states were initialized and reinits these states """ - task_ids = set(metric.task for metric in state.metrics) - tasks = Task.objects(id__in=list(task_ids), company=company_id).only( + tasks = Task.objects(id__in=list(task_metrics), company=company_id).only( "id", "metric_stats" ) @@ -140,6 +134,7 @@ class DebugImagesIterator: if not metric_stats: return [] + requested_metrics = task_metrics[task.id] return [ ( (task.id, stats.metric), @@ -147,6 +142,7 @@ class DebugImagesIterator: ) for stats in metric_stats.values() if self.EVENT_TYPE.value in stats.event_stats_by_type + and (not requested_metrics or stats.metric in requested_metrics) ] update_times = dict( @@ -154,32 +150,31 @@ class DebugImagesIterator: get_last_update_times_for_task_metrics(task) for task in tasks ) ) - outdated_metrics = [ - metric - for metric in state.metrics - if (metric.task, metric.name) in update_times - and update_times[metric.task, metric.name] > metric.timestamp - ] - state.metrics = [ - *(metric for metric in state.metrics if metric not in outdated_metrics), - *( - self._init_metric_states( - company_id, - [(metric.task, metric.name) for metric in outdated_metrics], - ) - ), - ] + + metrics_to_update = defaultdict(set) + for (task, metric), update_time in update_times.items(): + state_metric = first( + m for m in state.metrics if m.task == task and m.name == metric + ) + if not state_metric or state_metric.timestamp < update_time: + metrics_to_update[task].add(metric) + + if metrics_to_update: + state.metrics = [ + *( + metric + for metric in state.metrics + if metric.name not in metrics_to_update.get(metric.task, []) + ), + *(self._init_metric_states(company_id, metrics_to_update)), + ] def _init_metric_states( - self, company_id: str, metrics: Sequence[Tuple[str, str]] + self, company_id: str, task_metrics: Mapping[str, Set[str]] ) -> Sequence[MetricScrollState]: """ Returned initialized metric scroll stated for the requested task metrics """ - tasks = defaultdict(list) - for (task, metric) in metrics: - tasks[task].append(metric) - with ThreadPoolExecutor(EventSettings.max_workers) as pool: return list( chain.from_iterable( @@ -187,30 +182,25 @@ class DebugImagesIterator: partial( self._init_metric_states_for_task, company_id=company_id ), - tasks.items(), + task_metrics.items(), ) ) ) def _init_metric_states_for_task( - self, task_metrics: Tuple[str, Sequence[str]], company_id: str + self, task_metrics: Tuple[str, Set[str]], company_id: str ) -> Sequence[MetricScrollState]: """ Return metric scroll states for the task filled with the variant states for the variants that reported any debug images """ task, metrics = task_metrics + must = [{"term": {"task": task}}, {"exists": {"field": "url"}}] + if metrics: + must.append({"terms": {"metric": list(metrics)}}) es_req: dict = { "size": 0, - "query": { - "bool": { - "must": [ - {"term": {"task": task}}, - {"terms": {"metric": metrics}}, - {"exists": {"field": "url"}}, - ] - } - }, + "query": {"bool": {"must": must}}, "aggs": { "metrics": { "terms": { @@ -254,10 +244,7 @@ class DebugImagesIterator: with translate_errors_context(), TimingContext("es", "_init_metric_states"): es_res = search_company_events( - self.es, - company_id=company_id, - event_type=self.EVENT_TYPE, - body=es_req, + self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req, ) if "aggregations" not in es_res: return [] @@ -397,10 +384,7 @@ class DebugImagesIterator: } with translate_errors_context(), TimingContext("es", "get_debug_image_events"): es_res = search_company_events( - self.es, - company_id=company_id, - event_type=self.EVENT_TYPE, - body=es_req, + self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req, ) if "aggregations" not in es_res: return metric.task, metric.name, [] diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py index c813d04..f816788 100644 --- a/apiserver/bll/task/task_cleanup.py +++ b/apiserver/bll/task/task_cleanup.py @@ -110,12 +110,15 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]: if not metrics: return set() - task_metrics = [(task, metric) for metric in metrics] + task_metrics = {task: set(metrics)} scroll_id = None urls = defaultdict(set) while True: res = event_bll.debug_images_iterator.get_task_events( - company_id=company, metrics=task_metrics, iter_count=100, state_id=scroll_id + company_id=company, + task_metrics=task_metrics, + iter_count=100, + state_id=scroll_id, ) if not res.metric_events or not any( events for _, _, events in res.metric_events diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 83257bf..3d532ae 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -187,14 +187,14 @@ } task_metric { type: object - required: [task, metric] + required: [task] properties { task { description: "Task ID" type: string } metric { - description: "Metric name" + description: "Metric name. If not specified then all metrics for this task will be returned" type: string } } diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 6908bcb..18ae5ca 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -594,10 +594,16 @@ def get_debug_images_v1_8(call, company_id, _): response_data_model=DebugImageResponse, ) def get_debug_images(call, company_id, request: DebugImagesRequest): - task_ids = {m.task for m in request.metrics} + task_metrics = defaultdict(set) + for tm in request.metrics: + task_metrics[tm.task].add(tm.metric) + for metrics in task_metrics.values(): + if None in metrics: + metrics.clear() + tasks = task_bll.assert_exists( company_id, - task_ids=task_ids, + task_ids=list(task_metrics), allow_public=True, only=("company", "company_origin"), ) @@ -610,7 +616,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest): result = event_bll.debug_images_iterator.get_task_events( company_id=next(iter(companies)), - metrics=[(m.task, m.metric) for m in request.metrics], + task_metrics=task_metrics, iter_count=request.iters, navigate_earlier=request.navigate_earlier, refresh=request.refresh, diff --git a/apiserver/tests/automated/test_task_debug_images.py b/apiserver/tests/automated/test_task_debug_images.py index a582078..b9081eb 100644 --- a/apiserver/tests/automated/test_task_debug_images.py +++ b/apiserver/tests/automated/test_task_debug_images.py @@ -1,6 +1,5 @@ from functools import partial -from typing import Sequence - +from typing import Sequence, Mapping from apiserver.es_factory import es_factory from apiserver.tests.automated import TestService @@ -128,6 +127,98 @@ class TestTaskDebugImages(TestService): def test_task_debug_images(self): task = self._temp_task() + + # test empty + res = self.api.events.debug_images(metrics=[{"task": task}], iters=5) + self.assertFalse(res.metrics) + res = self.api.events.debug_images( + metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True + ) + self.assertFalse(res.metrics) + + # test not empty + metrics = { + "Metric1": ["Variant1", "Variant2"], + "Metric2": ["Variant3", "Variant4"], + } + events = [ + self._create_task_event( + task=task, + iteration=1, + metric=metric, + variant=variant, + url=f"{metric}_{variant}_{1}", + ) + for metric, variants in metrics.items() + for variant in variants + ] + self.send_batch(events) + scroll_id = self._assertTaskMetrics( + task=task, expected_metrics=metrics, iterations=1 + ) + + # test refresh + update = { + "Metric2": ["Variant3", "Variant4", "Variant5"], + "Metric3": ["VariantA", "VariantB"], + } + events = [ + self._create_task_event( + task=task, + iteration=2, + metric=metric, + variant=variant, + url=f"{metric}_{variant}_{2}", + ) + for metric, variants in update.items() + for variant in variants + ] + self.send_batch(events) + # without refresh the metric states are not updated + scroll_id = self._assertTaskMetrics( + task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id + ) + + # with refresh there are new metrics and existing ones are updated + metrics.update(update) + self._assertTaskMetrics( + task=task, + expected_metrics=metrics, + iterations=1, + scroll_id=scroll_id, + refresh=True, + ) + + pass + + def _assertTaskMetrics( + self, + task: str, + expected_metrics: Mapping[str, Sequence[str]], + iterations, + scroll_id: str = None, + refresh=False, + ) -> str: + res = self.api.events.debug_images( + metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh + ) + self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics)) + if not iterations: + self.assertTrue(all(m.iterations == [] for m in res.metrics)) + return res.scroll_id + + for metric_data in res.metrics: + expected_variants = set(expected_metrics[metric_data.metric]) + self.assertEqual(len(metric_data.iterations), iterations) + for it_data in metric_data.iterations: + self.assertEqual( + set(e.variant for e in it_data.events), expected_variants + ) + + return res.scroll_id + + def test_get_debug_images_navigation(self): + task = self._temp_task() metric = "Metric1" variants = [("Variant1", 7), ("Variant2", 4)] iterations = 10 @@ -195,7 +286,7 @@ class TestTaskDebugImages(TestService): expected_page: int, iters: int = 5, **extra_params, - ): + ) -> str: res = self.api.events.debug_images( metrics=[{"task": task, "metric": metric}], iters=iters,