From 35a11db58e8b8444563b1ac2ffc745212738816d Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 1 Jun 2020 11:27:36 +0300
Subject: [PATCH] Support task log retrieval with no scroll

---
 server/apimodels/__init__.py               |  10 ++
 server/apimodels/events.py                 |   8 +
 server/apimodels/workers.py                |  12 +-
 server/bll/event/debug_images_iterator.py  |  72 +++++----
 server/bll/event/event_bll.py              |  10 +-
 server/bll/event/log_events_iterator.py    | 169 +++++++++++++++++++++
 server/bll/redis_cache_manager.py          |  39 ++++-
 server/config/default/services/events.conf |   6 +-
 server/schema/services/events.conf         |   1 -
 server/services/events.py                  |  25 ++-
 server/services/projects.py                |   4 +-
 server/tests/automated/test_task_events.py |  98 ++++++++----
 12 files changed, 361 insertions(+), 93 deletions(-)
 create mode 100644 server/bll/event/log_events_iterator.py

diff --git a/server/apimodels/__init__.py b/server/apimodels/__init__.py
index 1367a79..e4fac44 100644
--- a/server/apimodels/__init__.py
+++ b/server/apimodels/__init__.py
@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
 from validators import email as email_validator, domain as domain_validator
 
 from apierrors import errors
+from utilities.json import loads, dumps
 
 
 def make_default(field_cls, default_value):
@@ -213,3 +214,12 @@ class StringEnum(Enum):
     # noinspection PyMethodParameters
     def _generate_next_value_(name, start, count, last_values):
         return name
+
+
+class JsonSerializableMixin:
+    def to_json(self: ModelBase):
+        return dumps(self.to_struct())
+
+    @classmethod
+    def from_json(cls: Type[ModelBase], s):
+        return cls(**loads(s))
diff --git a/server/apimodels/events.py b/server/apimodels/events.py
index f5315b7..5b8dcdc 100644
--- a/server/apimodels/events.py
+++ b/server/apimodels/events.py
@@ -40,6 +40,14 @@ class DebugImagesRequest(Base):
     scroll_id: str = StringField()
 
 
+class LogEventsRequest(Base):
+    task: str = StringField(required=True)
+    batch_size: int = IntField(default=500)
+    navigate_earlier: bool = BoolField(default=True)
+    refresh: bool = BoolField(default=False)
+    scroll_id: str = StringField()
+
+
 class IterationEvents(Base):
     iter: int = IntField()
     events: Sequence[dict] = ListField(items_types=dict)
diff --git a/server/apimodels/workers.py b/server/apimodels/workers.py
index 9ecac8d..40d8f4d 100644
--- a/server/apimodels/workers.py
+++ b/server/apimodels/workers.py
@@ -1,4 +1,3 @@
-import json
 from enum import Enum
 
 import six
@@ -13,7 +12,7 @@ from jsonmodels.fields import (
 )
 from jsonmodels.models import Base
 
-from apimodels import make_default, ListField, EnumField
+from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
 
 DEFAULT_TIMEOUT = 10 * 60
 
@@ -61,7 +60,7 @@ class IdNameEntry(Base):
     name = StringField()
 
 
-class WorkerEntry(Base):
+class WorkerEntry(Base, JsonSerializableMixin):
     key = StringField()  # not required due to migration issues
     id = StringField(required=True)
     user = EmbeddedField(IdNameEntry)
@@ -75,13 +74,6 @@ class WorkerEntry(Base):
     last_activity_time = DateTimeField(required=True)
     last_report_time = DateTimeField()
 
-    def to_json(self):
-        return json.dumps(self.to_struct())
-
-    @classmethod
-    def from_json(cls, s):
-        return cls(**json.loads(s))
-
 
 class CurrentTaskEntry(IdNameEntry):
     running_time = IntField()
diff --git a/server/bll/event/debug_images_iterator.py b/server/bll/event/debug_images_iterator.py
index b755fc3..ccedf09 100644
--- a/server/bll/event/debug_images_iterator.py
+++ b/server/bll/event/debug_images_iterator.py
@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
 from functools import partial
 from itertools import chain
 from operator import attrgetter, itemgetter
+from typing import Sequence, Tuple, Optional, Mapping
 
 import attr
 import dpath
 from boltons.iterutils import bucketize
 from elasticsearch import Elasticsearch
+from jsonmodels.fields import StringField, ListField, IntField
+from jsonmodels.models import Base
 from redis import StrictRedis
