Fix events.get_multi_task_plots to retrieve the last iterations for each task metric separately

allegroai 2023-07-26 18:34:30 +03:00
parent 42556c8dbb
commit 5d3ba4fa73
9 changed files with 172 additions and 30 deletions
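The change can be exercised directly against the REST API. The sketch below is illustrative only and not part of the commit: it assumes a ClearML apiserver reachable at http://localhost:8008, uses placeholder task IDs, and omits authentication; only the endpoint name and the request fields come from this commit.

import requests

APISERVER = "http://localhost:8008"      # assumed local apiserver address
TASKS = ["<task-id-1>", "<task-id-2>"]   # placeholder task IDs

# With last_iters_per_task_metric (new, default true) the last `iters` iterations
# are resolved per (task, metric) pair, so a metric that stopped reporting early
# still returns its latest plots. Set it to false to restore the previous
# behavior of taking the last iterations of the task as a whole.
resp = requests.post(
    f"{APISERVER}/events.get_multi_task_plots",
    json={
        "tasks": TASKS,
        "iters": 1,
        "no_scroll": True,
        "last_iters_per_task_metric": True,
    },
    # auth headers/cookies omitted for brevity
)
plots = resp.json()["data"]["plots"]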

View File

@@ -155,6 +155,13 @@ class TaskMetricsRequest(MultiTasksRequestBase):
     event_type: EventType = ActualEnumField(EventType, required=True)


+class MultiTaskPlotsRequest(MultiTasksRequestBase):
+    iters: int = IntField(default=1)
+    scroll_id: str = StringField()
+    no_scroll: bool = BoolField(default=False)
+    last_iters_per_task_metric: bool = BoolField(default=True)
+
+
 class TaskPlotsRequest(Base):
     task: str = StringField(required=True)
     iters: int = IntField(default=1)

View File

@@ -57,6 +57,10 @@ class EventsRequest(Base):
     metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)


+class PlotEventsRequest(EventsRequest):
+    last_iters_per_task_metric: bool = BoolField(default=True)
+
+
 class ScalarMetricsIterHistogram(HistogramRequestBase):
     metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
@@ -67,7 +71,7 @@ class SingleValueMetrics(Base):
 class GetTasksDataRequest(Base):
     debug_images: EventsRequest = EmbeddedField(EventsRequest)
-    plots: EventsRequest = EmbeddedField(EventsRequest)
+    plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
     scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
         ScalarMetricsIterHistogram
     )
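As a usage illustration (not part of the commit), this is the request body shape the get_task_data endpoint now accepts: the plots sub-request is parsed as PlotEventsRequest, so it takes the new flag alongside the usual EventsRequest fields. The task ID is a placeholder and the top-level fields follow the GetTasksDataRequest model, which is only partially shown above; the endpoint's service prefix is not visible in this excerpt.

# Illustrative request body for the get_task_data endpoint (placeholder task ID).
body = {
    "tasks": ["<task-id>"],
    "plots": {
        # PlotEventsRequest adds this flag on top of the base EventsRequest fields
        "iters": 1,
        "last_iters_per_task_metric": True,  # false = last iterations per task, as before
    },
}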

View File

@@ -717,6 +717,7 @@ class EventBLL(object):
         size=500,
         scroll_id=None,
         no_scroll=False,
+        last_iters_per_task_metric=False,
     ) -> TaskEventsResult:
         if scroll_id == self.empty_scroll:
             return TaskEventsResult()
@@ -743,25 +744,47 @@ class EventBLL(object):
         if last_iter_count is None:
             must.append({"terms": {"task": task_ids}})
         else:
-            tasks_iters = self.get_last_iters(
-                company_id=company_ids,
-                event_type=event_type,
-                task_id=task_ids,
-                iters=last_iter_count,
-                metrics=metrics,
-            )
-            should = [
-                {
-                    "bool": {
-                        "must": [
-                            {"term": {"task": task}},
-                            {"terms": {"iter": last_iters}},
-                        ]
-                    }
-                }
-                for task, last_iters in tasks_iters.items()
-                if last_iters
-            ]
+            if last_iters_per_task_metric:
+                task_metric_iters = self.get_last_iters_per_metric(
+                    company_id=company_ids,
+                    event_type=event_type,
+                    task_id=task_ids,
+                    iters=last_iter_count,
+                    metrics=metrics,
+                )
+                should = [
+                    {
+                        "bool": {
+                            "must": [
+                                {"term": {"task": task}},
+                                {"term": {"metric": metric}},
+                                {"terms": {"iter": last_iters}},
+                            ]
+                        }
+                    }
+                    for (task, metric), last_iters in task_metric_iters.items()
+                    if last_iters
+                ]
+            else:
+                tasks_iters = self.get_last_iters(
+                    company_id=company_ids,
+                    event_type=event_type,
+                    task_id=task_ids,
+                    iters=last_iter_count,
+                    metrics=metrics,
+                )
+                should = [
+                    {
+                        "bool": {
+                            "must": [
+                                {"term": {"task": task}},
+                                {"terms": {"iter": last_iters}},
+                            ]
+                        }
+                    }
+                    for task, last_iters in tasks_iters.items()
+                    if last_iters
+                ]
             if not should:
                 return TaskEventsResult()
             must.append({"bool": {"should": should}})
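To make the new branch concrete, this is roughly the should clause it builds when last_iters_per_task_metric is enabled; the task names, metric names, and iteration numbers below are made up for illustration.

# Illustrative only: each (task, metric) pair is restricted to its own last
# iteration(s), instead of one shared iteration window per task.
must = []  # other filters (task/event type) omitted from this sketch
should = [
    {
        "bool": {
            "must": [
                {"term": {"task": "task_1"}},
                {"term": {"metric": "loss"}},
                {"terms": {"iter": [42]}},  # last iteration seen for (task_1, loss)
            ]
        }
    },
    {
        "bool": {
            "must": [
                {"term": {"task": "task_1"}},
                {"term": {"metric": "accuracy"}},
                {"terms": {"iter": [40]}},  # this metric stopped reporting earlier
            ]
        }
    },
]
must.append({"bool": {"should": should}})  # as in the code above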
@@ -959,6 +982,68 @@ class EventBLL(object):
         return iterations, vectors

+    def get_last_iters_per_metric(
+        self,
+        company_id: Union[str, Sequence[str]],
+        event_type: EventType,
+        task_id: Union[str, Sequence[str]],
+        iters: int,
+        metrics: MetricVariants = None,
+    ) -> Mapping[Tuple[str, str], Sequence]:
+        company_ids = [company_id] if isinstance(company_id, str) else company_id
+        company_ids = [
+            c_id
+            for c_id in set(company_ids)
+            if not check_empty_data(self.es, c_id, event_type)
+        ]
+        if not company_ids:
+            return {}
+
+        task_ids = [task_id] if isinstance(task_id, str) else task_id
+        must = [{"terms": {"task": task_ids}}]
+        if metrics:
+            must.append(get_metric_variants_condition(metrics))
+
+        es_req: dict = {
+            "size": 0,
+            "aggs": {
+                "tasks": {
+                    "terms": {"field": "task"},
+                    "aggs": {
+                        "metrics": {
+                            "terms": {"field": "metric"},
+                            "aggs": {
+                                "iters": {
+                                    "terms": {
+                                        "field": "iter",
+                                        "size": iters,
+                                        "order": {"_key": "desc"},
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            },
+            "query": {"bool": {"must": must}},
+        }
+
+        with translate_errors_context():
+            es_res = search_company_events(
+                self.es,
+                company_id=company_ids,
+                event_type=event_type,
+                body=es_req,
+            )
+
+        if "aggregations" not in es_res:
+            return {}
+
+        return {
+            (tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]]
+            for tb in es_res["aggregations"]["tasks"]["buckets"]
+            for mb in tb["metrics"]["buckets"]
+        }
+
     def get_last_iters(
         self,
         company_id: Union[str, Sequence[str]],
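For reference, a small sketch of how the nested tasks -> metrics -> iters terms aggregation flattens into the mapping returned by get_last_iters_per_metric; the aggregation response below is fabricated, but has the shape Elasticsearch returns for this kind of nested terms aggregation.

# Fabricated response in the shape of the tasks -> metrics -> iters aggregation.
es_res = {
    "aggregations": {
        "tasks": {
            "buckets": [
                {
                    "key": "task_1",
                    "metrics": {
                        "buckets": [
                            {"key": "loss", "iters": {"buckets": [{"key": 42}, {"key": 41}]}},
                            {"key": "accuracy", "iters": {"buckets": [{"key": 40}]}},
                        ]
                    },
                }
            ]
        }
    }
}

# Same flattening as the method's return statement: {(task, metric): [last iterations]}
result = {
    (tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]]
    for tb in es_res["aggregations"]["tasks"]["buckets"]
    for mb in tb["metrics"]["buckets"]
}
assert result == {("task_1", "loss"): [42, 41], ("task_1", "accuracy"): [40]}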

View File

@@ -1149,6 +1149,13 @@ get_multi_task_plots {
             default: false
         }
     }
+    "999.0": ${get_multi_task_plots."2.22"} {
+        request.properties.last_iters_per_task_metric {
+            type: boolean
+            description: If set to 'true' and iters is passed then the last iterations are retrieved for each task metric separately. Otherwise the last iterations are retrieved for the whole task
+            default: true
+        }
+    }
 }
 get_vector_metrics_and_variants {
     "2.1" {

View File

@@ -587,6 +587,13 @@ get_task_data {
             items {"$ref": "#/definitions/single_value_task_metrics"}
         }
     }
+    "999.0": ${get_task_data."2.25"} {
+        request.properties.plots.properties.last_iters_per_task_metric {
+            type: boolean
+            description: If set to 'true' and iters is passed then the last iterations are retrieved for each task metric separately. Otherwise the last iterations are retrieved for the whole task
+            default: true
+        }
+    }
 }
 get_all_ex {
     "2.23" {

View File

@@ -30,6 +30,7 @@ from apiserver.apimodels.events import (
     GetVariantSampleRequest,
     GetMetricSamplesRequest,
     TaskMetric,
+    MultiTaskPlotsRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@@ -554,6 +555,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
 def _get_multitask_plots(
     companies: TaskCompanies,
     last_iters: int,
+    last_iters_per_task_metric: bool,
     metrics: MetricVariants = None,
     scroll_id=None,
     no_scroll=True,
@@ -573,6 +575,7 @@ def _get_multitask_plots(
         size=config.get(
             "services.events.events_retrieval.multi_plots_batch_size", 1000
         ),
+        last_iters_per_task_metric=last_iters_per_task_metric,
     )
     return_events = _get_top_iter_unique_events_per_task(
         result.events, max_iters=last_iters, task_names=task_names
@@ -580,19 +583,17 @@
     return return_events, result.total_events, result.next_scroll_id


-@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
-def get_multi_task_plots(call, company_id, _):
-    task_ids = call.data["tasks"]
-    iters = call.data.get("iters", 1)
-    scroll_id = call.data.get("scroll_id")
-    no_scroll = call.data.get("no_scroll", False)
-    model_events = call.data.get("model_events", False)
+@endpoint("events.get_multi_task_plots", min_version="1.8")
+def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
     companies = _get_task_or_model_index_companies(
-        company_id, task_ids, model_events=model_events
+        company_id, request.tasks, model_events=request.model_events
     )
     return_events, total_events, next_scroll_id = _get_multitask_plots(
-        companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll,
+        companies=companies,
+        last_iters=request.iters,
+        scroll_id=request.scroll_id,
+        no_scroll=request.no_scroll,
+        last_iters_per_task_metric=request.last_iters_per_task_metric,
     )
     call.result.data = dict(
         plots=return_events,

View File

@@ -264,6 +264,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
             companies=companies,
             last_iters=request.plots.iters,
             metrics=_get_metric_variants_from_request(request.plots.metrics),
+            last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
         )[0]

     if request.scalar_metrics_iter_histogram:

View File

@@ -153,7 +153,7 @@ class TestSubProjects(TestService):
             self.assertEqual(p.own_tasks, 0)
             self.assertIsNone(p.get("own_datasets"))
             self.assertEqual(
-                p.stats.active.total_tasks, 1 if p.basename != "Project4" else 0
+                p.stats.active.total_tasks, 1 if p.basename != "Project2" else 0
             )

     def test_project_aggregations(self):

View File

@@ -482,6 +482,36 @@ class TestTaskEvents(TestService):
                 mean(v for v in range(curr * interval, (curr + 1) * interval)),
             )

+    def test_multitask_plots(self):
+        task1 = self._temp_task()
+        events = [
+            self._create_task_event("plot", task1, 1, metric="A", variant="AX", plot_str="Task1_1_A_AX"),
+            self._create_task_event("plot", task1, 2, metric="B", variant="BX", plot_str="Task1_2_B_BX"),
+            self._create_task_event("plot", task1, 3, metric="B", variant="BX", plot_str="Task1_3_B_BX"),
+            self._create_task_event("plot", task1, 3, metric="C", variant="CX", plot_str="Task1_3_C_CX"),
+        ]
+        self.send_batch(events)
+
+        task2 = self._temp_task()
+        events = [
+            self._create_task_event("plot", task2, 1, metric="C", variant="CX", plot_str="Task2_1_C_CX"),
+            self._create_task_event("plot", task2, 2, metric="A", variant="AY", plot_str="Task2_2_A_AY"),
+        ]
+        self.send_batch(events)
+
+        plots = self.api.events.get_multi_task_plots(tasks=[task1, task2]).plots
+        self.assertEqual(len(plots), 3)
+
+        self.assertEqual(len(plots.A), 2)
+        self.assertEqual(len(plots.A.AX), 1)
+        self.assertEqual(len(plots.A.AY), 1)
+        self.assertEqual(plots.A.AX[task1]["1"]["plots"][0]["plot_str"], "Task1_1_A_AX")
+        self.assertEqual(plots.A.AY[task2]["2"]["plots"][0]["plot_str"], "Task2_2_A_AY")
+
+        self.assertEqual(len(plots.B), 1)
+        self.assertEqual(len(plots.B.BX), 1)
+        self.assertEqual(plots.B.BX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_B_BX")
+
+        self.assertEqual(len(plots.C), 1)
+        self.assertEqual(len(plots.C.CX), 2)
+        self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
+        self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
+
     def test_task_plots(self):
         task = self._temp_task()
         event = self._create_task_event("plot", task, 0)