Improve internal events implementation

allegroai 2021-01-05 18:20:38 +02:00
parent e2deff4eef
commit 6974aa3a99
8 changed files with 160 additions and 109 deletions
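
In broad strokes, the change moves index-name construction, empty-index checks, and the events_retrieval settings out of EventMetrics and the individual iterators into a shared event_common module. A condensed before/after sketch of the call-site pattern, paraphrased from the DebugImagesIterator hunks below rather than copied from any single file:

    # Before: each call site built the per-company index name itself
    es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
    if not self.es.indices.exists(es_index):
        return DebugImagesResult()
    es_res = self.es.search(index=es_index, body=es_req)

    # After: scoping, emptiness checks and settings live in event_common,
    # and the task id is passed along as an ES routing key
    if check_empty_data(self.es, company_id, self.EVENT_TYPE):
        return DebugImagesResult()
    es_res = search_company_events(
        self.es,
        company_id=company_id,
        event_type=self.EVENT_TYPE,
        body=es_req,
        routing=task,
    )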

View File

@@ -7,7 +7,7 @@ from jsonmodels.models import Base
from jsonmodels.validators import Length, Min, Max
from apiserver.apimodels import ListField, IntField, ActualEnumField
-from apiserver.bll.event.event_metrics import EventType
+from apiserver.bll.event.event_common import EventType
from apiserver.bll.event.scalar_key import ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.utilities.stringenum import StringEnum

View File

@@ -15,9 +15,12 @@ from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    check_empty_data,
+    search_company_events,
+)
from apiserver.bll.redis_cache_manager import RedisCacheManager
-from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
@@ -58,22 +61,12 @@ class DebugImagesResult(object):
class DebugImagesIterator:
    EVENT_TYPE = "training_debug_image"

-    @property
-    def state_expiration_sec(self):
-        return config.get(
-            f"services.events.events_retrieval.state_expiration_sec", 3600
-        )
-
-    @property
-    def _max_workers(self):
-        return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
-
    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        self.es = es
        self.cache_manager = RedisCacheManager(
            state_class=DebugImageEventsScrollState,
            redis=redis,
-            expiration_interval=self.state_expiration_sec,
+            expiration_interval=EventSettings.state_expiration_sec,
        )

    def get_task_events(
@@ -85,13 +78,12 @@ class DebugImagesIterator:
        refresh: bool = False,
        state_id: str = None,
    ) -> DebugImagesResult:
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, self.EVENT_TYPE):
            return DebugImagesResult()

        def init_state(state_: DebugImageEventsScrollState):
            unique_metrics = set(metrics)
-            state_.metrics = self._init_metric_states(es_index, list(unique_metrics))
+            state_.metrics = self._init_metric_states(company_id, list(unique_metrics))

        def validate_state(state_: DebugImageEventsScrollState):
            """
@@ -106,7 +98,7 @@ class DebugImagesIterator:
                    scroll_id=state_.id,
                )
            if refresh:
-                self._reinit_outdated_metric_states(company_id, es_index, state_)
+                self._reinit_outdated_metric_states(company_id, state_)
            for metric_state in state_.metrics:
                metric_state.reset()
@@ -114,12 +106,12 @@ class DebugImagesIterator:
            state_id=state_id, init_state=init_state, validate_state=validate_state
        ) as state:
            res = DebugImagesResult(next_scroll_id=state.id)
-            with ThreadPoolExecutor(self._max_workers) as pool:
+            with ThreadPoolExecutor(EventSettings.max_workers) as pool:
                res.metric_events = list(
                    pool.map(
                        partial(
                            self._get_task_metric_events,
-                            es_index=es_index,
+                            company_id=company_id,
                            iter_count=iter_count,
                            navigate_earlier=navigate_earlier,
                        ),
@@ -130,7 +122,7 @@ class DebugImagesIterator:
        return res

    def _reinit_outdated_metric_states(
-        self, company_id, es_index, state: DebugImageEventsScrollState
+        self, company_id, state: DebugImageEventsScrollState
    ):
        """
        Determines the metrics for which new debug image events were added
@@ -171,14 +163,14 @@ class DebugImagesIterator:
            *(metric for metric in state.metrics if metric not in outdated_metrics),
            *(
                self._init_metric_states(
-                    es_index,
+                    company_id,
                    [(metric.task, metric.name) for metric in outdated_metrics],
                )
            ),
        ]

    def _init_metric_states(
-        self, es_index, metrics: Sequence[Tuple[str, str]]
+        self, company_id: str, metrics: Sequence[Tuple[str, str]]
    ) -> Sequence[MetricScrollState]:
        """
        Returned initialized metric scroll stated for the requested task metrics
@@ -187,18 +179,20 @@ class DebugImagesIterator:
        for (task, metric) in metrics:
            tasks[task].append(metric)

-        with ThreadPoolExecutor(self._max_workers) as pool:
+        with ThreadPoolExecutor(EventSettings.max_workers) as pool:
            return list(
                chain.from_iterable(
                    pool.map(
-                        partial(self._init_metric_states_for_task, es_index=es_index),
+                        partial(
+                            self._init_metric_states_for_task, company_id=company_id
+                        ),
                        tasks.items(),
                    )
                )
            )

    def _init_metric_states_for_task(
-        self, task_metrics: Tuple[str, Sequence[str]], es_index
+        self, task_metrics: Tuple[str, Sequence[str]], company_id: str
    ) -> Sequence[MetricScrollState]:
        """
        Return metric scroll states for the task filled with the variant states
@@ -220,7 +214,7 @@ class DebugImagesIterator:
                "metrics": {
                    "terms": {
                        "field": "metric",
-                        "size": EventMetrics.max_metrics_count,
+                        "size": EventSettings.max_metrics_count,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
@@ -228,7 +222,7 @@ class DebugImagesIterator:
                        "variants": {
                            "terms": {
                                "field": "variant",
-                                "size": EventMetrics.max_variants_count,
+                                "size": EventSettings.max_variants_count,
                                "order": {"_key": "asc"},
                            },
                            "aggs": {
@@ -258,7 +252,13 @@ class DebugImagesIterator:
        }
        with translate_errors_context(), TimingContext("es", "_init_metric_states"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=task,
+            )
        if "aggregations" not in es_res:
            return []
@@ -290,7 +290,7 @@ class DebugImagesIterator:
    def _get_task_metric_events(
        self,
        metric: MetricScrollState,
-        es_index: str,
+        company_id: str,
        iter_count: int,
        navigate_earlier: bool,
    ) -> Tuple:
@@ -382,7 +382,7 @@ class DebugImagesIterator:
                    "variants": {
                        "terms": {
                            "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                            "order": {"_key": "asc"},
                        },
                        "aggs": {
@@ -396,7 +396,13 @@ class DebugImagesIterator:
            },
        }
        with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=metric.task,
+            )
        if "aggregations" not in es_res:
            return metric.task, metric.name, []

View File

@@ -10,9 +10,8 @@ from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import EventSettings, get_index_name
from apiserver.bll.redis_cache_manager import RedisCacheManager
-from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
@@ -47,18 +46,12 @@ class DebugSampleHistoryResult(object):
class DebugSampleHistory:
    EVENT_TYPE = "training_debug_image"

-    @property
-    def state_expiration_sec(self):
-        return config.get(
-            f"services.events.events_retrieval.state_expiration_sec", 3600
-        )
-
    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        self.es = es
        self.cache_manager = RedisCacheManager(
            state_class=DebugSampleHistoryState,
            redis=redis,
-            expiration_interval=self.state_expiration_sec,
+            expiration_interval=EventSettings.state_expiration_sec,
        )

    def get_next_debug_image(
@@ -73,7 +66,7 @@ class DebugSampleHistory:
        if not state or state.task != task:
            raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        es_index = get_index_name(company_id, self.EVENT_TYPE)
        if not self.es.indices.exists(es_index):
            return res
@@ -219,7 +212,7 @@ class DebugSampleHistory:
        If the iteration is not passed then get the latest event
        """
        res = DebugSampleHistoryResult()
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
+        es_index = get_index_name(company_id, self.EVENT_TYPE)
        if not self.es.indices.exists(es_index):
            return res
@@ -247,6 +240,9 @@ class DebugSampleHistory:
        if not var_state:
            return res

+        res.min_iteration = var_state.min_iteration
+        res.max_iteration = var_state.max_iteration
+
        must_conditions = [
            {"term": {"task": task}},
            {"term": {"metric": metric}},
@@ -291,9 +287,7 @@ class DebugSampleHistory:
            es_index=es_index, task=state.task, metric=state.metric
        )
        state.variant_states = [
-            VariantState(
-                name=var_name, min_iteration=min_iter, max_iteration=max_iter
-            )
+            VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
            for var_name, min_iter, max_iter in variant_iterations
        ]
@@ -324,7 +318,7 @@ class DebugSampleHistory:
                # all variants that sent debug images
                "terms": {
                    "field": "variant",
-                    "size": EventMetrics.max_variants_count,
+                    "size": EventSettings.max_variants_count,
                    "order": {"_key": "asc"},
                },
                "aggs": {

View File

@@ -13,12 +13,13 @@ from mongoengine import Q
from nested_dict import nested_dict

from apiserver.bll.event.debug_sample_history import DebugSampleHistory
+from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
-from apiserver.bll.event.event_metrics import EventMetrics, EventType
+from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
@@ -147,7 +148,7 @@ class EventBLL(object):
                event["metric"] = event.get("metric") or ""
                event["variant"] = event.get("variant") or ""

-                index_name = EventMetrics.get_index_name(company_id, event_type)
+                index_name = get_index_name(company_id, event_type)
                es_action = {
                    "_op_type": "index",  # overwrite if exists with same ID
                    "_index": index_name,
@@ -406,7 +407,7 @@ class EventBLL(object):
        if event_type is None:
            event_type = "*"

-        es_index = EventMetrics.get_index_name(company_id, event_type)
+        es_index = get_index_name(company_id, event_type)
        if not self.es.indices.exists(es_index):
            return [], None, 0
@@ -435,14 +436,14 @@ class EventBLL(object):
                "metrics": {
                    "terms": {
                        "field": "metric",
-                        "size": EventMetrics.max_metrics_count,
+                        "size": EventSettings.max_metrics_count,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
-                                "size": EventMetrics.max_variants_count,
+                                "size": EventSettings.max_variants_count,
                                "order": {"_key": "asc"},
                            },
                            "aggs": {
@@ -494,7 +495,7 @@ class EventBLL(object):
            es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
        else:
            event_type = "plot"
-            es_index = EventMetrics.get_index_name(company_id, event_type)
+            es_index = get_index_name(company_id, event_type)
            if not self.es.indices.exists(es_index):
                return TaskEventsResult()
@@ -597,7 +598,7 @@ class EventBLL(object):
        if event_type is None:
            event_type = "*"

-        es_index = EventMetrics.get_index_name(company_id, event_type)
+        es_index = get_index_name(company_id, event_type)
        if not self.es.indices.exists(es_index):
            return TaskEventsResult()
@@ -652,7 +653,7 @@ class EventBLL(object):
    def get_metrics_and_variants(self, company_id, task_id, event_type):
-        es_index = EventMetrics.get_index_name(company_id, event_type)
+        es_index = get_index_name(company_id, event_type)
        if not self.es.indices.exists(es_index):
            return {}
@@ -663,14 +664,14 @@ class EventBLL(object):
            "metrics": {
                "terms": {
                    "field": "metric",
-                    "size": EventMetrics.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                    "order": {"_key": "asc"},
                },
                "aggs": {
                    "variants": {
                        "terms": {
                            "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                            "order": {"_key": "asc"},
                        }
                    }
@@ -695,7 +696,7 @@ class EventBLL(object):
        return metrics

    def get_task_latest_scalar_values(self, company_id, task_id):
-        es_index = EventMetrics.get_index_name(company_id, "training_stats_scalar")
+        es_index = get_index_name(company_id, "training_stats_scalar")
        if not self.es.indices.exists(es_index):
            return {}
@@ -714,14 +715,14 @@ class EventBLL(object):
            "metrics": {
                "terms": {
                    "field": "metric",
-                    "size": EventMetrics.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                    "order": {"_key": "asc"},
                },
                "aggs": {
                    "variants": {
                        "terms": {
                            "field": "variant",
-                            "size": EventMetrics.max_variants_count,
+                            "size": EventSettings.max_variants_count,
                            "order": {"_key": "asc"},
                        },
                        "aggs": {
@@ -780,7 +781,7 @@ class EventBLL(object):
    def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
-        es_index = EventMetrics.get_index_name(company_id, "training_stats_vector")
+        es_index = get_index_name(company_id, "training_stats_vector")
        if not self.es.indices.exists(es_index):
            return [], []
@@ -849,7 +850,7 @@ class EventBLL(object):
                extra_msg, company=company_id, id=task_id
            )

-        es_index = EventMetrics.get_index_name(company_id, "*")
+        es_index = get_index_name(company_id, "*")
        es_req = {"query": {"term": {"task": task_id}}}
        with translate_errors_context(), TimingContext("es", "delete_task_events"):
            es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)

View File

@@ -0,0 +1,65 @@
from enum import Enum
from typing import Union, Sequence

from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch

from apiserver.config_repo import config


class EventType(Enum):
    metrics_scalar = "training_stats_scalar"
    metrics_vector = "training_stats_vector"
    metrics_image = "training_debug_image"
    metrics_plot = "plot"
    task_log = "log"


class EventSettings:
    @classproperty
    def max_workers(self):
        return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)

    @classproperty
    def state_expiration_sec(self):
        return config.get(
            f"services.events.events_retrieval.state_expiration_sec", 3600
        )

    @classproperty
    def max_metrics_count(self):
        return config.get("services.events.events_retrieval.max_metrics_count", 100)

    @classproperty
    def max_variants_count(self):
        return config.get("services.events.events_retrieval.max_variants_count", 100)


def get_index_name(company_id: str, event_type: str):
    event_type = event_type.lower().replace(" ", "_")
    return f"events-{event_type}-{company_id}"


def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
    es_index = get_index_name(company_id, event_type)
    if not es.indices.exists(es_index):
        return True
    return False


def search_company_events(
    es: Elasticsearch,
    company_id: Union[str, Sequence[str]],
    event_type: str,
    body: dict,
    **kwargs,
) -> dict:
    es_index = get_index_name(company_id, event_type)
    return es.search(index=es_index, body=body, **kwargs)


def delete_company_events(
    es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
) -> dict:
    es_index = get_index_name(company_id, event_type)
    return es.delete_by_query(index=es_index, body=body, **kwargs)

View File

@@ -2,16 +2,15 @@ import itertools
import math
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
-from enum import Enum
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple

-from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
from mongoengine import Q

from apiserver.apierrors import errors
+from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -22,14 +21,6 @@ from apiserver.tools import safe_get
log = config.logger(__file__)

-class EventType(Enum):
-    metrics_scalar = "training_stats_scalar"
-    metrics_vector = "training_stats_vector"
-    metrics_image = "training_debug_image"
-    metrics_plot = "plot"
-    task_log = "log"
-
class EventMetrics:
    MAX_AGGS_ELEMENTS_COUNT = 50
    MAX_SAMPLE_BUCKETS = 6000
@@ -37,23 +28,6 @@ class EventMetrics:
    def __init__(self, es: Elasticsearch):
        self.es = es

-    @classproperty
-    def max_metrics_count(self):
-        return config.get("services.events.events_retrieval.max_metrics_count", 100)
-
-    @classproperty
-    def max_variants_count(self):
-        return config.get("services.events.events_retrieval.max_variants_count", 100)
-
-    @property
-    def _max_concurrency(self):
-        return config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
-
-    @staticmethod
-    def get_index_name(company_id, event_type):
-        event_type = event_type.lower().replace(" ", "_")
-        return f"events-{event_type}-{company_id}"
-
    def get_scalar_metrics_average_per_iter(
        self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
    ) -> dict:
@@ -62,7 +36,7 @@ class EventMetrics:
        The amount of points in each histogram should not exceed
        the requested samples
        """
-        es_index = self.get_index_name(company_id, "training_stats_scalar")
+        es_index = get_index_name(company_id, "training_stats_scalar")
        if not self.es.indices.exists(es_index):
            return {}
@@ -89,7 +63,7 @@ class EventMetrics:
            self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
        )
        if run_parallel:
-            with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+            with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
                metrics = itertools.chain.from_iterable(
                    pool.map(get_scalar_average, interval_groups)
                )
@@ -136,7 +110,7 @@ class EventMetrics:
                "only tasks from the same company are supported"
            )

-        es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
+        es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
        if not self.es.indices.exists(es_index):
            return {}
@@ -147,7 +121,7 @@ class EventMetrics:
            key=ScalarKey.resolve(key),
            run_parallel=False,
        )
-        with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+        with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
            task_metrics = zip(
                task_ids, pool.map(get_scalar_average_per_iter, task_ids)
            )
@@ -216,14 +190,14 @@ class EventMetrics:
                "metrics": {
                    "terms": {
                        "field": "metric",
-                        "size": self.max_metrics_count,
+                        "size": EventSettings.max_metrics_count,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
-                                "size": self.max_variants_count,
+                                "size": EventSettings.max_variants_count,
                                "order": {"_key": "asc"},
                            },
                            "aggs": {
@@ -293,14 +267,14 @@ class EventMetrics:
                "metrics": {
                    "terms": {
                        "field": "metric",
-                        "size": self.max_metrics_count,
+                        "size": EventSettings.max_metrics_count,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
-                                "size": self.max_variants_count,
+                                "size": EventSettings.max_variants_count,
                                "order": {"_key": "asc"},
                            },
                            "aggs": aggregation,
@@ -382,11 +356,11 @@ class EventMetrics:
        For the requested tasks return all the metrics that
        reported events of the requested types
        """
-        es_index = EventMetrics.get_index_name(company_id, event_type.value)
+        es_index = get_index_name(company_id, event_type.value)
        if not self.es.indices.exists(es_index):
            return {}

-        with ThreadPoolExecutor(self._max_concurrency) as pool:
+        with ThreadPoolExecutor(EventSettings.max_workers) as pool:
            res = pool.map(
                partial(
                    self._get_task_metrics, es_index=es_index, event_type=event_type,
@@ -410,7 +384,7 @@ class EventMetrics:
            "metrics": {
                "terms": {
                    "field": "metric",
-                    "size": self.max_metrics_count,
+                    "size": EventSettings.max_metrics_count,
                    "order": {"_key": "asc"},
                }
            }

View File

@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch

-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import check_empty_data, search_company_events
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@@ -29,13 +29,12 @@ class LogEventsIterator:
        navigate_earlier: bool = True,
        from_timestamp: Optional[int] = None,
    ) -> TaskEventsResult:
-        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, self.EVENT_TYPE):
            return TaskEventsResult()

        res = TaskEventsResult()
        res.events, res.total_events = self._get_events(
-            es_index=es_index,
+            company_id=company_id,
            task_id=task_id,
            batch_size=batch_size,
            navigate_earlier=navigate_earlier,
@@ -45,7 +44,7 @@ class LogEventsIterator:
    def _get_events(
        self,
-        es_index,
+        company_id: str,
        task_id: str,
        batch_size: int,
        navigate_earlier: bool,
@@ -71,7 +70,13 @@ class LogEventsIterator:
            es_req["search_after"] = [from_timestamp]

        with translate_errors_context(), TimingContext("es", "get_task_events"):
-            es_result = self.es.search(index=es_index, body=es_req)
+            es_result = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=task_id,
+            )
            hits = es_result["hits"]["hits"]
            hits_total = es_result["hits"]["total"]["value"]
            if not hits:
@@ -92,7 +97,13 @@ class LogEventsIterator:
                    }
                },
            }
-            es_result = self.es.search(index=es_index, body=es_req)
+            es_result = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+                routing=task_id,
+            )
            last_second_hits = es_result["hits"]["hits"]
            if not last_second_hits or len(last_second_hits) < 2:
                # if only one element is returned for the last timestamp

View File

@@ -19,7 +19,7 @@ from apiserver.apimodels.events import (
    NextDebugImageSampleRequest,
)
from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_metrics import EventMetrics
+from apiserver.bll.event.event_common import get_index_name
from apiserver.bll.task import TaskBLL
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
@@ -321,7 +321,7 @@ def get_task_latest_scalar_values(call, company_id, _):
    metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
        index_company, task_id
    )
-    es_index = EventMetrics.get_index_name(index_company, "*")
+    es_index = get_index_name(index_company, "*")
    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
    call.result.data = dict(
        metrics=metrics,