Fix events.get_task_plots endpoint

This commit is contained in:
allegroai 2022-12-21 18:45:17 +02:00
parent 9c41124b81
commit 7e97ec5555
3 changed files with 46 additions and 123 deletions

View File

@@ -531,29 +531,46 @@ class EventBLL(object):
return events, next_scroll_id, total_events return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant( def get_task_plots(
self, self,
company_id: str, company_id: str,
task_id: str, task_id: str,
num_last_iterations: int, last_iterations_per_plot: int,
event_type: EventType,
metric_variants: MetricVariants = None, metric_variants: MetricVariants = None,
): ):
if check_empty_data(self.es, company_id=company_id, event_type=event_type): event_type = EventType.metrics_plot
return [] if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
must = [{"term": {"task": task_id}}] plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition, {"term": {"task": task_id}}]
if metric_variants: if metric_variants:
must.append(get_metric_variants_condition(metric_variants)) must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}} query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type) search_args = dict(
es=self.es, company_id=company_id, event_type=event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts( max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args, query=query, **search_args,
) )
max_variants = int(max_variants // num_last_iterations) max_variants = int(max_variants // last_iterations_per_plot)
es_req: dict = { es_req = {
"sort": [{"iter": {"order": "desc"}}],
"size": 0, "size": 0,
"query": query,
"aggs": { "aggs": {
"metrics": { "metrics": {
"terms": { "terms": {
@@ -569,11 +586,10 @@ class EventBLL(object):
"order": {"_key": "asc"}, "order": {"_key": "asc"},
}, },
"aggs": { "aggs": {
"iters": { "events": {
"terms": { "top_hits": {
"field": "iter", "sort": {"iter": {"order": "desc"}},
"size": num_last_iterations, "size": last_iterations_per_plot
"order": {"_key": "desc"},
} }
} }
}, },
@@ -581,116 +597,28 @@ class EventBLL(object):
}, },
} }
}, },
"query": query,
} }
with translate_errors_context(): with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args) es_response = search_company_events(
body=es_req,
ignore=404,
**search_args,
)
if "aggregations" not in es_res: aggs_result = es_response.get("aggregations")
return [] if not aggs_result:
return [
(metric["key"], variant["key"], iter["key"])
for metric in es_res["aggregations"]["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
for iter in variant["iters"]["buckets"]
]
def get_task_plots(
self,
company_id: str,
tasks: Sequence[str],
last_iterations_per_plot: int = None,
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
model_events: bool = False,
):
if scroll_id == self.empty_scroll:
return TaskEventsResult() return TaskEventsResult()
if scroll_id: events = [
with translate_errors_context(): hit["_source"]
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h") for metrics_bucket in aggs_result["metrics"]["buckets"]
else: for variants_bucket in metrics_bucket["variants"]["buckets"]
event_type = EventType.metrics_plot for hit in variants_bucket["events"]["hits"]["hits"]
if check_empty_data(self.es, company_id=company_id, event_type=event_type): ]
return TaskEventsResult()
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition]
if last_iterations_per_plot is None or model_events:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
company_id=company_id,
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
for metric, variant, iter in last_iters:
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"term": {"iter": iter}},
]
}
}
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
ignore=404,
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
self.uncompress_plots(events) self.uncompress_plots(events)
return TaskEventsResult( return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events events=events, total_events=len(events)
) )
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]: def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:

View File

@@ -1027,7 +1027,7 @@ get_task_plots {
} }
iters { iters {
type: integer type: integer
description: "Max number of latest iterations for which to return debug images" description: "Max number of latest iterations for which to return plots"
} }
scroll_id { scroll_id {
type: string type: string

View File

@@ -661,20 +661,15 @@ def _get_metric_variants_from_request(
def get_task_plots(call, company_id, request: TaskPlotsRequest): def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task task_id = request.task
iters = request.iters iters = request.iters
scroll_id = request.scroll_id
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events company_id, task_id, model_events=request.model_events
)[0] )[0]
result = event_bll.get_task_plots( result = event_bll.get_task_plots(
task_or_model.get_index_company(), task_or_model.get_index_company(),
tasks=[task_id], task_id=task_id,
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters, last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics), metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
) )
return_events = result.events return_events = result.events