Add events.get_task_log for improved log retrieval support

allegroai 2020-07-06 21:54:25 +03:00
parent 8219e3d4e2
commit 21f2ea8b17
8 changed files with 99 additions and 172 deletions

View File

@ -1 +1 @@
__version__ = "2.8.0"
__version__ = "2.9.0"

View File

@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
@ -44,8 +44,7 @@ class LogEventsRequest(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
from_timestamp: Optional[int] = IntField()
class IterationEvents(Base):

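A minimal sketch (not part of this commit) of how the updated request model might be populated, assuming the LogEventsRequest class shown above; the task id and timestamp values are made up:

# from_timestamp is the new pagination cursor that replaces refresh/scroll_id
req = LogEventsRequest(
    task="<task-id>",              # hypothetical task id
    batch_size=100,
    navigate_earlier=True,         # newest-to-oldest
    from_timestamp=1594056865000,  # epoch ms of the last event from the previous page
)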
View File

@ -40,7 +40,7 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:

View File

@ -2,30 +2,12 @@ from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext
class LogEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
task: str = StringField(required=True)
last_min_timestamp: Optional[int] = IntField()
last_max_timestamp: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state """
self.last_min_timestamp = self.last_max_timestamp = None
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
@ -36,19 +18,8 @@ class TaskEventsResult:
class LogEventsIterator:
EVENT_TYPE = "log"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
def __init__(self, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=LogEventsScrollState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
self,
@ -56,48 +27,29 @@ class LogEventsIterator:
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
def init_state(state_: LogEventsScrollState):
state_.task = task_id
def validate_state(state_: LogEventsScrollState):
"""
Checks that the task id stored in the state
is equal to the one passed with the current call
Refresh the state if requested
"""
if state_.task != task_id:
raise errors.bad_request.InvalidScrollId(
"Task stored in the state does not match the passed one",
scroll_id=state_.id,
)
if refresh:
state_.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res = TaskEventsResult(next_scroll_id=state.id)
res.events, res.total_events = self._get_events(
es_index=es_index,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
state=state,
)
return res
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
es_index=es_index,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
es_index,
task_id: str,
batch_size: int,
navigate_earlier: bool,
state: LogEventsScrollState,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
@ -111,29 +63,21 @@ class LogEventsIterator:
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": state.task}},
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if navigate_earlier and state.last_min_timestamp is not None:
es_req["search_after"] = [state.last_min_timestamp]
elif not navigate_earlier and state.last_max_timestamp is not None:
es_req["search_after"] = [state.last_max_timestamp]
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
if navigate_earlier:
state.last_max_timestamp = events[0]["timestamp"]
state.last_min_timestamp = events[-1]["timestamp"]
else:
state.last_min_timestamp = events[0]["timestamp"]
state.last_max_timestamp = events[-1]["timestamp"]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
@ -142,13 +86,13 @@ class LogEventsIterator:
"query": {
"bool": {
"must": [
{"term": {"task": state.task}},
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
hits = es_result["hits"]["hits"]
if not hits or len(hits) < 2:
# if only one element is returned for the last timestamp

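For illustration only, a standalone sketch of the timestamp-based paging that _get_events now performs (the first of its two queries); the index name, task id, and cursor value are placeholders:

from elasticsearch import Elasticsearch

es = Elasticsearch()
task_id = "<task-id>"                          # placeholder
from_timestamp = 1594056865000                 # epoch ms cursor, or None on the first page
es_req = {
    "size": 500,                               # batch_size
    "query": {"term": {"task": task_id}},
    "sort": {"timestamp": "desc"},             # navigate_earlier=True
}
if from_timestamp:
    es_req["search_after"] = [from_timestamp]  # resume strictly after this sort value
res = es.search(index="<events-log-index>", body=es_req, routing=task_id)
events = [hit["_source"] for hit in res["hits"]["hits"]]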
View File

@ -530,59 +530,51 @@
}
}
}
// "2.7" {
// description: "Get 'log' events for this task"
// request {
// type: object
// required: [
// task
// ]
// properties {
// task {
// type: string
// description: "Task ID"
// }
// batch_size {
// type: integer
// description: "The amount of log events to return"
// }
// navigate_earlier {
// type: boolean
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
// }
// refresh {
// type: boolean
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
// }
// scroll_id {
// type: string
// description: "Scroll ID of previous call (used for getting more results)"
// }
// }
// }
// response {
// type: object
// properties {
// events {
// type: array
// items { type: object }
// description: "Log items list"
// }
// returned {
// type: integer
// description: "Number of log events returned"
// }
// total {
// type: number
// description: "Total number of log events available for this query"
// }
// scroll_id {
// type: string
// description: "Scroll ID for getting more results"
// }
// }
// }
// }
"2.9" {
description: "Get 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
batch_size {
type: integer
description: "The amount of log events to return"
}
navigate_earlier {
type: boolean
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
}
from_timestamp {
type: number
description: "Epoch time in UTC ms to use as the navigation start"
}
}
}
response {
type: object
properties {
events {
type: array
items { type: object }
description: "Log items list"
}
returned {
type: integer
description: "Number of log events returned"
}
total {
type: number
description: "Total number of log events available for this query"
}
}
}
}
}
get_task_events {
"2.1" {

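A hedged example of paging through a task log with the 2.9 contract above; the apiserver address, URL layout, and response envelope are assumptions (authentication omitted), not part of this commit:

import requests

APISERVER = "http://localhost:8008"  # assumed default apiserver address
from_timestamp = None                # None starts from the latest events
while True:
    resp = requests.post(
        f"{APISERVER}/events.get_task_log",  # assumes endpoint names map to URL paths
        json={
            "task": "<task-id>",
            "batch_size": 500,
            "navigate_earlier": True,
            "from_timestamp": from_timestamp,
        },
    )
    data = resp.json()["data"]       # assumes responses are wrapped in a "data" envelope
    if not data["events"]:
        break
    from_timestamp = data["events"][-1]["timestamp"]  # oldest event on the page becomes the cursor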
View File

@ -94,27 +94,24 @@ def get_task_log_v1_7(call, company_id, req_model):
)
# uncomment this once the front end is ready
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
# def get_task_log(call, company_id, req_model: LogEventsRequest):
# task_id = req_model.task
# task_bll.assert_exists(company_id, task_id, allow_public=True)
#
# res = event_bll.log_events_iterator.get_task_events(
# company_id=company_id,
# task_id=task_id,
# batch_size=req_model.batch_size,
# navigate_earlier=req_model.navigate_earlier,
# refresh=req_model.refresh,
# state_id=req_model.scroll_id,
# )
#
# call.result.data = dict(
# events=res.events,
# returned=len(res.events),
# total=res.total_events,
# scroll_id=res.next_scroll_id,
# )
@endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, req_model: LogEventsRequest):
task_id = req_model.task
task_bll.assert_exists(company_id, task_id, allow_public=True)
res = event_bll.log_events_iterator.get_task_events(
company_id=company_id,
task_id=task_id,
batch_size=req_model.batch_size,
navigate_earlier=req_model.navigate_earlier,
from_timestamp=req_model.from_timestamp,
)
call.result.data = dict(
events=res.events,
returned=len(res.events),
total=res.total_events,
)
@endpoint("events.download_task_log", required_fields=["task"])

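For reference, a sketch of the payload shape that get_task_log above places in call.result.data; the event fields shown are illustrative, not an exhaustive schema:

# Values are made up; only events/returned/total are defined by the handler above
example_response = dict(
    events=[
        {"task": "<task-id>", "type": "log", "timestamp": 1594056865000, "msg": "epoch 1 done"},
    ],
    returned=1,
    total=10,
)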
View File

@ -6,7 +6,7 @@ import operator
import unittest
from functools import partial
from statistics import mean
from typing import Sequence
from typing import Sequence, Optional, Tuple
from boltons.iterutils import first
@ -16,7 +16,7 @@ from tests.automated import TestService
class TestTaskEvents(TestService):
def setUp(self, version="2.7"):
def setUp(self, version="2.9"):
super().setUp(version=version)
def _temp_task(self, name="test task events"):
@ -213,7 +213,6 @@ class TestTaskEvents(TestService):
self.assertEqual(len(res.events), 1)
def test_task_logs(self):
# this test will fail until the new api is uncommented
task = self._temp_task()
timestamp = es_factory.get_timestamp_millis()
events = [
@ -229,32 +228,28 @@ class TestTaskEvents(TestService):
self.send_batch(events)
# test forward navigation
scroll_id = None
for page in range(3):
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, expected_page=page
ftime, ltime = None, None
for page in range(2):
ftime, ltime = self._assert_log_events(
task=task, timestamp=ltime, expected_page=page
)
# test backwards navigation
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, navigate_earlier=False
self._assert_log_events(
task=task, timestamp=ftime, navigate_earlier=False
)
# refresh
self._assert_log_events(task=task, scroll_id=scroll_id)
self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
def _assert_log_events(
self,
task,
scroll_id,
batch_size: int = 5,
timestamp: Optional[int] = None,
expected_total: int = 10,
expected_page: int = 0,
**extra_params,
):
) -> Tuple[int, int]:
res = self.api.events.get_task_log(
task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
task=task, batch_size=batch_size, from_timestamp=timestamp, **extra_params,
)
self.assertEqual(res.total, expected_total)
expected_events = max(
@ -274,7 +269,8 @@ class TestTaskEvents(TestService):
for first, second in zip(res.events, res.events[1:])
)
)
return res.scroll_id
return (res.events[0].timestamp, res.events[-1].timestamp) if res.events else (None, None)
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"

View File

@ -1,7 +1,6 @@
from typing import Sequence
from uuid import uuid4
from apierrors import errors
from config import config
from tests.automated import TestService