mirror of
https://github.com/clearml/clearml-server
synced 2025-03-12 06:51:37 +00:00
Fix events.get_multitask_plots to retrieve last iterations per each task metric separately
This commit is contained in:
parent
42556c8dbb
commit
5d3ba4fa73
@ -155,6 +155,13 @@ class TaskMetricsRequest(MultiTasksRequestBase):
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class MultiTaskPlotsRequest(MultiTasksRequestBase):
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
last_iters_per_task_metric: bool = BoolField(default=True)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
|
@ -57,6 +57,10 @@ class EventsRequest(Base):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class PlotEventsRequest(EventsRequest):
|
||||
last_iters_per_task_metric: bool = BoolField(default=True)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogram(HistogramRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
@ -67,7 +71,7 @@ class SingleValueMetrics(Base):
|
||||
|
||||
class GetTasksDataRequest(Base):
|
||||
debug_images: EventsRequest = EmbeddedField(EventsRequest)
|
||||
plots: EventsRequest = EmbeddedField(EventsRequest)
|
||||
plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest)
|
||||
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(
|
||||
ScalarMetricsIterHistogram
|
||||
)
|
||||
|
@ -717,6 +717,7 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
no_scroll=False,
|
||||
last_iters_per_task_metric=False,
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@ -743,25 +744,47 @@ class EventBLL(object):
|
||||
if last_iter_count is None:
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
metrics=metrics,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
if last_iters_per_task_metric:
|
||||
task_metric_iters = self.get_last_iters_per_metric(
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
metrics=metrics,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
for task, last_iters in tasks_iters.items()
|
||||
if last_iters
|
||||
]
|
||||
for (task, metric), last_iters in task_metric_iters.items()
|
||||
if last_iters
|
||||
]
|
||||
else:
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
metrics=metrics,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for task, last_iters in tasks_iters.items()
|
||||
if last_iters
|
||||
]
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
@ -959,6 +982,68 @@ class EventBLL(object):
|
||||
|
||||
return iterations, vectors
|
||||
|
||||
def get_last_iters_per_metric(
|
||||
self,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
metrics: MetricVariants = None,
|
||||
) -> Mapping[Tuple[str, str], Sequence]:
|
||||
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||
company_ids = [
|
||||
c_id
|
||||
for c_id in set(company_ids)
|
||||
if not check_empty_data(self.es, c_id, event_type)
|
||||
]
|
||||
if not company_ids:
|
||||
return {}
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
must = [{"terms": {"task": task_ids}}]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task"},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return {}
|
||||
|
||||
return {
|
||||
(tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]]
|
||||
for tb in es_res["aggregations"]["tasks"]["buckets"]
|
||||
for mb in tb["metrics"]["buckets"]
|
||||
}
|
||||
|
||||
def get_last_iters(
|
||||
self,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
|
@ -1149,6 +1149,13 @@ get_multi_task_plots {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"999.0": ${get_multi_task_plots."2.22"} {
|
||||
request.properties.last_iters_per_task_metric {
|
||||
type: boolean
|
||||
description: If set to 'true' and iters passed then last iterations for each task metrics are retrieved. Otherwise last iterations for the whole task are retrieved
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
|
@ -587,6 +587,13 @@ get_task_data {
|
||||
items {"$ref": "#/definitions/single_value_task_metrics"}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_task_data."2.25"} {
|
||||
request.properties.plots.properties.last_iters_per_task_metric {
|
||||
type: boolean
|
||||
description: If set to 'true' and iters passed then last iterations for each task metrics are retrieved. Otherwise last iterations for the whole task are retrieved
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
"2.23" {
|
||||
|
@ -30,6 +30,7 @@ from apiserver.apimodels.events import (
|
||||
GetVariantSampleRequest,
|
||||
GetMetricSamplesRequest,
|
||||
TaskMetric,
|
||||
MultiTaskPlotsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||
@ -554,6 +555,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
def _get_multitask_plots(
|
||||
companies: TaskCompanies,
|
||||
last_iters: int,
|
||||
last_iters_per_task_metric: bool,
|
||||
metrics: MetricVariants = None,
|
||||
scroll_id=None,
|
||||
no_scroll=True,
|
||||
@ -573,6 +575,7 @@ def _get_multitask_plots(
|
||||
size=config.get(
|
||||
"services.events.events_retrieval.multi_plots_batch_size", 1000
|
||||
),
|
||||
last_iters_per_task_metric=last_iters_per_task_metric,
|
||||
)
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=last_iters, task_names=task_names
|
||||
@ -580,19 +583,17 @@ def _get_multitask_plots(
|
||||
return return_events, result.total_events, result.next_scroll_id
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
|
||||
def get_multi_task_plots(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
@endpoint("events.get_multi_task_plots", min_version="1.8")
|
||||
def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
|
||||
companies = _get_task_or_model_index_companies(
|
||||
company_id, task_ids, model_events=model_events
|
||||
company_id, request.tasks, model_events=request.model_events
|
||||
)
|
||||
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
||||
companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll,
|
||||
companies=companies,
|
||||
last_iters=request.iters,
|
||||
scroll_id=request.scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
last_iters_per_task_metric=request.last_iters_per_task_metric,
|
||||
)
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
|
@ -264,6 +264,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
companies=companies,
|
||||
last_iters=request.plots.iters,
|
||||
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
||||
last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
|
||||
)[0]
|
||||
|
||||
if request.scalar_metrics_iter_histogram:
|
||||
|
@ -153,7 +153,7 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(p.own_tasks, 0)
|
||||
self.assertIsNone(p.get("own_datasets"))
|
||||
self.assertEqual(
|
||||
p.stats.active.total_tasks, 1 if p.basename != "Project4" else 0
|
||||
p.stats.active.total_tasks, 1 if p.basename != "Project2" else 0
|
||||
)
|
||||
|
||||
def test_project_aggregations(self):
|
||||
|
@ -482,6 +482,36 @@ class TestTaskEvents(TestService):
|
||||
mean(v for v in range(curr * interval, (curr + 1) * interval)),
|
||||
)
|
||||
|
||||
def test_multitask_plots(self):
|
||||
task1 = self._temp_task()
|
||||
events = [
|
||||
self._create_task_event("plot", task1, 1, metric="A", variant="AX", plot_str="Task1_1_A_AX"),
|
||||
self._create_task_event("plot", task1, 2, metric="B", variant="BX", plot_str="Task1_2_B_BX"),
|
||||
self._create_task_event("plot", task1, 3, metric="B", variant="BX", plot_str="Task1_3_B_BX"),
|
||||
self._create_task_event("plot", task1, 3, metric="C", variant="CX", plot_str="Task1_3_C_CX"),
|
||||
]
|
||||
self.send_batch(events)
|
||||
task2 = self._temp_task()
|
||||
events = [
|
||||
self._create_task_event("plot", task2, 1, metric="C", variant="CX", plot_str="Task2_1_C_CX"),
|
||||
self._create_task_event("plot", task2, 2, metric="A", variant="AY", plot_str="Task2_2_A_AY"),
|
||||
]
|
||||
self.send_batch(events)
|
||||
plots = self.api.events.get_multi_task_plots(tasks=[task1, task2]).plots
|
||||
self.assertEqual(len(plots), 3)
|
||||
self.assertEqual(len(plots.A), 2)
|
||||
self.assertEqual(len(plots.A.AX), 1)
|
||||
self.assertEqual(len(plots.A.AY), 1)
|
||||
self.assertEqual(plots.A.AX[task1]["1"]["plots"][0]["plot_str"], "Task1_1_A_AX")
|
||||
self.assertEqual(plots.A.AY[task2]["2"]["plots"][0]["plot_str"], "Task2_2_A_AY")
|
||||
self.assertEqual(len(plots.B), 1)
|
||||
self.assertEqual(len(plots.B.BX), 1)
|
||||
self.assertEqual(plots.B.BX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_B_BX")
|
||||
self.assertEqual(len(plots.C), 1)
|
||||
self.assertEqual(len(plots.C.CX), 2)
|
||||
self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
|
||||
self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
|
||||
|
||||
def test_task_plots(self):
|
||||
task = self._temp_task()
|
||||
event = self._create_task_event("plot", task, 0)
|
||||
|
Loading…
Reference in New Issue
Block a user