Support iterating over all task metrics in task debug images

allegroai 2021-05-03 17:43:02 +03:00
parent c034c1a986
commit 0d5174c453
6 changed files with 150 additions and 66 deletions
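As a usage sketch of what this change enables (not part of the diff below): a task_metric entry may now omit "metric", in which case the service iterates over all metrics reported by that task. The helper name and the "api" client object are illustrative; the response fields match those exercised by the tests in this commit.

def iter_task_debug_images(api, task_id, iters=10):
    # "api" is assumed to expose the same events.debug_images call as the
    # test fixture further below; a task entry without "metric" now means
    # "all metrics reported by this task".
    scroll_id = None
    while True:
        res = api.events.debug_images(
            metrics=[{"task": task_id}], iters=iters, scroll_id=scroll_id
        )
        scroll_id = res.scroll_id
        if not any(m.iterations for m in res.metrics):
            return
        for metric_data in res.metrics:
            for iteration in metric_data.iterations:
                for event in iteration.events:
                    yield event  # e.g. event.url points at the debug image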

View File

@@ -38,7 +38,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
metric: str = StringField(default=None)
class DebugImagesRequest(Base):
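The practical effect of the model change above is that "metric" is no longer required on a task metric entry. A tiny standalone sketch using plain jsonmodels (the example class is hypothetical; the real TaskMetric derives from the apiserver's Base):

from jsonmodels import models, fields

class TaskMetricExample(models.Base):  # hypothetical stand-in for TaskMetric
    task = fields.StringField(required=True)
    metric = fields.StringField()  # optional now; missing means "all metrics"

tm = TaskMetricExample(task="some-task-id")
tm.validate()  # passes even though no metric name was given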

View File

@@ -3,17 +3,16 @@ from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
from typing import Sequence, Tuple, Optional, Mapping, Set
import attr
import dpath
from boltons.iterutils import bucketize
from boltons.iterutils import bucketize, first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
@@ -73,7 +72,7 @@ class DebugImagesIterator:
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
task_metrics: Mapping[str, Set[str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
@@ -83,8 +82,7 @@ class DebugImagesIterator:
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
state_.metrics = self._init_metric_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
"""
@@ -92,14 +90,8 @@ class DebugImagesIterator:
as requested in the current call.
Refresh the state if requested
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, state_)
self._reinit_outdated_metric_states(company_id, state_, task_metrics)
for metric_state in state_.metrics:
metric_state.reset()
@@ -123,14 +115,16 @@ class DebugImagesIterator:
return res
def _reinit_outdated_metric_states(
self, company_id, state: DebugImageEventsScrollState
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
@@ -140,6 +134,7 @@ class DebugImagesIterator:
if not metric_stats:
return []
requested_metrics = task_metrics[task.id]
return [
(
(task.id, stats.metric),
@@ -147,6 +142,7 @@ class DebugImagesIterator:
)
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
]
update_times = dict(
@@ -154,32 +150,31 @@ class DebugImagesIterator:
get_last_update_times_for_task_metrics(task) for task in tasks
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
company_id,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
]
metrics_to_update = defaultdict(set)
for (task, metric), update_time in update_times.items():
state_metric = first(
m for m in state.metrics if m.task == task and m.name == metric
)
if not state_metric or state_metric.timestamp < update_time:
metrics_to_update[task].add(metric)
if metrics_to_update:
state.metrics = [
*(
metric
for metric in state.metrics
if metric.name not in metrics_to_update.get(metric.task, [])
),
*(self._init_metric_states(company_id, metrics_to_update)),
]
def _init_metric_states(
self, company_id: str, metrics: Sequence[Tuple[str, str]]
self, company_id: str, task_metrics: Mapping[str, Set[str]]
) -> Sequence[MetricScrollState]:
"""
Return initialized metric scroll states for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
return list(
chain.from_iterable(
@@ -187,30 +182,25 @@ class DebugImagesIterator:
partial(
self._init_metric_states_for_task, company_id=company_id
),
tasks.items(),
task_metrics.items(),
)
)
)
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
self, task_metrics: Tuple[str, Set[str]], company_id: str
) -> Sequence[MetricScrollState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append({"terms": {"metric": list(metrics)}})
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
}
},
"query": {"bool": {"must": must}},
"aggs": {
"metrics": {
"terms": {
@@ -254,10 +244,7 @@ class DebugImagesIterator:
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return []
@@ -397,10 +384,7 @@ class DebugImagesIterator:
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []
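To make the "empty set means every metric" convention above easier to follow, here is a standalone restatement of the Elasticsearch filter built in _init_metric_states_for_task (the function name below is illustrative, not part of the diff):

def build_metric_states_query(task: str, metrics: set) -> dict:
    # Same shape as the es_req built above: debug image events are
    # identified by the presence of a "url" field, and the metric terms
    # filter is added only when specific metrics were requested.
    must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
    if metrics:  # an empty set means "all metrics for this task"
        must.append({"terms": {"metric": list(metrics)}})
    return {"size": 0, "query": {"bool": {"must": must}}}

build_metric_states_query("task-1", set())        # no metric filter
build_metric_states_query("task-1", {"Metric1"})  # only Metric1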

View File

@@ -110,12 +110,15 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
if not metrics:
return set()
task_metrics = [(task, metric) for metric in metrics]
task_metrics = {task: set(metrics)}
scroll_id = None
urls = defaultdict(set)
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company, metrics=task_metrics, iter_count=100, state_id=scroll_id
company_id=company,
task_metrics=task_metrics,
iter_count=100,
state_id=scroll_id,
)
if not res.metric_events or not any(
events for _, _, events in res.metric_events

View File

@@ -187,14 +187,14 @@
}
task_metric {
type: object
required: [task, metric]
required: [task]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
description: "Metric name. If not specified then all metrics for this task will be returned"
type: string
}
}
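For reference, the Python-literal equivalent of the metrics argument this schema now accepts (task IDs are placeholders):

metrics = [
    {"task": "<task-id-1>", "metric": "Metric1"},  # only Metric1 for this task
    {"task": "<task-id-2>"},                       # all metrics for this task
]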

View File

@@ -594,10 +594,16 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
task_metrics = defaultdict(set)
for tm in request.metrics:
task_metrics[tm.task].add(tm.metric)
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
tasks = task_bll.assert_exists(
company_id,
task_ids=task_ids,
task_ids=list(task_metrics),
allow_public=True,
only=("company", "company_origin"),
)
@@ -610,7 +616,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
result = event_bll.debug_images_iterator.get_task_events(
company_id=next(iter(companies)),
metrics=[(m.task, m.metric) for m in request.metrics],
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
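A worked example of the request-to-mapping conversion performed by the handler above. The helper is a hypothetical standalone restatement (the real code reads TaskMetric models from the request): a task listed at least once without a metric name ends up mapped to an empty set, which downstream means "all metrics".

from collections import defaultdict

def to_task_metrics(metrics: list) -> dict:
    # Collect requested metric names per task; clear the set when any entry
    # for that task omitted the metric name.
    task_metrics = defaultdict(set)
    for tm in metrics:
        task_metrics[tm["task"]].add(tm.get("metric"))
    for metric_names in task_metrics.values():
        if None in metric_names:
            metric_names.clear()
    return dict(task_metrics)

to_task_metrics(
    [{"task": "t1", "metric": "loss"}, {"task": "t1"}, {"task": "t2", "metric": "acc"}]
)
# -> {"t1": set(), "t2": {"acc"}}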

View File

@@ -1,6 +1,5 @@
from functools import partial
from typing import Sequence
from typing import Sequence, Mapping
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService
@@ -128,6 +127,98 @@ class TestTaskDebugImages(TestService):
def test_task_debug_images(self):
task = self._temp_task()
# test empty
res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
self.assertFalse(res.metrics)
res = self.api.events.debug_images(
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
)
self.assertFalse(res.metrics)
# test not empty
metrics = {
"Metric1": ["Variant1", "Variant2"],
"Metric2": ["Variant3", "Variant4"],
}
events = [
self._create_task_event(
task=task,
iteration=1,
metric=metric,
variant=variant,
url=f"{metric}_{variant}_{1}",
)
for metric, variants in metrics.items()
for variant in variants
]
self.send_batch(events)
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=1
)
# test refresh
update = {
"Metric2": ["Variant3", "Variant4", "Variant5"],
"Metric3": ["VariantA", "VariantB"],
}
events = [
self._create_task_event(
task=task,
iteration=2,
metric=metric,
variant=variant,
url=f"{metric}_{variant}_{2}",
)
for metric, variants in update.items()
for variant in variants
]
self.send_batch(events)
# without refresh the metric states are not updated
scroll_id = self._assertTaskMetrics(
task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
)
# with refresh there are new metrics and existing ones are updated
metrics.update(update)
self._assertTaskMetrics(
task=task,
expected_metrics=metrics,
iterations=1,
scroll_id=scroll_id,
refresh=True,
)
pass
def _assertTaskMetrics(
self,
task: str,
expected_metrics: Mapping[str, Sequence[str]],
iterations,
scroll_id: str = None,
refresh=False,
) -> str:
res = self.api.events.debug_images(
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
)
self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics))
if not iterations:
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
for metric_data in res.metrics:
expected_variants = set(expected_metrics[metric_data.metric])
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set(e.variant for e in it_data.events), expected_variants
)
return res.scroll_id
def test_get_debug_images_navigation(self):
task = self._temp_task()
metric = "Metric1"
variants = [("Variant1", 7), ("Variant2", 4)]
iterations = 10
@@ -195,7 +286,7 @@ class TestTaskDebugImages(TestService):
expected_page: int,
iters: int = 5,
**extra_params,
):
) -> str:
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=iters,