-from typing import Sequence, Tuple, Optional, Mapping
 
-import database
 from apierrors import errors
-from bll.redis_cache_manager import RedisCacheManager
+from apimodels import JsonSerializableMixin
 from bll.event.event_metrics import EventMetrics
+from bll.redis_cache_manager import RedisCacheManager
 from config import config
 from database.errors import translate_errors_context
-from jsonmodels.models import Base
-from jsonmodels.fields import StringField, ListField, IntField
-
 from database.model.task.metrics import MetricEventStats
 from database.model.task.task import Task
 from timing_context import TimingContext
-from utilities.json import loads, dumps
 
 
 class VariantScrollState(Base):
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
         self.last_min_iter = self.last_max_iter = None
 
 
-class DebugImageEventsScrollState(Base):
+class DebugImageEventsScrollState(Base, JsonSerializableMixin):
     id: str = StringField(required=True)
     metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
 
-    def to_json(self):
-        return dumps(self.to_struct())
-
-    @classmethod
-    def from_json(cls, s):
-        return cls(**loads(s))
-
 
 @attr.s(auto_attribs=True)
 class DebugImagesResult(object):
@@ -65,7 +56,12 @@ class DebugImagesResult(object):
 
 class DebugImagesIterator:
     EVENT_TYPE = "training_debug_image"
-    STATE_EXPIRATION_SECONDS = 3600
+
+    @property
+    def state_expiration_sec(self):
+        return config.get(
+            f"services.events.events_retrieval.state_expiration_sec", 3600
+        )
 
     @property
     def _max_workers(self):
