Add events.get_multi_task_metrics

This commit is contained in:
allegroai 2024-01-10 15:11:27 +02:00
parent 439911b84c
commit 3752db122b
5 changed files with 172 additions and 2 deletions

View File

@ -156,6 +156,10 @@ class TaskMetricsRequest(MultiTasksRequestBase):
event_type: EventType = ActualEnumField(EventType, required=True)
class MultiTaskMetricsRequest(MultiTasksRequestBase):
    """
    Request model with an optional event type filter.
    When event_type is not passed it falls back to EventType.all
    """

    event_type: EventType = ActualEnumField(
        EventType,
        default=EventType.all,
    )
class MultiTaskPlotsRequest(MultiTasksRequestBase):
iters: int = IntField(default=1)
scroll_id: str = StringField()

View File

@ -21,6 +21,7 @@ from apiserver.bll.event.event_common import (
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.tools import safe_get
@ -463,12 +464,96 @@ class EventMetrics:
return {"bool": {"must": must}}
def get_multi_task_metrics(
    self, companies: TaskCompanies, event_type: EventType
) -> Mapping[str, list]:
    """
    For the requested tasks return reported metrics and variants

    The tasks are grouped per company, every company is queried on a
    worker thread and the per-company results are merged into one mapping.

    :param companies: mapping of a company id to the tasks of that company
        (only the ``id`` attribute of each task is read here)
    :param event_type: type of the events to collect the metrics from
    :return: mapping of a metric name to the list of its variant names
    """
    tasks_ids = {
        company: [t.id for t in tasks] for company, tasks in companies.items()
    }
    with ThreadPoolExecutor(EventSettings.max_workers) as pool:
        companies_res: Sequence = list(
            pool.map(
                partial(
                    self._get_multi_task_metrics,
                    event_type=event_type,
                ),
                tasks_ids.items(),
            )
        )

    # A single company needs no merging; its result is returned untouched
    if len(companies_res) == 1:
        return companies_res[0]

    # The same metric may be reported in several companies,
    # so the variants are unioned per metric
    merged = defaultdict(set)
    for c_res in companies_res:
        for metric, variants in c_res.items():
            merged[metric].update(variants)

    # Sets have no defined iteration order. Sort both the metrics and the
    # variants so that the merged result is deterministic and follows the
    # same ascending key order that is requested for a single company
    return {metric: sorted(merged[metric]) for metric in sorted(merged)}
def _get_multi_task_metrics(
    self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
    """
    Collect the metrics and variants reported by the tasks of one company

    :param company_tasks: pair of (company id, task ids)
    :param event_type: type of the events to aggregate
    :return: mapping of a metric name to the list of its variant names.
        Empty if check_empty_data reports that there is nothing to search
        or if the search response carries no aggregations
    """
    company_id, task_ids = company_tasks
    if check_empty_data(self.es, company_id, event_type):
        return {}

    # Keyword arguments shared by both calls into the events search helpers
    common_kwargs = {
        "es": self.es,
        "company_id": company_id,
        "event_type": event_type,
    }
    tasks_query = QueryBuilder.terms("task", task_ids)
    metrics_limit, variants_limit = get_max_metric_and_variant_counts(
        query=tasks_query,
        **common_kwargs,
    )

    def terms_agg(field: str, size: int) -> dict:
        # Terms aggregation with the buckets in ascending key order
        return {"field": field, "size": size, "order": {"_key": "asc"}}

    # No hits are requested ("size": 0), only the nested
    # metrics -> variants aggregation
    es_req = {
        "size": 0,
        "query": tasks_query,
        "aggs": {
            "metrics": {
                "terms": terms_agg("metric", metrics_limit),
                "aggs": {
                    "variants": {"terms": terms_agg("variant", variants_limit)}
                },
            }
        },
    }
    es_res = search_company_events(
        body=es_req,
        **common_kwargs,
    )

    aggregations = es_res.get("aggregations")
    if not aggregations:
        return {}

    metric_variants = {}
    for metric_bucket in aggregations["metrics"]["buckets"]:
        variant_buckets = metric_bucket["variants"]["buckets"]
        metric_variants[metric_bucket["key"]] = [
            variant_bucket["key"] for variant_bucket in variant_buckets
        ]
    return metric_variants
def get_task_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
For the requested tasks return reported metrics per task
"""
if check_empty_data(self.es, company_id, event_type):
return {}

View File

@ -754,6 +754,42 @@ get_task_metrics{
}
}
}
# Schema of the events.get_multi_task_metrics endpoint
get_multi_task_metrics {
    "999.0" {
        description: """Get unique metrics and variants from the events of the specified type.
Only events reported for the passed task or model ids are analyzed."""
        request {
            type: object
            required: [ tasks ]
            properties {
                tasks {
                    description: task ids to get metrics from
                    type: array
                    items {type: string}
                }
                # The property name has to match the "model_events" attribute
                # that the endpoint handler reads from the request
                model_events {
                    description: If not set or set to false then passed ids are task ids otherwise model ids
                    type: boolean
                    default: false
                }
                event_type {
                    "description": Event type. If not specified then metrics are collected from the reported events of all types
                    "$ref": "#/definitions/event_type_enum"
                }
            }
        }
        response {
            type: object
            properties {
                metrics {
                    type: array
                    description: List of metrics and variants
                    items { "$ref": "#/definitions/metric_variants" }
                }
            }
        }
    }
}
get_task_log {
"1.5" {
description: "Get all 'log' events for this task"

View File

@ -31,6 +31,7 @@ from apiserver.apimodels.events import (
GetMetricSamplesRequest,
TaskMetric,
MultiTaskPlotsRequest,
MultiTaskMetricsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
@ -965,6 +966,30 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
}
@endpoint("events.get_multi_task_metrics")
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
    """
    Respond with the metrics and variants reported for the requested ids

    The response holds a "metrics" list ordered by the metric name where
    every entry carries the metric and its sorted variants. The list is
    empty if no companies were resolved for the passed ids
    """
    companies = _get_task_or_model_index_companies(
        company_id, request.tasks, model_events=request.model_events
    )
    if not companies:
        return {"metrics": []}

    metric_variants = event_bll.metrics.get_multi_task_metrics(
        companies=companies, event_type=request.event_type
    )
    # The mapping keys are unique, so walking them in sorted order gives
    # the entries ordered by metric name
    call.result.data = {
        "metrics": [
            {
                "metric": name,
                "variants": sorted(metric_variants[name]),
            }
            for name in sorted(metric_variants)
        ]
    }
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, _):
task_id = call.data["task"]

View File

@ -70,6 +70,26 @@ class TestTaskEvents(TestService):
self._assert_task_metrics(tasks, "log")
self._assert_task_metrics(tasks, "training_stats_scalar")
self._assert_multitask_metrics(
tasks=list(tasks), metrics=["Metric1", "Metric2", "Metric3"]
)
self._assert_multitask_metrics(
tasks=list(tasks),
event_type="training_debug_image",
metrics=["Metric1", "Metric2", "Metric3"],
)
self._assert_multitask_metrics(tasks=list(tasks), event_type="plot", metrics=[])
def _assert_multitask_metrics(
    self, tasks: Sequence[str], metrics: Sequence[str], event_type: str = None
):
    """
    Call events.get_multi_task_metrics for the tasks and verify the result

    The returned metric names should be equal to the expected ones (in the
    same order) and every metric should have the single "Test variant".
    The event type is passed to the call only when it is provided
    """
    call_kwargs = {"tasks": tasks}
    if event_type:
        call_kwargs["event_type"] = event_type

    reported = self.api.events.get_multi_task_metrics(**call_kwargs).metrics

    reported_names = [entry.metric for entry in reported]
    self.assertEqual(reported_names, metrics)

    variants_match = all(entry.variants == ["Test variant"] for entry in reported)
    self.assertTrue(variants_match)
def _assert_task_metrics(self, tasks: dict, event_type: str):
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
for task, metrics in tasks.items():