Support task log retrieval with no scroll

allegroai 2020-06-01 11:27:36 +03:00
parent d9bdebefc7
commit 35a11db58e
12 changed files with 361 additions and 93 deletions

View File

@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator
from apierrors import errors
from utilities.json import loads, dumps
def make_default(field_cls, default_value):
@@ -213,3 +214,12 @@ class StringEnum(Enum):
# noinspection PyMethodParameters
def _generate_next_value_(name, start, count, last_values):
return name
class JsonSerializableMixin:
def to_json(self: ModelBase):
return dumps(self.to_struct())
@classmethod
def from_json(cls: Type[ModelBase], s):
return cls(**loads(s))
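A quick round-trip sketch of the new mixin; the Example model and its field are made up for illustration, while to_json/from_json are the methods defined above:
from apimodels import JsonSerializableMixin
from jsonmodels.fields import StringField
from jsonmodels.models import Base

class Example(Base, JsonSerializableMixin):
    name = StringField()  # hypothetical field, just for the example

serialized = Example(name="test").to_json()  # JSON string built from to_struct()
restored = Example.from_json(serialized)     # new Example instance with name == "test"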

View File

@@ -40,6 +40,14 @@ class DebugImagesRequest(Base):
scroll_id: str = StringField()
class LogEventsRequest(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)

View File

@@ -1,4 +1,3 @@
import json
from enum import Enum
import six
@@ -13,7 +12,7 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apimodels import make_default, ListField, EnumField
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
@@ -61,7 +60,7 @@ class IdNameEntry(Base):
name = StringField()
class WorkerEntry(Base):
class WorkerEntry(Base, JsonSerializableMixin):
key = StringField() # not required due to migration issues
id = StringField(required=True)
user = EmbeddedField(IdNameEntry)
@@ -75,13 +74,6 @@ class WorkerEntry(Base):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
def to_json(self):
return json.dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**json.loads(s))
class CurrentTaskEntry(IdNameEntry):
running_time = IntField()

View File

@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping
import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField
from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps
class VariantScrollState(Base):
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
def to_json(self):
return dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**loads(s))
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
@@ -65,7 +56,12 @@ class DebugImagesResult(object):
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
STATE_EXPIRATION_SECONDS = 3600
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
@property
def _max_workers(self):
@@ -76,7 +72,7 @@ class DebugImagesIterator:
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=self.STATE_EXPIRATION_SECONDS,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
@@ -92,27 +88,31 @@ class DebugImagesIterator:
if not self.es.indices.exists(es_index):
return DebugImagesResult()
unique_metrics = set(metrics)
state = self.cache_manager.get_state(state_id) if state_id else None
if not state:
state = DebugImageEventsScrollState(
id=database.utils.id(),
metrics=self._init_metric_states(es_index, list(unique_metrics)),
)
else:
state_metrics = set((m.task, m.name) for m in state.metrics)
if state_metrics != unique_metrics:
raise errors.bad_request.InvalidScrollId(
"while getting debug images events", scroll_id=state_id
)
def init_state(state_: DebugImageEventsScrollState):
unique_metrics = set(metrics)
state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
def validate_state(state_: DebugImageEventsScrollState):
"""
Validate that the metrics stored in the state are the same
as those requested in the current call.
Refresh the state if requested.
"""
state_metrics = set((m.task, m.name) for m in state_.metrics)
if state_metrics != set(metrics):
raise errors.bad_request.InvalidScrollId(
"Task metrics stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reinit_outdated_metric_states(company_id, es_index, state)
for metric_state in state.metrics:
self._reinit_outdated_metric_states(company_id, es_index, state_)
for metric_state in state_.metrics:
metric_state.reset()
res = DebugImagesResult(next_scroll_id=state.id)
try:
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
) as state:
res = DebugImagesResult(next_scroll_id=state.id)
with ThreadPoolExecutor(self._max_workers) as pool:
res.metric_events = list(
pool.map(
@@ -125,10 +125,8 @@ class DebugImagesIterator:
state.metrics,
)
)
finally:
self.cache_manager.set_state(state)
return res
return res
def _reinit_outdated_metric_states(
self, company_id, es_index, state: DebugImageEventsScrollState

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from operator import attrgetter
from typing import Sequence
import attr
import six
from elasticsearch import helpers
from mongoengine import Q
@@ -16,6 +15,7 @@ import es_factory
from apierrors import errors
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@@ -29,13 +29,6 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s(auto_attribs=True)
class TaskEventsResult(object):
total_events: int = 0
next_scroll_id: str = None
events: list = attr.ib(factory=list)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
@@ -47,6 +40,7 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
@property
def metrics(self) -> EventMetrics:

View File

@@ -0,0 +1,169 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext
class LogEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
task: str = StringField(required=True)
last_min_timestamp: Optional[int] = IntField()
last_max_timestamp: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state"""
self.last_min_timestamp = self.last_max_timestamp = None
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = "log"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=LogEventsScrollState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> TaskEventsResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
def init_state(state_: LogEventsScrollState):
state_.task = task_id
def validate_state(state_: LogEventsScrollState):
"""
Check that the task id stored in the state
matches the one passed with the current call.
Refresh the state if requested.
"""
if state_.task != task_id:
raise errors.bad_request.InvalidScrollId(
"Task stored in the state does not match the passed one",
scroll_id=state_.id,
)
if refresh:
state_.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res = TaskEventsResult(next_scroll_id=state.id)
res.events, res.total_events = self._get_events(
es_index=es_index,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
state=state,
)
return res
def _get_events(
self,
es_index,
batch_size: int,
navigate_earlier: bool,
state: LogEventsScrollState,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch_size' events starting from the previous timestamp, either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set, start from the latest or earliest event.
For the last timestamp, all matching events are returned (even if the resulting size
exceeds batch_size) so that events sharing that timestamp are not lost between calls.
If any events were received, update 'last_min_timestamp' and 'last_max_timestamp'.
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": state.task}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if navigate_earlier and state.last_min_timestamp is not None:
es_req["search_after"] = [state.last_min_timestamp]
elif not navigate_earlier and state.last_max_timestamp is not None:
es_req["search_after"] = [state.last_max_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
if navigate_earlier:
state.last_max_timestamp = events[0]["timestamp"]
state.last_min_timestamp = events[-1]["timestamp"]
else:
state.last_min_timestamp = events[0]["timestamp"]
state.last_max_timestamp = events[-1]["timestamp"]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": state.task}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
last_timestamp_hits = es_result["hits"]["hits"]
if not last_timestamp_hits or len(last_timestamp_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
# the first query may have already returned events for this timestamp;
# use the ES document ids (not the event bodies) to avoid duplicates
already_present_ids = set(hit["_id"] for hit in hits)
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[
*events,
*(
hit["_source"]
for hit in last_timestamp_hits
if hit["_id"] not in already_present_ids
),
],
hits_total,
)
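For reference, here is a stripped-down, self-contained sketch of the timestamp paging that the _get_events docstring describes. The host, index name and task id are assumptions for illustration, and the tie-handling second query is omitted:
from elasticsearch import Elasticsearch

es = Elasticsearch(hosts=["http://localhost:9200"])  # hypothetical cluster
es_index = "events-log-somecompany"  # hypothetical per-company log index
task_id = "sometask"  # hypothetical task id
last_min_timestamp = None  # carried between calls via the scroll state

def fetch_earlier_events(batch_size: int = 500):
    """Return the next batch of log events, newest to oldest."""
    global last_min_timestamp
    es_req = {
        "size": batch_size,
        "query": {"term": {"task": task_id}},
        "sort": {"timestamp": "desc"},  # navigate_earlier=True ordering
    }
    if last_min_timestamp is not None:
        # resume right after the oldest timestamp seen in the previous call
        es_req["search_after"] = [last_min_timestamp]
    hits = es.search(index=es_index, body=es_req, routing=task_id)["hits"]["hits"]
    events = [hit["_source"] for hit in hits]
    if events:
        last_min_timestamp = events[-1]["timestamp"]
    return events
Because the sort is on timestamp alone, events that share the boundary timestamp could be skipped by search_after; this is why the iterator above issues a second query for the last timestamp and merges the leftovers into the result.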

View File

@@ -1,15 +1,21 @@
from typing import Optional, TypeVar, Generic, Type
from contextlib import contextmanager
from typing import Optional, TypeVar, Generic, Type, Callable
from redis import StrictRedis
import database
from timing_context import TimingContext
T = TypeVar("T")
def _do_nothing(_: T):
return
class RedisCacheManager(Generic[T]):
"""
Class for store/retreive of state objects from redis
Class for store/retrieve of state objects from redis
self.state_class - class of the state
self.redis - instance of redis
@@ -42,3 +48,32 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
@contextmanager
def get_or_create_state(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
):
"""
Try to retrieve the state with the given id from the Redis cache; if found, validate it.
Otherwise create a new state with a randomly generated id.
Yield the state and write it back to Redis once the user code block exits.
:param state_id: id of the state to retrieve
:param init_state: user callback to init a newly created state.
If not passed then no init except for the id generation is done
:param validate_state: user callback to validate a state retrieved from the cache.
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
try:
yield state
finally:
self.set_state(state)
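Below is a hypothetical usage sketch of the new get_or_create_state() context manager; the state class and task id are invented for the example, but the scroll-state classes added in this commit follow the same pattern:
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from redis import StrictRedis

from apimodels import JsonSerializableMixin
from bll.redis_cache_manager import RedisCacheManager

class MyScrollState(Base, JsonSerializableMixin):
    id: str = StringField(required=True)
    task: str = StringField()

cache_manager = RedisCacheManager(
    state_class=MyScrollState, redis=StrictRedis(), expiration_interval=3600
)

def init_state(state: MyScrollState):
    # runs only when no cached state exists for the given scroll id
    state.task = "sometask"  # hypothetical task id

def validate_state(state: MyScrollState):
    # runs only when a cached state was found; raise to reject a mismatched scroll id
    if state.task != "sometask":
        raise ValueError("scroll id belongs to a different task")

with cache_manager.get_or_create_state(
    state_id=None, init_state=init_state, validate_state=validate_state
) as state:
    print(state.id)  # the state is written back to Redis when the block exits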

View File

@@ -6,4 +6,8 @@ ignore_iteration {
# max number of concurrent queries to ES when calculating events metrics
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
max_metrics_concurrency: 4
events_retrieval {
state_expiration_sec: 3600
}

View File

@@ -258,7 +258,6 @@
properties {
added { type: integer }
errors { type: integer }
errors_info { type: object }
}
}
}

View File

@@ -11,6 +11,7 @@ from apimodels.events import (
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@@ -45,7 +46,7 @@ def add_batch(call: APICall, company_id, req_model):
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
def get_task_log_v1_5(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
@@ -93,6 +94,28 @@ def get_task_log_v1_7(call, company_id, req_model):
)
@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, req_model: LogEventsRequest):
task_id = req_model.task
task_bll.assert_exists(company_id, task_id, allow_public=True)
res = event_bll.log_events_iterator.get_task_events(
company_id=company_id,
task_id=task_id,
batch_size=req_model.batch_size,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
)
call.result.data = dict(
events=res.events,
returned=len(res.events),
total=res.total_events,
scroll_id=res.next_scroll_id,
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
task_id = call.data["task"]
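As a usage illustration of the new scroll-based endpoint, a hypothetical client could page through the whole task log like this (the api client object is an assumption; the response fields match the dict built in the handler above):
def read_full_log(api, task_id: str, batch_size: int = 500):
    """Collect all log events for a task, newest to oldest."""
    events, scroll_id = [], None
    while True:
        res = api.events.get_task_log(
            task=task_id,
            batch_size=batch_size,
            navigate_earlier=True,  # walk from the latest events backwards
            scroll_id=scroll_id,  # None on the first call, then the returned id
        )
        if not res.events:
            break
        events.extend(res.events)
        scroll_id = res.scroll_id
    return events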

View File

@@ -210,7 +210,7 @@ def get_all_ex(call: APICall):
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(status_count_pipeline):
for result in Task.aggregate(*status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
@@ -224,7 +224,7 @@ def get_all_ex(call: APICall):
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(runtime_pipeline)
for result in Task.aggregate(*runtime_pipeline)
}
def safe_get(obj, path, default=None):

View File

@@ -2,7 +2,7 @@
Comprehensive test of all(?) use cases of task events
"""
import json
import time
import operator
import unittest
from functools import partial
from statistics import mean
@@ -22,21 +22,17 @@ class TestTaskEvents(TestService):
)
return self.create_temp("tasks", **task_input)
def _create_task_event(self, type_, task, iteration, **kwargs):
@staticmethod
def _create_task_event(type_, task, iteration, **kwargs):
return {
"worker": "test",
"type": type_,
"task": task,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis(),
"timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
**kwargs,
}
def _copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj
def test_task_metrics(self):
tasks = {
self._temp_task(): {
@@ -83,8 +79,7 @@ class TestTaskEvents(TestService):
# test empty
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
metrics=[{"task": task, "metric": metric}], iters=5,
)
self.assertFalse(res.metrics)
@@ -116,11 +111,11 @@ class TestTaskEvents(TestService):
# test forward navigation
for page in range(3):
scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)
# test backwards navigation
scroll_id = assert_debug_images(
scroll_id=scroll_id, page=0, navigate_earlier=False
scroll_id=scroll_id, expected_page=0, navigate_earlier=False
)
# beyond the latest iteration and back
@@ -131,10 +126,10 @@ class TestTaskEvents(TestService):
navigate_earlier=False,
)
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
assert_debug_images(scroll_id=scroll_id, page=1)
assert_debug_images(scroll_id=scroll_id, expected_page=1)
# refresh
assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)
def _assertDebugImages(
self,
@@ -143,7 +138,7 @@ class TestTaskEvents(TestService):
max_iter: int,
unique_images: Sequence[int],
scroll_id,
page: int,
expected_page: int,
iters: int = 5,
**extra_params,
):
@@ -156,7 +151,7 @@ class TestTaskEvents(TestService):
data = res["metrics"][0]
self.assertEqual(data["task"], task)
self.assertEqual(data["metric"], metric)
left_iterations = max(0, max(unique_images) - page * iters)
left_iterations = max(0, max(unique_images) - expected_page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:
events_per_iter = sum(
@@ -166,26 +161,67 @@ class TestTaskEvents(TestService):
return res.scroll_id
def test_task_logs(self):
events = []
task = self._temp_task()
for iter_ in range(10):
log_event = self._create_task_event("log", task, iteration=iter_)
events.append(
self._copy_and_update(
log_event,
{"msg": "This is a log message from test task iter " + str(iter_)},
)
timestamp = es_factory.get_timestamp_millis()
events = [
self._create_task_event(
"log",
task=task,
iteration=iter_,
timestamp=timestamp + iter_ * 1000,
msg=f"This is a log message from test task iter {iter_}",
)
# sleep so timestamp is not the same
time.sleep(0.01)
for iter_ in range(10)
]
self.send_batch(events)
data = self.api.events.get_task_log(task=task)
assert len(data["events"]) == 10
# test forward navigation
scroll_id = None
for page in range(3):
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, expected_page=page
)
self.api.tasks.reset(task=task)
data = self.api.events.get_task_log(task=task)
assert len(data["events"]) == 0
# test backwards navigation
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, navigate_earlier=False
)
# refresh
self._assert_log_events(task=task, scroll_id=scroll_id)
self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
def _assert_log_events(
self,
task,
scroll_id,
batch_size: int = 5,
expected_total: int = 10,
expected_page: int = 0,
**extra_params,
):
res = self.api.events.get_task_log(
task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
)
self.assertEqual(res.total, expected_total)
expected_events = max(
0, batch_size - max(0, (expected_page + 1) * batch_size - expected_total)
)
self.assertEqual(res.returned, expected_events)
self.assertEqual(len(res.events), expected_events)
unique_events = len({ev.iter for ev in res.events})
self.assertEqual(len(res.events), unique_events)
if res.events:
cmp_operator = operator.ge
if not extra_params.get("navigate_earlier", True):
cmp_operator = operator.le
self.assertTrue(
all(
cmp_operator(first.timestamp, second.timestamp)
for first, second in zip(res.events, res.events[1:])
)
)
return res.scroll_id
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"