# clearml-server/apiserver/bll/event/debug_images_iterator.py

from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set

import attr
import dpath
from boltons.iterutils import bucketize, first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis

from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
check_empty_data,
search_company_events,
EventType,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext


class VariantScrollState(Base):
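    """
    Scroll state of a single metric variant. Field semantics inferred from
    usage in this module: recycle_url_marker holds the first event url seen
    for the variant (used to detect url reuse), and last_invalid_iteration is
    the highest iteration whose events are considered stale due to recycling
    """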
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
last_invalid_iteration: int = IntField()


class MetricScrollState(Base):
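    """
    Scroll state of a single task metric: last_min_iter and last_max_iter
    bound the iteration window returned so far, and timestamp is the last
    update time of the metric events at state initialization
    """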
task: str = StringField(required=True)
name: str = StringField(required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])

    def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None


class DebugImageEventsScrollState(Base, JsonSerializableMixin):
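    """
    Redis-cached scroll state for a debug images query:
    one MetricScrollState per requested task metric
    """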
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
warning: str = StringField()


@attr.s(auto_attribs=True)
class DebugImagesResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None


class DebugImagesIterator:
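    """
    Serve debug image events from Elasticsearch in iteration-ordered pages,
    keeping the per-metric scroll position in a Redis-cached state
    """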
EVENT_TYPE = EventType.metrics_image

    def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)

    def get_task_events(
self,
company_id: str,
task_metrics: Mapping[str, Set[str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return DebugImagesResult()

        def init_state(state_: DebugImageEventsScrollState):
state_.metrics = self._init_metric_states(company_id, task_metrics)

        def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
if refresh:
self._reinit_outdated_metric_states(company_id, state_, task_metrics)
for metric_state in state_.metrics:
metric_state.reset()

        with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
company_id=company_id,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
)
)
return res

    def _reinit_outdated_metric_states(
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)

        def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
            """
            For metrics that reported debug image events, return tuples of
            (task_id, metric_name) and their last update times
            """
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
requested_metrics = task_metrics[task.id]
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
)
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
]

        update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
)
)
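
        # Collect the metrics that are either missing from the cached state or
        # were updated after the state was built; only those are reinitialized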
metrics_to_update = defaultdict(set)
for (task, metric), update_time in update_times.items():
state_metric = first(
m for m in state.metrics if m.task == task and m.name == metric
)
if not state_metric or state_metric.timestamp < update_time:
metrics_to_update[task].add(metric)
if metrics_to_update:
state.metrics = [
*(
metric
for metric in state.metrics
if metric.name not in metrics_to_update.get(metric.task, [])
),
*(self._init_metric_states(company_id, metrics_to_update)),
]

    def _init_metric_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]]
) -> Sequence[MetricScrollState]:
"""
        Return initialized metric scroll states for the requested task metrics
"""
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(
self._init_metric_states_for_task, company_id=company_id
),
task_metrics.items(),
)
)
)

    def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Set[str]], company_id: str
) -> Sequence[MetricScrollState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append({"terms": {"metric": list(metrics)}})
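        # Nested aggregation: for each metric, list its variants and, per
        # variant, take the single url from the most recent iteration together
        # with that url's two latest iterations (a second hit, if present,
        # marks the iteration where the url was recycled)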
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
},
},
}
},
}

        with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return []

        def init_variant_scroll_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
return state

        return [
MetricScrollState(
task=task,
name=metric["key"],
variants=[
init_variant_scroll_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]

    def _get_task_metric_events(
self,
metric: MetricScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
"""
if metric.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
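        # Resume from the last returned window: strictly earlier than the
        # minimal iteration seen so far, or strictly later than the maximal
        # one, depending on the navigation direction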
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
}
},
}
},
}
},
}

        with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []

        def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for ev in dpath.get(v, "events/hits/hits")
]

        iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
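        # Events are fetched in ascending order when navigating later;
        # re-sort so that returned iterations are always latest-first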
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations

    @staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]
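

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module); the connection
# parameters, company id, task id and metric name below are placeholders:
#
#   redis = StrictRedis(host="localhost", port=6379)
#   es = Elasticsearch(hosts=["localhost:9200"])
#   iterator = DebugImagesIterator(redis=redis, es=es)
#   result = iterator.get_task_events(
#       company_id="<company-id>",
#       task_metrics={"<task-id>": {"<metric-name>"}},
#       iter_count=10,
#   )
#   for task_id, metric_name, iterations in result.metric_events:
#       for it in iterations:
#           ...  # it["iter"] is the iteration, it["events"] the image events
#   # Pass result.next_scroll_id back as state_id to fetch the next page
# ---------------------------------------------------------------------------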