From 6c8508eb7f7565f00166d9e85fa295b39ff6457b Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 1 Mar 2020 18:00:07 +0200 Subject: [PATCH] Add support for pagination in events.debug_images --- server/apierrors/__init__.py | 2 + server/apimodels/events.py | 47 ++- server/bll/event/debug_images_iterator.py | 464 +++++++++++++++++++++ server/bll/event/event_bll.py | 86 ++-- server/bll/event/event_metrics.py | 66 ++- server/bll/redis_cache_manager.py | 44 ++ server/bll/task/task_bll.py | 34 +- server/config/default/hosts.conf | 5 + server/database/model/task/metrics.py | 29 +- server/database/model/task/task.py | 3 +- server/schema/services/events.conf | 102 +++++ server/services/events.py | 58 ++- server/tests/automated/test_task_events.py | 138 +++++- 13 files changed, 1021 insertions(+), 57 deletions(-) create mode 100644 server/bll/event/debug_images_iterator.py create mode 100644 server/bll/redis_cache_manager.py diff --git a/server/apierrors/__init__.py b/server/apierrors/__init__.py index 46fad7b..915c555 100644 --- a/server/apierrors/__init__.py +++ b/server/apierrors/__init__.py @@ -89,6 +89,8 @@ _error_codes = { 1003: ('worker_registered', 'worker is already registered'), 1004: ('worker_not_registered', 'worker is not registered'), 1005: ('worker_stats_not_found', 'worker stats not found'), + + 1104: ('invalid_scroll_id', 'Invalid scroll id'), }, (401, 'unauthorized'): { diff --git a/server/apimodels/events.py b/server/apimodels/events.py index 97aded5..f5315b7 100644 --- a/server/apimodels/events.py +++ b/server/apimodels/events.py @@ -1,9 +1,12 @@ from typing import Sequence -from jsonmodels.fields import StringField +from jsonmodels import validators +from jsonmodels.fields import StringField, BoolField from jsonmodels.models import Base +from jsonmodels.validators import Length from apimodels import ListField, IntField, ActualEnumField +from bll.event.event_metrics import EventType from bll.event.scalar_key import ScalarKeyEnum @@ -17,4 +20,44 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase): class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): - tasks: Sequence[str] = ListField(items_types=str, required=True) + tasks: Sequence[str] = ListField( + items_types=str, validators=[Length(minimum_value=1)] + ) + + +class TaskMetric(Base): + task: str = StringField(required=True) + metric: str = StringField(required=True) + + +class DebugImagesRequest(Base): + metrics: Sequence[TaskMetric] = ListField( + items_types=TaskMetric, validators=[Length(minimum_value=1)] + ) + iters: int = IntField(default=1, validators=validators.Min(1)) + navigate_earlier: bool = BoolField(default=True) + refresh: bool = BoolField(default=False) + scroll_id: str = StringField() + + +class IterationEvents(Base): + iter: int = IntField() + events: Sequence[dict] = ListField(items_types=dict) + + +class MetricEvents(Base): + task: str = StringField() + metric: str = StringField() + iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents) + + +class DebugImageResponse(Base): + metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents) + scroll_id: str = StringField() + + +class TaskMetricsRequest(Base): + tasks: Sequence[str] = ListField( + items_types=str, validators=[Length(minimum_value=1)] + ) + event_type: EventType = ActualEnumField(EventType, required=True) diff --git a/server/bll/event/debug_images_iterator.py b/server/bll/event/debug_images_iterator.py new file mode 100644 index 0000000..b755fc3 --- /dev/null +++ 
b/server/bll/event/debug_images_iterator.py @@ -0,0 +1,464 @@ +from collections import defaultdict +from concurrent.futures.thread import ThreadPoolExecutor +from functools import partial +from itertools import chain +from operator import attrgetter, itemgetter + +import attr +import dpath +from boltons.iterutils import bucketize +from elasticsearch import Elasticsearch +from redis import StrictRedis +from typing import Sequence, Tuple, Optional, Mapping + +import database +from apierrors import errors +from bll.redis_cache_manager import RedisCacheManager +from bll.event.event_metrics import EventMetrics +from config import config +from database.errors import translate_errors_context +from jsonmodels.models import Base +from jsonmodels.fields import StringField, ListField, IntField + +from database.model.task.metrics import MetricEventStats +from database.model.task.task import Task +from timing_context import TimingContext +from utilities.json import loads, dumps + + +class VariantScrollState(Base): + name: str = StringField(required=True) + recycle_url_marker: str = StringField() + last_invalid_iteration: int = IntField() + + +class MetricScrollState(Base): + task: str = StringField(required=True) + name: str = StringField(required=True) + last_min_iter: Optional[int] = IntField() + last_max_iter: Optional[int] = IntField() + timestamp: int = IntField(default=0) + variants: Sequence[VariantScrollState] = ListField([VariantScrollState]) + + def reset(self): + """Reset the scrolling state for the metric""" + self.last_min_iter = self.last_max_iter = None + + +class DebugImageEventsScrollState(Base): + id: str = StringField(required=True) + metrics: Sequence[MetricScrollState] = ListField([MetricScrollState]) + + def to_json(self): + return dumps(self.to_struct()) + + @classmethod + def from_json(cls, s): + return cls(**loads(s)) + + +@attr.s(auto_attribs=True) +class DebugImagesResult(object): + metric_events: Sequence[tuple] = [] + next_scroll_id: str = None + + +class DebugImagesIterator: + EVENT_TYPE = "training_debug_image" + STATE_EXPIRATION_SECONDS = 3600 + + @property + def _max_workers(self): + return config.get("services.events.max_metrics_concurrency", 4) + + def __init__(self, redis: StrictRedis, es: Elasticsearch): + self.es = es + self.cache_manager = RedisCacheManager( + state_class=DebugImageEventsScrollState, + redis=redis, + expiration_interval=self.STATE_EXPIRATION_SECONDS, + ) + + def get_task_events( + self, + company_id: str, + metrics: Sequence[Tuple[str, str]], + iter_count: int, + navigate_earlier: bool = True, + refresh: bool = False, + state_id: str = None, + ) -> DebugImagesResult: + es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE) + if not self.es.indices.exists(es_index): + return DebugImagesResult() + + unique_metrics = set(metrics) + state = self.cache_manager.get_state(state_id) if state_id else None + if not state: + state = DebugImageEventsScrollState( + id=database.utils.id(), + metrics=self._init_metric_states(es_index, list(unique_metrics)), + ) + else: + state_metrics = set((m.task, m.name) for m in state.metrics) + if state_metrics != unique_metrics: + raise errors.bad_request.InvalidScrollId( + "while getting debug images events", scroll_id=state_id + ) + + if refresh: + self._reinit_outdated_metric_states(company_id, es_index, state) + for metric_state in state.metrics: + metric_state.reset() + + res = DebugImagesResult(next_scroll_id=state.id) + try: + with ThreadPoolExecutor(self._max_workers) as pool: + res.metric_events = 
list( + pool.map( + partial( + self._get_task_metric_events, + es_index=es_index, + iter_count=iter_count, + navigate_earlier=navigate_earlier, + ), + state.metrics, + ) + ) + finally: + self.cache_manager.set_state(state) + + return res + + def _reinit_outdated_metric_states( + self, company_id, es_index, state: DebugImageEventsScrollState + ): + """ + Determines the metrics for which new debug image events were added + since their states were initialized and reinits these states + """ + task_ids = set(metric.task for metric in state.metrics) + tasks = Task.objects(id__in=list(task_ids), company=company_id).only( + "id", "metric_stats" + ) + + def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]: + """For metrics that reported debug image events get tuples of task_id/metric_name and last update times""" + metric_stats: Mapping[str, MetricEventStats] = task.metric_stats + if not metric_stats: + return [] + + return [ + ( + (task.id, stats.metric), + stats.event_stats_by_type[self.EVENT_TYPE].last_update, + ) + for stats in metric_stats.values() + if self.EVENT_TYPE in stats.event_stats_by_type + ] + + update_times = dict( + chain.from_iterable( + get_last_update_times_for_task_metrics(task) for task in tasks + ) + ) + outdated_metrics = [ + metric + for metric in state.metrics + if (metric.task, metric.name) in update_times + and update_times[metric.task, metric.name] > metric.timestamp + ] + state.metrics = [ + *(metric for metric in state.metrics if metric not in outdated_metrics), + *( + self._init_metric_states( + es_index, + [(metric.task, metric.name) for metric in outdated_metrics], + ) + ), + ] + + def _init_metric_states( + self, es_index, metrics: Sequence[Tuple[str, str]] + ) -> Sequence[MetricScrollState]: + """ + Return initialized metric scroll states for the requested task metrics + """ + tasks = defaultdict(list) + for (task, metric) in metrics: + tasks[task].append(metric) + + with ThreadPoolExecutor(self._max_workers) as pool: + return list( + chain.from_iterable( + pool.map( + partial(self._init_metric_states_for_task, es_index=es_index), + tasks.items(), + ) + ) + ) + + def _init_metric_states_for_task( + self, task_metrics: Tuple[str, Sequence[str]], es_index + ) -> Sequence[MetricScrollState]: + """ + Return metric scroll states for the task filled with the variant states + for the variants that reported any debug images + """ + task, metrics = task_metrics + es_req: dict = { + "size": 0, + "query": { + "bool": { + "must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}] + } + }, + "aggs": { + "metrics": { + "terms": { + "field": "metric", + "size": EventMetrics.MAX_METRICS_COUNT, + }, + "aggs": { + "last_event_timestamp": {"max": {"field": "timestamp"}}, + "variants": { + "terms": { + "field": "variant", + "size": EventMetrics.MAX_VARIANTS_COUNT, + }, + "aggs": { + "urls": { + "terms": { + "field": "url", + "order": {"max_iter": "desc"}, + "size": 1, # we need only one url from the most recent iteration + }, + "aggs": { + "max_iter": {"max": {"field": "iter"}}, + "iters": { + "top_hits": { + "sort": {"iter": {"order": "desc"}}, + "size": 2, # need two last iterations so that we can take + # the second one as invalid + "_source": "iter", + } + }, + }, + } + }, + }, + }, + } + }, + } + + with translate_errors_context(), TimingContext("es", "_init_metric_states"): + es_res = self.es.search(index=es_index, body=es_req, routing=task) + if "aggregations" not in es_res: + return [] + + def init_variant_scroll_state(variant: dict): + """ +
Return new variant scroll state for the passed variant bucket + If the image urls get recycled then fill the last_invalid_iteration field + """ + state = VariantScrollState(name=variant["key"]) + top_iter_url = dpath.get(variant, "urls/buckets")[0] + iters = dpath.get(top_iter_url, "iters/hits/hits") + if len(iters) > 1: + state.last_invalid_iteration = dpath.get(iters[1], "_source/iter") + return state + + return [ + MetricScrollState( + task=task, + name=metric["key"], + variants=[ + init_variant_scroll_state(variant) + for variant in dpath.get(metric, "variants/buckets") + ], + timestamp=dpath.get(metric, "last_event_timestamp/value"), + ) + for metric in dpath.get(es_res, "aggregations/metrics/buckets") + ] + + def _get_task_metric_events( + self, + metric: MetricScrollState, + es_index: str, + iter_count: int, + navigate_earlier: bool, + ) -> Tuple: + """ + Return task metric events grouped by iterations + Update metric scroll state + """ + if metric.last_max_iter is None: + # the first fetch is always from the latest iteration to the earlier ones + navigate_earlier = True + + must_conditions = [ + {"term": {"task": metric.task}}, + {"term": {"metric": metric.name}}, + ] + must_not_conditions = [] + + range_condition = None + if navigate_earlier and metric.last_min_iter is not None: + range_condition = {"lt": metric.last_min_iter} + elif not navigate_earlier and metric.last_max_iter is not None: + range_condition = {"gt": metric.last_max_iter} + if range_condition: + must_conditions.append({"range": {"iter": range_condition}}) + + if navigate_earlier: + """ + When navigating to earlier iterations consider only + variants whose invalid iterations border is lower than + our starting iteration. For these variants make sure + that only events from the valid iterations are returned + """ + if not metric.last_min_iter: + variants = metric.variants + else: + variants = list( + v + for v in metric.variants + if v.last_invalid_iteration is None + or v.last_invalid_iteration < metric.last_min_iter + ) + if not variants: + return metric.task, metric.name, [] + must_conditions.append( + {"terms": {"variant": list(v.name for v in variants)}} + ) + else: + """ + When navigating to later iterations all variants may be relevant. 
+ For the variants whose invalid border is higher than our starting + iteration make sure that only events from valid iterations are returned + """ + variants = list( + v + for v in metric.variants + if v.last_invalid_iteration is not None + and v.last_invalid_iteration > metric.last_max_iter + ) + + variants_conditions = [ + { + "bool": { + "must": [ + {"term": {"variant": v.name}}, + {"range": {"iter": {"lte": v.last_invalid_iteration}}}, + ] + } + } + for v in variants + if v.last_invalid_iteration is not None + ] + if variants_conditions: + must_not_conditions.append({"bool": {"should": variants_conditions}}) + + es_req = { + "size": 0, + "query": { + "bool": {"must": must_conditions, "must_not": must_not_conditions} + }, + "aggs": { + "iters": { + "terms": { + "field": "iter", + "size": iter_count, + "order": {"_term": "desc" if navigate_earlier else "asc"}, + }, + "aggs": { + "variants": { + "terms": { + "field": "variant", + "size": EventMetrics.MAX_VARIANTS_COUNT, + }, + "aggs": { + "events": { + "top_hits": {"sort": {"url": {"order": "desc"}}} + } + }, + } + }, + } + }, + } + with translate_errors_context(), TimingContext("es", "get_debug_image_events"): + es_res = self.es.search(index=es_index, body=es_req, routing=metric.task) + if "aggregations" not in es_res: + return metric.task, metric.name, [] + + def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence: + return [ + ev["_source"] + for v in variant_buckets + for ev in dpath.get(v, "events/hits/hits") + ] + + iterations = [ + { + "iter": it["key"], + "events": get_iteration_events(dpath.get(it, "variants/buckets")), + } + for it in dpath.get(es_res, "aggregations/iters/buckets") + ] + if not navigate_earlier: + iterations.sort(key=itemgetter("iter"), reverse=True) + if iterations: + metric.last_max_iter = iterations[0]["iter"] + metric.last_min_iter = iterations[-1]["iter"] + + # Commented for now since the last invalid iteration is calculated in the beginning + # if navigate_earlier and any( + # variant.last_invalid_iteration is None for variant in variants + # ): + # """ + # Variants validation flags due to recycling can + # be set only on navigation to earlier frames + # """ + # iterations = self._update_variants_invalid_iterations(variants, iterations) + + return metric.task, metric.name, iterations + + @staticmethod + def _update_variants_invalid_iterations( + variants: Sequence[VariantScrollState], iterations: Sequence[dict] + ) -> Sequence[dict]: + """ + This code is currently not in use since the invalid iterations + are calculated during MetricState initialization + For variants that do not have recycle url marker set it from the + first event + For variants that do not have last_invalid_iteration set check if the + recycle marker was reached on a certain iteration and set it to the + corresponding iteration + For variants that have a newly set last_invalid_iteration remove + events from the invalid iterations + Return the updated iterations list + """ + variants_lookup = bucketize(variants, attrgetter("name")) + for it in iterations: + iteration = it["iter"] + events_to_remove = [] + for event in it["events"]: + variant = variants_lookup[event["variant"]][0] + if ( + variant.last_invalid_iteration + and variant.last_invalid_iteration >= iteration + ): + events_to_remove.append(event) + continue + event_url = event.get("url") + if not variant.recycle_url_marker: + variant.recycle_url_marker = event_url + elif variant.recycle_url_marker == event_url: + variant.last_invalid_iteration = iteration +
events_to_remove.append(event) + if events_to_remove: + it["events"] = [ev for ev in it["events"] if ev not in events_to_remove] + return [it for it in iterations if it["events"]] diff --git a/server/bll/event/event_bll.py b/server/bll/event/event_bll.py index 789d82e..5f1c217 100644 --- a/server/bll/event/event_bll.py +++ b/server/bll/event/event_bll.py @@ -2,7 +2,6 @@ import hashlib from collections import defaultdict from contextlib import closing from datetime import datetime -from enum import Enum from operator import attrgetter from typing import Sequence @@ -15,46 +14,39 @@ from nested_dict import nested_dict import database.utils as dbutils import es_factory from apierrors import errors -from bll.event.event_metrics import EventMetrics +from bll.event.debug_images_iterator import DebugImagesIterator +from bll.event.event_metrics import EventMetrics, EventType from bll.task import TaskBLL from config import config from database.errors import translate_errors_context from database.model.task.task import Task, TaskStatus +from redis_manager import redman from timing_context import TimingContext from utilities.dicts import flatten_nested_items - -class EventType(Enum): - metrics_scalar = "training_stats_scalar" - metrics_vector = "training_stats_vector" - metrics_image = "training_debug_image" - metrics_plot = "plot" - task_log = "log" - - # noinspection PyTypeChecker EVENT_TYPES = set(map(attrgetter("value"), EventType)) - - LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published) -@attr.s +@attr.s(auto_attribs=True) class TaskEventsResult(object): - events = attr.ib(type=list, default=attr.Factory(list)) - total_events = attr.ib(type=int, default=0) - next_scroll_id = attr.ib(type=str, default=None) + total_events: int = 0 + next_scroll_id: str = None + events: list = attr.ib(factory=list) class EventBLL(object): id_fields = ("task", "iter", "metric", "variant", "key") - def __init__(self, events_es=None): + def __init__(self, events_es=None, redis=None): self.es = events_es or es_factory.connect("events") self._metrics = EventMetrics(self.es) self._skip_iteration_for_metric = set( config.get("services.events.ignore_iteration.metrics", []) ) + self.redis = redis or redman.connection("apiserver") + self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis) @property def metrics(self) -> EventMetrics: @@ -64,9 +56,12 @@ class EventBLL(object): actions = [] task_ids = set() task_iteration = defaultdict(lambda: 0) - task_last_events = nested_dict( + task_last_scalar_events = nested_dict( 3, dict ) # task_id -> metric_hash -> variant_hash -> MetricEvent + task_last_events = nested_dict( + 3, dict + ) # task_id -> metric_hash -> event_type -> MetricEvent for event in events: # remove spaces from event type @@ -108,6 +103,9 @@ class EventBLL(object): event["value"] = event["values"] del event["values"] + event["metric"] = event.get("metric") or "" + event["variant"] = event.get("variant") or "" + index_name = EventMetrics.get_index_name(company_id, event_type) es_action = { "_op_type": "index", # overwrite if exists with same ID @@ -132,9 +130,12 @@ class EventBLL(object): ): task_iteration[task_id] = max(iter, task_iteration[task_id]) + self._update_last_metric_events_for_task( + last_events=task_last_events[task_id], event=event, + ) if event_type == EventType.metrics_scalar.value: - self._update_last_metric_event_for_task( - task_last_events=task_last_events, task_id=task_id, event=event + self._update_last_scalar_events_for_task( + 
last_events=task_last_scalar_events[task_id], event=event ) else: es_action["_routing"] = task_id @@ -187,6 +188,7 @@ class EventBLL(object): task_id=task_id, now=now, iter_max=task_iteration.get(task_id), + last_scalar_events=task_last_scalar_events.get(task_id), last_events=task_last_events.get(task_id), ) @@ -202,12 +204,12 @@ class EventBLL(object): return added, errors_in_bulk - def _update_last_metric_event_for_task(self, task_last_events, task_id, event): + def _update_last_scalar_events_for_task(self, last_events, event): """ - Update task_last_events structure for the provided task_id with the provided event details if this event is more + Update last_events structure with the provided event details if this event is more recent than the currently stored event for its metric/variant combination. - task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb + last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb key conflicts due to invalid characters and/or long field names. """ metric = event.get("metric") @@ -218,13 +220,34 @@ class EventBLL(object): metric_hash = dbutils.hash_field_name(metric) variant_hash = dbutils.hash_field_name(variant) - last_events = task_last_events[task_id] - timestamp = last_events[metric_hash][variant_hash].get("timestamp", None) if timestamp is None or timestamp < event["timestamp"]: last_events[metric_hash][variant_hash] = event - def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None): + def _update_last_metric_events_for_task(self, last_events, event): + """ + Update last_events structure with the provided event details if this event is more + recent than the currently stored event for its metric/event_type combination. + last_events contains [metric_name -> event_type -> event] + """ + metric = event.get("metric") + event_type = event.get("type") + if not (metric and event_type): + return + + timestamp = last_events[metric][event_type].get("timestamp", None) + if timestamp is None or timestamp < event["timestamp"]: + last_events[metric][event_type] = event + + def _update_task( + self, + company_id, + task_id, + now, + iter_max=None, + last_scalar_events=None, + last_events=None, + ): """ Update task information in DB with aggregated results after handling event(s) related to this task. 
@@ -237,15 +260,18 @@ class EventBLL(object): if iter_max is not None: fields["last_iteration_max"] = iter_max - if last_events: - fields["last_values"] = list( + if last_scalar_events: + fields["last_scalar_values"] = list( flatten_nested_items( - last_events, + last_scalar_events, nesting=2, include_leaves=["value", "metric", "variant"], ) ) + if last_events: + fields["last_events"] = last_events + if not fields: return False diff --git a/server/bll/event/event_metrics.py b/server/bll/event/event_metrics.py index a90ab3a..d83d6b0 100644 --- a/server/bll/event/event_metrics.py +++ b/server/bll/event/event_metrics.py @@ -1,13 +1,13 @@ import itertools from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from enum import Enum from functools import partial from operator import itemgetter +from typing import Sequence, Tuple, Callable, Iterable from boltons.iterutils import bucketize from elasticsearch import Elasticsearch -from typing import Sequence, Tuple, Callable, Iterable - from mongoengine import Q from apierrors import errors @@ -21,6 +21,14 @@ from utilities import safe_get log = config.logger(__file__) +class EventType(Enum): + metrics_scalar = "training_stats_scalar" + metrics_vector = "training_stats_vector" + metrics_image = "training_debug_image" + metrics_plot = "plot" + task_log = "log" + + class EventMetrics: MAX_TASKS_COUNT = 50 MAX_METRICS_COUNT = 200 @@ -66,7 +74,8 @@ class EventMetrics: """ if len(task_ids) > self.MAX_TASKS_COUNT: raise errors.BadRequest( - f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", len(task_ids) + f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", + len(task_ids), ) task_name_by_id = {} @@ -168,9 +177,7 @@ class EventMetrics: with ThreadPoolExecutor(max_workers=max_concurrency) as pool: metrics = itertools.chain.from_iterable( pool.map( - partial( - get_func, task_ids=task_ids, es_index=es_index, key=key - ), + partial(get_func, task_ids=task_ids, es_index=es_index, key=key), intervals, ) ) @@ -440,3 +447,50 @@ class EventMetrics: ] } } + + def get_tasks_metrics( + self, company_id, task_ids: Sequence, event_type: EventType + ) -> Sequence: + """ + For the requested tasks return all the metrics that + reported events of the requested types + """ + es_index = EventMetrics.get_index_name(company_id, event_type.value) + if not self.es.indices.exists(es_index): + return {} + + max_concurrency = config.get("services.events.max_metrics_concurrency", 4) + with ThreadPoolExecutor(max_concurrency) as pool: + res = pool.map( + partial( + self._get_task_metrics, es_index=es_index, event_type=event_type, + ), + task_ids, + ) + return list(zip(task_ids, res)) + + def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence: + es_req = { + "size": 0, + "query": { + "bool": { + "must": [ + {"term": {"task": task_id}}, + {"term": {"type": event_type.value}}, + ] + } + }, + "aggs": { + "metrics": { + "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT} + } + }, + } + + with translate_errors_context(), TimingContext("es", "_get_task_metrics"): + es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + + return [ + metric["key"] + for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[]) + ] diff --git a/server/bll/redis_cache_manager.py b/server/bll/redis_cache_manager.py new file mode 100644 index 0000000..3674de1 --- /dev/null +++ b/server/bll/redis_cache_manager.py @@ -0,0 +1,44 @@ +from typing import Optional, TypeVar, Generic, Type + +from 
redis import StrictRedis + +from timing_context import TimingContext + +T = TypeVar("T") + + +class RedisCacheManager(Generic[T]): + """ + Class for storing/retrieving state objects in Redis + + self.state_class - class of the state + self.redis - instance of redis + self.expiration_interval - expiration interval in seconds + """ + + def __init__( + self, state_class: Type[T], redis: StrictRedis, expiration_interval: int + ): + self.state_class = state_class + self.redis = redis + self.expiration_interval = expiration_interval + + def set_state(self, state: T) -> None: + redis_key = self._get_redis_key(state.id) + with TimingContext("redis", "cache_set_state"): + self.redis.set(redis_key, state.to_json()) + self.redis.expire(redis_key, self.expiration_interval) + + def get_state(self, state_id) -> Optional[T]: + redis_key = self._get_redis_key(state_id) + with TimingContext("redis", "cache_get_state"): + response = self.redis.get(redis_key) + if response: + return self.state_class.from_json(response) + + def delete_state(self, state_id) -> None: + with TimingContext("redis", "cache_delete_state"): + self.redis.delete(self._get_redis_key(state_id)) + + def _get_redis_key(self, state_id): + return f"{self.state_class}/{state_id}" diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index 2df6f8a..38e53e4 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -3,13 +3,14 @@ from datetime import datetime, timedelta from operator import attrgetter from random import random from time import sleep -from typing import Collection, Sequence, Tuple, Any, Optional, List +from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict import pymongo.results import six from mongoengine import Q from six import string_types +import database.utils as dbutils import es_factory from apierrors import errors from apimodels.tasks import Artifact as ApiArtifact @@ -17,6 +18,7 @@ from config import config from database.errors import translate_errors_context from database.model.model import Model from database.model.project import Project +from database.model.task.metrics import EventStats, MetricEventStats from database.model.task.output import Output from database.model.task.task import ( Task, @@ -197,7 +199,9 @@ class TaskBLL(object): system_tags=system_tags or [], type=task.type, script=task.script, - output=Output(destination=task.output.destination) if task.output else None, + output=Output(destination=task.output.destination) + if task.output + else None, execution=execution_dict, ) cls.validate(new_task) @@ -277,7 +281,8 @@ class TaskBLL(object): last_update: datetime = None, last_iteration: int = None, last_iteration_max: int = None, - last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None, + last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None, + last_events: Dict[str, Dict[str, dict]] = None, **extra_updates, ): """ @@ -289,7 +294,8 @@ task's last iteration value. :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only if the current task's last iteration value is smaller than the provided value. - :param last_values: Last reported metrics summary (value, metric, variant). + :param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant). + :param last_events: Last reported metrics summary (value, metric, event type). :param extra_updates: Extra task updates to include in this update call.
:return: """ @@ -300,17 +306,33 @@ class TaskBLL(object): elif last_iteration_max is not None: extra_updates.update(max__last_iteration=last_iteration_max) - if last_values is not None: + if last_scalar_values is not None: def op_path(op, *path): return "__".join((op, "last_metrics") + path) - for path, value in last_values: + for path, value in last_scalar_values: extra_updates[op_path("set", *path)] = value if path[-1] == "value": extra_updates[op_path("min", *path[:-1], "min_value")] = value extra_updates[op_path("max", *path[:-1], "max_value")] = value + if last_events is not None: + + def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]: + return { + event_type: EventStats(last_update=event["timestamp"]) + for event_type, event in metric_data.items() + } + + metric_stats = { + dbutils.hash_field_name(metric_key): MetricEventStats( + metric=metric_key, event_stats_by_type=events_per_type(metric_data), + ) + for metric_key, metric_data in last_events.items() + } + extra_updates["metric_stats"] = metric_stats + Task.objects(id=task_id, company=company_id).update( upsert=False, last_update=last_update, **extra_updates ) diff --git a/server/config/default/hosts.conf b/server/config/default/hosts.conf index 2d91366..17d9ab8 100644 --- a/server/config/default/hosts.conf +++ b/server/config/default/hosts.conf @@ -32,6 +32,11 @@ mongo { } redis { + apiserver { + host: "127.0.0.1" + port: 6379 + db: 0 + } workers { host: "127.0.0.1" port: 6379 diff --git a/server/database/model/task/metrics.py b/server/database/model/task/metrics.py index ce65aae..64c3a25 100644 --- a/server/database/model/task/metrics.py +++ b/server/database/model/task/metrics.py @@ -1,10 +1,18 @@ -from mongoengine import EmbeddedDocument, StringField, DynamicField +from mongoengine import ( + EmbeddedDocument, + StringField, + DynamicField, + LongField, + EmbeddedDocumentField, +) + +from database.fields import SafeMapField class MetricEvent(EmbeddedDocument): meta = { # For backwards compatibility reasons - 'strict': False, + "strict": False, } metric = StringField(required=True) @@ -12,3 +20,20 @@ class MetricEvent(EmbeddedDocument): value = DynamicField(required=True) min_value = DynamicField() # for backwards compatibility reasons max_value = DynamicField() # for backwards compatibility reasons + + +class EventStats(EmbeddedDocument): + meta = { + # For backwards compatibility reasons + "strict": False, + } + last_update = LongField() + + +class MetricEventStats(EmbeddedDocument): + meta = { + # For backwards compatibility reasons + "strict": False, + } + metric = StringField(required=True) + event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats)) diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index e7b2f49..0b55aa5 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -22,7 +22,7 @@ from database.model.base import ProperDictMixin from database.model.model_labels import ModelLabels from database.model.project import Project from database.utils import get_options -from .metrics import MetricEvent +from .metrics import MetricEvent, MetricEventStats from .output import Output DEFAULT_LAST_ITERATION = 0 @@ -162,3 +162,4 @@ class Task(AttributedDocument): last_update = DateTimeField() last_iteration = IntField(default=DEFAULT_LAST_ITERATION) last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent))) + metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats)) diff --git 
a/server/schema/services/events.conf b/server/schema/services/events.conf index ffd363c..07f0057 100644 --- a/server/schema/services/events.conf +++ b/server/schema/services/events.conf @@ -171,6 +171,30 @@ critical ] } + event_type_enum { + type: string + enum: [ + training_stats_scalar + training_stats_vector + training_debug_image + plot + log + ] + } + task_metric { + type: object + required: [task, metric] + properties { + task { + description: "Task ID" + type: string + } + metric { + description: "Metric name" + type: string + } + } + } task_log_event { description: """A log event associated with a task.""" type: object @@ -319,6 +343,84 @@ } } } + "2.7" { + description: "Get the debug image events for the requested number of iterations per each task's metric" + request { + type: object + required: [ + metrics + ] + properties { + metrics { + type: array + items { "$ref": "#/definitions/task_metric" } + description: "List of metrics for which the events will be retrieved" + } + iters { + type: integer + description: "Max number of latest iterations for which to return debug images" + } + navigate_earlier { + type: boolean + description: "If set then events are retrieved from later iterations to earlier ones, otherwise from earlier iterations to later ones. The default is True" + } + refresh { + type: boolean + description: "If set then the scroll will be moved to the latest iterations. The default is False" + } + scroll_id { + type: string + description: "Scroll ID of previous call (used for getting more results)" + } + } + } + response { + type: object + properties { + metrics { + type: array + items: { type: object } + description: "Debug image events grouped by task metrics and iterations" + } + scroll_id { + type: string + description: "Scroll ID for getting more results" + } + } + } + } + } + get_task_metrics { + "2.7": { + description: "For each task, get a list of metrics for which the requested event type was reported" + request { + type: object + required: [ + tasks + ] + properties { + tasks { + type: array + items { type: string } + description: "Task IDs" + } + event_type { + "description": "Event type" + "$ref": "#/definitions/event_type_enum" + } + } + } + response { + type: object + properties { + metrics { + type: array + items { type: object } + description: "List of tasks with their metrics" + } + } + } + } } get_task_log { "1.5" { diff --git a/server/services/events.py b/server/services/events.py index 0f95a58..83601cb 100644 --- a/server/services/events.py +++ b/server/services/events.py @@ -2,12 +2,15 @@ import itertools from collections import defaultdict from operator import itemgetter -import six - from apierrors import errors from apimodels.events import ( MultiTaskScalarMetricsIterHistogramRequest, ScalarMetricsIterHistogramRequest, + DebugImagesRequest, + DebugImageResponse, + MetricEvents, + IterationEvents, + TaskMetricsRequest, ) from bll.event import EventBLL from bll.event.event_metrics import EventMetrics @@ -299,7 +302,7 @@ def multi_task_scalar_metrics_iter_histogram( call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest ): task_ids = req_model.tasks - if isinstance(task_ids, six.string_types): + if isinstance(task_ids, str): task_ids = [s.strip() for s in task_ids.split(",")] # Note, bll already validates task ids as it needs their names call.result.data = dict( @@ -481,7 +484,7 @@ def get_debug_images_v1_7(call, company_id, req_model): @endpoint("events.debug_images", min_version="1.8", required_fields=["task"]) -def
get_debug_images(call, company_id, req_model): +def get_debug_images_v1_8(call, company_id, req_model): task_id = call.data["task"] iters = call.data.get("iters") or 1 scroll_id = call.data.get("scroll_id") @@ -507,6 +510,53 @@ def get_debug_images(call, company_id, req_model): ) +@endpoint( + "events.debug_images", + min_version="2.7", + request_data_model=DebugImagesRequest, + response_data_model=DebugImageResponse, +) +def get_debug_images(call, company_id, req_model: DebugImagesRequest): + tasks = set(m.task for m in req_model.metrics) + task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True) + result = event_bll.debug_images_iterator.get_task_events( + company_id=company_id, + metrics=[(m.task, m.metric) for m in req_model.metrics], + iter_count=req_model.iters, + navigate_earlier=req_model.navigate_earlier, + refresh=req_model.refresh, + state_id=req_model.scroll_id, + ) + + call.result.data_model = DebugImageResponse( + scroll_id=result.next_scroll_id, + metrics=[ + MetricEvents( + task=task, + metric=metric, + iterations=[ + IterationEvents(iter=iteration["iter"], events=iteration["events"]) + for iteration in iterations + ], + ) + for (task, metric, iterations) in result.metric_events + ], + ) + + +@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest) +def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest): + task_bll.assert_exists( + call.identity.company, task_ids=req_model.tasks, allow_public=True + ) + res = event_bll.metrics.get_tasks_metrics( + company_id, task_ids=req_model.tasks, event_type=req_model.event_type + ) + call.result.data = { + "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res] + } + + @endpoint("events.delete_for_task", required_fields=["task"]) def delete_for_task(call, company_id, req_model): task_id = call.data["task"] diff --git a/server/tests/automated/test_task_events.py b/server/tests/automated/test_task_events.py index 4e43934..217eff5 100644 --- a/server/tests/automated/test_task_events.py +++ b/server/tests/automated/test_task_events.py @@ -4,19 +4,16 @@ Comprehensive test of all(?) 
use cases of datasets and frames import json import time import unittest +from functools import partial from statistics import mean - from typing import Sequence import es_factory -from config import config from tests.automated import TestService -log = config.logger(__file__) - class TestTaskEvents(TestService): - def setUp(self, version="1.7"): + def setUp(self, version="2.7"): super().setUp(version=version) def _temp_task(self, name="test task events"): @@ -25,13 +22,14 @@ class TestTaskEvents(TestService): ) return self.create_temp("tasks", **task_input) - def _create_task_event(self, type_, task, iteration): + def _create_task_event(self, type_, task, iteration, **kwargs): return { "worker": "test", "type": type_, "task": task, "iter": iteration, "timestamp": es_factory.get_timestamp_millis(), + **kwargs, } def _copy_and_update(self, src_obj, new_data): @@ -39,6 +37,134 @@ class TestTaskEvents(TestService): obj.update(new_data) return obj + def test_task_metrics(self): + tasks = { + self._temp_task(): { + "Metric1": ["training_debug_image"], + "Metric2": ["training_debug_image", "log"], + }, + self._temp_task(): {"Metric3": ["training_debug_image"]}, + } + events = [ + self._create_task_event( + event_type, + task=task, + iteration=1, + metric=metric, + variant="Test variant", + ) + for task, metrics in tasks.items() + for metric, event_types in metrics.items() + for event_type in event_types + ] + self.send_batch(events) + self._assert_task_metrics(tasks, "training_debug_image") + self._assert_task_metrics(tasks, "log") + self._assert_task_metrics(tasks, "training_stats_scalar") + + def _assert_task_metrics(self, tasks: dict, event_type: str): + res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type) + for task, metrics in tasks.items(): + res_metrics = next( + (tm.metrics for tm in res.metrics if tm.task == task), () + ) + self.assertEqual( + set(res_metrics), + set( + metric for metric, events in metrics.items() if event_type in events + ), + ) + + def test_task_debug_images(self): + task = self._temp_task() + metric = "Metric1" + variants = [("Variant1", 7), ("Variant2", 4)] + iterations = 10 + + # test empty + res = self.api.events.debug_images( + metrics=[{"task": task, "metric": metric}], + iters=5, + ) + self.assertFalse(res.metrics) + + # create events + events = [ + self._create_task_event( + "training_debug_image", + task=task, + iteration=n, + metric=metric, + variant=variant, + url=f"{metric}_{variant}_{n % unique_images}", + ) + for n in range(iterations) + for (variant, unique_images) in variants + ] + self.send_batch(events) + + # init testing + unique_images = [unique for (_, unique) in variants] + scroll_id = None + assert_debug_images = partial( + self._assertDebugImages, + task=task, + metric=metric, + max_iter=iterations - 1, + unique_images=unique_images, + ) + + # test forward navigation + for page in range(3): + scroll_id = assert_debug_images(scroll_id=scroll_id, page=page) + + # test backwards navigation + scroll_id = assert_debug_images( + scroll_id=scroll_id, page=0, navigate_earlier=False + ) + + # beyond the latest iteration and back + res = self.api.events.debug_images( + metrics=[{"task": task, "metric": metric}], + iters=5, + scroll_id=scroll_id, + navigate_earlier=False, + ) + self.assertEqual(len(res["metrics"][0]["iterations"]), 0) + assert_debug_images(scroll_id=scroll_id, page=1) + + # refresh + assert_debug_images(scroll_id=scroll_id, page=0, refresh=True) + + def _assertDebugImages( + self, + task, + metric, + max_iter: 
int, + unique_images: Sequence[int], + scroll_id, + page: int, + iters: int = 5, + **extra_params, + ): + res = self.api.events.debug_images( + metrics=[{"task": task, "metric": metric}], + iters=iters, + scroll_id=scroll_id, + **extra_params, + ) + data = res["metrics"][0] + self.assertEqual(data["task"], task) + self.assertEqual(data["metric"], metric) + left_iterations = max(0, max(unique_images) - page * iters) + self.assertEqual(len(data["iterations"]), min(iters, left_iterations)) + for it in data["iterations"]: + events_per_iter = sum( + 1 for unique in unique_images if unique > max_iter - it["iter"] + ) + self.assertEqual(len(it["events"]), events_per_iter) + return res.scroll_id + def test_task_logs(self): events = [] task = self._temp_task()
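Usage note (not part of the diff above): a minimal client-side sketch of how the new scroll-based events.debug_images call could be paged through. Only the request fields (metrics, iters, navigate_earlier, refresh, scroll_id) and the response fields (metrics, scroll_id) are taken from this patch; the server address, the /v2.7 URL form, the "data" response envelope, the omitted authentication, and the task/metric names are illustrative assumptions.

# Sketch: page backwards (latest iterations first) through debug images for one
# task metric, reusing the scroll_id returned by each call for the next page.
# Assumed: a reachable apiserver at API_BASE and a task that already reported
# debug image events.
import requests

API_BASE = "http://localhost:8008/v2.7"  # assumed apiserver address and API version prefix
TASK_ID = "<existing task id>"           # placeholder
METRIC = "Metric1"                       # placeholder: a metric that reported debug images


def fetch_page(scroll_id=None, iters=5, navigate_earlier=True, refresh=False):
    """Fetch one page of debug image events, grouped by iteration."""
    body = {
        "metrics": [{"task": TASK_ID, "metric": METRIC}],
        "iters": iters,
        "navigate_earlier": navigate_earlier,
        "refresh": refresh,
    }
    if scroll_id:
        body["scroll_id"] = scroll_id
    resp = requests.post(f"{API_BASE}/events.debug_images", json=body)
    resp.raise_for_status()
    return resp.json()["data"]  # assumed: results are wrapped in a "data" envelope


scroll_id = None
while True:
    page = fetch_page(scroll_id=scroll_id)
    scroll_id = page["scroll_id"]  # pass back on the next call to continue the scroll
    iterations = page["metrics"][0]["iterations"] if page["metrics"] else []
    if not iterations:
        break  # no earlier iterations left for this metric
    for it in iterations:
        print(it["iter"], len(it["events"]), "debug image events")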