Return results per task iterations in debug images request

allegroai 2021-05-03 18:08:14 +03:00
parent 6b0c45a861
commit b67aa05d6f
6 changed files with 163 additions and 221 deletions
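The gist of the change: the debug-images iterator used to return one (task, metric, iterations) tuple per task metric; it now returns one (task, iterations) tuple per task, with each iteration carrying the events of all requested metrics together. A minimal sketch of the two result shapes (all values below are illustrative, not taken from a real response):

# Illustrative only: the event dicts are made up; real events come from
# Elasticsearch and carry more fields.

# Before: one entry per (task, metric) pair
old_metric_events = [
    ("task-1", "loss",
     [{"iter": 7, "events": [{"variant": "v1", "url": "loss7.png"}]}]),
    ("task-1", "accuracy",
     [{"iter": 7, "events": [{"variant": "v1", "url": "acc7.png"}]}]),
]

# After: one entry per task; each iteration mixes all requested metrics,
# so every event has to name its own metric
new_metric_events = [
    ("task-1", [{
        "iter": 7,
        "events": [
            {"metric": "loss", "variant": "v1", "url": "loss7.png"},
            {"metric": "accuracy", "variant": "v1", "url": "acc7.png"},
        ],
    }]),
]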


@@ -89,7 +89,6 @@ class IterationEvents(Base):
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)


@@ -1,13 +1,12 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set
import attr
import dpath
from boltons.iterutils import bucketize, first
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
@@ -27,19 +26,22 @@ from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
class VariantState(Base):
variant: str = StringField(required=True)
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
class MetricState(Base):
metric: str = StringField(required=True)
variants: Sequence[VariantState] = ListField([VariantState], required=True)
timestamp: int = IntField(default=0)
class TaskScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
@@ -48,7 +50,7 @@ class MetricScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
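The scroll state mirrors that regrouping: pagination cursors move from the metric level to the task level, and each task owns its metric and variant states. A minimal, runnable sketch of the new hierarchy using the same jsonmodels classes (trimmed to the fields shown above; example values are made up):

from jsonmodels.fields import IntField, ListField, StringField
from jsonmodels.models import Base

class VariantState(Base):
    variant: str = StringField(required=True)
    last_invalid_iteration: int = IntField()

class MetricState(Base):
    metric: str = StringField(required=True)
    variants = ListField([VariantState], required=True)
    timestamp: int = IntField(default=0)

class TaskScrollState(Base):
    task: str = StringField(required=True)
    metrics = ListField([MetricState], required=True)
    last_min_iter: int = IntField()
    last_max_iter: int = IntField()

# One task state holding two metrics, each with its own variants
state = TaskScrollState(
    task="task-1",
    metrics=[
        MetricState(metric="loss", variants=[VariantState(variant="v1")]),
        MetricState(
            metric="accuracy",
            variants=[VariantState(variant="v1", last_invalid_iteration=3)],
        ),
    ],
)
assert state.metrics[1].variants[0].last_invalid_iteration == 3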
@@ -82,7 +84,7 @@ class DebugImagesIterator:
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
state_.metrics = self._init_metric_states(company_id, task_metrics)
state_.tasks = self._init_task_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
"""
@@ -91,9 +93,7 @@ class DebugImagesIterator:
Refresh the state if requested
"""
if refresh:
self._reinit_outdated_metric_states(company_id, state_, task_metrics)
for metric_state in state_.metrics:
metric_state.reset()
self._reinit_outdated_task_states(company_id, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -108,88 +108,113 @@ class DebugImagesIterator:
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
state.tasks,
)
)
return res
def _reinit_outdated_metric_states(
def _reinit_outdated_task_states(
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
Determine the metrics for which new debug image events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return {}
requested_metrics = task_metrics[task.id]
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
)
return {
stats.metric: stats.event_stats_by_type[
self.EVENT_TYPE.value
].last_update
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
]
}
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
update_times = {
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
}
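So update_times ends up as a small two-level mapping, one inner dict per task. A hand-made example of its shape:

from datetime import datetime

# {task_id: {metric_name: last debug-image event update time}}
update_times = {
    "task-1": {
        "loss": datetime(2021, 5, 3, 17, 59),
        "accuracy": datetime(2021, 5, 3, 18, 2),
    },
    "task-2": {},  # no debug image events reported for the requested metrics
}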
task_metric_states = {
task_state.task: {
metric_state.metric: metric_state for metric_state in task_state.metrics
}
for task_state in state.tasks
}
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = set(
m
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
)
)
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
metrics_to_update = defaultdict(set)
for (task, metric), update_time in update_times.items():
state_metric = first(
m for m in state.metrics if m.task == task and m.name == metric
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
) -> TaskScrollState:
task = old_state.task
updated_state = first(uts for uts in updates if uts.task == task)
if not updated_state:
old_state.reset()
return old_state
updated_metrics = [m.metric for m in updated_state.metrics]
return TaskScrollState(
task=task,
metrics=[
*updated_state.metrics,
*(
old_metric
for old_metric in old_state.metrics
if old_metric.metric not in updated_metrics
),
],
)
if not state_metric or state_metric.timestamp < update_time:
metrics_to_update[task].add(metric)
if metrics_to_update:
state.metrics = [
*(
metric
for metric in state.metrics
if metric.name not in metrics_to_update.get(metric.task, [])
),
*(self._init_metric_states(company_id, metrics_to_update)),
]
state.tasks = [
merge_with_updated_task_states(task_state, updated_task_states)
for task_state in state.tasks
]
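The merge step is self-contained enough to sketch over plain dicts: freshly initialized metric states win, untouched old metric states are carried over, and a task with nothing new keeps its metrics but loses its pagination cursors (the reset). The helper below is hypothetical, not the server code:

def merge_task_state(old_state: dict, updated_states: list) -> dict:
    """Sketch of merge_with_updated_task_states over plain dicts."""
    updated = next(
        (u for u in updated_states if u["task"] == old_state["task"]), None
    )
    if updated is None:
        # nothing newer for this task: keep metrics, restart pagination
        metrics = old_state["metrics"]
    else:
        updated_names = {m["metric"] for m in updated["metrics"]}
        metrics = updated["metrics"] + [
            m for m in old_state["metrics"] if m["metric"] not in updated_names
        ]
    # either way the cursors are cleared, like TaskScrollState.reset()
    return {"task": old_state["task"], "metrics": metrics,
            "last_min_iter": None, "last_max_iter": None}

old = {"task": "t1", "metrics": [{"metric": "loss", "ts": 1}, {"metric": "acc", "ts": 1}]}
new = [{"task": "t1", "metrics": [{"metric": "loss", "ts": 2}]}]
assert merge_task_state(old, new)["metrics"] == [
    {"metric": "loss", "ts": 2},
    {"metric": "acc", "ts": 1},
]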
def _init_metric_states(
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]]
) -> Sequence[MetricScrollState]:
) -> Sequence[TaskScrollState]:
"""
Return initialized task scroll states for the requested task metrics
"""
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(
self._init_metric_states_for_task, company_id=company_id
),
task_metrics.items(),
)
)
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id),
task_metrics.items(),
)
return [
TaskScrollState(task=task, metrics=metric_states,)
for task, metric_states in zip(task_metrics, task_metric_states)
]
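The closing zip(task_metrics, task_metric_states) leans on two ordering guarantees: dicts iterate in insertion order, and ThreadPoolExecutor.map yields results in input order no matter which worker finishes first. A quick self-contained check of the second point:

from concurrent.futures.thread import ThreadPoolExecutor
import time

task_metrics = {"task-1": {"loss"}, "task-2": {"accuracy"}, "task-3": set()}

def init_states(item):
    task, metrics = item
    time.sleep(0.01 if task == "task-1" else 0)  # task-1 finishes last
    return [f"state:{m}" for m in sorted(metrics)]

with ThreadPoolExecutor(4) as pool:
    states = pool.map(init_states, task_metrics.items())
    pairs = list(zip(task_metrics, states))

# Results stay aligned with the input order despite scheduling
assert [t for t, _ in pairs] == ["task-1", "task-2", "task-3"]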
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Set[str]], company_id: str
) -> Sequence[MetricScrollState]:
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
@@ -249,12 +274,12 @@ class DebugImagesIterator:
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
def init_variant_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
Return new variant state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
state = VariantState(variant=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
@@ -262,102 +287,52 @@ class DebugImagesIterator:
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
variants=[
init_variant_scroll_state(variant)
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
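The traversal above can be tried against a trimmed stand-in for the aggregation result; dpath.get simply walks nested dicts by a slash-separated path. The sample response below is hand-made and reduced to the fields the parser reads:

import dpath

# Hand-made stand-in for es_res["aggregations"], reduced to the parsed fields
aggregations = {
    "metrics": {"buckets": [{
        "key": "loss",
        "last_event_timestamp": {"value": 1620054494000},
        "variants": {"buckets": [{
            "key": "v1",
            "urls": {"buckets": [{"iters": {"hits": {"hits": [
                {"_source": {"iter": 7, "url": "img7.png"}},
            ]}}}]},
        }]},
    }]},
}

for metric in dpath.get(aggregations, "metrics/buckets"):
    timestamp = dpath.get(metric, "last_event_timestamp/value")
    for variant in dpath.get(metric, "variants/buckets"):
        top_iter_url = dpath.get(variant, "urls/buckets")[0]
        iters = dpath.get(top_iter_url, "iters/hits/hits")
        print(metric["key"], variant["key"], timestamp, iters[0]["_source"]["url"])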
def _get_task_metric_events(
self,
metric: MetricScrollState,
task_state: TaskScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
Update task scroll state
"""
if metric.last_max_iter is None:
if not task_state.metrics:
return task_state.task, []
if task_state.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if navigate_earlier and task_state.last_min_iter is not None:
range_condition = {"lt": task_state.last_min_iter}
elif not navigate_earlier and task_state.last_max_iter is not None:
range_condition = {"gt": task_state.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
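The range clause is plain cursor pagination over the iteration number: scrolling back continues strictly below the lowest iteration already served, scrolling forward strictly above the highest, and the very first fetch has no cursor at all. A compact sketch (the helper name is hypothetical):

from typing import Optional

def iter_range_condition(
    navigate_earlier: bool,
    last_min_iter: Optional[int],
    last_max_iter: Optional[int],
) -> Optional[dict]:
    """Return the ES range filter for the next page of iterations, if any."""
    if navigate_earlier and last_min_iter is not None:
        return {"range": {"iter": {"lt": last_min_iter}}}
    if not navigate_earlier and last_max_iter is not None:
        return {"range": {"iter": {"gt": last_max_iter}}}
    return None  # first fetch: no cursor yet

assert iter_range_condition(True, 5, 9) == {"range": {"iter": {"lt": 5}}}
assert iter_range_condition(False, 5, 9) == {"range": {"iter": {"gt": 9}}}
assert iter_range_condition(True, None, None) is None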
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {
@@ -366,15 +341,26 @@ class DebugImagesIterator:
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"metrics": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"url": {"order": "desc"}}
}
}
},
}
},
}
@@ -387,74 +373,41 @@ class DebugImagesIterator:
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []
return task_state.task, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
invalid_iterations = {
(m.metric, v.variant): v.last_invalid_iteration
for m in task_state.metrics
for v in m.variants
}
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return False
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
if is_valid_event(ev["_source"])
]
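Since an iteration bucket now fans out over metrics and then variants, the flattening plus the recycle filter can be exercised on a toy bucket; events for unknown (metric, variant) pairs, or at or below a variant's last invalid iteration, are dropped:

import dpath

invalid_iterations = {("loss", "v1"): 3, ("loss", "v2"): None}

def is_valid_event(event: dict) -> bool:
    key = event.get("metric"), event.get("variant")
    if key not in invalid_iterations:
        return False
    max_invalid = invalid_iterations[key]
    return max_invalid is None or event.get("iter") > max_invalid

# Toy iteration bucket in the new iters -> metrics -> variants -> events shape
it_ = {
    "key": 3,
    "metrics": {"buckets": [{
        "key": "loss",
        "variants": {"buckets": [
            {"key": "v1", "events": {"hits": {"hits": [
                {"_source": {"metric": "loss", "variant": "v1", "iter": 3}},
            ]}}},
            {"key": "v2", "events": {"hits": {"hits": [
                {"_source": {"metric": "loss", "variant": "v2", "iter": 3}},
            ]}}},
        ]},
    }]},
}

events = [
    ev["_source"]
    for m in dpath.get(it_, "metrics/buckets")
    for v in dpath.get(m, "variants/buckets")
    for ev in dpath.get(v, "events/hits/hits")
    if is_valid_event(ev["_source"])
]
# v1's event is at its last invalid iteration (3), so only v2 survives
assert events == [{"metric": "loss", "variant": "v2", "iter": 3}]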
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
task_state.last_max_iter = iterations[0]["iter"]
task_state.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in use since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]
return task_state.task, iterations


@@ -1,4 +1,3 @@
from collections import defaultdict
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
@@ -133,7 +132,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
task_metrics = {task: set(metrics)}
scroll_id = None
urls = defaultdict(set)
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company,
@@ -142,17 +141,16 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
state_id=scroll_id,
)
if not res.metric_events or not any(
events for _, _, events in res.metric_events
iterations for _, iterations in res.metric_events
):
break
scroll_id = res.next_scroll_id
for _, metric, iterations in res.metric_events:
metric_urls = set(ev.get("url") for it in iterations for ev in it["events"])
metric_urls.discard(None)
urls[metric].update(metric_urls)
for task, iterations in res.metric_events:
urls.update(ev.get("url") for it in iterations for ev in it["events"])
return set(chain.from_iterable(urls.values()))
urls.discard(None)
return urls
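With the per-task grouping, URL collection reduces to one flat set (note the discard of None, which covers events that carry no url). A sketch of the same scroll loop against a hypothetical stand-in for the iterator; FakeResult and get_task_events are made up for the example:

# Hypothetical stand-in for the debug-images iterator; the real one pages
# through Elasticsearch using scroll state
class FakeResult:
    def __init__(self, metric_events, next_scroll_id):
        self.metric_events = metric_events
        self.next_scroll_id = next_scroll_id

pages = [
    FakeResult([("task-1", [{"iter": 7, "events": [{"url": "a.png"}, {}]}])], "s1"),
    FakeResult([("task-1", [])], "s2"),  # empty page terminates the loop
]

def get_task_events(state_id=None):
    return pages[0] if state_id is None else pages[1]

urls = set()
scroll_id = None
while True:
    res = get_task_events(state_id=scroll_id)
    if not res.metric_events or not any(
        iterations for _, iterations in res.metric_events
    ):
        break
    scroll_id = res.next_scroll_id
    for _, iterations in res.metric_events:
        urls.update(ev.get("url") for it in iterations for ev in it["events"])
urls.discard(None)  # events without a url contribute None
assert urls == {"a.png"}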
def cleanup_task(


@@ -193,10 +193,6 @@
description: "Task ID"
type: string
}
metric {
description: "Metric name. If not specified then all metrics for this task will be returned"
type: string
}
}
}
task_log_event {
@@ -370,7 +366,7 @@
}
}
"2.7" {
description: "Get the debug image events for the requested amount of iterations per each task's metric"
description: "Get the debug image events for the requested amount of iterations per each task"
request {
type: object
required: [


@@ -628,13 +628,12 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
metrics=[
MetricEvents(
task=task,
metric=metric,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, metric, iterations) in result.metric_events
for (task, iterations) in result.metric_events
],
)
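Each (task, iterations) tuple thus becomes a single MetricEvents entry keyed by task alone; the wrapper loses its metric field because every event inside an iteration names its own metric. The serialized response then takes roughly this shape (values illustrative):

response = {
    "scroll_id": "<opaque scroll state id>",
    "metrics": [
        {
            "task": "task-1",
            "iterations": [
                {
                    "iter": 7,
                    "events": [
                        {"metric": "loss", "variant": "v1", "iter": 7, "url": "a.png"},
                        {"metric": "accuracy", "variant": "v1", "iter": 7, "url": "b.png"},
                    ],
                }
            ],
        }
    ],
}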


@@ -130,11 +130,11 @@ class TestTaskDebugImages(TestService):
# test empty
res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
self.assertFalse(res.metrics)
self.assertFalse(res.metrics[0].iterations)
res = self.api.events.debug_images(
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
)
self.assertFalse(res.metrics)
self.assertFalse(res.metrics[0].iterations)
# test not empty
metrics = {
@@ -180,10 +180,9 @@ class TestTaskDebugImages(TestService):
)
# with refresh there are new metrics and existing ones are updated
metrics.update(update)
self._assertTaskMetrics(
task=task,
expected_metrics=metrics,
expected_metrics=update,
iterations=1,
scroll_id=scroll_id,
refresh=True,
@@ -202,17 +201,16 @@ class TestTaskDebugImages(TestService):
res = self.api.events.debug_images(
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
)
self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics))
if not iterations:
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
expected_variants = set(
    (m, var) for m, vars_ in expected_metrics.items() for var in vars_
)
for metric_data in res.metrics:
expected_variants = set(expected_metrics[metric_data.metric])
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set(e.variant for e in it_data.events), expected_variants
set((e.metric, e.variant) for e in it_data.events), expected_variants
)
return res.scroll_id
@@ -227,7 +225,7 @@ class TestTaskDebugImages(TestService):
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}], iters=5,
)
self.assertFalse(res.metrics)
self.assertFalse(res.metrics[0].iterations)
# create events
events = [
@@ -295,7 +293,6 @@ class TestTaskDebugImages(TestService):
)
data = res["metrics"][0]
self.assertEqual(data["task"], task)
self.assertEqual(data["metric"], metric)
left_iterations = max(0, max(unique_images) - expected_page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]: