mirror of
https://github.com/clearml/clearml-server
synced 2025-05-10 14:50:44 +00:00
Fix events.get_multi_task_plots to retrieve the last iterations for each task metric separately
This commit is contained in:
parent
42556c8dbb
commit
5d3ba4fa73
@ -155,6 +155,13 @@ class TaskMetricsRequest(MultiTasksRequestBase):
|
|||||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskPlotsRequest(MultiTasksRequestBase):
    """
    Request model of the events.get_multi_task_plots endpoint.

    The task selection fields are inherited from MultiTasksRequestBase
    (presumably `tasks` and `model_events` - confirm against the base class).
    """

    # How many last iterations to return plots for
    iters: int = IntField(default=1)
    # Scroll id returned by a previous call, used to continue the retrieval
    scroll_id: str = StringField()
    # When set, the results are retrieved without opening a scroll
    no_scroll: bool = BoolField(default=False)
    # When set (the default), the last iterations are resolved separately for
    # each task metric; otherwise they are resolved once for the whole task
    last_iters_per_task_metric: bool = BoolField(default=True)
|
||||||
|
|
||||||
|
|
||||||
class TaskPlotsRequest(Base):
|
class TaskPlotsRequest(Base):
|
||||||
task: str = StringField(required=True)
|
task: str = StringField(required=True)
|
||||||
iters: int = IntField(default=1)
|
iters: int = IntField(default=1)
|
||||||
|
@ -57,6 +57,10 @@ class EventsRequest(Base):
|
|||||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||||
|
|
||||||
|
|
||||||
|
class PlotEventsRequest(EventsRequest):
    """
    Events request for plots: an EventsRequest with an additional switch that
    controls how the last iterations are resolved.
    """

    # When set (the default), the last iterations are resolved separately for
    # each task metric; otherwise they are resolved once for the whole task
    last_iters_per_task_metric: bool = BoolField(default=True)
