Fix events.get_task_plots endpoint

This commit is contained in:
allegroai 2022-12-21 18:45:17 +02:00
parent 9c41124b81
commit 7e97ec5555
3 changed files with 46 additions and 123 deletions

View File

@ -531,29 +531,46 @@ class EventBLL(object):
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
def get_task_plots(
self,
company_id: str,
task_id: str,
num_last_iterations: int,
event_type: EventType,
last_iterations_per_plot: int,
metric_variants: MetricVariants = None,
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
event_type = EventType.metrics_plot
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
must = [{"term": {"task": task_id}}]
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition, {"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
search_args = dict(
es=self.es, company_id=company_id, event_type=event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args,
)
max_variants = int(max_variants // num_last_iterations)
max_variants = int(max_variants // last_iterations_per_plot)
es_req: dict = {
es_req = {
"sort": [{"iter": {"order": "desc"}}],
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
@ -569,11 +586,10 @@ class EventBLL(object):
"order": {"_key": "asc"},
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": num_last_iterations,
"order": {"_key": "desc"},
"events": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": last_iterations_per_plot
}
}
},
@ -581,116 +597,28 @@ class EventBLL(object):
},
}
},
"query": query,
}
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args)
es_response = search_company_events(
body=es_req,
ignore=404,
**search_args,
)
if "aggregations" not in es_res:
return []
return [
(metric["key"], variant["key"], iter["key"])
for metric in es_res["aggregations"]["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
for iter in variant["iters"]["buckets"]
]
def get_task_plots(
self,
company_id: str,
tasks: Sequence[str],
last_iterations_per_plot: int = None,
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
model_events: bool = False,
):
if scroll_id == self.empty_scroll:
aggs_result = es_response.get("aggregations")
if not aggs_result:
return TaskEventsResult()
if scroll_id:
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
event_type = EventType.metrics_plot
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition]
if last_iterations_per_plot is None or model_events:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
company_id=company_id,
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
for metric, variant, iter in last_iters:
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"term": {"iter": iter}},
]
}
}
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
ignore=404,
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
events = [
hit["_source"]
for metrics_bucket in aggs_result["metrics"]["buckets"]
for variants_bucket in metrics_bucket["variants"]["buckets"]
for hit in variants_bucket["events"]["hits"]["hits"]
]
self.uncompress_plots(events)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
events=events, total_events=len(events)
)
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:

View File

@ -1027,7 +1027,7 @@ get_task_plots {
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
description: "Max number of latest iterations for which to return plots"
}
scroll_id {
type: string

View File

@ -661,20 +661,15 @@ def _get_metric_variants_from_request(
def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events
)[0]
result = event_bll.get_task_plots(
task_or_model.get_index_company(),
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
task_id=task_id,
last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
)
return_events = result.events