@@ -76,7 +72,7 @@ class DebugImagesIterator:
         self.cache_manager = RedisCacheManager(
             state_class=DebugImageEventsScrollState,
             redis=redis,
-            expiration_interval=self.STATE_EXPIRATION_SECONDS,
+            expiration_interval=self.state_expiration_sec,
         )
 
     def get_task_events(
@@ -92,27 +88,31 @@
         if not self.es.indices.exists(es_index):
             return DebugImagesResult()
 
-        unique_metrics = set(metrics)
-        state = self.cache_manager.get_state(state_id) if state_id else None
-        if not state:
-            state = DebugImageEventsScrollState(
-                id=database.utils.id(),
-                metrics=self._init_metric_states(es_index, list(unique_metrics)),
-            )
-        else:
-            state_metrics = set((m.task, m.name) for m in state.metrics)
-            if state_metrics != unique_metrics:
-                raise errors.bad_request.InvalidScrollId(
-                    "while getting debug images events", scroll_id=state_id
-                )
+        def init_state(state_: DebugImageEventsScrollState):
+            unique_metrics = set(metrics)
+            state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
+
+        def validate_state(state_: DebugImageEventsScrollState):
+            """
+            Validate that the metrics stored in the state are the same
+            as requested in the current call.
+            Refresh the state if requested
+            """
+            state_metrics = set((m.task, m.name) for m in state_.metrics)
+            if state_metrics != set(metrics):
+                raise errors.bad_request.InvalidScrollId(
+                    "Task metrics stored in the state do not match the passed ones",
+                    scroll_id=state_.id,
+                )
             if refresh:
-                self._reinit_outdated_metric_states(company_id, es_index, state)
-                for metric_state in state.metrics:
+                self._reinit_outdated_metric_states(company_id, es_index, state_)
+                for metric_state in state_.metrics:
                     metric_state.reset()
 
-        res = DebugImagesResult(next_scroll_id=state.id)
-        try:
+        with self.cache_manager.get_or_create_state(
+            state_id=state_id, init_state=init_state, validate_state=validate_state
+        ) as state:
+            res = DebugImagesResult(next_scroll_id=state.id)
             with ThreadPoolExecutor(self._max_workers) as pool:
                 res.metric_events = list(
                     pool.map(
@@ -125,10 +125,8 @@
                         state.metrics,
                     )
                 )
-        finally:
-            self.cache_manager.set_state(state)
-
-        return res
+            return res
 
     def _reinit_outdated_metric_states(
         self, company_id, es_index, state: DebugImageEventsScrollState
diff --git a/server/bll/event/event_bll.py b/server/bll/event/event_bll.py
index 3c0600f..fb33d3d 100644
--- a/server/bll/event/event_bll.py
+++ b/server/bll/event/event_bll.py
@@ -5,7 +5,6 @@ from datetime import datetime
 from operator import attrgetter
 from typing import Sequence
 
-import attr
 import six
 from elasticsearch import helpers
 from mongoengine import Q
@@ -16,6 +15,7 @@ import es_factory
 from apierrors import errors
 from bll.event.debug_images_iterator import DebugImagesIterator
 from bll.event.event_metrics import EventMetrics, EventType
+from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
 from bll.task import TaskBLL
 from config import config
 from database.errors import translate_errors_context
@@ -29,13 +29,6 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
 LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
 
 
-@attr.s(auto_attribs=True)
-class TaskEventsResult(object):
-    total_events: int = 0
-    next_scroll_id: str = None
-    events: list = attr.ib(factory=list)
-
-
 class EventBLL(object):
     id_fields = ("task", "iter", "metric", "variant", "key")
 
@@ -47,6 +40,7 @@
         )
         self.redis = redis or redman.connection("apiserver")
         self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
+        self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
 
     @property
     def metrics(self) -> EventMetrics:
diff --git a/server/bll/event/log_events_iterator.py b/server/bll/event/log_events_iterator.py
new file mode 100644
index 0000000..79259d7
--- /dev/null
+++ b/server/bll/event/log_events_iterator.py
@@ -0,0 +1,169 @@
+from typing import Optional, Tuple, Sequence
+
+import attr
+from elasticsearch import Elasticsearch
+from jsonmodels.fields import StringField, IntField
+from jsonmodels.models import Base
+from redis import StrictRedis
+
+from apierrors import errors
+from apimodels import JsonSerializableMixin
+from bll.event.event_metrics import EventMetrics
+from bll.redis_cache_manager import RedisCacheManager
+from config import config
+from database.errors import translate_errors_context
+from timing_context import TimingContext
+
+
+class LogEventsScrollState(Base, JsonSerializableMixin):
+    id: str = StringField(required=True)
+    task: str = StringField(required=True)
+    last_min_timestamp: Optional[int] = IntField()
+    last_max_timestamp: Optional[int] = IntField()
+
+    def reset(self):
+        """Reset the scrolling state"""
+        self.last_min_timestamp = self.last_max_timestamp = None
+
+
+@attr.s(auto_attribs=True)
+class TaskEventsResult:
+    total_events: int = 0
+    next_scroll_id: str = None
+    events: list = attr.Factory(list)
+
+
+class LogEventsIterator:
+    EVENT_TYPE = "log"
+
+    @property
+    def state_expiration_sec(self):
+        return config.get(
+            f"services.events.events_retrieval.state_expiration_sec", 3600
+        )
+
+    def __init__(self, redis: StrictRedis, es: Elasticsearch):
+        self.es = es
+        self.cache_manager = RedisCacheManager(
+            state_class=LogEventsScrollState,
+            redis=redis,
+            expiration_interval=self.state_expiration_sec,
+        )
+
+    def get_task_events(
+        self,
+        company_id: str,
+        task_id: str,
+        batch_size: int,
+        navigate_earlier: bool = True,
+        refresh: bool = False,
+        state_id: str = None,
+    ) -> TaskEventsResult:
+        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        if not self.es.indices.exists(es_index):
+            return TaskEventsResult()
+
+        def init_state(state_: LogEventsScrollState):
+            state_.task = task_id
+
+        def validate_state(state_: LogEventsScrollState):
+            """
+            Check that the task id stored in the state
+            matches the one passed with the current call.
+            Refresh the state if requested
+            """
+            if state_.task != task_id:
+                raise errors.bad_request.InvalidScrollId(
+                    "Task stored in the state does not match the passed one",
+                    scroll_id=state_.id,
+                )
+            if refresh:
+                state_.reset()
+
+        with self.cache_manager.get_or_create_state(
+            state_id=state_id, init_state=init_state, validate_state=validate_state,
+        ) as state:
+            res = TaskEventsResult(next_scroll_id=state.id)
+            res.events, res.total_events = self._get_events(
+                es_index=es_index,
+                batch_size=batch_size,
+                navigate_earlier=navigate_earlier,
+                state=state,
+            )
+            return res
+
+    def _get_events(
+        self,
+        es_index,
+        batch_size: int,
+        navigate_earlier: bool,
+        state: LogEventsScrollState,
+    ) -> Tuple[Sequence[dict], int]:
+        """
+        Return up to 'batch_size' events starting from the previous timestamp, either in the
+        direction of earlier events (navigate_earlier=True) or in the direction of later events.
+        If last_min_timestamp and last_max_timestamp are not set then start from the latest or earliest.
+        All the events for the last returned timestamp are included (even if the resulting size
+        exceeds batch_size) so that events sharing this timestamp are not lost between calls.
+        If any events were received, update 'last_min_timestamp' and 'last_max_timestamp'
+        """
+
+        # retrieve the next batch of events
+        es_req = {
+            "size": batch_size,
+            "query": {"term": {"task": state.task}},
+            "sort": {"timestamp": "desc" if navigate_earlier else "asc"},
+        }
+
+        if navigate_earlier and state.last_min_timestamp is not None:
+            es_req["search_after"] = [state.last_min_timestamp]
+        elif not navigate_earlier and state.last_max_timestamp is not None:
+            es_req["search_after"] = [state.last_max_timestamp]
+
+        with translate_errors_context(), TimingContext("es", "get_task_events"):
+            es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
+            hits = es_result["hits"]["hits"]
+            hits_total = es_result["hits"]["total"]
+            if not hits:
+                return [], hits_total
+
+            events = [hit["_source"] for hit in hits]
+            if navigate_earlier:
+                state.last_max_timestamp = events[0]["timestamp"]
+                state.last_min_timestamp = events[-1]["timestamp"]
+            else:
+                state.last_min_timestamp = events[0]["timestamp"]
+                state.last_max_timestamp = events[-1]["timestamp"]
+
+            # retrieve the events that match the last event timestamp
+            # but did not make it into the previous call due to batch_size limitation
+            es_req = {
+                "size": 10000,
+                "query": {
+                    "bool": {
+                        "must": [
+                            {"term": {"task": state.task}},
+                            {"term": {"timestamp": events[-1]["timestamp"]}},
+                        ]
+                    }
+                },
+            }
+            es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
+            hits = es_result["hits"]["hits"]
+            if not hits or len(hits) < 2:
+                # if only one element is returned for the last timestamp
+                # then it is already present in the events
+                return events, hits_total
+
+            last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
+            already_present_ids = set(ev["_id"] for ev in events)
+
+            # return the list merged from original query results +
+            # leftovers from the last timestamp
+            return (
+                [
+                    *events,
+                    *(ev for ev in last_events if ev["_id"] not in already_present_ids),
+                ],
+                hits_total,
+            )
diff --git a/server/bll/redis_cache_manager.py b/server/bll/redis_cache_manager.py
index 3674de1..6aac68b 100644
--- a/server/bll/redis_cache_manager.py
+++ b/server/bll/redis_cache_manager.py
@@ -1,15 +1,21 @@
-from typing import Optional, TypeVar, Generic, Type
+from contextlib import contextmanager
+from typing import Optional, TypeVar, Generic, Type, Callable
 
 from redis import StrictRedis
 
+import database
 from timing_context import TimingContext
 
 T = TypeVar("T")
 
 
+def _do_nothing(_: T):
+    return
+
+
 class RedisCacheManager(Generic[T]):
     """
-    Class for store/retreive of state objects from redis
+    Class for store/retrieve of state objects from redis
 
     self.state_class - class of the state
     self.redis - instance of redis
@@ -42,3 +48,32 @@
 
     def _get_redis_key(self, state_id):
         return f"{self.state_class}/{state_id}"
+
+    @contextmanager
+    def get_or_create_state(
+        self,
+        state_id=None,
+        init_state: Callable[[T], None] = _do_nothing,
+        validate_state: Callable[[T], None] = _do_nothing,
+    ):
+        """
+        Try to retrieve the state with the given id from the Redis cache and validate it if found.
+        Otherwise create a new one with a randomly generated id.
+        Yield the state and write it back to Redis once the user code block exits
+        :param state_id: id of the state to retrieve
+        :param init_state: user callback to init the newly created state
+            If not passed then no init except for the id generation is done
+        :param validate_state: user callback to validate the state if retrieved from cache
+            Should raise an exception if the state is not valid. If not passed then no validation is done
+        """
+        state = self.get_state(state_id) if state_id else None
+        if state:
+            validate_state(state)
+        else:
+            state = self.state_class(id=database.utils.id())
+            init_state(state)
+
+        try:
+            yield state
+        finally:
+            self.set_state(state)
diff --git a/server/config/default/services/events.conf b/server/config/default/services/events.conf
index 91f5810..5adea87 100644
--- a/server/config/default/services/events.conf
+++ b/server/config/default/services/events.conf
@@ -6,4 +6,8 @@ ignore_iteration {
 
 # max number of concurrent queries to ES when calculating events metrics
 # should not exceed the amount of concurrent connections set in the ES driver
-max_metrics_concurrency: 4
\ No newline at end of file
+max_metrics_concurrency: 4
+
+events_retrieval {
+    state_expiration_sec: 3600
+}
diff --git a/server/schema/services/events.conf b/server/schema/services/events.conf
index cbd28b5..8f9ec96 100644
--- a/server/schema/services/events.conf
+++ b/server/schema/services/events.conf
@@ -258,7 +258,6 @@
                 properties {
                     added { type: integer }
                     errors { type: integer }
-                    errors_info { type: object }
                 }
             }
         }
diff --git a/server/services/events.py b/server/services/events.py
index 83601cb..219ce3d 100644
--- a/server/services/events.py
+++ b/server/services/events.py
@@ -11,6 +11,7 @@ from apimodels.events import (
     MetricEvents,
     IterationEvents,
     TaskMetricsRequest,
+    LogEventsRequest,
 )
 from bll.event import EventBLL
 from bll.event.event_metrics import EventMetrics
@@ -45,7 +46,7 @@ def add_batch(call: APICall, company_id, req_model):
 
 
 @endpoint("events.get_task_log", required_fields=["task"])
-def get_task_log(call, company_id, req_model):
+def get_task_log_v1_5(call, company_id, req_model):
     task_id = call.data["task"]
     task_bll.assert_exists(company_id, task_id, allow_public=True)
     order = call.data.get("order") or "desc"
@@ -93,6 +94,28 @@ def get_task_log_v1_7(call, company_id, req_model):
     )
 
 
+@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
+def get_task_log(call, company_id, req_model: LogEventsRequest):
+    task_id = req_model.task
+    task_bll.assert_exists(company_id, task_id, allow_public=True)
+
+    res = event_bll.log_events_iterator.get_task_events(
+        company_id=company_id,
+        task_id=task_id,
+        batch_size=req_model.batch_size,
+        navigate_earlier=req_model.navigate_earlier,
+        refresh=req_model.refresh,
+        state_id=req_model.scroll_id,
+    )
+
+    call.result.data = dict(
+        events=res.events,
+        returned=len(res.events),
+        total=res.total_events,
+        scroll_id=res.next_scroll_id,
+    )
+
+
 @endpoint("events.download_task_log", required_fields=["task"])
 def download_task_log(call, company_id, req_model):
     task_id = call.data["task"]
diff --git a/server/services/projects.py b/server/services/projects.py
index ff98102..cdaecb4 100644
--- a/server/services/projects.py
+++ b/server/services/projects.py
@@ -210,7 +210,7 @@ def get_all_ex(call: APICall):
     status_count = defaultdict(lambda: {})
     key = itemgetter(EntityVisibility.archived.value)
 
-    for result in Task.aggregate(status_count_pipeline):
+    for result in Task.aggregate(*status_count_pipeline):
         for k, group in groupby(sorted(result["counts"], key=key), key):
             section = (
                 EntityVisibility.archived if k else EntityVisibility.active
@@ -224,7 +224,7 @@
 
     runtime = {
        result["_id"]: {k: v for k, v in result.items() if k != "_id"}
-        for result in Task.aggregate(runtime_pipeline)
+        for result in Task.aggregate(*runtime_pipeline)
     }
 
     def safe_get(obj, path, default=None):
diff --git a/server/tests/automated/test_task_events.py b/server/tests/automated/test_task_events.py
index 217eff5..e4bf3af 100644
--- a/server/tests/automated/test_task_events.py
+++ b/server/tests/automated/test_task_events.py
@@ -2,7 +2,7 @@
 Comprehensive test of all(?) use cases of datasets and frames
 """
 import json
-import time
+import operator
 import unittest
 from functools import partial
 from statistics import mean
@@ -22,21 +22,17 @@ class TestTaskEvents(TestService):
         )
         return self.create_temp("tasks", **task_input)
 
