Allow filtering on event metrics in the multi-task endpoints get_task_single_value_metrics, multi_task_scalar_metrics_iter_histogram, and get_multi_task_plots

allegroai 2024-01-10 15:07:46 +02:00
parent 35c4061992
commit 88a7773621
5 changed files with 94 additions and 32 deletions
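
The change threads an optional metrics filter (a list of metric names, each with an optional list of variants) from the three multi-task endpoints down to the Elasticsearch event queries. A minimal sketch of how a client might use it, assuming the default apiserver port and hypothetical task ids and token (the payload shape follows the metric_variants schema referenced below):

import requests

# Hypothetical host, task ids and token; "metrics" is the new optional filter.
payload = {
    "tasks": ["<task-id-1>", "<task-id-2>"],
    "metrics": [
        {"metric": "loss", "variants": ["train", "validation"]},
        {"metric": "accuracy"},  # no variants listed: all variants of the metric
    ],
}
resp = requests.post(
    "http://localhost:8008/events.get_task_single_value_metrics",
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
print(resp.json()["data"]["tasks"])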

View File

@@ -41,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
             )
         ],
     )
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
     model_events: bool = BoolField(default=False)
@@ -148,7 +149,7 @@ class MultiTasksRequestBase(Base):
 class SingleValueMetricsRequest(MultiTasksRequestBase):
-    pass
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
 
 
 class TaskMetricsRequest(MultiTasksRequestBase):
@@ -160,6 +161,7 @@ class MultiTaskPlotsRequest(MultiTasksRequestBase):
     scroll_id: str = StringField()
     no_scroll: bool = BoolField(default=False)
     last_iters_per_task_metric: bool = BoolField(default=True)
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
 
 
 class TaskPlotsRequest(Base):

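For context, the MetricVariants apimodel reused by these three request classes is not shown in this diff; a sketch of its likely shape, inferred from how it is used here (one metric plus optional variants):

class MetricVariants(Base):
    # One metric name, optionally narrowed to specific variants;
    # no variants means "all variants of this metric".
    metric: str = StringField(required=True)
    variants: Sequence[str] = ListField(items_types=str)
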
View File

@@ -161,7 +161,9 @@ class EventMetrics:
         return res
 
     def get_task_single_value_metrics(
-        self, companies: TaskCompanies
+        self,
+        companies: TaskCompanies,
+        metric_variants: MetricVariants = None,
     ) -> Mapping[str, dict]:
         """
         For the requested tasks return all the events delivered for the single iteration (-2**31)
@@ -179,7 +181,13 @@ class EventMetrics:
         with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
             task_events = list(
                 itertools.chain.from_iterable(
-                    pool.map(self._get_task_single_value_metrics, companies.items())
+                    pool.map(
+                        partial(
+                            self._get_task_single_value_metrics,
+                            metric_variants=metric_variants,
+                        ),
+                        companies.items(),
+                    )
                 ),
             )
@@ -195,19 +203,19 @@ class EventMetrics:
         }
 
     def _get_task_single_value_metrics(
-        self, tasks: Tuple[str, Sequence[str]]
+        self, tasks: Tuple[str, Sequence[str]], metric_variants: MetricVariants = None
     ) -> Sequence[dict]:
         company_id, task_ids = tasks
+        must = [
+            {"terms": {"task": task_ids}},
+            {"term": {"iter": SINGLE_SCALAR_ITERATION}},
+        ]
+        if metric_variants:
+            must.append(get_metric_variants_condition(metric_variants))
         es_req = {
             "size": 10000,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"terms": {"task": task_ids}},
-                        {"term": {"iter": SINGLE_SCALAR_ITERATION}},
-                    ]
-                }
-            },
+            "query": {"bool": {"must": must}},
         }
         with translate_errors_context():
             es_res = search_company_events(
@@ -280,7 +288,8 @@ class EventMetrics:
         query = {"bool": {"must": must}}
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         max_variants = int(max_variants // 2)
         es_req = {
@@ -366,7 +375,8 @@ class EventMetrics:
         query = self._get_task_metrics_query(task_id=task_id, metrics=metrics)
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         max_variants = int(max_variants // 2)
         es_req = {
@@ -432,7 +442,9 @@ class EventMetrics:
     @classmethod
     def _get_task_metrics_query(
-        cls, task_id: str, metrics: Sequence[Tuple[str, str]],
+        cls,
+        task_id: str,
+        metrics: Sequence[Tuple[str, str]],
     ):
         must = cls._task_conditions(task_id)
         if metrics:

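get_metric_variants_condition, appended to the must clause above, is defined elsewhere in the event BLL; its exact body is not part of this diff, but an equivalent ES condition would look roughly like the sketch below: one should entry per requested metric, restricted to the listed variants when any are given, so an event matches if it belongs to any requested metric/variant pair.

def get_metric_variants_condition(metric_variants) -> dict:
    # Match any requested metric; if variants were listed for it,
    # also require the event variant to be one of them.
    conditions = []
    for mv in metric_variants:
        metric_condition = [{"term": {"metric": mv.metric}}]
        if mv.variants:
            metric_condition.append({"terms": {"variant": mv.variants}})
        conditions.append({"bool": {"must": metric_condition}})
    return {"bool": {"should": conditions}}
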
View File

@@ -971,10 +971,17 @@ get_task_events {
         }
     }
     "2.22": ${get_task_events."2.1"} {
-        request.properties.model_events {
-            type: boolean
-            description: If set then get retrieving model events. Otherwise task events
-            default: false
+        request.properties {
+            model_events {
+                type: boolean
+                description: If set then get retrieving model events. Otherwise task events
+                default: false
+            }
+            metrics {
+                type: array
+                description: List of metrics and variants
+                items { "$ref": "#/definitions/metric_variants" }
+            }
         }
     }
 }
@@ -1156,6 +1163,13 @@ get_multi_task_plots {
             default: true
         }
     }
+    "999.0": ${get_multi_task_plots."2.26"} {
+        request.properties.metrics {
+            type: array
+            description: List of metrics and variants
+            items { "$ref": "#/definitions/metric_variants" }
+        }
+    }
 }
get_vector_metrics_and_variants {
"2.1" {
@@ -1342,6 +1356,13 @@ multi_task_scalar_metrics_iter_histogram {
             default: false
         }
     }
+    "999.0": ${multi_task_scalar_metrics_iter_histogram."2.22"} {
+        request.properties.metrics {
+            type: array
+            description: List of metrics and variants
+            items { "$ref": "#/definitions/metric_variants" }
+        }
+    }
 }
 get_task_single_value_metrics {
     "2.20" {
@@ -1369,6 +1390,13 @@ get_task_single_value_metrics {
             default: false
         }
     }
+    "999.0": ${get_task_single_value_metrics."2.22"} {
+        request.properties.metrics {
+            type: array
+            description: List of metrics and variants
+            items { "$ref": "#/definitions/metric_variants" }
+        }
+    }
 }
 get_task_latest_scalar_values {
     "2.1" {

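The "999.0" entries extend the latest released versions with the new optional metrics array ("999.0" reads like a placeholder for the next release version). Assuming the metric_variants definition has the metric/variants shape sketched earlier, a request to the extended histogram endpoint might carry (task ids hypothetical; samples and key are pre-existing properties):

histogram_request = {
    "tasks": ["<task-id-1>", "<task-id-2>"],
    "samples": 500,
    "key": "iter",
    "metrics": [{"metric": "loss", "variants": ["train"]}],
}
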
View File

@@ -38,6 +38,7 @@ from apiserver.bll.event.events_iterator import Scroll
 from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
 from apiserver.bll.model import ModelBLL
 from apiserver.bll.task import TaskBLL
+from apiserver.bll.task.utils import get_task_with_write_access
 from apiserver.config_repo import config
 from apiserver.database.model.model import Model
 from apiserver.database.model.task.task import Task
@@ -73,7 +74,7 @@ def add(call: APICall, company_id, _):
     data = call.data.copy()
     added, err_count, err_info = event_bll.add_events(
         company_id=company_id,
-        user_id=call.identity.user,
+        identity=call.identity,
         events=[data],
         worker=call.worker,
     )
@@ -88,7 +89,7 @@ def add_batch(call: APICall, company_id, _):
     added, err_count, err_info = event_bll.add_events(
         company_id=company_id,
-        user_id=call.identity.user,
+        identity=call.identity,
         events=events,
         worker=call.worker,
     )
@@ -521,6 +522,7 @@ def multi_task_scalar_metrics_iter_histogram(
             ),
             samples=request.samples,
             key=request.key,
+            metric_variants=_get_metric_variants_from_request(request.metrics),
         )
     )
@@ -548,7 +550,8 @@ def get_task_single_value_metrics(
         tasks=_get_single_value_metrics_response(
             companies=companies,
             value_metrics=event_bll.metrics.get_task_single_value_metrics(
-                companies=companies
+                companies=companies,
+                metric_variants=_get_metric_variants_from_request(request.metrics),
             ),
         )
     )
@@ -591,10 +594,11 @@ def _get_multitask_plots(
     companies: TaskCompanies,
     last_iters: int,
     last_iters_per_task_metric: bool,
-    metrics: MetricVariants = None,
+    request_metrics: Sequence[ApiMetrics] = None,
     scroll_id=None,
     no_scroll=True,
 ) -> Tuple[dict, int, str]:
+    metrics = _get_metric_variants_from_request(request_metrics)
     task_names = {
         t.id: t.name for t in itertools.chain.from_iterable(companies.values())
     }
@@ -629,6 +633,7 @@ def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest):
         scroll_id=request.scroll_id,
         no_scroll=request.no_scroll,
         last_iters_per_task_metric=request.last_iters_per_task_metric,
+        request_metrics=request.metrics,
     )
     call.result.data = dict(
         plots=return_events,
@@ -965,7 +970,9 @@ def delete_for_task(call, company_id, _):
     task_id = call.data["task"]
     allow_locked = call.data.get("allow_locked", False)
-    task_bll.assert_exists(company_id, task_id, return_tasks=False)
+    get_task_with_write_access(
+        task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
+    )
     call.result.data = dict(
         deleted=event_bll.delete_task_events(
             company_id, task_id, allow_locked=allow_locked
@@ -990,7 +997,9 @@ def delete_for_model(call: APICall, company_id: str, _):
 def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
     task_id = request.task
-    task_bll.assert_exists(company_id, task_id, return_tasks=False)
+    get_task_with_write_access(
+        task_id=task_id, company_id=company_id, identity=call.identity, only=("id",)
+    )
     call.result.data = dict(
         deleted=event_bll.clear_task_log(
             company_id=company_id,

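_get_metric_variants_from_request, used throughout these handlers, already exists in this module and is not shown in the diff; presumably it maps the API-level metrics objects onto the BLL filter, along these lines (a sketch, assuming an empty request yields None so that no ES filter is applied):

def _get_metric_variants_from_request(
    metrics: Sequence[ApiMetrics],
) -> Optional[Sequence[MetricVariants]]:
    # No metrics in the request -> no filtering at the query level
    if not metrics:
        return None
    return [
        MetricVariants(metric=m.metric, variants=m.variants) for m in metrics
    ]
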
View File

@@ -19,7 +19,9 @@ from apiserver.apimodels.reports import (
 from apiserver.apierrors import errors
 from apiserver.apimodels.base import UpdateResponse
 from apiserver.bll.project.project_bll import reports_project_name, reports_tag
+from apiserver.bll.task.utils import get_task_with_write_access
 from apiserver.database.model.model import Model
+from apiserver.service_repo.auth import Identity
 from apiserver.services.models import conform_model_data
 from apiserver.services.utils import process_include_subprojects, sort_tags_response
 from apiserver.bll.organization import OrgBLL
@@ -57,15 +59,15 @@ update_fields = {
 }
 
 
-def _assert_report(company_id, task_id, only_fields=None, requires_write_access=True):
+def _assert_report(company_id: str, task_id: str, identity: Identity, only_fields=None):
     if only_fields and "type" not in only_fields:
         only_fields += ("type",)
-    task = TaskBLL.get_task_with_access(
+    task = get_task_with_write_access(
         task_id=task_id,
         company_id=company_id,
+        identity=identity,
         only=only_fields,
-        requires_write_access=requires_write_access,
     )
 
     if task.type != TaskType.report:
         raise errors.bad_request.OperationSupportedOnReportsOnly(id=task_id)
@@ -78,6 +80,7 @@ def update_report(call: APICall, company_id: str, request: UpdateReportRequest):
     task = _assert_report(
         task_id=request.task,
         company_id=company_id,
+        identity=call.identity,
         only_fields=("status",),
     )
@@ -265,7 +268,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
         res["plots"] = _get_multitask_plots(
             companies=companies,
             last_iters=request.plots.iters,
-            metrics=_get_metric_variants_from_request(request.plots.metrics),
+            request_metrics=request.plots.metrics,
             last_iters_per_task_metric=request.plots.last_iters_per_task_metric,
         )[0]
@@ -302,6 +305,7 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
     task = _assert_report(
         company_id=company_id,
         task_id=request.task,
+        identity=call.identity,
         only_fields=("project",),
     )
     user_id = call.identity.user
@@ -337,7 +341,9 @@ def move(call: APICall, company_id: str, request: MoveReportRequest):
     response_data_model=UpdateResponse,
 )
 def publish(call: APICall, company_id, request: PublishReportRequest):
-    task = _assert_report(company_id=company_id, task_id=request.task)
+    task = _assert_report(
+        company_id=company_id, task_id=request.task, identity=call.identity
+    )
     updates = ChangeStatusRequest(
         task=task,
         new_status=TaskStatus.published,
@@ -352,7 +358,9 @@ def publish(call: APICall, company_id, request: PublishReportRequest):
 @endpoint("reports.archive")
 def archive(call: APICall, company_id, request: ArchiveReportRequest):
-    task = _assert_report(company_id=company_id, task_id=request.task)
+    task = _assert_report(
+        company_id=company_id, task_id=request.task, identity=call.identity
+    )
     archived = task.update(
         status_message=request.message,
         status_reason="",
@@ -366,7 +374,9 @@ def archive(call: APICall, company_id, request: ArchiveReportRequest):
 @endpoint("reports.unarchive")
 def unarchive(call: APICall, company_id, request: ArchiveReportRequest):
-    task = _assert_report(company_id=company_id, task_id=request.task)
+    task = _assert_report(
+        company_id=company_id, task_id=request.task, identity=call.identity
+    )
     unarchived = task.update(
         status_message=request.message,
         status_reason="",
@@ -394,6 +404,7 @@ def delete(call: APICall, company_id, request: DeleteReportRequest):
     task = _assert_report(
         company_id=company_id,
         task_id=request.task,
+        identity=call.identity,
         only_fields=("project",),
     )
     if (