mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 10:43:10 +00:00
Fix events.get_task_plots endpoint
This commit is contained in:
parent
9c41124b81
commit
7e97ec5555
@ -531,29 +531,46 @@ class EventBLL(object):
|
||||
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
def get_task_plots(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
num_last_iterations: int,
|
||||
event_type: EventType,
|
||||
last_iterations_per_plot: int,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
event_type = EventType.metrics_plot
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
must = [{"term": {"task": task_id}}]
|
||||
plot_valid_condition = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {PlotFields.valid_plot: True}},
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": PlotFields.valid_plot}}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
must = [plot_valid_condition, {"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
max_variants = int(max_variants // num_last_iterations)
|
||||
max_variants = int(max_variants // last_iterations_per_plot)
|
||||
|
||||
es_req: dict = {
|
||||
es_req = {
|
||||
"sort": [{"iter": {"order": "desc"}}],
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
@ -569,11 +586,10 @@ class EventBLL(object):
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_key": "desc"},
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": last_iterations_per_plot
|
||||
}
|
||||
}
|
||||
},
|
||||
@ -581,116 +597,28 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
es_response = search_company_events(
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
**search_args,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
return [
|
||||
(metric["key"], variant["key"], iter["key"])
|
||||
for metric in es_res["aggregations"]["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
for iter in variant["iters"]["buckets"]
|
||||
]
|
||||
|
||||
def get_task_plots(
|
||||
self,
|
||||
company_id: str,
|
||||
tasks: Sequence[str],
|
||||
last_iterations_per_plot: int = None,
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
model_events: bool = False,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
aggs_result = es_response.get("aggregations")
|
||||
if not aggs_result:
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
event_type = EventType.metrics_plot
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
plot_valid_condition = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {PlotFields.valid_plot: True}},
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": PlotFields.valid_plot}}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
must = [plot_valid_condition]
|
||||
|
||||
if last_iterations_per_plot is None or model_events:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
num_last_iterations=last_iterations_per_plot,
|
||||
event_type=event_type,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
|
||||
for metric, variant, iter in last_iters:
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"term": {"iter": iter}},
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
events = [
|
||||
hit["_source"]
|
||||
for metrics_bucket in aggs_result["metrics"]["buckets"]
|
||||
for variants_bucket in metrics_bucket["variants"]["buckets"]
|
||||
for hit in variants_bucket["events"]["hits"]["hits"]
|
||||
]
|
||||
self.uncompress_plots(events)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
events=events, total_events=len(events)
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
|
@ -1027,7 +1027,7 @@ get_task_plots {
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
|
@ -661,20 +661,15 @@ def _get_metric_variants_from_request(
|
||||
def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=request.model_events
|
||||
)[0]
|
||||
result = event_bll.get_task_plots(
|
||||
task_or_model.get_index_company(),
|
||||
tasks=[task_id],
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
task_id=task_id,
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
model_events=request.model_events,
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
|
Loading…
Reference in New Issue
Block a user