From 88a7773621fc2d9fac4aac447c2f15eba69721b9 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Wed, 10 Jan 2024 15:07:46 +0200
Subject: [PATCH] Allow filtering on event metrics in multi-task endpoints
 get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram and
 get_multi_task_plots

---
 apiserver/apimodels/events.py         |  4 ++-
 apiserver/bll/event/event_metrics.py  | 40 +++++++++++++++++----------
 apiserver/schema/services/events.conf | 36 +++++++++++++++++++++---
 apiserver/services/events.py          | 21 ++++++++++----
 apiserver/services/reports.py         | 25 ++++++++++++-----
 5 files changed, 94 insertions(+), 32 deletions(-)

diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py
index 8bd52d9..ca39cdc 100644
--- a/apiserver/apimodels/events.py
+++ b/apiserver/apimodels/events.py
@@ -41,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
             )
         ],
     )
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
     model_events: bool = BoolField(default=False)
 
 
@@ -148,7 +149,7 @@ class MultiTasksRequestBase(Base):
 
 
 class SingleValueMetricsRequest(MultiTasksRequestBase):
-    pass
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
 
 
 class TaskMetricsRequest(MultiTasksRequestBase):
@@ -160,6 +161,7 @@ class MultiTaskPlotsRequest(MultiTasksRequestBase):
     scroll_id: str = StringField()
     no_scroll: bool = BoolField(default=False)
     last_iters_per_task_metric: bool = BoolField(default=True)
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
 
 
 class TaskPlotsRequest(Base):
diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py
index a3e2dd8..cdbbe5a 100644
--- a/apiserver/bll/event/event_metrics.py
+++ b/apiserver/bll/event/event_metrics.py
@@ -161,7 +161,9 @@ class EventMetrics:
         return res
 
     def get_task_single_value_metrics(
-        self, companies: TaskCompanies
+        self,
+        companies: TaskCompanies,
+        metric_variants: MetricVariants = None,
     ) -> Mapping[str, dict]:
         """
         For the requested tasks return all the events delivered for the single iteration (-2**31)
@@ -179,7 +181,13 @@
         with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
             task_events = list(
                 itertools.chain.from_iterable(
-                    pool.map(self._get_task_single_value_metrics, companies.items())
+                    pool.map(
+                        partial(
+                            self._get_task_single_value_metrics,
+                            metric_variants=metric_variants,
+                        ),
+                        companies.items(),
+                    )
                 ),
             )
 
@@ -195,19 +203,19 @@ class EventMetrics:
         }
 
     def _get_task_single_value_metrics(
-        self, tasks: Tuple[str, Sequence[str]]
+        self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
     ) -> Sequence[dict]:
         company_id, task_ids = tasks
+        must = [
+            {"terms": {"task": task_ids}},
+            {"term": {"iter": SINGLE_SCALAR_ITERATION}},
+        ]
+        if metric_variants:
+            must.append(get_metric_variants_condition(metric_variants))
+
         es_req = {
             "size": 10000,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"terms": {"task": task_ids}},
-                        {"term": {"iter": SINGLE_SCALAR_ITERATION}},
-                    ]
-                }
-            },
+            "query": {"bool": {"must": must}},
         }
         with translate_errors_context():
             es_res = search_company_events(
@@ -280,7 +288,8 @@ class EventMetrics:
         query = {"bool": {"must": must}}
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         max_variants = int(max_variants // 2)
         es_req = {
@@ -366,7 +375,8 @@ class EventMetrics:
         query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         max_variants = int(max_variants // 2)
         es_req = {
@@ -432,7 +442,9 @@ class EventMetrics:
 
     @classmethod
     def _get_task_metrics_query(
-        cls, task_id: str, metrics: Sequence[Tuple[str, str]],
+        cls,
+        task_id: str,
+        metrics: Sequence[Tuple[str, str]],
     ):
         must = cls._task_conditions(task_id)
         if metrics:
diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf
index fcde8fc..e86a8d4 100644
--- a/apiserver/schema/services/events.conf
+++ b/apiserver/schema/services/events.conf
@@ -971,10 +971,17 @@ get_task_events {
         }
     }
     "2.22": ${get_task_events."2.1"} {
-        request.properties.model_events {
-            type: boolean
-            description: If set then get retrieving model events. Otherwise task events
-            default: false
+        request.properties {
+            model_events {
+                type: boolean
+                description: If set then get retrieving model events. Otherwise task events
+                default: false
+            }
+            metrics {
+                type: array
+                description: List of metrics and variants
+                items { "$ref": "#/definitions/metric_variants" }
+            }
         }
     }
 }
@@ -1156,6 +1163,13 @@ get_multi_task_plots {
             default: true
         }
     }
+    "999.0": ${get_multi_task_plots."2.26"} {
+        request.properties.metrics {
+            type: array
+            description: List of metrics and variants
+            items { "$ref": "#/definitions/metric_variants" }
+        }
+    }
 }
 get_vector_metrics_and_variants {
     "2.1" {
@@ -1342,6 +1356,13 @@ multi_task_scalar_metrics_iter_histogram {
             default: false
         }
     }
+    "999.0": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
+        request.properties.metrics {
+            type: array
+            description: List of metrics and variants
+            items { "$ref": "#/definitions/metric_variants" }
+        }
+    }
 }
 get_task_single_value_metrics {
     "2.20" {
@@ -1369,6 +1390,13 @@ get_task_single_value_metrics {
             default: false
         }
     }
+    "999.0": ${get_task_single_value_metrics."2.22"} {
+        request.properties.metrics {
+            type: array
+            description: List of metrics and variants
+            items { "$ref": "#/definitions/metric_variants" }
+        }
+    }
 }
 get_task_latest_scalar_values {
     "2.1" {
diff --git a/apiserver/services/events.py b/apiserver/services/events.py
index b547949..26fcb09 100644
--- a/apiserver/services/events.py
+++ b/apiserver/services/events.py
@@ -38,6 +38,7 @@ from apiserver.bll.event.events_iterator import Scroll
 from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
 from apiserver.bll.model import ModelBLL
 from apiserver.bll.task import TaskBLL
+from apiserver.bll.task.utils import get_task_with_write_access
 from apiserver.config_repo import config
 from apiserver.database.model.model import Model
 from apiserver.database.model.task.task import Task
@@ -73,7 +74,7 @@ def add(call: APICall, company_id, _):
     data = call.data.copy()
     added, err_count, err_info = event_bll.add_events(
         company_id=company_id,
-        user_id=call.identity.user,
+        identity=call.identity,
         events=[data],
         worker=call.worker,
     )
@@ -88,7 +89,7 @@ def add_batch(call: APICall, company_id, _):
 
     added, err_count, err_info = event_bll.add_events(
         company_id=company_id,
-        user_id=call.identity.user,
+        identity=call.identity,
         events=events,
         worker=call.worker,
     )
@@ -521,6 +522,7 @@ def multi_task_scalar_metrics_iter_histogram(
             ),
             samples=request.samples,
             key=request.key,
+            metric_variants=_get_metric_variants_from_request(request.metrics),
         )
     )
 
@@ -548,7 +550,8 @@ def get_task_single_value_metrics(
         tasks=_get_single_value_metrics_response(
             companies=companies,
             value_metrics=event_bll.metrics.get_task_single_value_metrics(
-                companies=companies
+                companies=companies,
+                metric_variants=_get_metric_variants_from_request(request.metrics),
             ),
         )
     )
@@ -591,10 +594,11 @@ def _get_multitask_plots(
     companies: TaskCompanies,
     last_iters: int,
     last_iters_per_task_metric: bool,
-    metrics: MetricVariants = None,
+    request_metrics: Sequence[ApiMetrics] = None,
     scroll_id=None,
     no_scroll=True,
 ) -> Tuple[dict, int, str]:
+    metrics = _get_metric_variants_from_request(request_metrics)
     task_names = {
         t.id: t.name for t in itertools.chain.from_iterable(companies.values())
     }
@@ -629,6 +633,7 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
         scroll_id=request.scroll_id,
         no_scroll=request.no_scroll,
         last_iters_per_task_metric=request.last_iters_per_task_metric,
+        request_metrics=request.metrics,
     )
     call.result.data = dict(
         plots=return_events,
@@ -965,7 +970,9 @@ def delete_for_task(call, company_id, _):
     task_id = call.data["task"]
     allow_locked = call.data.get("allow_locked", False)
 
-    task_bll.assert_exists(company_id, task_id, return_tasks=False)
+    get_task_with_write_access(
+        task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
+    )
     call.result.data = dict(
         deleted=event_bll.delete_task_events(
             company_id, task_id, allow_locked=allow_locked
@@ -990,7 +997,9 @@ def delete_for_model(call: APICall, company_id: str, _):
 def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
     task_id = request.task
 
-    task_bll.assert_exists(company_id, task_id, return_tasks=False)
+    get_task_with_write_access(
+        task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
+    )
     call.result.data = dict(
         deleted=event_bll.clear_task_log(
             company_id=company_id,
diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py
index 645f563..de4ba60 100644
--- a/apiserver/services/reports.py
+++ b/apiserver/services/reports.py
@@ -19,7 +19,9 @@ from apiserver.apimodels.reports import (
 from apiserver.apierrors import errors
 from apiserver.apimodels.base import UpdateResponse
 from apiserver.bll.project.project_bll import reports_project_name, reports_tag
+from apiserver.bll.task.utils import get_task_with_write_access
 from apiserver.database.model.model import Model
+from apiserver.service_repo.auth import Identity
 from apiserver.services.models import conform_model_data
 from apiserver.services.utils import process_include_subprojects, sort_tags_response
 from apiserver.bll.organization import OrgBLL
@@ -57,15 +59,15 @@ update_fields = {
 }
 
 
-def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
+def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
     if only_fields and "type" not in only_fields:
         only_fields += ("type",)
 
-    task = TaskBLL.get_task_with_access(
+    task = get_task_with_write_access(
         task_id=task_id,
         company_id=company_id,
+        identity=identity,
         only=only_fields,
-        requires_write_access=requires_write_access,
     )
     if task.type != TaskType.report:
         raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
@@ -78,6 +80,7 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
     task = _assert_report(
         task_id=request.task,
         company_id=company_id,
+        identity=call.identity,
         only_fields=("status",),
     )
 
@@ -265,7 +268,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
res["plots"] = _get_multitask_plots( companies=companies, last_iters=request.plots.iters, - metrics=_get_metric_variants_from_request(request.plots.metrics), + request_metrics=request.plots.metrics, last_iters_per_task_metric=request.plots.last_iters_per_task_metric, )[0] @@ -302,6 +305,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest): task = _assert_report( company_id=company_id, task_id=request.task, + identity=call.identity, only_fields=("project",), ) user_id = call.identity.user @@ -337,7 +341,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest): response_data_model=UpdateResponse, ) def publish(call: APICall, company_id, request: PublishReportRequest): - task = _assert_report(company_id=company_id, task_id=request.task) + task = _assert_report( + company_id=company_id, task_id=request.task, identity=call.identity + ) updates = ChangeStatusRequest( task=task, new_status=TaskStatus.published, @@ -352,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest): @endpoint("reports.archive") def archive(call: APICall, company_id, request: ArchiveReportRequest): - task = _assert_report(company_id=company_id, task_id=request.task) + task = _assert_report( + company_id=company_id, task_id=request.task, identity=call.identity + ) archived = task.update( status_message=request.message, status_reason="", @@ -366,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest): @endpoint("reports.unarchive") def unarchive(call: APICall, company_id, request: ArchiveReportRequest): - task = _assert_report(company_id=company_id, task_id=request.task) + task = _assert_report( + company_id=company_id, task_id=request.task, identity=call.identity + ) unarchived = task.update( status_message=request.message, status_reason="", @@ -394,6 +404,7 @@ def delete(call: APICall, company_id, request: DeleteReportRequest): task = _assert_report( company_id=company_id, task_id=request.task, + identity=call.identity, only_fields=("project",), ) if (