Add task names to events.get_single_value_metrics endpoint response

This commit is contained in:
allegroai 2023-07-26 18:50:53 +03:00
parent f7dcbd96ec
commit 3927604648
5 changed files with 97 additions and 35 deletions

View File

@ -111,6 +111,10 @@ single_value_task_metrics {
type: string type: string
description: Task ID description: Task ID
} }
task_name {
type: string
description: Task name
}
values { values {
type: array type: array
items { items {

View File

@ -81,7 +81,11 @@ def add_batch(call: APICall, company_id, _):
if events is None or len(events) == 0: if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems() raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker,) added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info) call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@ -250,7 +254,9 @@ def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"] task_id = call.data["task"]
model_events = call.data["model_events"] model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events, company_id,
task_id,
model_events=model_events,
)[0] )[0]
call.result.data = dict( call.result.data = dict(
metrics=event_bll.get_metrics_and_variants( metrics=event_bll.get_metrics_and_variants(
@ -264,7 +270,9 @@ def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"] task_id = call.data["task"]
model_events = call.data["model_events"] model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events, company_id,
task_id,
model_events=model_events,
)[0] )[0]
call.result.data = dict( call.result.data = dict(
metrics=event_bll.get_metrics_and_variants( metrics=event_bll.get_metrics_and_variants(
@ -282,7 +290,9 @@ def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"] task_id = call.data["task"]
model_events = call.data["model_events"] model_events = call.data["model_events"]
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events, company_id,
task_id,
model_events=model_events,
)[0] )[0]
metric = call.data["metric"] metric = call.data["metric"]
variant = call.data["variant"] variant = call.data["variant"]
@ -315,7 +325,9 @@ def make_response(
def get_task_events(_, company_id, request: TaskEventsRequest): def get_task_events(_, company_id, request: TaskEventsRequest):
task_id = request.task task_id = request.task
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=request.model_events, company_id,
task_id,
model_events=request.model_events,
)[0] )[0]
key = ScalarKeyEnum.iter key = ScalarKeyEnum.iter
@ -393,7 +405,9 @@ def get_scalar_metric_data(call, company_id, _):
model_events = call.data.get("model_events", False) model_events = call.data.get("model_events", False)
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events, company_id,
task_id,
model_events=model_events,
)[0] )[0]
result = event_bll.get_task_events( result = event_bll.get_task_events(
task_or_model.get_index_company(), task_or_model.get_index_company(),
@ -457,13 +471,17 @@ def scalar_metrics_iter_histogram(
def _get_task_or_model_index_companies( def _get_task_or_model_index_companies(
company_id: str, task_ids: Sequence[str], model_events=False, company_id: str,
task_ids: Sequence[str],
model_events=False,
) -> TaskCompanies: ) -> TaskCompanies:
""" """
Returns lists of tasks grouped by company Returns lists of tasks grouped by company
""" """
tasks_or_models = _assert_task_or_model_exists( tasks_or_models = _assert_task_or_model_exists(
company_id, task_ids, model_events=model_events, company_id,
task_ids,
model_events=model_events,
) )
unique_ids = set(task_ids) unique_ids = set(task_ids)
@ -502,21 +520,32 @@ def multi_task_scalar_metrics_iter_histogram(
def _get_single_value_metrics_response( def _get_single_value_metrics_response(
value_metrics: Mapping[str, dict] companies: TaskCompanies, value_metrics: Mapping[str, dict]
) -> Sequence[dict]: ) -> Sequence[dict]:
return [{"task": task, "values": values} for task, values in value_metrics.items()] task_names = {
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
}
return [
{"task": task_id, "task_name": task_names.get(task_id), "values": values}
for task_id, values in value_metrics.items()
]
@endpoint("events.get_task_single_value_metrics") @endpoint("events.get_task_single_value_metrics")
def get_task_single_value_metrics( def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest call, company_id: str, request: SingleValueMetricsRequest
): ):
res = event_bll.metrics.get_task_single_value_metrics( companies = _get_task_or_model_index_companies(
companies=_get_task_or_model_index_companies( company_id, request.tasks, request.model_events
company_id, request.tasks, request.model_events )
), call.result.data = dict(
tasks=_get_single_value_metrics_response(
companies=companies,
value_metrics=event_bll.metrics.get_task_single_value_metrics(
companies=companies
),
)
) )
call.result.data = dict(tasks=_get_single_value_metrics_response(res))
@endpoint("events.get_multi_task_plots", required_fields=["tasks"]) @endpoint("events.get_multi_task_plots", required_fields=["tasks"])
@ -770,7 +799,9 @@ def get_debug_images_v1_8(call, company_id, _):
model_events = call.data.get("model_events", False) model_events = call.data.get("model_events", False)
tasks_or_model = _assert_task_or_model_exists( tasks_or_model = _assert_task_or_model_exists(
company_id, task_id, model_events=model_events, company_id,
task_id,
model_events=model_events,
)[0] )[0]
result = event_bll.get_task_events( result = event_bll.get_task_events(
tasks_or_model.get_index_company(), tasks_or_model.get_index_company(),
@ -826,7 +857,9 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
) )
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest): def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events, company_id,
request.task,
model_events=request.model_events,
)[0] )[0]
res = event_bll.debug_image_sample_history.get_sample_for_variant( res = event_bll.debug_image_sample_history.get_sample_for_variant(
company_id=task_or_model.get_index_company(), company_id=task_or_model.get_index_company(),
@ -848,7 +881,9 @@ def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
) )
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest): def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events, company_id,
request.task,
model_events=request.model_events,
)[0] )[0]
res = event_bll.debug_image_sample_history.get_next_sample( res = event_bll.debug_image_sample_history.get_next_sample(
company_id=task_or_model.get_index_company(), company_id=task_or_model.get_index_company(),
@ -861,11 +896,14 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
@endpoint( @endpoint(
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest, "events.get_plot_sample",
request_data_model=GetMetricSamplesRequest,
) )
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest): def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events, company_id,
request.task,
model_events=request.model_events,
)[0] )[0]
res = event_bll.plot_sample_history.get_samples_for_metric( res = event_bll.plot_sample_history.get_samples_for_metric(
company_id=task_or_model.get_index_company(), company_id=task_or_model.get_index_company(),
@ -880,11 +918,14 @@ def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
@endpoint( @endpoint(
"events.next_plot_sample", request_data_model=NextHistorySampleRequest, "events.next_plot_sample",
request_data_model=NextHistorySampleRequest,
) )
def next_plot_sample(call, company_id, request: NextHistorySampleRequest): def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task_or_model = _assert_task_or_model_exists( task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events, company_id,
request.task,
model_events=request.model_events,
)[0] )[0]
res = event_bll.plot_sample_history.get_next_sample( res = event_bll.plot_sample_history.get_next_sample(
company_id=task_or_model.get_index_company(), company_id=task_or_model.get_index_company(),
@ -899,7 +940,9 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest) @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest): def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task_or_models = _assert_task_or_model_exists( task_or_models = _assert_task_or_model_exists(
company_id, request.tasks, model_events=request.model_events, company_id,
request.tasks,
model_events=request.model_events,
) )
res = event_bll.metrics.get_task_metrics( res = event_bll.metrics.get_task_metrics(
task_or_models[0].get_index_company(), task_or_models[0].get_index_company(),

View File

@ -76,7 +76,9 @@ def _assert_report(company_id, task_id, only_fields=None, requires_write_access=
@endpoint("reports.update", response_data_model=UpdateResponse) @endpoint("reports.update", response_data_model=UpdateResponse)
def update_report(call: APICall, company_id: str, request: UpdateReportRequest): def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
task = _assert_report( task = _assert_report(
task_id=request.task, company_id=company_id, only_fields=("status",), task_id=request.task,
company_id=company_id,
only_fields=("status",),
) )
partial_update_dict = { partial_update_dict = {
@ -181,9 +183,9 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest):
project_ids = [project_ids] project_ids = [project_ids]
query = Q(parent__in=project_ids) | Q(id__in=project_ids) query = Q(parent__in=project_ids) | Q(id__in=project_ids)
project_ids = Project.objects( project_ids = Project.objects(query & Q(basename=reports_project_name)).scalar(
query & Q(basename=reports_project_name) "id"
).scalar("id") )
if not project_ids: if not project_ids:
return {"tasks": []} return {"tasks": []}
call_data["project"] = list(project_ids) call_data["project"] = list(project_ids)
@ -281,7 +283,10 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
if request.single_value_metrics: if request.single_value_metrics:
res["single_value_metrics"] = _get_single_value_metrics_response( res["single_value_metrics"] = _get_single_value_metrics_response(
event_bll.metrics.get_task_single_value_metrics(companies=companies) companies=companies,
value_metrics=event_bll.metrics.get_task_single_value_metrics(
companies=companies
),
) )
call.result.data = res call.result.data = res
@ -295,7 +300,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
) )
task = _assert_report( task = _assert_report(
company_id=company_id, task_id=request.task, only_fields=("project",), company_id=company_id,
task_id=request.task,
only_fields=("project",),
) )
user_id = call.identity.user user_id = call.identity.user
project_name = request.project_name project_name = request.project_name
@ -326,7 +333,8 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
@endpoint( @endpoint(
"reports.publish", response_data_model=UpdateResponse, "reports.publish",
response_data_model=UpdateResponse,
) )
def publish(call: APICall, company_id, request: PublishReportRequest): def publish(call: APICall, company_id, request: PublishReportRequest):
task = _assert_report(company_id=company_id, task_id=request.task) task = _assert_report(company_id=company_id, task_id=request.task)
@ -384,7 +392,9 @@ def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
@endpoint("reports.delete") @endpoint("reports.delete")
def delete(call: APICall, company_id, request: DeleteReportRequest): def delete(call: APICall, company_id, request: DeleteReportRequest):
task = _assert_report( task = _assert_report(
company_id=company_id, task_id=request.task, only_fields=("project",), company_id=company_id,
task_id=request.task,
only_fields=("project",),
) )
if ( if (
task.status != TaskStatus.created task.status != TaskStatus.created

View File

@ -151,12 +151,13 @@ class TestReports(TestService):
def test_reports_task_data(self): def test_reports_task_data(self):
report_task = self._temp_report(name="Rep1") report_task = self._temp_report(name="Rep1")
non_reports_task_name = "test non-reports"
for model_events in (False, True): for model_events in (False, True):
if model_events: if model_events:
non_report_task = self._temp_model(name="hello") non_report_task = self._temp_model(name=non_reports_task_name)
event_args = {"model_event": True} event_args = {"model_event": True}
else: else:
non_report_task = self._temp_task(name="hello") non_report_task = self._temp_task(name=non_reports_task_name)
event_args = {} event_args = {}
debug_image_events = [ debug_image_events = [
self._create_task_event( self._create_task_event(
@ -235,6 +236,7 @@ class TestReports(TestService):
self.assertEqual(len(res.single_value_metrics), 1) self.assertEqual(len(res.single_value_metrics), 1)
task_metrics = res.single_value_metrics[0] task_metrics = res.single_value_metrics[0]
self.assertEqual(task_metrics.task, non_report_task) self.assertEqual(task_metrics.task, non_report_task)
self.assertEqual(task_metrics.task_name, non_reports_task_name)
self.assertEqual( self.assertEqual(
{(v["metric"], v["variant"]) for v in task_metrics["values"]}, {(v["metric"], v["variant"]) for v in task_metrics["values"]},
{(f"Metric_{x}", f"Variant_{y}") for x in range(2) for y in range(2)}, {(f"Metric_{x}", f"Variant_{y}") for x in range(2) for y in range(2)},
@ -253,7 +255,7 @@ class TestReports(TestService):
task_plots = tasks[non_report_task] task_plots = tasks[non_report_task]
self.assertEqual(len(task_plots), 1) self.assertEqual(len(task_plots), 1)
iter_plots = task_plots["1"] iter_plots = task_plots["1"]
self.assertEqual(iter_plots.name, "hello") self.assertEqual(iter_plots.name, non_reports_task_name)
self.assertEqual(len(iter_plots.plots), 1) self.assertEqual(len(iter_plots.plots), 1)
ev = iter_plots.plots[0] ev = iter_plots.plots[0]
self.assertEqual(ev["metric"], m) self.assertEqual(ev["metric"], m)

View File

@ -14,8 +14,9 @@ from apiserver.tests.automated import TestService
class TestTaskEvents(TestService): class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True) delete_params = dict(can_fail=True, force=True)
default_task_name = "test task events"
def _temp_task(self, name="test task events"): def _temp_task(self, name=default_task_name):
task_input = dict(name=name, type="training",) task_input = dict(name=name, type="training",)
return self.create_temp( return self.create_temp(
"tasks", delete_paramse=self.delete_params, **task_input "tasks", delete_paramse=self.delete_params, **task_input
@ -115,6 +116,7 @@ class TestTaskEvents(TestService):
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
data = res[0] data = res[0]
self.assertEqual(data.task, task) self.assertEqual(data.task, task)
self.assertEqual(data.task_name, self.default_task_name)
self.assertEqual(len(data["values"]), 1) self.assertEqual(len(data["values"]), 1)
value = data["values"][0] value = data["values"][0]
self.assertEqual(value.metric, metric) self.assertEqual(value.metric, metric)
@ -147,6 +149,7 @@ class TestTaskEvents(TestService):
data = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks[0] data = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks[0]
self.assertEqual(data.task, task) self.assertEqual(data.task, task)
self.assertEqual(data.task_name, self.default_task_name)
self.assertEqual(len(data["values"]), 1) self.assertEqual(len(data["values"]), 1)
value = data["values"][0] value = data["values"][0]
self.assertEqual(value.value, new_value) self.assertEqual(value.value, new_value)