-    def _create_task_event(self, type_, task, iteration, **kwargs):
+    @staticmethod
+    def _create_task_event(type_, task, iteration, **kwargs):
         return {
             "worker": "test",
             "type": type_,
             "task": task,
             "iter": iteration,
-            "timestamp": es_factory.get_timestamp_millis(),
+            "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
             **kwargs,
         }
 
-    def _copy_and_update(self, src_obj, new_data):
-        obj = src_obj.copy()
-        obj.update(new_data)
-        return obj
-
     def test_task_metrics(self):
         tasks = {
             self._temp_task(): {
@@ -83,8 +79,7 @@
 
         # test empty
         res = self.api.events.debug_images(
-            metrics=[{"task": task, "metric": metric}],
-            iters=5,
+            metrics=[{"task": task, "metric": metric}], iters=5,
         )
         self.assertFalse(res.metrics)
 
@@ -116,11 +111,11 @@
 
         # test forward navigation
         for page in range(3):
-            scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
+            scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)
 
         # test backwards navigation
         scroll_id = assert_debug_images(
-            scroll_id=scroll_id, page=0, navigate_earlier=False
+            scroll_id=scroll_id, expected_page=0, navigate_earlier=False
         )
 
         # beyond the latest iteration and back
@@ -131,10 +126,10 @@
             navigate_earlier=False,
         )
         self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
