Improve internal events implementation

allegroai 2021-01-05 18:20:38 +02:00
parent e2deff4eef
commit 6974aa3a99
8 changed files with 160 additions and 109 deletions

apiserver/apimodels/events.py

@@ -7,7 +7,7 @@ from jsonmodels.models import Base
 from jsonmodels.validators import Length, Min, Max
 from apiserver.apimodels import ListField, IntField, ActualEnumField
-from apiserver.bll.event.event_metrics import EventType
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.event.scalar_key import ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.utilities.stringenum import StringEnum

apiserver/bll/event/debug_images_iterator.py

@@ -15,9 +15,12 @@ from redis import StrictRedis
 from apiserver.apierrors import errors
 from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    check_empty_data,
+    search_company_events,
+)
 from apiserver.bll.redis_cache_manager import RedisCacheManager
-from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.task.metrics import MetricEventStats
 from apiserver.database.model.task.task import Task
@@ -58,22 +61,12 @@ class DebugImagesResult(object):
 class DebugImagesIterator:
     EVENT_TYPE = "training_debug_image"
-    @property
-    def state_expiration_sec(self):
-        return config.get(
-            f"services.events.events_retrieval.state_expiration_sec", 3600
-        )
-    @property
-    def _max_workers(self):
-        return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
         self.cache_manager = RedisCacheManager(
             state_class=DebugImageEventsScrollState,
             redis=redis,
-            expiration_interval=self.state_expiration_sec,
+            expiration_interval=EventSettings.state_expiration_sec,
         )
     def get_task_events(
@@ -85,13 +78,12 @@
         refresh: bool = False,
         state_id: str = None,
     ) -> DebugImagesResult:
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, self.EVENT_TYPE):
             return DebugImagesResult()
         def init_state(state_: DebugImageEventsScrollState):
             unique_metrics = set(metrics)
-            state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
+            state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
         def validate_state(state_: DebugImageEventsScrollState):
             """
@@ -106,7 +98,7 @@
                     scroll_id=state_.id,
                 )
             if refresh:
-                self._reinit_outdated_metric_states(company_id, es_index, state_)
+                self._reinit_outdated_metric_states(company_id, state_)
                 for metric_state in state_.metrics:
                     metric_state.reset()
@@ -114,12 +106,12 @@
             state_id=state_id, init_state=init_state, validate_state=validate_state
         ) as state:
             res = DebugImagesResult(next_scroll_id=state.id)
-            with ThreadPoolExecutor(self._max_workers) as pool:
+            with ThreadPoolExecutor(EventSettings.max_workers) as pool:
                 res.metric_events = list(
                     pool.map(
                         partial(
                             self._get_task_metric_events,
-                            es_index=es_index,
+                            company_id=company_id,
                             iter_count=iter_count,
                             navigate_earlier=navigate_earlier,
                         ),
@@ -130,7 +122,7 @@
             return res
     def _reinit_outdated_metric_states(
-        self, company_id, es_index, state: DebugImageEventsScrollState
+        self, company_id, state: DebugImageEventsScrollState
     ):
         """
         Determines the metrics for which new debug image events were added
@@ -171,14 +163,14 @@
             *(metric for metric in state.metrics if metric not in outdated_metrics),
             *(
                 self._init_metric_states(
-                    es_index,
+                    company_id,
                     [(metric.task, metric.name) for metric in outdated_metrics],
                 )
             ),
         ]
     def _init_metric_states(
-        self, es_index, metrics: Sequence[Tuple[str, str]]
+        self, company_id: str, metrics: Sequence[Tuple[str, str]]
     ) -> Sequence[MetricScrollState]:
         """
         Returned initialized metric scroll stated for the requested task metrics
@@ -187,18 +179,20 @@
         for (task, metric) in metrics:
             tasks[task].append(metric)
-        with ThreadPoolExecutor(self._max_workers) as pool:
+        with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             return list(
                 chain.from_iterable(
                     pool.map(
-                        partial(self._init_metric_states_for_task, es_index=es_index),
+                        partial(
+                            self._init_metric_states_for_task, company_id=company_id
+                        ),
                         tasks.items(),
                     )
                 )
             )
     def _init_metric_states_for_task(
-        self, task_metrics: Tuple[str, Sequence[str]], es_index
+        self, task_metrics: Tuple[str, Sequence[str]], company_id: str
     ) -> Sequence[MetricScrollState]:
         """
         Return metric scroll states for the task filled with the variant states
@@ -220,7 +214,7 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": EventMetrics.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {
@@ -228,7 +222,7 @@
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         },
                         "aggs": {
@@ -258,7 +252,13 @@
         }
         with translate_errors_context(), TimingContext("es", "_init_metric_states"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=task,
+            )
         if "aggregations" not in es_res:
             return []
@@ -290,7 +290,7 @@
     def _get_task_metric_events(
         self,
         metric: MetricScrollState,
-        es_index: str,
+        company_id: str,
         iter_count: int,
         navigate_earlier: bool,
     ) -> Tuple:
@@ -382,7 +382,7 @@
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         },
                         "aggs": {
@@ -396,7 +396,13 @@
             },
         }
         with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=metric.task,
+            )
         if "aggregations" not in es_res:
             return metric.task, metric.name, []
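
The recurring change in this file replaces direct `es.search(index=es_index, ...)` calls with `search_company_events`, which resolves the per-company index from `company_id` and forwards a `routing` value so Elasticsearch only needs to consult the shard that indexed that task's events. A minimal sketch of the resulting call pattern, with hypothetical company and task ids:

```python
from elasticsearch import Elasticsearch

from apiserver.bll.event.event_common import search_company_events

es = Elasticsearch()

# "company1" and "task1" are hypothetical ids, for illustration only.
es_res = search_company_events(
    es,
    company_id="company1",
    event_type="training_debug_image",
    body={"query": {"term": {"task": "task1"}}},
    routing="task1",  # forwarded to es.search(); restricts the query to one shard
)
```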

apiserver/bll/event/debug_sample_history.py

@@ -10,9 +10,8 @@ from redis import StrictRedis
 from apiserver.apierrors import errors
 from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import EventSettings, get_index_name
 from apiserver.bll.redis_cache_manager import RedisCacheManager
-from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
 from apiserver.utilities.dicts import nested_get
@@ -47,18 +46,12 @@
 class DebugSampleHistory:
     EVENT_TYPE = "training_debug_image"
-    @property
-    def state_expiration_sec(self):
-        return config.get(
-            f"services.events.events_retrieval.state_expiration_sec", 3600
-        )
     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
         self.cache_manager = RedisCacheManager(
             state_class=DebugSampleHistoryState,
             redis=redis,
-            expiration_interval=self.state_expiration_sec,
+            expiration_interval=EventSettings.state_expiration_sec,
         )
     def get_next_debug_image(
@@ -73,7 +66,7 @@
         if not state or state.task != task:
             raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        es_index = get_index_name(company_id, self.EVENT_TYPE)
         if not self.es.indices.exists(es_index):
             return res
@@ -219,7 +212,7 @@
         If the iteration is not passed then get the latest event
         """
         res = DebugSampleHistoryResult()
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        es_index = get_index_name(company_id, self.EVENT_TYPE)
         if not self.es.indices.exists(es_index):
             return res
@@ -247,6 +240,9 @@
         if not var_state:
             return res
+        res.min_iteration = var_state.min_iteration
+        res.max_iteration = var_state.max_iteration
         must_conditions = [
             {"term": {"task": task}},
             {"term": {"metric": metric}},
@@ -291,9 +287,7 @@
             es_index=es_index, task=state.task, metric=state.metric
         )
         state.variant_states = [
-            VariantState(
-                name=var_name, min_iteration=min_iter, max_iteration=max_iter
-            )
+            VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
             for var_name, min_iter, max_iter in variant_iterations
         ]
@@ -324,7 +318,7 @@
                 # all variants that sent debug images
                 "terms": {
                     "field": "variant",
-                    "size": EventMetrics.max_variants_count,
+                    "size": EventSettings.max_variants_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {

apiserver/bll/event/event_bll.py

@@ -13,12 +13,13 @@ from mongoengine import Q
 from nested_dict import nested_dict
 from apiserver.bll.event.debug_sample_history import DebugSampleHistory
+from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
 from apiserver.apierrors import errors
 from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
-from apiserver.bll.event.event_metrics import EventMetrics, EventType
+from apiserver.bll.event.event_metrics import EventMetrics
 from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
 from apiserver.bll.task import TaskBLL
 from apiserver.config_repo import config
@@ -147,7 +148,7 @@
             event["metric"] = event.get("metric") or ""
             event["variant"] = event.get("variant") or ""
-            index_name = EventMetrics.get_index_name(company_id, event_type)
+            index_name = get_index_name(company_id, event_type)
             es_action = {
                 "_op_type": "index",  # overwrite if exists with same ID
                 "_index": index_name,
@@ -406,7 +407,7 @@
         if event_type is None:
             event_type = "*"
-        es_index = EventMetrics.get_index_name(company_id, event_type)
+        es_index = get_index_name(company_id, event_type)
         if not self.es.indices.exists(es_index):
             return [], None, 0
@@ -435,14 +436,14 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": EventMetrics.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         },
                         "aggs": {
@@ -494,7 +495,7 @@
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             event_type = "plot"
-            es_index = EventMetrics.get_index_name(company_id, event_type)
+            es_index = get_index_name(company_id, event_type)
             if not self.es.indices.exists(es_index):
                 return TaskEventsResult()
@@ -597,7 +598,7 @@
         if event_type is None:
             event_type = "*"
-        es_index = EventMetrics.get_index_name(company_id, event_type)
+        es_index = get_index_name(company_id, event_type)
         if not self.es.indices.exists(es_index):
             return TaskEventsResult()
@@ -652,7 +653,7 @@
     def get_metrics_and_variants(self, company_id, task_id, event_type):
-        es_index = EventMetrics.get_index_name(company_id, event_type)
+        es_index = get_index_name(company_id, event_type)
         if not self.es.indices.exists(es_index):
             return {}
@@ -663,14 +664,14 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": EventMetrics.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         }
                     }
@@ -695,7 +696,7 @@
         return metrics
     def get_task_latest_scalar_values(self, company_id, task_id):
-        es_index = EventMetrics.get_index_name(company_id, "training_stats_scalar")
+        es_index = get_index_name(company_id, "training_stats_scalar")
        if not self.es.indices.exists(es_index):
             return {}
@@ -714,14 +715,14 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": EventMetrics.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         },
                         "aggs": {
@@ -780,7 +781,7 @@
     def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
-        es_index = EventMetrics.get_index_name(company_id, "training_stats_vector")
+        es_index = get_index_name(company_id, "training_stats_vector")
         if not self.es.indices.exists(es_index):
             return [], []
@@ -849,7 +850,7 @@
                 extra_msg, company=company_id, id=task_id
             )
-        es_index = EventMetrics.get_index_name(company_id, "*")
+        es_index = get_index_name(company_id, "*")
         es_req = {"query": {"term": {"task": task_id}}}
         with translate_errors_context(), TimingContext("es", "delete_task_events"):
             es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
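
All of the `EventMetrics.get_index_name` call sites above now go through the module-level `get_index_name` from `event_common` (shown in the next file). The naming scheme itself is unchanged: one index per company per event type, with `"*"` usable as a wildcard event type. A quick illustration, with a hypothetical company id:

```python
from apiserver.bll.event.event_common import get_index_name

print(get_index_name("company1", "training_stats_scalar"))
# events-training_stats_scalar-company1
print(get_index_name("company1", "*"))
# events-*-company1  (matches every event-type index for the company)
```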

apiserver/bll/event/event_common.py (new file)

@@ -0,0 +1,65 @@
+from enum import Enum
+from typing import Union, Sequence
+from boltons.typeutils import classproperty
+from elasticsearch import Elasticsearch
+from apiserver.config_repo import config
+class EventType(Enum):
+    metrics_scalar = "training_stats_scalar"
+    metrics_vector = "training_stats_vector"
+    metrics_image = "training_debug_image"
+    metrics_plot = "plot"
+    task_log = "log"
+class EventSettings:
+    @classproperty
+    def max_workers(self):
+        return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
+    @classproperty
+    def state_expiration_sec(self):
+        return config.get(
+            f"services.events.events_retrieval.state_expiration_sec", 3600
+        )
+    @classproperty
+    def max_metrics_count(self):
+        return config.get("services.events.events_retrieval.max_metrics_count", 100)
+    @classproperty
+    def max_variants_count(self):
+        return config.get("services.events.events_retrieval.max_variants_count", 100)
+def get_index_name(company_id: str, event_type: str):
+    event_type = event_type.lower().replace(" ", "_")
+    return f"events-{event_type}-{company_id}"
+def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
+    es_index = get_index_name(company_id, event_type)
+    if not es.indices.exists(es_index):
+        return True
+    return False
+def search_company_events(
+    es: Elasticsearch,
+    company_id: Union[str, Sequence[str]],
+    event_type: str,
+    body: dict,
+    **kwargs,
+) -> dict:
+    es_index = get_index_name(company_id, event_type)
+    return es.search(index=es_index, body=body, **kwargs)
+def delete_company_events(
+    es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
+) -> dict:
+    es_index = get_index_name(company_id, event_type)
+    return es.delete_by_query(index=es_index, body=body, **kwargs)
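
This new module centralizes what was previously spread between `EventMetrics` and the individual iterators: the `EventType` constants, the config-backed settings, index naming, and the existence/search/delete helpers. A sketch of the intended call pattern, assuming hypothetical company and task ids:

```python
from elasticsearch import Elasticsearch

from apiserver.bll.event.event_common import (
    EventSettings,
    check_empty_data,
    search_company_events,
)

es = Elasticsearch()
company_id, task_id = "company1", "task1"  # hypothetical ids
event_type = "training_debug_image"

# check_empty_data returns True when the per-company index was never created,
# letting callers bail out early instead of issuing a failing search.
if not check_empty_data(es, company_id, event_type):
    res = search_company_events(
        es,
        company_id=company_id,
        event_type=event_type,
        body={"query": {"term": {"task": task_id}}},
        routing=task_id,
    )

# The settings are classproperties: read directly off the class, with
# config.get() re-evaluated on every access.
pool_size = EventSettings.max_workers
```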

apiserver/bll/event/event_metrics.py

@@ -2,16 +2,15 @@ import itertools
 import math
 from collections import defaultdict
 from concurrent.futures.thread import ThreadPoolExecutor
-from enum import Enum
 from functools import partial
 from operator import itemgetter
 from typing import Sequence, Tuple
-from boltons.typeutils import classproperty
 from elasticsearch import Elasticsearch
 from mongoengine import Q
 from apiserver.apierrors import errors
+from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
@@ -22,14 +21,6 @@ from apiserver.tools import safe_get
 log = config.logger(__file__)
-class EventType(Enum):
-    metrics_scalar = "training_stats_scalar"
-    metrics_vector = "training_stats_vector"
-    metrics_image = "training_debug_image"
-    metrics_plot = "plot"
-    task_log = "log"
 class EventMetrics:
     MAX_AGGS_ELEMENTS_COUNT = 50
     MAX_SAMPLE_BUCKETS = 6000
@@ -37,23 +28,6 @@ class EventMetrics:
     def __init__(self, es: Elasticsearch):
         self.es = es
-    @classproperty
-    def max_metrics_count(self):
-        return config.get("services.events.events_retrieval.max_metrics_count", 100)
-    @classproperty
-    def max_variants_count(self):
-        return config.get("services.events.events_retrieval.max_variants_count", 100)
-    @property
-    def _max_concurrency(self):
-        return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
-    @staticmethod
-    def get_index_name(company_id, event_type):
-        event_type = event_type.lower().replace(" ", "_")
-        return f"events-{event_type}-{company_id}"
     def get_scalar_metrics_average_per_iter(
         self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
     ) -> dict:
@@ -62,7 +36,7 @@
         The amount of points in each histogram should not exceed
         the requested samples
         """
-        es_index = self.get_index_name(company_id, "training_stats_scalar")
+        es_index = get_index_name(company_id, "training_stats_scalar")
         if not self.es.indices.exists(es_index):
             return {}
@@ -89,7 +63,7 @@
             self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
         )
         if run_parallel:
-            with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+            with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
                 metrics = itertools.chain.from_iterable(
                     pool.map(get_scalar_average, interval_groups)
                 )
@@ -136,7 +110,7 @@
                 "only tasks from the same company are supported"
             )
-        es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
+        es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
         if not self.es.indices.exists(es_index):
             return {}
@@ -147,7 +121,7 @@
             key=ScalarKey.resolve(key),
             run_parallel=False,
        )
-        with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+        with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
             task_metrics = zip(
                 task_ids, pool.map(get_scalar_average_per_iter, task_ids)
             )
@@ -216,14 +190,14 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": self.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": self.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         },
                         "aggs": {
@@ -293,14 +267,14 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": self.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
-                            "size": self.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                             "order": {"_key": "asc"},
                         },
                         "aggs": aggregation,
@@ -382,11 +356,11 @@
         For the requested tasks return all the metrics that
         reported events of the requested types
         """
-        es_index = EventMetrics.get_index_name(company_id, event_type.value)
+        es_index = get_index_name(company_id, event_type.value)
         if not self.es.indices.exists(es_index):
             return {}
-        with ThreadPoolExecutor(self._max_concurrency) as pool:
+        with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             res = pool.map(
                 partial(
                     self._get_task_metrics, es_index=es_index, event_type=event_type,
@@ -410,7 +384,7 @@
             "metrics": {
                 "terms": {
                     "field": "metric",
-                    "size": self.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                     "order": {"_key": "asc"},
                 }
             }
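
With `EventType` and the limit accessors moved out, `EventMetrics` keeps only its query logic; the settings now live on `EventSettings` as boltons `classproperty` attributes, readable straight off the class. A minimal standalone illustration of that pattern (the value is a stand-in, not the real config lookup):

```python
from boltons.typeutils import classproperty

class Settings:
    @classproperty
    def max_workers(cls):
        # stand-in for config.get("services.events.events_retrieval...", 4)
        return 4

print(Settings.max_workers)  # 4 -- accessed on the class, no instance needed
```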

apiserver/bll/event/log_events_iterator.py

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Sequence
 import attr
 from elasticsearch import Elasticsearch
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import check_empty_data, search_company_events
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -29,13 +29,12 @@
         navigate_earlier: bool = True,
         from_timestamp: Optional[int] = None,
     ) -> TaskEventsResult:
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, self.EVENT_TYPE):
             return TaskEventsResult()
         res = TaskEventsResult()
         res.events, res.total_events = self._get_events(
-            es_index=es_index,
+            company_id=company_id,
             task_id=task_id,
             batch_size=batch_size,
             navigate_earlier=navigate_earlier,
@@ -45,7 +44,7 @@
     def _get_events(
         self,
-        es_index,
+        company_id: str,
         task_id: str,
         batch_size: int,
         navigate_earlier: bool,
@@ -71,7 +70,13 @@
             es_req["search_after"] = [from_timestamp]
         with translate_errors_context(), TimingContext("es", "get_task_events"):
-            es_result = self.es.search(index=es_index, body=es_req)
+            es_result = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=task_id,
+            )
             hits = es_result["hits"]["hits"]
             hits_total = es_result["hits"]["total"]["value"]
             if not hits:
@@ -92,7 +97,13 @@
                     }
                 },
             }
-            es_result = self.es.search(index=es_index, body=es_req)
+            es_result = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=task_id,
+            )
             last_second_hits = es_result["hits"]["hits"]
             if not last_second_hits or len(last_second_hits) < 2:
                 # if only one element is returned for the last timestamp
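
The two-query shape in `_get_events` is worth spelling out: log events are paged by `timestamp` via `search_after`, and since timestamps are not unique, the iterator appears to re-query the boundary timestamp (the second call above) to collect every event sharing it, so a batch boundary falling inside such a group does not drop or duplicate events. A rough sketch of the request body being assembled, with hypothetical values and an assumed query shape:

```python
# Hypothetical values for illustration.
batch_size, task_id = 100, "task1"
from_timestamp = 1609859638000  # boundary timestamp of the previous batch

es_req = {
    "size": batch_size,
    "sort": {"timestamp": {"order": "desc"}},  # navigate_earlier=True
    "query": {"term": {"task": task_id}},      # assumed query shape
}
if from_timestamp:
    # Resume pagination after the previous batch's boundary timestamp.
    es_req["search_after"] = [from_timestamp]
```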

apiserver/services/events.py

@@ -19,7 +19,7 @@ from apiserver.apimodels.events import (
     NextDebugImageSampleRequest,
 )
 from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import get_index_name
 from apiserver.bll.task import TaskBLL
 from apiserver.service_repo import APICall, endpoint
 from apiserver.utilities import json
@@ -321,7 +321,7 @@ def get_task_latest_scalar_values(call, company_id, _):
     metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
         index_company, task_id
     )
-    es_index = EventMetrics.get_index_name(index_company, "*")
+    es_index = get_index_name(index_company, "*")
     last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
     call.result.data = dict(
         metrics=metrics,