|
||||||
|
|
||||||
|
|
||||||
class ScalarMetricsIterHistogram(HistogramRequestBase):
|
class ScalarMetricsIterHistogram(HistogramRequestBase):
|
||||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||||
|
|
||||||
@ -67,7 +71,7 @@ class SingleValueMetrics(Base):
|
|||||||
|
|
||||||
class GetTasksDataRequest(Base):
|
class GetTasksDataRequest(Base):
|
||||||
debug_images: EventsRequest = EmbeddedField(EventsRequest)
|
debug_images: EventsRequest = EmbeddedField(EventsRequest)
|
||||||
plots: EventsRequest = EmbeddedField(EventsRequest)
|
plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
|
||||||
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
|
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
|
||||||
ScalarMetricsIterHistogram
|
ScalarMetricsIterHistogram
|
||||||
)
|
)
|
||||||
|
@ -717,6 +717,7 @@ class EventBLL(object):
|
|||||||
size=500,
|
size=500,
|
||||||
scroll_id=None,
|
scroll_id=None,
|
||||||
no_scroll=False,
|
no_scroll=False,
|
||||||
|
last_iters_per_task_metric=False,
|
||||||
) -> TaskEventsResult:
|
) -> TaskEventsResult:
|
||||||
if scroll_id == self.empty_scroll:
|
if scroll_id == self.empty_scroll:
|
||||||
return TaskEventsResult()
|
return TaskEventsResult()
|
||||||
@ -743,25 +744,47 @@ class EventBLL(object):
|
|||||||
if last_iter_count is None:
|
if last_iter_count is None:
|
||||||
must.append({"terms": {"task": task_ids}})
|
must.append({"terms": {"task": task_ids}})
|
||||||
else:
|
else:
|
||||||
tasks_iters = self.get_last_iters(
|
if last_iters_per_task_metric:
|
||||||
company_id=company_ids,
|
task_metric_iters = self.get_last_iters_per_metric(
|
||||||
event_type=event_type,
|
company_id=company_ids,
|
||||||
task_id=task_ids,
|
event_type=event_type,
|
||||||
iters=last_iter_count,
|
task_id=task_ids,
|
||||||
metrics=metrics,
|
iters=last_iter_count,
|
||||||
)
|
metrics=metrics,
|
||||||
should = [
|
)
|
||||||
{
|
should = [
|
||||||
"bool": {
|
{
|
||||||
"must": [
|
"bool": {
|
||||||
{"term": {"task": task}},
|
"must": [
|
||||||
{"terms": {"iter": last_iters}},
|
{"term": {"task": task}},
|
||||||
]
|
{"term": {"metric": metric}},
|
||||||
|
{"terms": {"iter": last_iters}},
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
for (task, metric), last_iters in task_metric_iters.items()
|
||||||
for task, last_iters in tasks_iters.items()
|
if last_iters
|
||||||
if last_iters
|
]
|
||||||
]
|
else:
|
||||||
|
tasks_iters = self.get_last_iters(
|
||||||
|
company_id=company_ids,
|
||||||
|
event_type=event_type,
|
||||||
|
task_id=task_ids,
|
||||||
|
iters=last_iter_count,
|
||||||
|
metrics=metrics,
|
||||||
|
)
|
||||||
|
should = [
|
||||||
|
{
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{"term": {"task": task}},
|
||||||
|
{"terms": {"iter": last_iters}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for task, last_iters in tasks_iters.items()
|
||||||
|
if last_iters
|
||||||
|
]
|
||||||
if not should:
|
if not should:
|
||||||
return TaskEventsResult()
|
return TaskEventsResult()
|
||||||
must.append({"bool": {"should": should}})
|
must.append({"bool": {"should": should}})
|
||||||
@ -959,6 +982,68 @@ class EventBLL(object):
|
|||||||
|
|
||||||
return iterations, vectors
|
return iterations, vectors
|
||||||
|
|
||||||
|
def get_last_iters_per_metric(
    self,
    company_id: Union[str, Sequence[str]],
    event_type: "EventType",
    task_id: Union[str, Sequence[str]],
    iters: int,
    metrics: "MetricVariants" = None,
) -> Mapping[Tuple[str, str], Sequence]:
    """
    Return the last reported iterations for each (task, metric) pair.

    :param company_id: a single company id or a sequence of company ids whose
        event indices should be searched
    :param event_type: the type of the events to look at
    :param task_id: a single task id or a sequence of task ids
    :param iters: the maximal amount of last iterations to return per
        (task, metric) pair
    :param metrics: optional metrics/variants filter
    :return: mapping from (task id, metric) to the list of the iteration
        numbers, the highest iteration first. Empty mapping if there are
        no tasks, no company has events data or nothing was found
    """
    task_ids = [task_id] if isinstance(task_id, str) else list(task_id)
    if not task_ids:
        # Nothing to look for; this also avoids a terms aggregation of size 0
        return {}

    company_ids = [company_id] if isinstance(company_id, str) else company_id
    company_ids = [
        c_id
        for c_id in set(company_ids)
        if not check_empty_data(self.es, c_id, event_type)
    ]
    if not company_ids:
        return {}

    must = [{"terms": {"task": task_ids}}]
    if metrics:
        must.append(get_metric_variants_condition(metrics))

    # A terms aggregation returns only 10 buckets unless "size" is specified.
    # Without explicit sizes the data of any task beyond the 10th (and of any
    # metric beyond the 10th inside a task) would be silently dropped
    max_metrics_per_task = 1000
    es_req: dict = {
        "size": 0,
        "aggs": {
            "tasks": {
                "terms": {"field": "task", "size": len(task_ids)},
                "aggs": {
                    "metrics": {
                        "terms": {
                            "field": "metric",
                            "size": max_metrics_per_task,
                        },
                        "aggs": {
                            "iters": {
                                "terms": {
                                    "field": "iter",
                                    "size": iters,
                                    # the highest iterations come first
                                    "order": {"_key": "desc"},
                                }
                            }
                        },
                    }
                },
            }
        },
        "query": {"bool": {"must": must}},
    }

    with translate_errors_context():
        es_res = search_company_events(
            self.es,
            company_id=company_ids,
            event_type=event_type,
            body=es_req,
        )
    if "aggregations" not in es_res:
        return {}

    return {
        (tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]]
        for tb in es_res["aggregations"]["tasks"]["buckets"]
        for mb in tb["metrics"]["buckets"]
    }
|
||||||
|
|
||||||
def get_last_iters(
|
def get_last_iters(
|
||||||
self,
|
self,
|
||||||
company_id: Union[str, Sequence[str]],
|
company_id: Union[str, Sequence[str]],
|
||||||
|
@ -1149,6 +1149,13 @@ get_multi_task_plots {
|
|||||||
default: false
|
default: false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${get_multi_task_plots."2.22"} {
|
||||||
|
request.properties.last_iters_per_task_metric {
|
||||||
|
type: boolean
|
||||||
|
description: If set to 'true' and 'iters' is passed then the last iterations are retrieved for each task metric separately. Otherwise the last iterations of the whole task are retrieved
|
||||||
|
default: true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
get_vector_metrics_and_variants {
|
get_vector_metrics_and_variants {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
|
@ -587,6 +587,13 @@ get_task_data {
|
|||||||
items {"$ref": "#/definitions/single_value_task_metrics"}
|
items {"$ref": "#/definitions/single_value_task_metrics"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${get_task_data."2.25"} {
|
||||||
|
request.properties.plots.properties.last_iters_per_task_metric {
|
||||||
|
type: boolean
|
||||||
|
description: If set to 'true' and 'iters' is passed then the last iterations are retrieved for each task metric separately. Otherwise the last iterations of the whole task are retrieved
|
||||||
|
default: true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
get_all_ex {
|
get_all_ex {
|
||||||
"2.23" {
|
"2.23" {
|
||||||
|
@ -30,6 +30,7 @@ from apiserver.apimodels.events import (
|
|||||||
GetVariantSampleRequest,
|
GetVariantSampleRequest,
|
||||||
GetMetricSamplesRequest,
|
GetMetricSamplesRequest,
|
||||||
TaskMetric,
|
TaskMetric,
|
||||||
|
MultiTaskPlotsRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.event import EventBLL
|
from apiserver.bll.event import EventBLL
|
||||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||||
@ -554,6 +555,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
|||||||
def _get_multitask_plots(
|
def _get_multitask_plots(
|
||||||
companies: TaskCompanies,
|
companies: TaskCompanies,
|
||||||
last_iters: int,
|
last_iters: int,
|
||||||
|
last_iters_per_task_metric: bool,
|
||||||
metrics: MetricVariants = None,
|
metrics: MetricVariants = None,
|
||||||
scroll_id=None,
|
scroll_id=None,
|
||||||
no_scroll=True,
|
no_scroll=True,
|
||||||
@ -573,6 +575,7 @@ def _get_multitask_plots(
|
|||||||
size=config.get(
|
size=config.get(
|
||||||
"services.events.events_retrieval.multi_plots_batch_size", 1000
|
"services.events.events_retrieval.multi_plots_batch_size", 1000
|
||||||
),
|
),
|
||||||
|
last_iters_per_task_metric=last_iters_per_task_metric,
|
||||||
)
|
)
|
||||||
return_events = _get_top_iter_unique_events_per_task(
|
return_events = _get_top_iter_unique_events_per_task(
|
||||||
result.events, max_iters=last_iters, task_names=task_names
|
result.events, max_iters=last_iters, task_names=task_names
|
||||||
@ -580,19 +583,17 @@ def _get_multitask_plots(
|
|||||||
return return_events, result.total_events, result.next_scroll_id
|
return return_events, result.total_events, result.next_scroll_id
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
|
@endpoint("events.get_multi_task_plots", min_version="1.8")
|
||||||
def get_multi_task_plots(call, company_id, _):
|
def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
||||||
task_ids = call.data["tasks"]
|
|
||||||
iters = call.data.get("iters", 1)
|
|
||||||
scroll_id = call.data.get("scroll_id")
|
|
||||||
no_scroll = call.data.get("no_scroll", False)
|
|
||||||
model_events = call.data.get("model_events", False)
|
|
||||||
|
|
||||||
companies = _get_task_or_model_index_companies(
|
companies = _get_task_or_model_index_companies(
|
||||||
company_id, task_ids, model_events=model_events
|
company_id, request.tasks, model_events=request.model_events
|
||||||
)
|
)
|
||||||
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
||||||
companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll,
|
companies=companies,
|
||||||
|
last_iters=request.iters,
|
||||||
|
scroll_id=request.scroll_id,
|
||||||
|
no_scroll=request.no_scroll,
|
||||||
|
last_iters_per_task_metric=request.last_iters_per_task_metric,
|
||||||
)
|
)
|
||||||
call.result.data = dict(
|
call.result.data = dict(
|
||||||
plots=return_events,
|
plots=return_events,
|
||||||
|
@ -264,6 +264,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
|||||||
companies=companies,
|
companies=companies,
|
||||||
last_iters=request.plots.iters,
|
last_iters=request.plots.iters,
|
||||||
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
||||||
|
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
if request.scalar_metrics_iter_histogram:
|
if request.scalar_metrics_iter_histogram:
|
||||||
|
@ -153,7 +153,7 @@ class TestSubProjects(TestService):
|
|||||||
self.assertEqual(p.own_tasks, 0)
|
self.assertEqual(p.own_tasks, 0)
|
||||||
self.assertIsNone(p.get("own_datasets"))
|
self.assertIsNone(p.get("own_datasets"))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
p.stats.active.total_tasks, 1 if p.basename != "Project4" else 0
|
p.stats.active.total_tasks, 1 if p.basename != "Project2" else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_project_aggregations(self):
|
def test_project_aggregations(self):
|
||||||
|
@ -482,6 +482,36 @@ class TestTaskEvents(TestService):
|
|||||||
mean(v for v in range(curr * interval, (curr + 1) * interval)),
|
mean(v for v in range(curr * interval, (curr + 1) * interval)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_multitask_plots(self):
    """
    events.get_multi_task_plots called with the default parameters returns,
    for every metric and variant, the plot of the last iteration in which
    each task reported that metric, even when that is not the last iteration
    of the task.
    """

    def report_plots(task, plot_specs):
        # Create one plot event per (iteration, metric, variant, plot_str)
        # spec and send all of them in a single batch
        self.send_batch(
            [
                self._create_task_event(
                    "plot",
                    task,
                    iteration,
                    metric=metric,
                    variant=variant,
                    plot_str=plot_str,
                )
                for iteration, metric, variant, plot_str in plot_specs
            ]
        )

    def reported_plot_str(variant_plots, task, iteration):
        # The plot string of the first plot stored for the task iteration
        return variant_plots[task][iteration]["plots"][0]["plot_str"]

    task1 = self._temp_task()
    report_plots(
        task1,
        [
            (1, "A", "AX", "Task1_1_A_AX"),
            (2, "B", "BX", "Task1_2_B_BX"),
            (3, "B", "BX", "Task1_3_B_BX"),
            (3, "C", "CX", "Task1_3_C_CX"),
        ],
    )
    task2 = self._temp_task()
    report_plots(
        task2,
        [
            (1, "C", "CX", "Task2_1_C_CX"),
            (2, "A", "AY", "Task2_2_A_AY"),
        ],
    )

    response = self.api.events.get_multi_task_plots(tasks=[task1, task2])
    plots = response.plots
    self.assertEqual(len(plots), 3)

    # metric A: reported by task1 (iteration 1) and by task2 (iteration 2)
    self.assertEqual(len(plots.A), 2)
    self.assertEqual(len(plots.A.AX), 1)
    self.assertEqual(len(plots.A.AY), 1)
    self.assertEqual(reported_plot_str(plots.A.AX, task1, "1"), "Task1_1_A_AX")
    self.assertEqual(reported_plot_str(plots.A.AY, task2, "2"), "Task2_2_A_AY")

    # metric B: reported by task1 only, the plot of iteration 3 is checked
    self.assertEqual(len(plots.B), 1)
    self.assertEqual(len(plots.B.BX), 1)
    self.assertEqual(reported_plot_str(plots.B.BX, task1, "3"), "Task1_3_B_BX")

    # metric C: the same variant is reported by both tasks
    self.assertEqual(len(plots.C), 1)
    self.assertEqual(len(plots.C.CX), 2)
    self.assertEqual(reported_plot_str(plots.C.CX, task1, "3"), "Task1_3_C_CX")
    self.assertEqual(reported_plot_str(plots.C.CX, task2, "1"), "Task2_1_C_CX")
|
||||||
|
|
||||||
def test_task_plots(self):
|
def test_task_plots(self):
|
||||||
task = self._temp_task()
|
task = self._temp_task()
|
||||||
event = self._create_task_event("plot", task, 0)
|
event = self._create_task_event("plot", task, 0)
|
||||||
|
Loading…
Reference in New Issue
Block a user