Add task names to events.get_task_single_value_metrics endpoint response

allegroai 2023-07-26 18:50:53 +03:00
parent f7dcbd96ec
commit 3927604648
5 changed files with 97 additions and 35 deletions
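
For reference, after this change each entry returned by events.get_task_single_value_metrics (and by the single_value_metrics section of reports.get_task_data) carries the originating task's name next to its ID. A minimal sketch of one response entry; the task ID and metric values below are made up for illustration:

    {
        "task": "<task-id>",              # placeholder task ID
        "task_name": "test task events",  # new field added by this commit
        "values": [
            {"metric": "Metric_0", "variant": "Variant_0", "value": 1.0},
        ],
    }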

View File

@@ -111,6 +111,10 @@ single_value_task_metrics {
                 type: string
                 description: Task ID
             }
+            task_name {
+                type: string
+                description: Task name
+            }
             values {
                 type: array
                 items {

View File

@@ -81,7 +81,11 @@ def add_batch(call: APICall, company_id, _):
     if events is None or len(events) == 0:
         raise errors.bad_request.BatchContainsNoItems()
-    added, err_count, err_info = event_bll.add_events(company_id, events, call.worker,)
+    added, err_count, err_info = event_bll.add_events(
+        company_id,
+        events,
+        call.worker,
+    )
     call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -250,7 +254,9 @@ def get_vector_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
     model_events = call.data["model_events"]
     task_or_model = _assert_task_or_model_exists(
-        company_id, task_id, model_events=model_events,
+        company_id,
+        task_id,
+        model_events=model_events,
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
@@ -264,7 +270,9 @@ def get_scalar_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
     model_events = call.data["model_events"]
     task_or_model = _assert_task_or_model_exists(
-        company_id, task_id, model_events=model_events,
+        company_id,
+        task_id,
+        model_events=model_events,
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
@@ -282,7 +290,9 @@ def vector_metrics_iter_histogram(call, company_id, _):
     task_id = call.data["task"]
     model_events = call.data["model_events"]
     task_or_model = _assert_task_or_model_exists(
-        company_id, task_id, model_events=model_events,
+        company_id,
+        task_id,
+        model_events=model_events,
    )[0]
     metric = call.data["metric"]
     variant = call.data["variant"]
@@ -315,7 +325,9 @@ def make_response(
 def get_task_events(_, company_id, request: TaskEventsRequest):
     task_id = request.task
     task_or_model = _assert_task_or_model_exists(
-        company_id, task_id, model_events=request.model_events,
+        company_id,
+        task_id,
+        model_events=request.model_events,
     )[0]
     key = ScalarKeyEnum.iter
@@ -393,7 +405,9 @@ def get_scalar_metric_data(call, company_id, _):
     model_events = call.data.get("model_events", False)
     task_or_model = _assert_task_or_model_exists(
-        company_id, task_id, model_events=model_events,
+        company_id,
+        task_id,
+        model_events=model_events,
     )[0]
     result = event_bll.get_task_events(
         task_or_model.get_index_company(),
@@ -457,13 +471,17 @@ def scalar_metrics_iter_histogram(
 def _get_task_or_model_index_companies(
-    company_id: str, task_ids: Sequence[str], model_events=False,
+    company_id: str,
+    task_ids: Sequence[str],
+    model_events=False,
 ) -> TaskCompanies:
     """
     Returns lists of tasks grouped by company
     """
     tasks_or_models = _assert_task_or_model_exists(
-        company_id, task_ids, model_events=model_events,
+        company_id,
+        task_ids,
+        model_events=model_events,
     )
     unique_ids = set(task_ids)
@@ -502,21 +520,32 @@ def multi_task_scalar_metrics_iter_histogram(
 def _get_single_value_metrics_response(
-    value_metrics: Mapping[str, dict]
+    companies: TaskCompanies, value_metrics: Mapping[str, dict]
 ) -> Sequence[dict]:
-    return [{"task": task, "values": values} for task, values in value_metrics.items()]
+    task_names = {
+        task.id: task.name for task in itertools.chain.from_iterable(companies.values())
+    }
+    return [
+        {"task": task_id, "task_name": task_names.get(task_id), "values": values}
+        for task_id, values in value_metrics.items()
+    ]


 @endpoint("events.get_task_single_value_metrics")
 def get_task_single_value_metrics(
     call, company_id: str, request: SingleValueMetricsRequest
 ):
-    res = event_bll.metrics.get_task_single_value_metrics(
-        companies=_get_task_or_model_index_companies(
-            company_id, request.tasks, request.model_events
-        ),
-    )
-    call.result.data = dict(tasks=_get_single_value_metrics_response(res))
+    companies = _get_task_or_model_index_companies(
+        company_id, request.tasks, request.model_events
+    )
+    call.result.data = dict(
+        tasks=_get_single_value_metrics_response(
+            companies=companies,
+            value_metrics=event_bll.metrics.get_task_single_value_metrics(
+                companies=companies
+            ),
+        )
+    )


 @endpoint("events.get_multi_task_plots", required_fields=["tasks"])
@@ -770,7 +799,9 @@ def get_debug_images_v1_8(call, company_id, _):
     model_events = call.data.get("model_events", False)
     tasks_or_model = _assert_task_or_model_exists(
-        company_id, task_id, model_events=model_events,
+        company_id,
+        task_id,
+        model_events=model_events,
     )[0]
     result = event_bll.get_task_events(
         tasks_or_model.get_index_company(),
@@ -826,7 +857,9 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
 )
 def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
     task_or_model = _assert_task_or_model_exists(
-        company_id, request.task, model_events=request.model_events,
+        company_id,
+        request.task,
+        model_events=request.model_events,
     )[0]
     res = event_bll.debug_image_sample_history.get_sample_for_variant(
         company_id=task_or_model.get_index_company(),
@@ -848,7 +881,9 @@ def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
 )
 def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
     task_or_model = _assert_task_or_model_exists(
-        company_id, request.task, model_events=request.model_events,
+        company_id,
+        request.task,
+        model_events=request.model_events,
     )[0]
     res = event_bll.debug_image_sample_history.get_next_sample(
         company_id=task_or_model.get_index_company(),
@@ -861,11 +896,14 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
 @endpoint(
-    "events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
+    "events.get_plot_sample",
+    request_data_model=GetMetricSamplesRequest,
 )
 def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
     task_or_model = _assert_task_or_model_exists(
-        company_id, request.task, model_events=request.model_events,
+        company_id,
+        request.task,
+        model_events=request.model_events,
     )[0]
     res = event_bll.plot_sample_history.get_samples_for_metric(
         company_id=task_or_model.get_index_company(),
@@ -880,11 +918,14 @@ def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
 @endpoint(
-    "events.next_plot_sample", request_data_model=NextHistorySampleRequest,
+    "events.next_plot_sample",
+    request_data_model=NextHistorySampleRequest,
 )
 def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
     task_or_model = _assert_task_or_model_exists(
-        company_id, request.task, model_events=request.model_events,
+        company_id,
+        request.task,
+        model_events=request.model_events,
     )[0]
     res = event_bll.plot_sample_history.get_next_sample(
         company_id=task_or_model.get_index_company(),
@@ -899,7 +940,9 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task_or_models = _assert_task_or_model_exists(
-        company_id, request.tasks, model_events=request.model_events,
+        company_id,
+        request.tasks,
+        model_events=request.model_events,
     )
     res = event_bll.metrics.get_task_metrics(
         task_or_models[0].get_index_company(),

View File

@@ -76,7 +76,9 @@ def _assert_report(company_id, task_id, only_fields=None, requires_write_access=
 @endpoint("reports.update", response_data_model=UpdateResponse)
 def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
     task = _assert_report(
-        task_id=request.task, company_id=company_id, only_fields=("status",),
+        task_id=request.task,
+        company_id=company_id,
+        only_fields=("status",),
     )
     partial_update_dict = {
@@ -181,9 +183,9 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest):
             project_ids = [project_ids]
         query = Q(parent__in=project_ids) | Q(id__in=project_ids)
-        project_ids = Project.objects(
-            query & Q(basename=reports_project_name)
-        ).scalar("id")
+        project_ids = Project.objects(query & Q(basename=reports_project_name)).scalar(
+            "id"
+        )
         if not project_ids:
             return {"tasks": []}
         call_data["project"] = list(project_ids)
@@ -281,7 +283,10 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
     if request.single_value_metrics:
         res["single_value_metrics"] = _get_single_value_metrics_response(
-            event_bll.metrics.get_task_single_value_metrics(companies=companies)
+            companies=companies,
+            value_metrics=event_bll.metrics.get_task_single_value_metrics(
+                companies=companies
+            ),
         )
     call.result.data = res
@@ -295,7 +300,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
     )
     task = _assert_report(
-        company_id=company_id, task_id=request.task, only_fields=("project",),
+        company_id=company_id,
+        task_id=request.task,
+        only_fields=("project",),
     )
     user_id = call.identity.user
     project_name = request.project_name
@@ -326,7 +333,8 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
 @endpoint(
-    "reports.publish", response_data_model=UpdateResponse,
+    "reports.publish",
+    response_data_model=UpdateResponse,
 )
 def publish(call: APICall, company_id, request: PublishReportRequest):
     task = _assert_report(company_id=company_id, task_id=request.task)
@@ -384,7 +392,9 @@ def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
 @endpoint("reports.delete")
 def delete(call: APICall, company_id, request: DeleteReportRequest):
     task = _assert_report(
-        company_id=company_id, task_id=request.task, only_fields=("project",),
+        company_id=company_id,
+        task_id=request.task,
+        only_fields=("project",),
     )
     if (
         task.status != TaskStatus.created

View File

@@ -151,12 +151,13 @@ class TestReports(TestService):
     def test_reports_task_data(self):
         report_task = self._temp_report(name="Rep1")
+        non_reports_task_name = "test non-reports"
         for model_events in (False, True):
             if model_events:
-                non_report_task = self._temp_model(name="hello")
+                non_report_task = self._temp_model(name=non_reports_task_name)
                 event_args = {"model_event": True}
             else:
-                non_report_task = self._temp_task(name="hello")
+                non_report_task = self._temp_task(name=non_reports_task_name)
                 event_args = {}
             debug_image_events = [
                 self._create_task_event(
@@ -235,6 +236,7 @@ class TestReports(TestService):
             self.assertEqual(len(res.single_value_metrics), 1)
             task_metrics = res.single_value_metrics[0]
             self.assertEqual(task_metrics.task, non_report_task)
+            self.assertEqual(task_metrics.task_name, non_reports_task_name)
             self.assertEqual(
                 {(v["metric"], v["variant"]) for v in task_metrics["values"]},
                 {(f"Metric_{x}", f"Variant_{y}") for x in range(2) for y in range(2)},
@@ -253,7 +255,7 @@ class TestReports(TestService):
             task_plots = tasks[non_report_task]
             self.assertEqual(len(task_plots), 1)
             iter_plots = task_plots["1"]
-            self.assertEqual(iter_plots.name, "hello")
+            self.assertEqual(iter_plots.name, non_reports_task_name)
             self.assertEqual(len(iter_plots.plots), 1)
             ev = iter_plots.plots[0]
             self.assertEqual(ev["metric"], m)

View File

@@ -14,8 +14,9 @@ from apiserver.tests.automated import TestService
 class TestTaskEvents(TestService):
     delete_params = dict(can_fail=True, force=True)
+    default_task_name = "test task events"

-    def _temp_task(self, name="test task events"):
+    def _temp_task(self, name=default_task_name):
         task_input = dict(name=name, type="training",)
         return self.create_temp(
             "tasks", delete_paramse=self.delete_params, **task_input
@@ -115,6 +116,7 @@ class TestTaskEvents(TestService):
         self.assertEqual(len(res), 1)
         data = res[0]
         self.assertEqual(data.task, task)
+        self.assertEqual(data.task_name, self.default_task_name)
         self.assertEqual(len(data["values"]), 1)
         value = data["values"][0]
         self.assertEqual(value.metric, metric)
@@ -147,6 +149,7 @@ class TestTaskEvents(TestService):
         data = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks[0]
         self.assertEqual(data.task, task)
+        self.assertEqual(data.task_name, self.default_task_name)
         self.assertEqual(len(data["values"]), 1)
         value = data["values"][0]
         self.assertEqual(value.value, new_value)
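
Outside the test framework, the new field can be read straight off the endpoint response. Below is a minimal sketch using plain HTTP against a local API server; the URL, credentials, and task ID are placeholders and are not part of this commit:

    import requests

    # A sketch, assuming a locally running apiserver and valid ClearML API credentials.
    resp = requests.post(
        "http://localhost:8008/events.get_task_single_value_metrics",
        json={"tasks": ["<task-id>"]},           # placeholder task ID
        auth=("<access_key>", "<secret_key>"),   # placeholder credentials
    )
    for entry in resp.json()["data"]["tasks"]:
        # "task_name" accompanies the task ID as of this commit
        print(entry["task"], entry["task_name"], entry["values"])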