Add support for navigate_current_metric in events.get_debug_image_sample

This commit is contained in:
allegroai 2022-07-08 17:31:44 +03:00
parent cff98ae900
commit dd8a1503b0
8 changed files with 156 additions and 817 deletions

View File

@@ -48,7 +48,7 @@ class TaskMetric(Base):
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
class MetricEventsRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
@@ -64,13 +64,14 @@ class TaskMetricVariant(Base):
variant: str = StringField(required=True)
class GetDebugImageSampleRequest(TaskMetricVariant):
class GetHistorySampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
class NextDebugImageSampleRequest(Base):
class NextHistorySampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
@@ -119,17 +120,22 @@ class MetricEvents(Base):
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
class MetricEventsResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
class TaskMetricsRequest(Base):
class MultiTasksRequestBase(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
class SingleValueMetricsRequest(MultiTasksRequestBase):
pass
class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
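A minimal usage sketch for the renamed request model (not part of the diff; the id values are placeholders and the usual jsonmodels keyword construction is assumed):

from apiserver.apimodels.events import GetHistorySampleRequest

# Default behaviour: navigation stays within the passed metric (backwards compatible)
req = GetHistorySampleRequest(task="<task-id>", metric="loss", variant="series_a")
print(req.navigate_current_metric)  # True

# Opt in to navigating over samples of all the metrics reported by the task
req = GetHistorySampleRequest(
    task="<task-id>",
    metric="loss",
    variant="series_a",
    navigate_current_metric=False,
)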

View File

@@ -1,415 +0,0 @@
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
class VariantState(Base):
variant: str = StringField(required=True)
last_invalid_iteration: int = IntField()
class MetricState(Base):
metric: str = StringField(required=True)
variants: Sequence[VariantState] = ListField([VariantState], required=True)
timestamp: int = IntField(default=0)
class TaskScrollState(Base):
task: str = StringField(required=True)
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_task_events(
self,
company_id: str,
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return DebugImagesResult()
def init_state(state_: DebugImageEventsScrollState):
state_.tasks = self._init_task_states(company_id, task_metrics)
def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as requested in the current call.
Refresh the state if requested
"""
if refresh:
self._reinit_outdated_task_states(company_id, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
company_id=company_id,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.tasks,
)
)
return res
def _reinit_outdated_task_states(
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new debug image events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(
task: Task,
) -> Mapping[str, datetime]:
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return {}
requested_metrics = task_metrics[task.id]
return {
stats.metric: stats.event_stats_by_type[
self.EVENT_TYPE.value
].last_update
for stats in metric_stats.values()
if self.EVENT_TYPE.value in stats.event_stats_by_type
and (not requested_metrics or stats.metric in requested_metrics)
}
update_times = {
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
}
task_metric_states = {
task_state.task: {
metric_state.metric: metric_state for metric_state in task_state.metrics
}
for task_state in state.tasks
}
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = {
m: task_metrics[task].get(m)
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
}
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
) -> TaskScrollState:
task = old_state.task
updated_state = first(uts for uts in updates if uts.task == task)
if not updated_state:
old_state.reset()
return old_state
updated_metrics = [m.metric for m in updated_state.metrics]
return TaskScrollState(
task=task,
metrics=[
*updated_state.metrics,
*(
old_metric
for old_metric in old_state.metrics
if old_metric.metric not in updated_metrics
),
],
)
state.tasks = [
merge_with_updated_task_states(task_state, updated_task_states)
for task_state in state.tasks
]
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Return initialized metric scroll states for the requested task metrics
"""
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id),
task_metrics.items(),
)
return [
TaskScrollState(task=task, metrics=metric_states,)
for task, metric_states in zip(task_metrics, task_metric_states)
]
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return []
def init_variant_state(variant: dict):
"""
Return a new variant state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantState(variant=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
return state
return [
MetricState(
metric=metric["key"],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
variants=[
init_variant_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
task_state: TaskScrollState,
company_id: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update task scroll state
"""
if not task_state.metrics:
return task_state.task, []
if task_state.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": task_state.task}},
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
{"exists": {"field": "url"}},
]
range_condition = None
if navigate_earlier and task_state.last_min_iter is not None:
range_condition = {"lt": task_state.last_min_iter}
elif not navigate_earlier and task_state.last_max_iter is not None:
range_condition = {"gt": task_state.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventSettings.max_metrics_count,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"url": {"order": "desc"}}
}
}
},
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
)
if "aggregations" not in es_res:
return task_state.task, []
invalid_iterations = {
(m.metric, v.variant): v.last_invalid_iteration
for m in task_state.metrics
for v in m.variants
}
def is_valid_event(event: dict) -> bool:
key = event.get("metric"), event.get("variant")
if key not in invalid_iterations:
return False
max_invalid = invalid_iterations[key]
return max_invalid is None or event.get("iter") > max_invalid
def get_iteration_events(it_: dict) -> Sequence:
return [
ev["_source"]
for m in dpath.get(it_, "metrics/buckets")
for v in dpath.get(m, "variants/buckets")
for ev in dpath.get(v, "events/hits/hits")
if is_valid_event(ev["_source"])
]
iterations = []
for it in dpath.get(es_res, "aggregations/iters/buckets"):
events = get_iteration_events(it)
if events:
iterations.append({"iter": it["key"], "events": events})
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
task_state.last_max_iter = iterations[0]["iter"]
task_state.last_min_iter = iterations[-1]["iter"]
return task_state.task, iterations
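For clarity, a small self-contained illustration of the invalid-iteration filtering performed above, using fabricated metric/variant data (an event is kept only if its metric/variant pair was requested and its iteration is newer than the last recycled-url iteration of that variant):

from typing import Dict, Optional, Tuple

invalid_iterations: Dict[Tuple[str, str], Optional[int]] = {
    ("loss", "series_a"): 47,    # urls were recycled up to iteration 47
    ("loss", "series_b"): None,  # urls were never recycled
}

def is_valid_event(event: dict) -> bool:
    key = event.get("metric"), event.get("variant")
    if key not in invalid_iterations:
        return False
    max_invalid = invalid_iterations[key]
    return max_invalid is None or event.get("iter") > max_invalid

print(is_valid_event({"metric": "loss", "variant": "series_a", "iter": 90}))  # True
print(is_valid_event({"metric": "loss", "variant": "series_a", "iter": 30}))  # False
print(is_valid_event({"metric": "acc", "variant": "x", "iter": 1}))           # False, not requested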

View File

@@ -1,375 +0,0 @@
import operator
from typing import Sequence, Tuple, Optional
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
EventType,
check_empty_data,
search_company_events,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugSampleHistoryState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
reached_first: bool = BoolField()
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
@attr.s(auto_attribs=True)
class DebugSampleHistoryResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class DebugSampleHistory:
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugSampleHistoryState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_debug_image(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> DebugSampleHistoryResult:
"""
Get the debug image for the next/prev variant on the current iteration
If it does not exist then try getting the image for the first/last variant from the next/prev iteration
"""
res = DebugSampleHistoryResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
image = self._get_next_for_current_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
if not image:
return res
self._fill_res_and_update_state(image=image, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
):
state.variant = image["variant"]
state.iteration = image["iter"]
res.event = image
var_state = first(s for s in state.variant_states if s.name == state.variant)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or image is found
"""
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in state.variant_states
if cmp(var_state.name, state.variant)
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"terms": {"variant": [v.name for v in variants]}},
{"term": {"iter": state.iteration}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": {"variant": "desc" if navigate_earlier else "asc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image from the first variant of the next iteration (if navigate_earlier is set to False)
or from the last variant of the previous iteration (otherwise)
The variants for which the image falls into the invalid range are discarded
If no suitable image is found then None is returned
"""
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"exists": {"field": "url"}},
]
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in state.variant_states
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = state.variant_states
if not variants:
return
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
must_conditions.append({"bool": {"should": variants_conditions}})
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
es_req = {
"size": 1,
"sort": [{"iter": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_debug_image_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
) -> DebugSampleHistoryResult:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(s for s in state.variant_states if s.name == variant)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
):
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
image=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
company_id=company_id, task=state.task, metric=state.metric
)
state.variant_states = [
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
for var_name, min_iter, max_iter in variant_iterations
]
def _get_variant_iterations(
self,
company_id: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
) -> Sequence[Tuple[str, int, int]]:
"""
Return valid min and max iterations for which the task reported images
The min iteration is the lowest iteration that contains a non-recycled image url
"""
must = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"exists": {"field": "url"}},
]
if variants:
must.append({"terms": {"variant": variants}})
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": {
"variants": {
# all variants that sent debug images
"terms": {
"field": "variant",
"size": EventSettings.max_variants_count,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
):
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(
es_res, ("aggregations", "variants", "buckets")
)
]

View File

@@ -13,7 +13,6 @@ from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
from apiserver.bll.event.event_common import (
EventType,
get_index_name,
@@ -26,11 +25,14 @@ from apiserver.bll.event.event_common import (
get_max_metric_and_variant_counts,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plot_iterator import HistoryPlotIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
@@ -70,13 +72,17 @@ class EventBLL(object):
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
self.redis = redis or redman.connection("apiserver")
self._metrics = EventMetrics(self.es)
self._skip_iteration_for_metric = set(
config.get("services.events.ignore_iteration.metrics", [])
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.debug_images_iterator = MetricDebugImagesIterator(
es=self.es, redis=self.redis
)
self.debug_image_sample_history = HistoryDebugImageIterator(
es=self.es, redis=self.redis
)
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
self.events_iterator = EventsIterator(es=self.es)
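A hedged sketch of how the renamed attributes are expected to be reached from endpoint code (mirroring the endpoint changes later in this commit; the import path is assumed and the optional iteration/refresh/state_id arguments are assumed to keep their defaults):

from apiserver.bll.event import EventBLL  # assumed public import path

event_bll = EventBLL()  # connects to the configured events ES and Redis

res = event_bll.debug_image_sample_history.get_sample_for_variant(
    company_id="<company-id>",
    task="<task-id>",
    metric="loss",
    variant="series_a",
    navigate_current_metric=True,  # the flag added by this commit
)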

View File

@@ -0,0 +1,56 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .history_sample_iterator import HistorySampleIterator, VariantState
class HistoryDebugImageIterator(HistorySampleIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
return {"bool": {"should": variants_conditions}}
def _process_event(self, event: dict) -> dict:
return event
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
# The min iteration is the lowest iteration that contains a non-recycled image url
aggs = {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
}
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return min_iter, max_iter
return aggs, get_min_max_data
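To make the min/max computation concrete, a small illustration on a fabricated variant bucket shaped like the aggregation above (the earliest valid iteration is the max_iter of the url bucket ordered first, i.e. the first iteration whose url was not later overwritten):

from apiserver.utilities.dicts import nested_get

variant_bucket = {
    "key": "series_a",
    "last_iter": {"value": 120.0},
    "urls": {
        # a single bucket is requested, ordered by max_iter ascending
        "buckets": [{"key": "s3://bucket/img_0.png", "max_iter": {"value": 35.0}}]
    },
}

urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])          # 35: first non-recycled iteration
max_iter = int(variant_bucket["last_iter"]["value"])  # 120: last reported iteration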

View File

@@ -0,0 +1,53 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import EventType
from .metric_events_iterator import MetricEventsIterator, VariantState
class MetricDebugImagesIterator(MetricEventsIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_image)
def _get_extra_conditions(self) -> Sequence[dict]:
return [{"exists": {"field": "url"}}]
def _get_variant_state_aggs(self) -> Tuple[dict, Callable[[dict, VariantState], None]]:
aggs = {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
}
def fill_variant_state_data(variant_bucket: dict, state: VariantState):
"""If the image urls get recycled then fill the last_invalid_iteration field"""
top_iter_url = nested_get(variant_bucket, ("urls", "buckets"))[0]
iters = nested_get(top_iter_url, ("iters", "hits", "hits"))
if len(iters) > 1:
state.last_invalid_iteration = nested_get(iters[1], ("_source", "iter"))
return aggs, fill_variant_state_data
def _process_event(self, event: dict) -> dict:
return event
def _get_same_variant_events_order(self) -> dict:
return {"url": {"order": "desc"}}

View File

@@ -619,6 +619,13 @@ get_debug_image_sample {
}
response {"$ref": "#/definitions/debug_image_sample_response"}
}
"999.0": ${get_debug_image_sample."2.12"} {
request.properties.navigate_current_metric {
description: If set then subsequent navigation with next_debug_image_sample is done on the debug images of the passed metric only. Otherwise the navigation is done on all the metrics
type: boolean
default: true
}
}
}
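A hedged example of a matching request, assuming a local apiserver on the default port and token auth (ids, metric names and the auth header are placeholders; the exact URL and auth scheme depend on the deployment):

import requests

payload = {
    "task": "<task-id>",
    "metric": "loss",
    "variant": "series_a",
    "navigate_current_metric": True,  # default; set False so that next_debug_image_sample walks all metrics
}
resp = requests.post(
    "http://localhost:8008/events.get_debug_image_sample",  # assumed address
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
)
scroll_id = resp.json()["data"]["scroll_id"]  # pass to events.next_debug_image_sample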
next_debug_image_sample {
"2.12": {
@@ -1251,7 +1258,7 @@ multi_task_scalar_metrics_iter_histogram {
type: array
items {
type: string
description: "List of task Task IDs"
description: "Task ID"
}
}
samples {

View File

@@ -12,15 +12,15 @@ from apiserver.apierrors import errors
from apiserver.apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEventsRequest,
MetricEventsResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
GetHistorySampleRequest,
NextHistorySampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
TaskEventsRequest,
@@ -757,10 +757,10 @@ def get_debug_images_v1_8(call, company_id, _):
@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
request_data_model=MetricEventsRequest,
response_data_model=MetricEventsResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
def get_debug_images(call, company_id, request: MetricEventsRequest):
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task][tm.metric] = tm.variants
@@ -790,7 +790,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
state_id=request.scroll_id,
)
call.result.data_model = DebugImageResponse(
call.result.data_model = MetricEventsResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
@@ -808,13 +808,13 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
@endpoint(
"events.get_debug_image_sample",
min_version="2.12",
request_data_model=GetDebugImageSampleRequest,
request_data_model=GetHistorySampleRequest,
)
def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest):
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
res = event_bll.debug_sample_history.get_debug_image_for_variant(
res = event_bll.debug_image_sample_history.get_sample_for_variant(
company_id=task.company,
task=request.task,
metric=request.metric,
@@ -822,6 +822,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,
navigate_current_metric=request.navigate_current_metric,
)
call.result.data = attr.asdict(res, recurse=False)
@@ -829,13 +830,13 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
@endpoint(
"events.next_debug_image_sample",
min_version="2.12",
request_data_model=NextDebugImageSampleRequest,
request_data_model=NextHistorySampleRequest,
)
def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
res = event_bll.debug_sample_history.get_next_debug_image(
res = event_bll.debug_image_sample_history.get_next_sample(
company_id=task.company,
task=request.task,
state_id=request.scroll_id,