-        assert_debug_images(scroll_id=scroll_id, page=1)
+        assert_debug_images(scroll_id=scroll_id, expected_page=1)
 
         # refresh
-        assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
+        assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)
 
     def _assertDebugImages(
         self,
@@ -143,7 +138,7 @@
         max_iter: int,
         unique_images: Sequence[int],
         scroll_id,
-        page: int,
+        expected_page: int,
         iters: int = 5,
         **extra_params,
     ):
@@ -156,7 +151,7 @@
         data = res["metrics"][0]
         self.assertEqual(data["task"], task)
         self.assertEqual(data["metric"], metric)
-        left_iterations = max(0, max(unique_images) - page * iters)
+        left_iterations = max(0, max(unique_images) - expected_page * iters)
         self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
         for it in data["iterations"]:
             events_per_iter = sum(
@@ -166,26 +161,67 @@
         return res.scroll_id
 
     def test_task_logs(self):
-        events = []
         task = self._temp_task()
-        for iter_ in range(10):
-            log_event = self._create_task_event("log", task, iteration=iter_)
-            events.append(
-                self._copy_and_update(
-                    log_event,
-                    {"msg": "This is a log message from test task iter " + str(iter_)},
-                )
+        timestamp = es_factory.get_timestamp_millis()
+        events = [
+            self._create_task_event(
+                "log",
+                task=task,
+                iteration=iter_,
+                timestamp=timestamp + iter_ * 1000,
+                msg=f"This is a log message from test task iter {iter_}",
             )
-            # sleep so timestamp is not the same
-            time.sleep(0.01)
+            for iter_ in range(10)
+        ]
 
         self.send_batch(events)
-        data = self.api.events.get_task_log(task=task)
-        assert len(data["events"]) == 10
+        # test forward navigation
+        scroll_id = None
+        for page in range(3):
+            scroll_id = self._assert_log_events(
+                task=task, scroll_id=scroll_id, expected_page=page
+            )
 
-        self.api.tasks.reset(task=task)
-        data = self.api.events.get_task_log(task=task)
-        assert len(data["events"]) == 0
+        # test backwards navigation
+        scroll_id = self._assert_log_events(
+            task=task, scroll_id=scroll_id, navigate_earlier=False
+        )
+
+        # refresh
+        self._assert_log_events(task=task, scroll_id=scroll_id)
+        self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
+
+    def _assert_log_events(
+        self,
+        task,
+        scroll_id,
+        batch_size: int = 5,
+        expected_total: int = 10,
+        expected_page: int = 0,
+        **extra_params,
+    ):
+        res = self.api.events.get_task_log(
+            task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
+        )
+        self.assertEqual(res.total, expected_total)
+        expected_events = max(
+            0, batch_size - max(0, (expected_page + 1) * batch_size - expected_total)
+        )
+        self.assertEqual(res.returned, expected_events)
+        self.assertEqual(len(res.events), expected_events)
+        unique_events = len({ev.iter for ev in res.events})
+        self.assertEqual(len(res.events), unique_events)
+        if res.events:
+            cmp_operator = operator.ge
+            if not extra_params.get("navigate_earlier", True):
+                cmp_operator = operator.le
+            self.assertTrue(
+                all(
+                    cmp_operator(first.timestamp, second.timestamp)
+                    for first, second in zip(res.events, res.events[1:])
+                )
+            )
+        return res.scroll_id
 
     def test_task_metric_value_intervals_keys(self):
         metric = "Metric1"