Add events.get_task_log for improved log retrieval support

allegroai 2020-07-06 21:54:25 +03:00
parent 8219e3d4e2
commit 21f2ea8b17
8 changed files with 99 additions and 172 deletions

View File

@@ -1 +1 @@
-__version__ = "2.8.0"
+__version__ = "2.9.0"

View File

@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Sequence, Optional
 from jsonmodels import validators
 from jsonmodels.fields import StringField, BoolField
@@ -44,8 +44,7 @@ class LogEventsRequest(Base):
     task: str = StringField(required=True)
     batch_size: int = IntField(default=500)
     navigate_earlier: bool = BoolField(default=True)
-    refresh: bool = BoolField(default=False)
-    scroll_id: str = StringField()
+    from_timestamp: Optional[int] = IntField()


 class IterationEvents(Base):
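A minimal standalone sketch of the reworked request model in use (the imports and the `IntField` location here are assumptions; the server's real model builds on its own `apimodels` base classes and a jsonmodels version with field defaults, as the diff shows):

from jsonmodels.models import Base
from jsonmodels.fields import StringField, IntField, BoolField


class LogEventsRequest(Base):
    # 'refresh' and 'scroll_id' are gone; 'from_timestamp' now carries the
    # client-held paging cursor instead of a server-side scroll state
    task = StringField(required=True)
    batch_size = IntField(default=500)
    navigate_earlier = BoolField(default=True)
    from_timestamp = IntField()  # epoch time in UTC ms; omitted on the first call


first_page = LogEventsRequest(task="0123abcd", batch_size=100)
first_page.validate()  # raises if 'task' is missing or a field has the wrong type

# follow-up call: echo back the last timestamp seen in the previous response
next_page = LogEventsRequest(task="0123abcd", batch_size=100, from_timestamp=1594054465000)
next_page.validate()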

View File

@@ -40,7 +40,7 @@ class EventBLL(object):
         )
         self.redis = redis or redman.connection("apiserver")
         self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
-        self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
+        self.log_events_iterator = LogEventsIterator(es=self.es)

     @property
     def metrics(self) -> EventMetrics:

View File

@@ -2,30 +2,12 @@ from typing import Optional, Tuple, Sequence
 import attr
 from elasticsearch import Elasticsearch
-from jsonmodels.fields import StringField, IntField
-from jsonmodels.models import Base
-from redis import StrictRedis

-from apierrors import errors
-from apimodels import JsonSerializableMixin
 from bll.event.event_metrics import EventMetrics
-from bll.redis_cache_manager import RedisCacheManager
-from config import config
 from database.errors import translate_errors_context
 from timing_context import TimingContext


-class LogEventsScrollState(Base, JsonSerializableMixin):
-    id: str = StringField(required=True)
-    task: str = StringField(required=True)
-    last_min_timestamp: Optional[int] = IntField()
-    last_max_timestamp: Optional[int] = IntField()
-
-    def reset(self):
-        """Reset the scrolling state"""
-        self.last_min_timestamp = self.last_max_timestamp = None
-
-
 @attr.s(auto_attribs=True)
 class TaskEventsResult:
     total_events: int = 0
@@ -36,19 +18,8 @@ class TaskEventsResult:
 class LogEventsIterator:
     EVENT_TYPE = "log"

-    @property
-    def state_expiration_sec(self):
-        return config.get(
-            f"services.events.events_retrieval.state_expiration_sec", 3600
-        )
-
-    def __init__(self, redis: StrictRedis, es: Elasticsearch):
+    def __init__(self, es: Elasticsearch):
         self.es = es
-        self.cache_manager = RedisCacheManager(
-            state_class=LogEventsScrollState,
-            redis=redis,
-            expiration_interval=self.state_expiration_sec,
-        )

     def get_task_events(
         self,
@@ -56,48 +27,29 @@ class LogEventsIterator:
         task_id: str,
         batch_size: int,
         navigate_earlier: bool = True,
-        refresh: bool = False,
-        state_id: str = None,
+        from_timestamp: Optional[int] = None,
     ) -> TaskEventsResult:
         es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
         if not self.es.indices.exists(es_index):
             return TaskEventsResult()

-        def init_state(state_: LogEventsScrollState):
-            state_.task = task_id
-
-        def validate_state(state_: LogEventsScrollState):
-            """
-            Checks that the task id stored in the state
-            is equal to the one passed with the current call
-            Refresh the state if requested
-            """
-            if state_.task != task_id:
-                raise errors.bad_request.InvalidScrollId(
-                    "Task stored in the state does not match the passed one",
-                    scroll_id=state_.id,
-                )
-            if refresh:
-                state_.reset()
-
-        with self.cache_manager.get_or_create_state(
-            state_id=state_id, init_state=init_state, validate_state=validate_state,
-        ) as state:
-            res = TaskEventsResult(next_scroll_id=state.id)
-            res.events, res.total_events = self._get_events(
-                es_index=es_index,
-                batch_size=batch_size,
-                navigate_earlier=navigate_earlier,
-                state=state,
-            )
-            return res
+        res = TaskEventsResult()
+        res.events, res.total_events = self._get_events(
+            es_index=es_index,
+            task_id=task_id,
+            batch_size=batch_size,
+            navigate_earlier=navigate_earlier,
+            from_timestamp=from_timestamp,
+        )
+        return res

     def _get_events(
         self,
         es_index,
+        task_id: str,
         batch_size: int,
         navigate_earlier: bool,
-        state: LogEventsScrollState,
+        from_timestamp: Optional[int],
     ) -> Tuple[Sequence[dict], int]:
         """
         Return up to 'batch size' events starting from the previous timestamp either in the
@@ -111,29 +63,21 @@ class LogEventsIterator:
         # retrieve the next batch of events
         es_req = {
             "size": batch_size,
-            "query": {"term": {"task": state.task}},
+            "query": {"term": {"task": task_id}},
             "sort": {"timestamp": "desc" if navigate_earlier else "asc"},
         }
-        if navigate_earlier and state.last_min_timestamp is not None:
-            es_req["search_after"] = [state.last_min_timestamp]
-        elif not navigate_earlier and state.last_max_timestamp is not None:
-            es_req["search_after"] = [state.last_max_timestamp]
+        if from_timestamp:
+            es_req["search_after"] = [from_timestamp]

         with translate_errors_context(), TimingContext("es", "get_task_events"):
-            es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
             hits = es_result["hits"]["hits"]
             hits_total = es_result["hits"]["total"]
             if not hits:
                 return [], hits_total

         events = [hit["_source"] for hit in hits]
-        if navigate_earlier:
-            state.last_max_timestamp = events[0]["timestamp"]
-            state.last_min_timestamp = events[-1]["timestamp"]
-        else:
-            state.last_min_timestamp = events[0]["timestamp"]
-            state.last_max_timestamp = events[-1]["timestamp"]

         # retrieve the events that match the last event timestamp
         # but did not make it into the previous call due to batch_size limitation
@@ -142,13 +86,13 @@ class LogEventsIterator:
             "query": {
                 "bool": {
                     "must": [
-                        {"term": {"task": state.task}},
+                        {"term": {"task": task_id}},
                         {"term": {"timestamp": events[-1]["timestamp"]}},
                     ]
                 }
             },
         }
-        es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
+        es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
         hits = es_result["hits"]["hits"]
         if not hits or len(hits) < 2:
             # if only one element is returned for the last timestamp
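The change above drops the Redis-backed scroll state in favor of stateless `search_after` paging keyed on the event timestamp. A condensed sketch of the same pattern against a bare Elasticsearch client (`fetch_log_page` is a hypothetical helper, not the server's API; treating `hits.total` as a plain integer assumes the pre-7.x response shape used in this codebase):

from typing import Optional, Sequence, Tuple

from elasticsearch import Elasticsearch


def fetch_log_page(
    es: Elasticsearch,
    es_index: str,
    task_id: str,
    batch_size: int = 500,
    navigate_earlier: bool = True,
    from_timestamp: Optional[int] = None,
) -> Tuple[Sequence[dict], int]:
    """Return one page of log events plus the total hit count."""
    es_req = {
        "size": batch_size,
        "query": {"term": {"task": task_id}},
        # newest-first when navigating earlier, oldest-first otherwise
        "sort": {"timestamp": "desc" if navigate_earlier else "asc"},
    }
    if from_timestamp:
        # search_after resumes strictly past this sort value, so the client
        # simply echoes back the last timestamp it received
        es_req["search_after"] = [from_timestamp]

    es_result = es.search(index=es_index, body=es_req, routing=task_id)
    hits = es_result["hits"]["hits"]
    total = es_result["hits"]["total"]  # an int in Elasticsearch 6.x responses
    events = [hit["_source"] for hit in hits]
    # The production iterator issues one more query for events that share the
    # last timestamp, since paging on a non-unique sort key can split them
    # across pages; that follow-up is omitted here for brevity.
    return events, total

Because the paging cursor now lives entirely on the client, the Redis cache manager, the scroll-state model, and the expired/invalid scroll-id error path all disappear from the server side.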

View File

@@ -530,59 +530,51 @@
                 }
             }
         }
-        // "2.7" {
-        //     description: "Get 'log' events for this task"
-        //     request {
-        //         type: object
-        //         required: [
-        //             task
-        //         ]
-        //         properties {
-        //             task {
-        //                 type: string
-        //                 description: "Task ID"
-        //             }
-        //             batch_size {
-        //                 type: integer
-        //                 description: "The amount of log events to return"
-        //             }
-        //             navigate_earlier {
-        //                 type: boolean
-        //                 description: "If set then log events are retrieved from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
-        //             }
-        //             refresh {
-        //                 type: boolean
-        //                 description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
-        //             }
-        //             scroll_id {
-        //                 type: string
-        //                 description: "Scroll ID of previous call (used for getting more results)"
-        //             }
-        //         }
-        //     }
-        //     response {
-        //         type: object
-        //         properties {
-        //             events {
-        //                 type: array
-        //                 items { type: object }
-        //                 description: "Log items list"
-        //             }
-        //             returned {
-        //                 type: integer
-        //                 description: "Number of log events returned"
-        //             }
-        //             total {
-        //                 type: number
-        //                 description: "Total number of log events available for this query"
-        //             }
-        //             scroll_id {
-        //                 type: string
-        //                 description: "Scroll ID for getting more results"
-        //             }
-        //         }
-        //     }
-        // }
+        "2.9" {
+            description: "Get 'log' events for this task"
+            request {
+                type: object
+                required: [
+                    task
+                ]
+                properties {
+                    task {
+                        type: string
+                        description: "Task ID"
+                    }
+                    batch_size {
+                        type: integer
+                        description: "The amount of log events to return"
+                    }
+                    navigate_earlier {
+                        type: boolean
+                        description: "If set then log events are retrieved from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
+                    }
+                    from_timestamp {
+                        type: number
+                        description: "Epoch time in UTC ms to use as the navigation start"
+                    }
+                }
+            }
+            response {
+                type: object
+                properties {
+                    events {
+                        type: array
+                        items { type: object }
+                        description: "Log items list"
+                    }
+                    returned {
+                        type: integer
+                        description: "Number of log events returned"
+                    }
+                    total {
+                        type: number
+                        description: "Total number of log events available for this query"
+                    }
+                }
+            }
+        }
     }
     get_task_events {
         "2.1" {

View File

@@ -94,27 +94,24 @@ def get_task_log_v1_7(call, company_id, req_model):
     )


-# uncomment this once the front end is ready
-# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
-# def get_task_log(call, company_id, req_model: LogEventsRequest):
-#     task_id = req_model.task
-#     task_bll.assert_exists(company_id, task_id, allow_public=True)
-#
-#     res = event_bll.log_events_iterator.get_task_events(
-#         company_id=company_id,
-#         task_id=task_id,
-#         batch_size=req_model.batch_size,
-#         navigate_earlier=req_model.navigate_earlier,
-#         refresh=req_model.refresh,
-#         state_id=req_model.scroll_id,
-#     )
-#
-#     call.result.data = dict(
-#         events=res.events,
-#         returned=len(res.events),
-#         total=res.total_events,
-#         scroll_id=res.next_scroll_id,
-#     )
+@endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
+def get_task_log(call, company_id, req_model: LogEventsRequest):
+    task_id = req_model.task
+    task_bll.assert_exists(company_id, task_id, allow_public=True)
+
+    res = event_bll.log_events_iterator.get_task_events(
+        company_id=company_id,
+        task_id=task_id,
+        batch_size=req_model.batch_size,
+        navigate_earlier=req_model.navigate_earlier,
+        from_timestamp=req_model.from_timestamp,
+    )
+
+    call.result.data = dict(
+        events=res.events,
+        returned=len(res.events),
+        total=res.total_events,
+    )


 @endpoint("events.download_task_log", required_fields=["task"])
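With the endpoint live, a client can page through a task's entire log without the server keeping any scroll state. A hypothetical paging loop (the base URL, port, and authentication are assumptions; the apiserver generally exposes each endpoint as POST /<service>.<action> with a JSON body):

import requests

BASE_URL = "http://localhost:8008"  # assumed apiserver address; adjust to your deployment
session = requests.Session()
# session.auth = (access_key, secret_key)  # real deployments require credentials


def iter_task_log(task_id: str, batch_size: int = 500):
    """Yield all log events for a task, newest to oldest, one page at a time."""
    from_timestamp = None
    while True:
        payload = {"task": task_id, "batch_size": batch_size}
        if from_timestamp is not None:
            payload["from_timestamp"] = from_timestamp
        resp = session.post(f"{BASE_URL}/events.get_task_log", json=payload)
        resp.raise_for_status()
        data = resp.json()["data"]
        if not data["events"]:
            return  # an empty page means the log is exhausted
        yield from data["events"]
        # resume the next page strictly after the oldest timestamp seen so far;
        # events sharing a boundary timestamp may repeat across pages
        from_timestamp = data["events"][-1]["timestamp"]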

View File

@@ -6,7 +6,7 @@ import operator
 import unittest
 from functools import partial
 from statistics import mean
-from typing import Sequence
+from typing import Sequence, Optional, Tuple

 from boltons.iterutils import first
@@ -16,7 +16,7 @@ from tests.automated import TestService
 class TestTaskEvents(TestService):
-    def setUp(self, version="2.7"):
+    def setUp(self, version="2.9"):
         super().setUp(version=version)

     def _temp_task(self, name="test task events"):
@@ -213,7 +213,6 @@ class TestTaskEvents(TestService):
         self.assertEqual(len(res.events), 1)

     def test_task_logs(self):
-        # this test will fail until the new api is uncommented
         task = self._temp_task()
         timestamp = es_factory.get_timestamp_millis()
         events = [
@@ -229,32 +228,28 @@ class TestTaskEvents(TestService):
         self.send_batch(events)

         # test forward navigation
-        scroll_id = None
-        for page in range(3):
-            scroll_id = self._assert_log_events(
-                task=task, scroll_id=scroll_id, expected_page=page
-            )
+        ftime, ltime = None, None
+        for page in range(2):
+            ftime, ltime = self._assert_log_events(
+                task=task, timestamp=ltime, expected_page=page
+            )

         # test backwards navigation
-        scroll_id = self._assert_log_events(
-            task=task, scroll_id=scroll_id, navigate_earlier=False
-        )
+        self._assert_log_events(
+            task=task, timestamp=ftime, navigate_earlier=False
+        )

-        # refresh
-        self._assert_log_events(task=task, scroll_id=scroll_id)
-        self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)

     def _assert_log_events(
         self,
         task,
-        scroll_id,
         batch_size: int = 5,
+        timestamp: Optional[int] = None,
         expected_total: int = 10,
         expected_page: int = 0,
         **extra_params,
-    ):
+    ) -> Tuple[int, int]:
         res = self.api.events.get_task_log(
-            task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
+            task=task, batch_size=batch_size, from_timestamp=timestamp, **extra_params,
         )
         self.assertEqual(res.total, expected_total)
         expected_events = max(
@@ -274,7 +269,8 @@ class TestTaskEvents(TestService):
                 for first, second in zip(res.events, res.events[1:])
             )
         )
-        return res.scroll_id
+        return (res.events[0].timestamp, res.events[-1].timestamp) if res.events else (None, None)

     def test_task_metric_value_intervals_keys(self):
         metric = "Metric1"

View File

@@ -1,7 +1,6 @@
 from typing import Sequence
 from uuid import uuid4

-from apierrors import errors
 from config import config

 from tests.automated import TestService