Use EVENT_TYPE enum instead of string

allegroai 2021-01-05 18:21:11 +02:00
parent 6974aa3a99
commit 4707647c92
8 changed files with 239 additions and 146 deletions
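Summary: event kinds were previously passed around as raw strings ("log", "plot", "training_debug_image", ...) and Elasticsearch index names were built ad hoc via get_index_name(). This commit threads the EventType enum through the event BLL and routes all queries through the shared helpers check_empty_data / search_company_events / delete_company_events, which unwrap EventType.value in one place. A minimal sketch of the pattern, reconstructed from the hunks below (the metrics_scalar and metrics_vector values are inferred from the string call sites they replace):

    from enum import Enum

    class EventType(Enum):
        metrics_scalar = "training_stats_scalar"
        metrics_vector = "training_stats_vector"
        metrics_image = "training_debug_image"
        metrics_plot = "plot"
        task_log = "log"
        all = "*"

    # before: callers passed a raw string, easy to mistype
    event_bll.scroll_task_events(company_id, task_id, order, event_type="log")
    # after: callers pass an enum member, validated at import time
    event_bll.scroll_task_events(company_id, task_id, order, event_type=EventType.task_log)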

View File

@@ -19,6 +19,7 @@ from apiserver.bll.event.event_common import (
     EventSettings,
     check_empty_data,
     search_company_events,
+    EventType,
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
@@ -59,7 +60,7 @@ class DebugImagesResult(object):
 class DebugImagesIterator:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image

     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -142,10 +143,10 @@ class DebugImagesIterator:
             return [
                 (
                     (task.id, stats.metric),
-                    stats.event_stats_by_type[self.EVENT_TYPE].last_update,
+                    stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
                 )
                 for stats in metric_stats.values()
-                if self.EVENT_TYPE in stats.event_stats_by_type
+                if self.EVENT_TYPE.value in stats.event_stats_by_type
             ]

         update_times = dict(
@@ -257,7 +258,6 @@ class DebugImagesIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task,
             )
         if "aggregations" not in es_res:
             return []
@@ -401,7 +401,6 @@ class DebugImagesIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=metric.task,
             )
         if "aggregations" not in es_res:
             return metric.task, metric.name, []
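Note the .value unwrapping above: stats.event_stats_by_type is keyed by the raw event-type strings stored with the events, and a plain Enum member does not compare equal to its value, so once EVENT_TYPE becomes EventType.metrics_image the dict lookups must go through .value. A short illustration (standard Python enum behavior):

    EventType.metrics_image == "training_debug_image"        # False: Enum member != str
    EventType.metrics_image.value == "training_debug_image"  # True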

View File

@@ -10,7 +10,12 @@ from redis import StrictRedis
 from apiserver.apierrors import errors
 from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_common import EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    EventType,
+    check_empty_data,
+    search_company_events,
+)
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -44,7 +49,7 @@ class DebugSampleHistoryResult(object):
 class DebugSampleHistory:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image

     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -66,14 +71,13 @@ class DebugSampleHistory:
         if not state or state.task != task:
             raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res

         image = self._get_next_for_current_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
         ) or self._get_next_for_another_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
         )
         if not image:
             return res
@@ -94,7 +98,7 @@ class DebugSampleHistory:
         res.max_iteration = var_state.max_iteration

     def _get_next_for_current_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
@@ -127,7 +131,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_current_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -136,7 +142,7 @@ class DebugSampleHistory:
         return hits[0]["_source"]

     def _get_next_for_another_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
@@ -189,7 +195,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_another_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -212,14 +220,13 @@ class DebugSampleHistory:
         If the iteration is not passed then get the latest event
         """
         res = DebugSampleHistoryResult()
-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res

         def init_state(state_: DebugSampleHistoryState):
             state_.task = task
             state_.metric = metric
-            self._reset_variant_states(es_index, state=state_)
+            self._reset_variant_states(company_id=company_id, state=state_)

         def validate_state(state_: DebugSampleHistoryState):
             if state_.task != task or state_.metric != metric:
@@ -228,7 +235,7 @@ class DebugSampleHistory:
                     scroll_id=state_.id,
                 )
             if refresh:
-                self._reset_variant_states(es_index, state=state_)
+                self._reset_variant_states(company_id=company_id, state=state_)

         state: DebugSampleHistoryState
         with self.cache_manager.get_or_create_state(
@@ -271,7 +278,12 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_for_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -282,9 +294,9 @@ class DebugSampleHistory:
         )
         return res

-    def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
+    def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
         variant_iterations = self._get_variant_iterations(
-            es_index=es_index, task=state.task, metric=state.metric
+            company_id=company_id, task=state.task, metric=state.metric
         )
         state.variant_states = [
             VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
@@ -293,7 +305,7 @@ class DebugSampleHistory:
     def _get_variant_iterations(
         self,
-        es_index: str,
+        company_id: str,
         task: str,
         metric: str,
         variants: Optional[Sequence[str]] = None,
@@ -344,7 +356,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_iterations"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
             variant = variant_bucket["key"]

View File

@@ -13,7 +13,14 @@ from mongoengine import Q
 from nested_dict import nested_dict

 from apiserver.bll.event.debug_sample_history import DebugSampleHistory
-from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    get_index_name,
+    check_empty_data,
+    search_company_events,
+    delete_company_events,
+)
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
@@ -156,7 +163,7 @@ class EventBLL(object):
             }

             # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
-            if event_type != "log":
+            if event_type != EventType.task_log.value:
                 es_action["_id"] = self._get_event_id(event)
             else:
                 es_action["_id"] = dbutils.id()
@@ -389,10 +396,10 @@ class EventBLL(object):
     def scroll_task_events(
         self,
-        company_id,
-        task_id,
-        order,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        order: str,
+        event_type: EventType,
         batch_size=10000,
         scroll_id=None,
     ):
@@ -404,12 +411,7 @@ class EventBLL(object):
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             size = min(batch_size, 10000)
-            if event_type is None:
-                event_type = "*"
-            es_index = get_index_name(company_id, event_type)
-
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return [], None, 0

             es_req = {
@@ -419,15 +421,25 @@ class EventBLL(object):
             }

             with translate_errors_context(), TimingContext("es", "scroll_task_events"):
-                es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    scroll="1h",
+                )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
         return events, next_scroll_id, total_events

     def get_last_iterations_per_event_metric_variant(
-        self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
+        self,
+        company_id: str,
+        task_id: str,
+        num_last_iterations: int,
+        event_type: EventType,
     ):
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

         es_req: dict = {
@@ -461,13 +473,14 @@ class EventBLL(object):
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
-        if event_type:
-            es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})

         with translate_errors_context(), TimingContext(
             "es", "task_last_iter_metric_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         if "aggregations" not in es_res:
             return []
@@ -494,9 +507,8 @@ class EventBLL(object):
             with translate_errors_context(), TimingContext("es", "get_task_events"):
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
-            event_type = "plot"
-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            event_type = EventType.metrics_plot
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

             plot_valid_condition = {
@@ -519,7 +531,10 @@ class EventBLL(object):
             should = []
             for i, task_id in enumerate(tasks):
                 last_iters = self.get_last_iterations_per_event_metric_variant(
-                    es_index, task_id, last_iterations_per_plot, event_type
+                    company_id=company_id,
+                    task_id=task_id,
+                    num_last_iterations=last_iterations_per_plot,
+                    event_type=event_type,
                 )
                 if not last_iters:
                     continue
@@ -551,8 +566,13 @@ class EventBLL(object):
             }

             with translate_errors_context(), TimingContext("es", "get_task_plots"):
-                es_res = self.es.search(
-                    index=es_index, body=es_req, ignore=404, scroll="1h",
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    ignore=404,
+                    scroll="1h",
                 )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -577,9 +597,9 @@ class EventBLL(object):
     def get_task_events(
         self,
-        company_id,
-        task_id,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        event_type: EventType,
         metric=None,
         variant=None,
         last_iter_count=None,
@@ -595,11 +615,8 @@ class EventBLL(object):
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
-            if event_type is None:
-                event_type = "*"
-
-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

             must = []
@@ -614,7 +631,10 @@ class EventBLL(object):
                 should = []
                 for i, task_id in enumerate(task_ids):
                     last_iters = self.get_last_iters(
-                        es_index, task_id, event_type, last_iter_count
+                        company_id=company_id,
+                        event_type=event_type,
+                        task_id=task_id,
+                        iters=last_iter_count,
                     )
                     if not last_iters:
                         continue
@@ -642,8 +662,13 @@ class EventBLL(object):
             }

             with translate_errors_context(), TimingContext("es", "get_task_events"):
-                es_res = self.es.search(
-                    index=es_index, body=es_req, ignore=404, scroll="1h",
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    ignore=404,
+                    scroll="1h",
                 )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -651,11 +676,10 @@ class EventBLL(object):
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )

-    def get_metrics_and_variants(self, company_id, task_id, event_type):
-
-        es_index = get_index_name(company_id, event_type)
-
-        if not self.es.indices.exists(es_index):
+    def get_metrics_and_variants(
+        self, company_id: str, task_id: str, event_type: EventType
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         es_req = {
@@ -684,7 +708,9 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         metrics = {}
         for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -695,10 +721,9 @@ class EventBLL(object):
         return metrics

-    def get_task_latest_scalar_values(self, company_id, task_id):
-        es_index = get_index_name(company_id, "training_stats_scalar")
-
-        if not self.es.indices.exists(es_index):
+    def get_task_latest_scalar_values(self, company_id: str, task_id: str):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         es_req = {
@@ -753,7 +778,9 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         metrics = []
         max_timestamp = 0
@@ -780,9 +807,8 @@ class EventBLL(object):
         return metrics, max_timestamp

     def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
-
-        es_index = get_index_name(company_id, "training_stats_vector")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_vector
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return [], []

         es_req = {
@@ -800,7 +826,9 @@ class EventBLL(object):
             "sort": ["iter"],
         }
         with translate_errors_context(), TimingContext("es", "task_stats_vector"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         vectors = []
         iterations = []
@@ -810,8 +838,10 @@ class EventBLL(object):
         return iterations, vectors

-    def get_last_iters(self, es_index, task_id, event_type, iters):
-        if not self.es.indices.exists(es_index):
+    def get_last_iters(
+        self, company_id: str, event_type: EventType, task_id: str, iters: int
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

         es_req: dict = {
@@ -827,11 +857,12 @@ class EventBLL(object):
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
-        if event_type:
-            es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})

         with translate_errors_context(), TimingContext("es", "task_last_iter"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         if "aggregations" not in es_res:
             return []
@@ -850,9 +881,14 @@ class EventBLL(object):
                 extra_msg, company=company_id, id=task_id
             )

-        es_index = get_index_name(company_id, "*")
         es_req = {"query": {"term": {"task": task_id}}}
         with translate_errors_context(), TimingContext("es", "delete_task_events"):
-            es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
+            es_res = delete_company_events(
+                es=self.es,
+                company_id=company_id,
+                event_type=EventType.all,
+                body=es_req,
+                refresh=True,
+            )

         return es_res.get("deleted", 0)

View File

@@ -13,6 +13,7 @@ class EventType(Enum):
     metrics_image = "training_debug_image"
     metrics_plot = "plot"
     task_log = "log"
+    all = "*"


 class EventSettings:
@@ -40,8 +41,8 @@ def get_index_name(company_id: str, event_type: str):
     return f"events-{event_type}-{company_id}"


-def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
-    es_index = get_index_name(company_id, event_type)
+def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
+    es_index = get_index_name(company_id, event_type.value)
     if not es.indices.exists(es_index):
         return True
     return False
@@ -50,16 +51,16 @@ def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
 def search_company_events(
     es: Elasticsearch,
     company_id: Union[str, Sequence[str]],
-    event_type: str,
+    event_type: EventType,
     body: dict,
     **kwargs,
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.search(index=es_index, body=body, **kwargs)


 def delete_company_events(
-    es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
+    es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.delete_by_query(index=es_index, body=body, **kwargs)
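These helpers are now the single place where an EventType is turned into a concrete index name (events-&lt;type&gt;-&lt;company&gt;). A usage sketch against the signatures above (the company id, task id, and query body are placeholders):

    from elasticsearch import Elasticsearch
    from apiserver.bll.event.event_common import (
        EventType,
        check_empty_data,
        search_company_events,
        delete_company_events,
    )

    es = Elasticsearch()
    company_id = "some-company"                        # placeholder
    body = {"query": {"term": {"task": "some-task"}}}  # placeholder

    # bail out early when the per-company index does not exist yet
    if not check_empty_data(es, company_id=company_id, event_type=EventType.task_log):
        res = search_company_events(
            es, company_id=company_id, event_type=EventType.task_log, body=body
        )

    # EventType.all maps to "*", so this addresses every events-* index of the company
    deleted = delete_company_events(
        es=es, company_id=company_id, event_type=EventType.all, body=body, refresh=True
    ).get("deleted", 0)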

View File

@@ -10,7 +10,12 @@ from elasticsearch import Elasticsearch
 from mongoengine import Q

 from apiserver.apierrors import errors
-from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    search_company_events,
+    check_empty_data,
+)
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
@@ -36,31 +41,40 @@ class EventMetrics:
         The amount of points in each histogram should not exceed
         the requested samples
         """
-        es_index = get_index_name(company_id, "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         return self._get_scalar_average_per_iter_core(
-            task_id, es_index, samples, ScalarKey.resolve(key)
+            task_id, company_id, event_type, samples, ScalarKey.resolve(key)
         )

     def _get_scalar_average_per_iter_core(
         self,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         samples: int,
         key: ScalarKey,
         run_parallel: bool = True,
     ) -> dict:
         intervals = self._get_task_metric_intervals(
-            es_index=es_index, task_id=task_id, samples=samples, field=key.field
+            company_id=company_id,
+            event_type=event_type,
+            task_id=task_id,
+            samples=samples,
+            field=key.field,
         )
         if not intervals:
             return {}

         interval_groups = self._group_task_metric_intervals(intervals)
         get_scalar_average = partial(
-            self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
+            self._get_scalar_average,
+            task_id=task_id,
+            company_id=company_id,
+            event_type=event_type,
+            key=key,
         )
         if run_parallel:
             with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
@@ -110,13 +124,15 @@ class EventMetrics:
                 "only tasks from the same company are supported"
             )

-        es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        company_id = next(iter(companies))
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         get_scalar_average_per_iter = partial(
             self._get_scalar_average_per_iter_core,
-            es_index=es_index,
+            company_id=company_id,
+            event_type=event_type,
             samples=samples,
             key=ScalarKey.resolve(key),
             run_parallel=False,
@@ -175,7 +191,12 @@ class EventMetrics:
         return metric_interval_groups

     def _get_task_metric_intervals(
-        self, es_index, task_id: str, samples: int, field: str = "iter"
+        self,
+        company_id: str,
+        event_type: EventType,
+        task_id: str,
+        samples: int,
+        field: str = "iter",
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
@@ -212,7 +233,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )

         aggs_result = es_res.get("aggregations")
         if not aggs_result:
@@ -255,7 +278,8 @@ class EventMetrics:
         self,
         metrics_interval: MetricIntervalGroup,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         key: ScalarKey,
     ) -> Sequence[MetricData]:
         """
@@ -283,7 +307,11 @@ class EventMetrics:
             }
         }
         aggs_result = self._query_aggregation_for_task_metrics(
-            es_index, aggs=aggs, task_id=task_id, metrics=metrics
+            company_id=company_id,
+            event_type=event_type,
+            aggs=aggs,
+            task_id=task_id,
+            metrics=metrics,
         )

         if not aggs_result:
@@ -314,7 +342,8 @@ class EventMetrics:
     def _query_aggregation_for_task_metrics(
         self,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         aggs: dict,
         task_id: str,
         metrics: Sequence[Tuple[str, str]],
@@ -345,7 +374,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )

         return es_res.get("aggregations")
@@ -356,30 +387,26 @@ class EventMetrics:
         For the requested tasks return all the metrics that
         reported events of the requested types
         """
-        es_index = get_index_name(company_id, event_type.value)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, event_type):
             return {}

         with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             res = pool.map(
                 partial(
-                    self._get_task_metrics, es_index=es_index, event_type=event_type,
+                    self._get_task_metrics,
+                    company_id=company_id,
+                    event_type=event_type,
                 ),
                 task_ids,
             )
         return list(zip(task_ids, res))

-    def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
+    def _get_task_metrics(
+        self, task_id: str, company_id: str, event_type: EventType
+    ) -> Sequence:
         es_req = {
             "size": 0,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"term": {"task": task_id}},
-                        {"term": {"type": event_type.value}},
-                    ]
-                }
-            },
+            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
             "aggs": {
                 "metrics": {
                     "terms": {
@@ -392,7 +419,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )
         return [
             metric["key"]

View File

@@ -3,7 +3,11 @@ from typing import Optional, Tuple, Sequence
 import attr
 from elasticsearch import Elasticsearch

-from apiserver.bll.event.event_common import check_empty_data, search_company_events
+from apiserver.bll.event.event_common import (
+    check_empty_data,
+    search_company_events,
+    EventType,
+)
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -16,7 +20,7 @@ class TaskEventsResult:
 class LogEventsIterator:
-    EVENT_TYPE = "log"
+    EVENT_TYPE = EventType.task_log

     def __init__(self, es: Elasticsearch):
         self.es = es
@@ -75,7 +79,6 @@ class LogEventsIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task_id,
             )
             hits = es_result["hits"]["hits"]
             hits_total = es_result["hits"]["total"]["value"]
@@ -102,7 +105,6 @@ class LogEventsIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task_id,
             )
             last_second_hits = es_result["hits"]["hits"]
             if not last_second_hits or len(last_second_hits) < 2:

View File

@@ -32,6 +32,7 @@ from furl import furl
 from mongoengine import Q

 from apiserver.bll.event import EventBLL
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.task.artifacts import get_artifact_id
 from apiserver.bll.task.param_utils import (
     split_param_name,
@@ -532,19 +533,22 @@ class PrePopulate:
         scroll_id = None
         while True:
             res = cls.event_bll.get_task_events(
-                task.company, task.id, scroll_id=scroll_id
+                company_id=task.company,
+                task_id=task.id,
+                event_type=EventType.all,
+                scroll_id=scroll_id,
             )
             if not res.events:
                 break
             scroll_id = res.next_scroll_id
             for event in res.events:
                 event_type = event.get("type")
-                if event_type == "training_debug_image":
+                if event_type == EventType.metrics_image.value:
                     url = cls._get_fixed_url(event.get("url"))
                     if url:
                         event["url"] = url
                         artifacts.append(url)
-                elif event_type == "plot":
+                elif event_type == EventType.metrics_plot.value:
                     plot_str: str = event.get("plot_str", "")
                     for match in cls.img_source_regex.findall(plot_str):
                         url = cls._get_fixed_url(match)

View File

@@ -19,7 +19,7 @@ from apiserver.apimodels.events import (
     NextDebugImageSampleRequest,
 )
 from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_common import get_index_name
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.task import TaskBLL
 from apiserver.service_repo import APICall, endpoint
 from apiserver.utilities import json
@@ -63,7 +63,7 @@ def get_task_log_v1_5(call, company_id, _):
         task.get_index_company(),
         task_id,
         order,
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -90,7 +90,7 @@ def get_task_log_v1_7(call, company_id, _):
         company_id=task.get_index_company(),
         task_id=task_id,
         order=scroll_order,
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -118,11 +118,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
         from_timestamp=request.from_timestamp,
     )

-    if (
-        request.order and (
-            (request.navigate_earlier and request.order == LogOrderEnum.asc)
-            or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
-        )
+    if request.order and (
+        (request.navigate_earlier and request.order == LogOrderEnum.asc)
+        or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
     ):
         res.events.reverse()
@@ -182,7 +180,7 @@ def download_task_log(call, company_id, _):
         task.get_index_company(),
         task_id,
         order="asc",
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -219,7 +217,7 @@ def get_vector_metrics_and_variants(call, company_id, _):
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_vector"
+            task.get_index_company(), task_id, EventType.metrics_vector
        )
    )
@@ -232,7 +230,7 @@ def get_scalar_metrics_and_variants(call, company_id, _):
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_scalar"
+            task.get_index_company(), task_id, EventType.metrics_scalar
        )
    )
@@ -272,7 +270,7 @@ def get_task_events(call, company_id, _):
         task.get_index_company(),
         task_id,
         sort=[{"timestamp": {"order": order}}],
-        event_type=event_type,
+        event_type=EventType(event_type) if event_type else EventType.all,
         scroll_id=scroll_id,
         size=batch_size,
     )
@@ -297,7 +295,7 @@ def get_scalar_metric_data(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_stats_scalar",
+        event_type=EventType.metrics_scalar,
         sort=[{"iter": {"order": "desc"}}],
         metric=metric,
         scroll_id=scroll_id,
@@ -321,8 +319,9 def get_task_latest_scalar_values(call, company_id, _):
     metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
         index_company, task_id
     )
-    es_index = get_index_name(index_company, "*")
-    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
+    last_iters = event_bll.get_last_iters(
+        company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
+    )
     call.result.data = dict(
         metrics=metrics,
         last_iter=last_iters[0] if last_iters else 0,
@@ -344,7 +343,10 @@ def scalar_metrics_iter_histogram(
         company_id, request.task, allow_public=True, only=("company", "company_origin")
     )[0]
     metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
-        task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
+        task.get_index_company(),
+        task_id=request.task,
+        samples=request.samples,
+        key=request.key,
     )
     call.result.data = metrics
@@ -394,7 +396,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -436,7 +438,7 @@ def get_multi_task_plots(call, company_id, req_model):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -476,7 +478,7 @@ def get_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -539,7 +541,7 @@ def get_debug_images_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -568,7 +570,7 @@ def get_debug_images_v1_8(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -594,7 +596,10 @@
 def get_debug_images(call, company_id, request: DebugImagesRequest):
     task_ids = {m.task for m in request.metrics}
     tasks = task_bll.assert_exists(
-        company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=task_ids,
+        allow_public=True,
+        only=("company", "company_origin"),
     )

     companies = {t.get_index_company() for t in tasks}
@@ -662,7 +667,7 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
         company_id=task.company,
         task=request.task,
         state_id=request.scroll_id,
-        navigate_earlier=request.navigate_earlier
+        navigate_earlier=request.navigate_earlier,
     )
     call.result.data = attr.asdict(res, recurse=False)
@@ -670,7 +675,10 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):

 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task = task_bll.assert_exists(
-        company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=request.tasks,
+        allow_public=True,
+        only=("company", "company_origin"),
     )[0]
     res = event_bll.metrics.get_tasks_metrics(
         task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
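At the service boundary the request payload still carries a plain string, so it is converted to the enum exactly once at the edge, as in get_task_events above: EventType(event_type) if event_type else EventType.all. A quick sketch of that round trip (standard Enum value lookup):

    assert EventType("log") is EventType.task_log   # wire string -> enum member
    assert EventType.task_log.value == "log"        # enum member -> wire string
    # EventType("nope") would raise ValueError, rejecting unknown event types early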