Support querying task events per specific metrics and variants

allegroai 2021-07-25 14:29:41 +03:00
parent 677bb3ba6d
commit 9069cfe1da
7 changed files with 245 additions and 78 deletions

View File

@@ -14,12 +14,20 @@ from apiserver.utilities.stringenum import StringEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class MetricVariants(Base):
metric: str = StringField(required=True)
variants: Sequence[str] = ListField(
items_types=str, validators=Length(minimum_value=1)
)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -39,6 +47,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(default=None)
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
@@ -59,8 +68,8 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
class NextDebugImageSampleRequest(Base):
@@ -102,3 +111,10 @@ class TaskMetricsRequest(Base):
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
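
The new MetricVariants apimodel lets every request that carries it narrow a query to specific metrics, and optionally to specific variants within each metric. A minimal sketch of a request body using the new field (the task ID and metric names are hypothetical):

# Hypothetical body for events.scalar_metrics_iter_histogram; each entry in
# "metrics" filters one metric, optionally narrowed to specific variants.
histogram_request = {
    "task": "my_task_id",   # hypothetical task ID
    "samples": 2000,        # new default, still capped at 6000
    "key": "iter",
    "metrics": [
        {"metric": "loss", "variants": ["train", "validation"]},
        {"metric": "accuracy"},  # no variants: match all variants of "accuracy"
    ],
}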

View File

@@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
@@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
@@ -74,7 +75,7 @@ class DebugImagesIterator:
def get_task_events(
self,
company_id: str,
task_metrics: Mapping[str, Set[str]],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
@@ -118,7 +119,7 @@
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new debug image events were added
@@ -158,11 +159,11 @@
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = set(
m
metrics_to_recalc = {
m: task_metrics[task].get(m)
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
)
}
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
@@ -196,7 +197,7 @@ class DebugImagesIterator:
]
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]]
self, company_id: str, task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Return initialized metric scroll states for the requested task metrics
@@ -213,7 +214,7 @@
]
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Set[str]], company_id: str
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
@@ -222,10 +223,11 @@
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append({"terms": {"metric": list(metrics)}})
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -6,9 +6,8 @@ from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, Dict
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import six
from elasticsearch import helpers
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
@@ -22,6 +21,8 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
delete_company_events,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
@@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2**63 - 1
MIN_LONG = -2**63
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
class PlotFields:
@@ -94,7 +95,7 @@ class EventBLL(object):
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
actions: List[dict] = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
@@ -197,7 +198,6 @@ class EventBLL(object):
actions.append(es_action)
action: Dict[dict]
plot_actions = [
action["_source"]
for action in actions
@@ -260,7 +260,8 @@ class EventBLL(object):
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
)
if not added:
@@ -466,10 +467,16 @@ class EventBLL(object):
task_id: str,
num_last_iterations: int,
event_type: EventType,
metric_variants: MetricVariants = None,
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"aggs": {
@@ -499,7 +506,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -527,6 +534,7 @@ class EventBLL(object):
sort=None,
size: int = 500,
scroll_id: str = None,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@@ -555,6 +563,8 @@ class EventBLL(object):
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
@@ -563,6 +573,7 @@ class EventBLL(object):
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
@@ -669,19 +680,19 @@ class EventBLL(object):
sort=None,
size=500,
scroll_id=None,
):
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metric:
must.append({"term": {"metric": metric}})
@@ -691,26 +702,24 @@ class EventBLL(object):
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_id,
iters=last_iter_count,
)
if not last_iters:
continue
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"iter": last_iters}},
]
}
tasks_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
}
)
}
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
@@ -748,6 +757,7 @@ class EventBLL(object):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
es_req = {
"size": 0,
"aggs": {
@@ -768,7 +778,7 @@
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -787,21 +797,24 @@ class EventBLL(object):
return metrics
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
def get_task_latest_scalar_values(
self, company_id, task_id
) -> Tuple[Sequence[dict], int]:
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
return [], 0
query = {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
}
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
},
"query": query,
"aggs": {
"metrics": {
"terms": {
@@ -905,11 +918,16 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters(
self, company_id: str, event_type: EventType, task_id: str, iters: int
):
self,
company_id: str,
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
es_req: dict = {
"size": 0,
"aggs": {
@@ -921,7 +939,7 @@
}
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
@@ -930,9 +948,12 @@ class EventBLL(object):
)
if "aggregations" not in es_res:
return []
return {}
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
return {
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
@@ -965,7 +986,9 @@ class EventBLL(object):
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
es_res = delete_company_events(
es=self.es,
company_id=company_id,

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Union, Sequence
from typing import Union, Sequence, Mapping
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
@@ -16,6 +16,9 @@ class EventType(Enum):
all = "*"
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings:
@classproperty
def max_workers(self):
@@ -64,3 +67,23 @@ def delete_company_events(
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(
metric_variants: MetricVariants,
) -> dict:
conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"terms": {"variant": variants}},
]
}
}
if variants
else {"term": {"metric": metric}}
for metric, variants in metric_variants.items()
]
return {"bool": {"should": conditions}}

View File

@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
EventSettings,
search_company_events,
check_empty_data,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
@@ -34,7 +36,12 @@ class EventMetrics:
self.es = es
def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
self,
company_id: str,
task_id: str,
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -46,7 +53,12 @@ class EventMetrics:
return {}
return self._get_scalar_average_per_iter_core(
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
task_id=task_id,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
)
def _get_scalar_average_per_iter_core(
@@ -57,6 +69,7 @@ class EventMetrics:
samples: int,
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -64,6 +77,7 @@ class EventMetrics:
task_id=task_id,
samples=samples,
field=key.field,
metric_variants=metric_variants,
)
if not intervals:
return {}
@@ -197,6 +211,7 @@ class EventMetrics:
task_id: str,
samples: int,
field: str = "iter",
metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
@@ -204,9 +219,14 @@ class EventMetrics:
Return the list of metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req = {
"size": 0,
"query": {"term": {"task": task_id}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -1,6 +1,18 @@
{
_description : "Provides an API for running tasks to report events collected by the system."
_definitions {
metric_variants {
type: object
properties {
metric {
description: The metric name
type: string
}
variants {
type: array
description: The names of the metric variants
items {type: string}
}
}
}
metrics_scalar_event {
description: "Used for reporting scalar metrics during training task"
type: object
@@ -193,6 +205,29 @@
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
}
}
task_metric_variants {
type: object
required: [task]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
variants {
description: Metric variant names
type: array
items {type: string}
}
}
}
task_log_event {
@@ -376,7 +411,7 @@
metrics {
type: array
items { "$ref": "#/definitions/task_metric" }
description: "List metrics for which the envents will be retreived"
description: "List of task metrics for which the envents will be retreived"
}
iters {
type: integer
@@ -411,6 +446,17 @@
}
}
}
"2.14": ${debug_images."2.7"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/task_metric_variants" }
}
}
}
}
}
get_debug_image_sample {
"2.12": {
@@ -804,6 +850,17 @@
}
}
}
"2.14": ${get_task_plots."2.1"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
get_multi_task_plots {
"2.1" {
@@ -962,6 +1019,17 @@
}
}
}
"2.14": ${scalar_metrics_iter_histogram."2.1"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
}
multi_task_scalar_metrics_iter_histogram {
"2.1" {

View File

@@ -3,6 +3,7 @@ from collections import defaultdict
from operator import itemgetter
import attr
from typing import Sequence, Optional
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
@@ -17,9 +18,11 @@ from apiserver.apimodels.events import (
LogOrderEnum,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.task import TaskBLL
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
@@ -321,7 +324,7 @@ def get_task_latest_scalar_values(call, company_id, _):
)
last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
)
).get(task_id)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
@@ -494,11 +497,22 @@ def get_task_plots_v1_7(call, company_id, _):
)
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
def _get_metric_variants_from_request(
req_metrics: Sequence[ApiMetrics],
) -> Optional[MetricVariants]:
if not req_metrics:
return None
return {m.metric: m.variants for m in req_metrics}
@endpoint(
"events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
)
def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -509,6 +523,7 @@ def get_task_plots(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
return_events = result.events
@@ -594,9 +609,9 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_metrics = defaultdict(set)
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task].add(tm.metric)
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
@@ -734,11 +749,11 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
def _get_top_iter_unique_events(events, max_iters):
top_unique_events = defaultdict(lambda: [])
for e in events:
key = e.get("metric", "") + e.get("variant", "")
for ev in events:
key = ev.get("metric", "") + ev.get("variant", "")
evs = top_unique_events[key]
if len(evs) < max_iters:
evs.append(e)
evs.append(ev)
unique_events = list(
itertools.chain.from_iterable(list(top_unique_events.values()))
)
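
In get_debug_images the same idea is keyed per task: each requested (task, metric, variants) triple lands in a nested mapping, and naming a task without a metric clears that task's mapping so all of its metrics are returned. A standalone sketch with hypothetical values:

from collections import defaultdict

# Hypothetical request entries: (task, metric, variants)
requested = [
    ("task_1", "loss", ["train"]),
    ("task_1", "accuracy", None),
    ("task_2", None, None),  # no metric: fetch all metrics of task_2
]

task_metrics = defaultdict(dict)
for task, metric, variants in requested:
    task_metrics[task][metric] = variants
for metrics in task_metrics.values():
    if None in metrics:  # a None metric key means "everything for this task"
        metrics.clear()
# -> {"task_1": {"loss": ["train"], "accuracy": None}, "task_2": {}}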