diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index e8dabbf..0eca4e2 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -14,12 +14,20 @@ from apiserver.utilities.stringenum import StringEnum class HistogramRequestBase(Base): - samples: int = IntField(default=6000, validators=[Min(1), Max(6000)]) + samples: int = IntField(default=2000, validators=[Min(1), Max(6000)]) key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter) +class MetricVariants(Base): + metric: str = StringField(required=True) + variants: Sequence[str] = ListField( + items_types=str, validators=Length(minimum_value=1) + ) + + class ScalarMetricsIterHistogramRequest(HistogramRequestBase): task: str = StringField(required=True) + metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): @@ -39,6 +47,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): class TaskMetric(Base): task: str = StringField(required=True) metric: str = StringField(default=None) + variants: Sequence[str] = ListField(items_types=str) class DebugImagesRequest(Base): @@ -59,8 +68,8 @@ class TaskMetricVariant(Base): class GetDebugImageSampleRequest(TaskMetricVariant): iteration: Optional[int] = IntField() - scroll_id: Optional[str] = StringField() refresh: bool = BoolField(default=False) + scroll_id: Optional[str] = StringField() class NextDebugImageSampleRequest(Base): @@ -102,3 +111,10 @@ class TaskMetricsRequest(Base): items_types=str, validators=[Length(minimum_value=1)] ) event_type: EventType = ActualEnumField(EventType, required=True) + + +class TaskPlotsRequest(Base): + task: str = StringField(required=True) + iters: int = IntField(default=1) + scroll_id: str = StringField() + metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) diff --git a/apiserver/bll/event/debug_images_iterator.py 
b/apiserver/bll/event/debug_images_iterator.py index ea79efb..ecd0eda 100644 --- a/apiserver/bll/event/debug_images_iterator.py +++ b/apiserver/bll/event/debug_images_iterator.py @@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor from datetime import datetime from functools import partial from operator import itemgetter -from typing import Sequence, Tuple, Optional, Mapping, Set +from typing import Sequence, Tuple, Optional, Mapping import attr import dpath @@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import ( check_empty_data, search_company_events, EventType, + get_metric_variants_condition, ) from apiserver.bll.redis_cache_manager import RedisCacheManager from apiserver.database.errors import translate_errors_context @@ -74,7 +75,7 @@ class DebugImagesIterator: def get_task_events( self, company_id: str, - task_metrics: Mapping[str, Set[str]], + task_metrics: Mapping[str, dict], iter_count: int, navigate_earlier: bool = True, refresh: bool = False, @@ -118,7 +119,7 @@ class DebugImagesIterator: self, company_id, state: DebugImageEventsScrollState, - task_metrics: Mapping[str, Set[str]], + task_metrics: Mapping[str, dict], ): """ Determine the metrics for which new debug image events were added @@ -158,11 +159,11 @@ class DebugImagesIterator: task_metrics_to_recalc = {} for task, metrics_times in update_times.items(): old_metric_states = task_metric_states[task] - metrics_to_recalc = set( - m + metrics_to_recalc = { + m: task_metrics[task].get(m) for m, t in metrics_times.items() if m not in old_metric_states or old_metric_states[m].timestamp < t - ) + } if metrics_to_recalc: task_metrics_to_recalc[task] = metrics_to_recalc @@ -196,7 +197,7 @@ class DebugImagesIterator: ] def _init_task_states( - self, company_id: str, task_metrics: Mapping[str, Set[str]] + self, company_id: str, task_metrics: Mapping[str, dict] ) -> Sequence[TaskScrollState]: """ Returned initialized metric scroll stated for the requested task metrics @@ -213,7 
+214,7 @@ class DebugImagesIterator: ] def _init_metric_states_for_task( - self, task_metrics: Tuple[str, Set[str]], company_id: str + self, task_metrics: Tuple[str, dict], company_id: str ) -> Sequence[MetricState]: """ Return metric scroll states for the task filled with the variant states @@ -222,10 +223,11 @@ class DebugImagesIterator: task, metrics = task_metrics must = [{"term": {"task": task}}, {"exists": {"field": "url"}}] if metrics: - must.append({"terms": {"metric": list(metrics)}}) + must.append(get_metric_variants_condition(metrics)) + query = {"bool": {"must": must}} es_req: dict = { "size": 0, - "query": {"bool": {"must": must}}, + "query": query, "aggs": { "metrics": { "terms": { diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 365c4de..2a6af59 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -6,9 +6,8 @@ from collections import defaultdict from contextlib import closing from datetime import datetime from operator import attrgetter -from typing import Sequence, Set, Tuple, Optional, Dict +from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union -import six from elasticsearch import helpers from elasticsearch.helpers import BulkIndexError from mongoengine import Q @@ -22,6 +21,8 @@ from apiserver.bll.event.event_common import ( check_empty_data, search_company_events, delete_company_events, + MetricVariants, + get_metric_variants_condition, ) from apiserver.bll.util import parallel_chunked_decorator from apiserver.database import utils as dbutils @@ -43,8 +44,8 @@ from apiserver.utilities.json import loads # noinspection PyTypeChecker EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType)) LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published) -MAX_LONG = 2**63 - 1 -MIN_LONG = -2**63 +MAX_LONG = 2 ** 63 - 1 +MIN_LONG = -(2 ** 63) class PlotFields: @@ -94,7 +95,7 @@ class EventBLL(object): def add_events( self, company_id, events, worker, 
allow_locked_tasks=False ) -> Tuple[int, int, dict]: - actions = [] + actions: List[dict] = [] task_ids = set() task_iteration = defaultdict(lambda: 0) task_last_scalar_events = nested_dict( @@ -197,7 +198,6 @@ class EventBLL(object): actions.append(es_action) - action: Dict[dict] plot_actions = [ action["_source"] for action in actions @@ -260,7 +260,8 @@ class EventBLL(object): invalid_iterations_count = errors_per_type.get(invalid_iteration_error) if invalid_iterations_count: raise BulkIndexError( - f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error] + f"{invalid_iterations_count} document(s) failed to index.", + [invalid_iteration_error], ) if not added: @@ -466,10 +467,16 @@ class EventBLL(object): task_id: str, num_last_iterations: int, event_type: EventType, + metric_variants: MetricVariants = None, ): if check_empty_data(self.es, company_id=company_id, event_type=event_type): return [] + must = [{"term": {"task": task_id}}] + if metric_variants: + must.append(get_metric_variants_condition(metric_variants)) + query = {"bool": {"must": must}} + es_req: dict = { "size": 0, "aggs": { @@ -499,7 +506,7 @@ class EventBLL(object): }, } }, - "query": {"bool": {"must": [{"term": {"task": task_id}}]}}, + "query": query, } with translate_errors_context(), TimingContext( @@ -527,6 +534,7 @@ class EventBLL(object): sort=None, size: int = 500, scroll_id: str = None, + metric_variants: MetricVariants = None, ): if scroll_id == self.empty_scroll: return TaskEventsResult() @@ -555,6 +563,8 @@ class EventBLL(object): if last_iterations_per_plot is None: must.append({"terms": {"task": tasks}}) + if metric_variants: + must.append(get_metric_variants_condition(metric_variants)) else: should = [] for i, task_id in enumerate(tasks): @@ -563,6 +573,7 @@ class EventBLL(object): task_id=task_id, num_last_iterations=last_iterations_per_plot, event_type=event_type, + metric_variants=metric_variants, ) if not last_iters: continue @@ -669,19 +680,19 @@ 
class EventBLL(object): sort=None, size=500, scroll_id=None, - ): + ) -> TaskEventsResult: if scroll_id == self.empty_scroll: - return [], scroll_id, 0 + return TaskEventsResult() if scroll_id: with translate_errors_context(), TimingContext("es", "get_task_events"): es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h") else: - task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id - if check_empty_data(self.es, company_id=company_id, event_type=event_type): return TaskEventsResult() + task_ids = [task_id] if isinstance(task_id, str) else task_id + must = [] if metric: must.append({"term": {"metric": metric}}) @@ -691,26 +702,24 @@ class EventBLL(object): if last_iter_count is None: must.append({"terms": {"task": task_ids}}) else: - should = [] - for i, task_id in enumerate(task_ids): - last_iters = self.get_last_iters( - company_id=company_id, - event_type=event_type, - task_id=task_id, - iters=last_iter_count, - ) - if not last_iters: - continue - should.append( - { - "bool": { - "must": [ - {"term": {"task": task_id}}, - {"terms": {"iter": last_iters}}, - ] - } + tasks_iters = self.get_last_iters( + company_id=company_id, + event_type=event_type, + task_id=task_ids, + iters=last_iter_count, + ) + should = [ + { + "bool": { + "must": [ + {"term": {"task": task}}, + {"terms": {"iter": last_iters}}, + ] } - ) + } + for task, last_iters in tasks_iters.items() + if last_iters + ] if not should: return TaskEventsResult() must.append({"bool": {"should": should}}) @@ -748,6 +757,7 @@ class EventBLL(object): if check_empty_data(self.es, company_id=company_id, event_type=event_type): return {} + query = {"bool": {"must": [{"term": {"task": task_id}}]}} es_req = { "size": 0, "aggs": { @@ -768,7 +778,7 @@ class EventBLL(object): }, } }, - "query": {"bool": {"must": [{"term": {"task": task_id}}]}}, + "query": query, } with translate_errors_context(), TimingContext( @@ -787,21 +797,24 @@ class EventBLL(object): return metrics - def 
get_task_latest_scalar_values(self, company_id: str, task_id: str): + def get_task_latest_scalar_values( + self, company_id, task_id + ) -> Tuple[Sequence[dict], int]: event_type = EventType.metrics_scalar if check_empty_data(self.es, company_id=company_id, event_type=event_type): - return {} + return [], 0 + query = { + "bool": { + "must": [ + {"query_string": {"query": "value:>0"}}, + {"term": {"task": task_id}}, + ] + } + } es_req = { "size": 0, - "query": { - "bool": { - "must": [ - {"query_string": {"query": "value:>0"}}, - {"term": {"task": task_id}}, - ] - } - }, + "query": query, "aggs": { "metrics": { "terms": { @@ -905,11 +918,16 @@ class EventBLL(object): return iterations, vectors def get_last_iters( - self, company_id: str, event_type: EventType, task_id: str, iters: int - ): + self, + company_id: str, + event_type: EventType, + task_id: Union[str, Sequence[str]], + iters: int, + ) -> Mapping[str, Sequence]: if check_empty_data(self.es, company_id=company_id, event_type=event_type): - return [] + return {} + task_ids = [task_id] if isinstance(task_id, str) else task_id es_req: dict = { "size": 0, "aggs": { @@ -921,7 +939,7 @@ class EventBLL(object): } } }, - "query": {"bool": {"must": [{"term": {"task": task_id}}]}}, + "query": {"bool": {"must": [{"terms": {"task": task_ids}}]}}, } with translate_errors_context(), TimingContext("es", "task_last_iter"): @@ -930,9 +948,12 @@ class EventBLL(object): ) if "aggregations" not in es_res: - return [] + return {} - return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]] + return { + tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]] + for tb in es_res["aggregations"]["tasks"]["buckets"] + } def delete_task_events(self, company_id, task_id, allow_locked=False): with translate_errors_context(): @@ -965,7 +986,9 @@ class EventBLL(object): so it should be checked by the calling code """ es_req = {"query": {"terms": {"task": task_ids}}} - with translate_errors_context(), TimingContext("es", 
"delete_multi_tasks_events"): + with translate_errors_context(), TimingContext( + "es", "delete_multi_tasks_events" + ): es_res = delete_company_events( es=self.es, company_id=company_id, diff --git a/apiserver/bll/event/event_common.py b/apiserver/bll/event/event_common.py index bb9a075..a1dd0b1 100644 --- a/apiserver/bll/event/event_common.py +++ b/apiserver/bll/event/event_common.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Union, Sequence +from typing import Union, Sequence, Mapping from boltons.typeutils import classproperty from elasticsearch import Elasticsearch @@ -16,6 +16,9 @@ class EventType(Enum): all = "*" +MetricVariants = Mapping[str, Sequence[str]] + + class EventSettings: @classproperty def max_workers(self): @@ -64,3 +67,23 @@ def delete_company_events( ) -> dict: es_index = get_index_name(company_id, event_type.value) return es.delete_by_query(index=es_index, body=body, **kwargs) + + +def get_metric_variants_condition( + metric_variants: MetricVariants, +) -> Sequence: + conditions = [ + { + "bool": { + "must": [ + {"term": {"metric": metric}}, + {"terms": {"variant": variants}}, + ] + } + } + if variants + else {"term": {"metric": metric}} + for metric, variants in metric_variants.items() + ] + + return {"bool": {"should": conditions}} diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py index ef45d23..6606786 100644 --- a/apiserver/bll/event/event_metrics.py +++ b/apiserver/bll/event/event_metrics.py @@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import ( EventSettings, search_company_events, check_empty_data, + MetricVariants, + get_metric_variants_condition, ) from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum from apiserver.config_repo import config @@ -34,7 +36,12 @@ class EventMetrics: self.es = es def get_scalar_metrics_average_per_iter( - self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum + self, + company_id: str, + task_id: str, + samples: 
int, + key: ScalarKeyEnum, + metric_variants: MetricVariants = None, ) -> dict: """ Get scalar metric histogram per metric and variant @@ -46,7 +53,12 @@ class EventMetrics: return {} return self._get_scalar_average_per_iter_core( - task_id, company_id, event_type, samples, ScalarKey.resolve(key) + task_id=task_id, + company_id=company_id, + event_type=event_type, + samples=samples, + key=ScalarKey.resolve(key), + metric_variants=metric_variants, ) def _get_scalar_average_per_iter_core( @@ -57,6 +69,7 @@ class EventMetrics: samples: int, key: ScalarKey, run_parallel: bool = True, + metric_variants: MetricVariants = None, ) -> dict: intervals = self._get_task_metric_intervals( company_id=company_id, @@ -64,6 +77,7 @@ class EventMetrics: task_id=task_id, samples=samples, field=key.field, + metric_variants=metric_variants, ) if not intervals: return {} @@ -197,6 +211,7 @@ class EventMetrics: task_id: str, samples: int, field: str = "iter", + metric_variants: MetricVariants = None, ) -> Sequence[MetricInterval]: """ Calculate interval per task metric variant so that the resulting @@ -204,9 +219,14 @@ class EventMetrics: Return the list og metric variant intervals as the following tuple: (metric, variant, interval, samples) """ + must = [{"term": {"task": task_id}}] + if metric_variants: + must.append(get_metric_variants_condition(metric_variants)) + query = {"bool": {"must": must}} + es_req = { "size": 0, - "query": {"term": {"task": task_id}}, + "query": query, "aggs": { "metrics": { "terms": { diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index bf56546..6a6bcc3 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -1,6 +1,18 @@ { _description : "Provides an API for running tasks to report events collected by the system." 
_definitions { + metric_variants { + type: object + metric { + description: The metric name + type: string + } + variants { + type: array + description: The names of the metric variants + items {type: string} + } + } metrics_scalar_event { description: "Used for reporting scalar metrics during training task" type: object @@ -193,6 +205,29 @@ description: "Task ID" type: string } + metric { + description: "Metric name" + type: string + } + } + } + task_metric_variants { + type: object + required: [task] + properties { + task { + description: "Task ID" + type: string + } + metric { + description: "Metric name" + type: string + } + variants { + description: Metric variant names + type: array + items {type: string} + } } } task_log_event { @@ -376,7 +411,7 @@ metrics { type: array items { "$ref": "#/definitions/task_metric" } - description: "List metrics for which the envents will be retreived" + description: "List of task metrics for which the events will be retrieved" } iters { type: integer @@ -411,6 +446,17 @@ } } } + "2.14": ${debug_images."2.7"} { + request { + properties { + metrics { + type: array + description: List of metrics and variants + items { "$ref": "#/definitions/task_metric_variants" } + } + } + } + } } get_debug_image_sample { "2.12": { @@ -804,6 +850,17 @@ } } } + "2.14": ${get_task_plots."2.1"} { + request { + properties { + metrics { + type: array + description: List of metrics and variants + items { "$ref": "#/definitions/metric_variants" } + } + } + } + } } get_multi_task_plots { "2.1" { @@ -962,6 +1019,17 @@ } } } + "2.14": ${scalar_metrics_iter_histogram."2.1"} { + request { + properties { + metrics { + type: array + description: List of metrics and variants + items { "$ref": "#/definitions/metric_variants" } + } + } + } + } } multi_task_scalar_metrics_iter_histogram { "2.1" { diff --git a/apiserver/services/events.py b/apiserver/services/events.py index b792025..4e400a7 100644 --- a/apiserver/services/events.py +++ 
b/apiserver/services/events.py @@ -3,6 +3,7 @@ from collections import defaultdict from operator import itemgetter import attr +from typing import Sequence, Optional from apiserver.apierrors import errors from apiserver.apimodels.events import ( @@ -17,9 +18,11 @@ from apiserver.apimodels.events import ( LogOrderEnum, GetDebugImageSampleRequest, NextDebugImageSampleRequest, + MetricVariants as ApiMetrics, + TaskPlotsRequest, ) from apiserver.bll.event import EventBLL -from apiserver.bll.event.event_common import EventType +from apiserver.bll.event.event_common import EventType, MetricVariants from apiserver.bll.task import TaskBLL from apiserver.service_repo import APICall, endpoint from apiserver.utilities import json @@ -321,7 +324,7 @@ def get_task_latest_scalar_values(call, company_id, _): ) last_iters = event_bll.get_last_iters( company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1 - ) + ).get(task_id) call.result.data = dict( metrics=metrics, last_iter=last_iters[0] if last_iters else 0, @@ -494,11 +497,22 @@ def get_task_plots_v1_7(call, company_id, _): ) -@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"]) -def get_task_plots(call, company_id, _): - task_id = call.data["task"] - iters = call.data.get("iters", 1) - scroll_id = call.data.get("scroll_id") +def _get_metric_variants_from_request( + req_metrics: Sequence[ApiMetrics], +) -> Optional[MetricVariants]: + if not req_metrics: + return None + + return {m.metric: m.variants for m in req_metrics} + + +@endpoint( + "events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest +) +def get_task_plots(call, company_id, request: TaskPlotsRequest): + task_id = request.task + iters = request.iters + scroll_id = request.scroll_id task = task_bll.assert_exists( company_id, task_id, allow_public=True, only=("company", "company_origin") @@ -509,6 +523,7 @@ def get_task_plots(call, company_id, _): sort=[{"iter": {"order": "desc"}}], 
last_iterations_per_plot=iters, scroll_id=scroll_id, + metric_variants=_get_metric_variants_from_request(request.metrics), ) return_events = result.events @@ -594,9 +609,9 @@ def get_debug_images_v1_8(call, company_id, _): response_data_model=DebugImageResponse, ) def get_debug_images(call, company_id, request: DebugImagesRequest): - task_metrics = defaultdict(set) + task_metrics = defaultdict(dict) for tm in request.metrics: - task_metrics[tm.task].add(tm.metric) + task_metrics[tm.task][tm.metric] = tm.variants for metrics in task_metrics.values(): if None in metrics: metrics.clear() @@ -734,11 +749,11 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks): def _get_top_iter_unique_events(events, max_iters): top_unique_events = defaultdict(lambda: []) - for e in events: - key = e.get("metric", "") + e.get("variant", "") + for ev in events: + key = ev.get("metric", "") + ev.get("variant", "") evs = top_unique_events[key] if len(evs) < max_iters: - evs.append(e) + evs.append(ev) unique_events = list( itertools.chain.from_iterable(list(top_unique_events.values())) )