diff --git a/apiserver/bll/event/debug_images_iterator.py b/apiserver/bll/event/debug_images_iterator.py
index d219847..09470dd 100644
--- a/apiserver/bll/event/debug_images_iterator.py
+++ b/apiserver/bll/event/debug_images_iterator.py
@@ -19,6 +19,7 @@ from apiserver.bll.event.event_common import (
     EventSettings,
     check_empty_data,
     search_company_events,
+    EventType,
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
@@ -59,7 +60,7 @@ class DebugImagesResult(object):
 
 
 class DebugImagesIterator:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image
 
     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -142,10 +143,10 @@ class DebugImagesIterator:
         return [
             (
                 (task.id, stats.metric),
-                stats.event_stats_by_type[self.EVENT_TYPE].last_update,
+                stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
             )
             for stats in metric_stats.values()
-            if self.EVENT_TYPE in stats.event_stats_by_type
+            if self.EVENT_TYPE.value in stats.event_stats_by_type
         ]
 
         update_times = dict(
@@ -257,7 +258,6 @@
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task,
             )
         if "aggregations" not in es_res:
             return []
@@ -401,7 +401,6 @@
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=metric.task,
             )
         if "aggregations" not in es_res:
             return metric.task, metric.name, []
diff --git a/apiserver/bll/event/debug_sample_history.py b/apiserver/bll/event/debug_sample_history.py
index ebd6984..aa568f8 100644
--- a/apiserver/bll/event/debug_sample_history.py
+++ b/apiserver/bll/event/debug_sample_history.py
@@ -10,7 +10,12 @@ from redis import StrictRedis
 
 from apiserver.apierrors import errors
 from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_common import EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    EventType,
+    check_empty_data,
+    search_company_events,
+)
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -44,7 +49,7 @@ class DebugSampleHistoryResult(object):
 
 
 class DebugSampleHistory:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image
 
     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -66,14 +71,13 @@
         if not state or state.task != task:
             raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
 
-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
            return res
 
         image = self._get_next_for_current_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
        ) or self._get_next_for_another_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
        )
         if not image:
             return res
@@ -94,7 +98,7 @@
         res.max_iteration = var_state.max_iteration
 
     def _get_next_for_current_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
@@ -127,7 +131,9 @@
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_current_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )
 
         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -136,7 +142,7 @@
         return hits[0]["_source"]
 
     def _get_next_for_another_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
@@ -189,7 +195,9 @@
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_another_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )
 
         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -212,14 +220,13 @@
         If the iteration is not passed then get the latest event
         """
         res = DebugSampleHistoryResult()
-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res
 
         def init_state(state_: DebugSampleHistoryState):
             state_.task = task
             state_.metric = metric
-            self._reset_variant_states(es_index, state=state_)
+            self._reset_variant_states(company_id=company_id, state=state_)
 
         def validate_state(state_: DebugSampleHistoryState):
             if state_.task != task or state_.metric != metric:
@@ -228,7 +235,7 @@
                     scroll_id=state_.id,
                 )
             if refresh:
-                self._reset_variant_states(es_index, state=state_)
+                self._reset_variant_states(company_id=company_id, state=state_)
 
         state: DebugSampleHistoryState
         with self.cache_manager.get_or_create_state(
@@ -271,7 +278,12 @@
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_for_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+            )
 
         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -282,9 +294,9 @@
         )
         return res
 
-    def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
+    def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
         variant_iterations = self._get_variant_iterations(
-            es_index=es_index, task=state.task, metric=state.metric
+            company_id=company_id, task=state.task, metric=state.metric
         )
         state.variant_states = [
             VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
@@ -293,7 +305,7 @@
 
     def _get_variant_iterations(
         self,
-        es_index: str,
+        company_id: str,
         task: str,
         metric: str,
         variants: Optional[Sequence[str]] = None,
@@ -344,7 +356,9 @@
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_iterations"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )
 
         def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
             variant = variant_bucket["key"]
diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py
index 13a9864..43c3489 100644
--- a/apiserver/bll/event/event_bll.py
+++ b/apiserver/bll/event/event_bll.py
@@ -13,7 +13,14 @@ from mongoengine import Q
 from nested_dict import nested_dict
 
 from apiserver.bll.event.debug_sample_history import DebugSampleHistory
-from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    get_index_name,
+    check_empty_data,
+    search_company_events,
+    delete_company_events,
+)
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
@@ -156,7 +163,7 @@
             }
 
             # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
-            if event_type != "log":
+            if event_type != EventType.task_log.value:
                 es_action["_id"] = self._get_event_id(event)
             else:
                 es_action["_id"] = dbutils.id()
@@ -389,10 +396,10 @@
 
     def scroll_task_events(
         self,
-        company_id,
-        task_id,
-        order,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        order: str,
+        event_type: EventType,
         batch_size=10000,
         scroll_id=None,
     ):
@@ -404,12 +411,7 @@
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             size = min(batch_size, 10000)
-            if event_type is None:
-                event_type = "*"
-
-            es_index = get_index_name(company_id, event_type)
-
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return [], None, 0
 
             es_req = {
@@ -419,15 +421,25 @@
             }
 
             with translate_errors_context(), TimingContext("es", "scroll_task_events"):
-                es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    scroll="1h",
+                )
 
         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
         return events, next_scroll_id, total_events
 
     def get_last_iterations_per_event_metric_variant(
-        self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
+        self,
+        company_id: str,
+        task_id: str,
+        num_last_iterations: int,
+        event_type: EventType,
     ):
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []
 
         es_req: dict = {
@@ -461,13 +473,14 @@
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
-        if event_type:
-            es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
 
         with translate_errors_context(), TimingContext(
             "es", "task_last_iter_metric_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )
+
         if "aggregations" not in es_res:
             return []
 
@@ -494,9 +507,8 @@
             with translate_errors_context(), TimingContext("es", "get_task_events"):
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
-            event_type = "plot"
-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            event_type = EventType.metrics_plot
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()
 
             plot_valid_condition = {
@@ -519,7 +531,10 @@
             should = []
             for i, task_id in enumerate(tasks):
                 last_iters = self.get_last_iterations_per_event_metric_variant(
-                    es_index, task_id, last_iterations_per_plot, event_type
+                    company_id=company_id,
+                    task_id=task_id,
+                    num_last_iterations=last_iterations_per_plot,
+                    event_type=event_type,
                 )
                 if not last_iters:
                     continue
@@ -551,8 +566,13 @@
             }
 
             with translate_errors_context(), TimingContext("es", "get_task_plots"):
-                es_res = self.es.search(
-                    index=es_index, body=es_req, ignore=404, scroll="1h",
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    ignore=404,
+                    scroll="1h",
                 )
 
         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -577,9 +597,9 @@
 
     def get_task_events(
         self,
-        company_id,
-        task_id,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        event_type: EventType,
         metric=None,
         variant=None,
         last_iter_count=None,
@@ -595,11 +615,8 @@
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
-            if event_type is None:
-                event_type = "*"
-            es_index = get_index_name(company_id, event_type)
 
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()
 
             must = []
@@ -614,7 +631,10 @@
             should = []
             for i, task_id in enumerate(task_ids):
                 last_iters = self.get_last_iters(
-                    es_index, task_id, event_type, last_iter_count
+                    company_id=company_id,
+                    event_type=event_type,
+                    task_id=task_id,
+                    iters=last_iter_count,
                 )
                 if not last_iters:
                     continue
@@ -642,8 +662,13 @@
             }
 
             with translate_errors_context(), TimingContext("es", "get_task_events"):
-                es_res = self.es.search(
-                    index=es_index, body=es_req, ignore=404, scroll="1h",
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    ignore=404,
+                    scroll="1h",
                 )
 
         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -651,11 +676,10 @@
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )
 
-    def get_metrics_and_variants(self, company_id, task_id, event_type):
-
-        es_index = get_index_name(company_id, event_type)
-
-        if not self.es.indices.exists(es_index):
+    def get_metrics_and_variants(
+        self, company_id: str, task_id: str, event_type: EventType
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}
 
         es_req = {
@@ -684,7 +708,9 @@
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )
 
         metrics = {}
         for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -695,10 +721,9 @@
 
         return metrics
 
-    def get_task_latest_scalar_values(self, company_id, task_id):
-        es_index = get_index_name(company_id, "training_stats_scalar")
-
-        if not self.es.indices.exists(es_index):
+    def get_task_latest_scalar_values(self, company_id: str, task_id: str):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}
 
         es_req = {
@@ -753,7 +778,9 @@
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
"events_get_metrics_and_variants" ): - es_res = self.es.search(index=es_index, body=es_req) + es_res = search_company_events( + self.es, company_id=company_id, event_type=event_type, body=es_req + ) metrics = [] max_timestamp = 0 @@ -780,9 +807,8 @@ class EventBLL(object): return metrics, max_timestamp def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant): - - es_index = get_index_name(company_id, "training_stats_vector") - if not self.es.indices.exists(es_index): + event_type = EventType.metrics_vector + if check_empty_data(self.es, company_id=company_id, event_type=event_type): return [], [] es_req = { @@ -800,7 +826,9 @@ class EventBLL(object): "sort": ["iter"], } with translate_errors_context(), TimingContext("es", "task_stats_vector"): - es_res = self.es.search(index=es_index, body=es_req) + es_res = search_company_events( + self.es, company_id=company_id, event_type=event_type, body=es_req + ) vectors = [] iterations = [] @@ -810,8 +838,10 @@ class EventBLL(object): return iterations, vectors - def get_last_iters(self, es_index, task_id, event_type, iters): - if not self.es.indices.exists(es_index): + def get_last_iters( + self, company_id: str, event_type: EventType, task_id: str, iters: int + ): + if check_empty_data(self.es, company_id=company_id, event_type=event_type): return [] es_req: dict = { @@ -827,11 +857,12 @@ class EventBLL(object): }, "query": {"bool": {"must": [{"term": {"task": task_id}}]}}, } - if event_type: - es_req["query"]["bool"]["must"].append({"term": {"type": event_type}}) with translate_errors_context(), TimingContext("es", "task_last_iter"): - es_res = self.es.search(index=es_index, body=es_req) + es_res = search_company_events( + self.es, company_id=company_id, event_type=event_type, body=es_req + ) + if "aggregations" not in es_res: return [] @@ -850,9 +881,14 @@ class EventBLL(object): extra_msg, company=company_id, id=task_id ) - es_index = get_index_name(company_id, "*") es_req = {"query": {"term": {"task": task_id}}} with translate_errors_context(), TimingContext("es", "delete_task_events"): - es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True) + es_res = delete_company_events( + es=self.es, + company_id=company_id, + event_type=EventType.all, + body=es_req, + refresh=True, + ) return es_res.get("deleted", 0) diff --git a/apiserver/bll/event/event_common.py b/apiserver/bll/event/event_common.py index 4e3aa0b..bb9a075 100644 --- a/apiserver/bll/event/event_common.py +++ b/apiserver/bll/event/event_common.py @@ -13,6 +13,7 @@ class EventType(Enum): metrics_image = "training_debug_image" metrics_plot = "plot" task_log = "log" + all = "*" class EventSettings: @@ -40,8 +41,8 @@ def get_index_name(company_id: str, event_type: str): return f"events-{event_type}-{company_id}" -def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool: - es_index = get_index_name(company_id, event_type) +def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool: + es_index = get_index_name(company_id, event_type.value) if not es.indices.exists(es_index): return True return False @@ -50,16 +51,16 @@ def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> boo def search_company_events( es: Elasticsearch, company_id: Union[str, Sequence[str]], - event_type: str, + event_type: EventType, body: dict, **kwargs, ) -> dict: - es_index = get_index_name(company_id, event_type) + es_index = get_index_name(company_id, event_type.value) return es.search(index=es_index, 
 
 
 def delete_company_events(
-    es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
+    es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.delete_by_query(index=es_index, body=body, **kwargs)
diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py
index 732e54c..ef45d23 100644
--- a/apiserver/bll/event/event_metrics.py
+++ b/apiserver/bll/event/event_metrics.py
@@ -10,7 +10,12 @@ from elasticsearch import Elasticsearch
 from mongoengine import Q
 
 from apiserver.apierrors import errors
-from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    search_company_events,
+    check_empty_data,
+)
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
@@ -36,31 +41,40 @@
         The amount of points in each histogram should not exceed the requested samples
         """
-        es_index = get_index_name(company_id, "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}
 
         return self._get_scalar_average_per_iter_core(
-            task_id, es_index, samples, ScalarKey.resolve(key)
+            task_id, company_id, event_type, samples, ScalarKey.resolve(key)
         )
 
     def _get_scalar_average_per_iter_core(
         self,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         samples: int,
         key: ScalarKey,
         run_parallel: bool = True,
     ) -> dict:
         intervals = self._get_task_metric_intervals(
-            es_index=es_index, task_id=task_id, samples=samples, field=key.field
+            company_id=company_id,
+            event_type=event_type,
+            task_id=task_id,
+            samples=samples,
+            field=key.field,
         )
         if not intervals:
             return {}
 
         interval_groups = self._group_task_metric_intervals(intervals)
         get_scalar_average = partial(
-            self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
+            self._get_scalar_average,
+            task_id=task_id,
+            company_id=company_id,
+            event_type=event_type,
+            key=key,
         )
         if run_parallel:
             with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
@@ -110,13 +124,15 @@
                 "only tasks from the same company are supported"
             )
 
-        es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        company_id = next(iter(companies))
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}
 
         get_scalar_average_per_iter = partial(
             self._get_scalar_average_per_iter_core,
-            es_index=es_index,
+            company_id=company_id,
+            event_type=event_type,
             samples=samples,
             key=ScalarKey.resolve(key),
             run_parallel=False,
@@ -175,7 +191,12 @@
         return metric_interval_groups
 
     def _get_task_metric_intervals(
-        self, es_index, task_id: str, samples: int, field: str = "iter"
+        self,
+        company_id: str,
+        event_type: EventType,
+        task_id: str,
+        samples: int,
+        field: str = "iter",
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
@@ -212,7 +233,9 @@
         }
 
         with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )
 
         aggs_result = es_res.get("aggregations")
         if not aggs_result:
@@ -255,7 +278,8 @@
         self,
         metrics_interval: MetricIntervalGroup,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         key: ScalarKey,
     ) -> Sequence[MetricData]:
         """
@@ -283,7 +307,11 @@
             }
         }
         aggs_result = self._query_aggregation_for_task_metrics(
-            es_index, aggs=aggs, task_id=task_id, metrics=metrics
+            company_id=company_id,
+            event_type=event_type,
+            aggs=aggs,
+            task_id=task_id,
+            metrics=metrics,
         )
 
         if not aggs_result:
@@ -314,7 +342,8 @@
 
     def _query_aggregation_for_task_metrics(
         self,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         aggs: dict,
         task_id: str,
         metrics: Sequence[Tuple[str, str]],
@@ -345,7 +374,9 @@
         }
 
         with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )
 
         return es_res.get("aggregations")
 
@@ -356,30 +387,26 @@
         For the requested tasks return all the metrics that reported events
         of the requested types
         """
-        es_index = get_index_name(company_id, event_type.value)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, event_type):
             return {}
 
         with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             res = pool.map(
                 partial(
-                    self._get_task_metrics, es_index=es_index, event_type=event_type,
+                    self._get_task_metrics,
+                    company_id=company_id,
+                    event_type=event_type,
                 ),
                 task_ids,
             )
         return list(zip(task_ids, res))
 
-    def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
+    def _get_task_metrics(
+        self, task_id: str, company_id: str, event_type: EventType
+    ) -> Sequence:
         es_req = {
             "size": 0,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"term": {"task": task_id}},
-                        {"term": {"type": event_type.value}},
-                    ]
-                }
-            },
+            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
             "aggs": {
                 "metrics": {
                     "terms": {
@@ -392,7 +419,9 @@
         }
 
         with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )
 
         return [
             metric["key"]
diff --git a/apiserver/bll/event/log_events_iterator.py b/apiserver/bll/event/log_events_iterator.py
index d3ba3c1..bfc202e 100644
--- a/apiserver/bll/event/log_events_iterator.py
+++ b/apiserver/bll/event/log_events_iterator.py
@@ -3,7 +3,11 @@ from typing import Optional, Tuple, Sequence
 import attr
 from elasticsearch import Elasticsearch
 
-from apiserver.bll.event.event_common import check_empty_data, search_company_events
+from apiserver.bll.event.event_common import (
+    check_empty_data,
+    search_company_events,
+    EventType,
+)
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
 
@@ -16,7 +20,7 @@
 
 
 class LogEventsIterator:
-    EVENT_TYPE = "log"
+    EVENT_TYPE = EventType.task_log
 
     def __init__(self, es: Elasticsearch):
         self.es = es
@@ -75,7 +79,6 @@
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task_id,
             )
             hits = es_result["hits"]["hits"]
             hits_total = es_result["hits"]["total"]["value"]
es_result["hits"]["total"]["value"] @@ -102,7 +105,6 @@ class LogEventsIterator: company_id=company_id, event_type=self.EVENT_TYPE, body=es_req, - routing=task_id, ) last_second_hits = es_result["hits"]["hits"] if not last_second_hits or len(last_second_hits) < 2: diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index 421c51a..a6af6e8 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -32,6 +32,7 @@ from furl import furl from mongoengine import Q from apiserver.bll.event import EventBLL +from apiserver.bll.event.event_common import EventType from apiserver.bll.task.artifacts import get_artifact_id from apiserver.bll.task.param_utils import ( split_param_name, @@ -532,19 +533,22 @@ class PrePopulate: scroll_id = None while True: res = cls.event_bll.get_task_events( - task.company, task.id, scroll_id=scroll_id + company_id=task.company, + task_id=task.id, + event_type=EventType.all, + scroll_id=scroll_id, ) if not res.events: break scroll_id = res.next_scroll_id for event in res.events: event_type = event.get("type") - if event_type == "training_debug_image": + if event_type == EventType.metrics_image.value: url = cls._get_fixed_url(event.get("url")) if url: event["url"] = url artifacts.append(url) - elif event_type == "plot": + elif event_type == EventType.metrics_plot.value: plot_str: str = event.get("plot_str", "") for match in cls.img_source_regex.findall(plot_str): url = cls._get_fixed_url(match) diff --git a/apiserver/services/events.py b/apiserver/services/events.py index cdc94a5..6908bcb 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -19,7 +19,7 @@ from apiserver.apimodels.events import ( NextDebugImageSampleRequest, ) from apiserver.bll.event import EventBLL -from apiserver.bll.event.event_common import get_index_name +from apiserver.bll.event.event_common import EventType from apiserver.bll.task import TaskBLL from apiserver.service_repo import APICall, endpoint from apiserver.utilities import json @@ -63,7 +63,7 @@ def get_task_log_v1_5(call, company_id, _): task.get_index_company(), task_id, order, - event_type="log", + event_type=EventType.task_log, batch_size=batch_size, scroll_id=scroll_id, ) @@ -90,7 +90,7 @@ def get_task_log_v1_7(call, company_id, _): company_id=task.get_index_company(), task_id=task_id, order=scroll_order, - event_type="log", + event_type=EventType.task_log, batch_size=batch_size, scroll_id=scroll_id, ) @@ -118,11 +118,9 @@ def get_task_log(call, company_id, request: LogEventsRequest): from_timestamp=request.from_timestamp, ) - if ( - request.order and ( - (request.navigate_earlier and request.order == LogOrderEnum.asc) - or (not request.navigate_earlier and request.order == LogOrderEnum.desc) - ) + if request.order and ( + (request.navigate_earlier and request.order == LogOrderEnum.asc) + or (not request.navigate_earlier and request.order == LogOrderEnum.desc) ): res.events.reverse() @@ -182,7 +180,7 @@ def download_task_log(call, company_id, _): task.get_index_company(), task_id, order="asc", - event_type="log", + event_type=EventType.task_log, batch_size=batch_size, scroll_id=scroll_id, ) @@ -219,7 +217,7 @@ def get_vector_metrics_and_variants(call, company_id, _): )[0] call.result.data = dict( metrics=event_bll.get_metrics_and_variants( - task.get_index_company(), task_id, "training_stats_vector" + task.get_index_company(), task_id, EventType.metrics_vector ) ) @@ -232,7 +230,7 @@ def 
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_scalar"
+            task.get_index_company(), task_id, EventType.metrics_scalar
         )
     )
 
@@ -272,7 +270,7 @@ def get_task_events(call, company_id, _):
         task.get_index_company(),
         task_id,
         sort=[{"timestamp": {"order": order}}],
-        event_type=event_type,
+        event_type=EventType(event_type) if event_type else EventType.all,
         scroll_id=scroll_id,
         size=batch_size,
     )
@@ -297,7 +295,7 @@ def get_scalar_metric_data(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_stats_scalar",
+        event_type=EventType.metrics_scalar,
         sort=[{"iter": {"order": "desc"}}],
         metric=metric,
         scroll_id=scroll_id,
@@ -321,8 +319,9 @@ def get_task_latest_scalar_values(call, company_id, _):
     metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
         index_company, task_id
     )
-    es_index = get_index_name(index_company, "*")
-    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
+    last_iters = event_bll.get_last_iters(
+        company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1
+    )
     call.result.data = dict(
         metrics=metrics,
         last_iter=last_iters[0] if last_iters else 0,
@@ -344,7 +343,10 @@ def scalar_metrics_iter_histogram(
         company_id, request.task, allow_public=True, only=("company", "company_origin")
     )[0]
     metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
-        task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
+        task.get_index_company(),
+        task_id=request.task,
+        samples=request.samples,
+        key=request.key,
     )
     call.result.data = metrics
 
@@ -394,7 +396,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -436,7 +438,7 @@ def get_multi_task_plots(call, company_id, req_model):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -476,7 +478,7 @@ def get_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -539,7 +541,7 @@ def get_debug_images_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -568,7 +570,7 @@ def get_debug_images_v1_8(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
        last_iter_count=iters,
        scroll_id=scroll_id,
@@ -594,7 +596,10 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
     task_ids = {m.task for m in request.metrics}
     tasks = task_bll.assert_exists(
-        company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=task_ids,
+        allow_public=True,
+        only=("company", "company_origin"),
     )
 
     companies = {t.get_index_company() for t in tasks}
 
@@ -662,7 +667,7 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
         company_id=task.company,
         task=request.task,
         state_id=request.scroll_id,
-        navigate_earlier=request.navigate_earlier
+        navigate_earlier=request.navigate_earlier,
     )
     call.result.data = attr.asdict(res, recurse=False)
 
@@ -670,7 +675,10 @@
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task = task_bll.assert_exists(
-        company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=request.tasks,
+        allow_public=True,
+        only=("company", "company_origin"),
     )[0]
     res = event_bll.metrics.get_tasks_metrics(
         task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
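
Reviewer note (not part of the patch): a minimal usage sketch of how the refactored helpers in event_common.py compose after this change. The names `es`, `company_id`, and `task_id` are placeholders, and the snippet assumes only what the diff itself defines: the per-company index layout from get_index_name ("events-<event type value>-<company id>") and the helper signatures above.

    from apiserver.bll.event.event_common import (
        EventType,
        check_empty_data,
        search_company_events,
    )

    # EventType.task_log resolves to the index "events-log-<company_id>";
    # EventType.all resolves to "events-*-<company_id>", i.e. every event
    # index of the company (used by delete_company_events for task cleanup).
    event_type = EventType.task_log
    if not check_empty_data(es, company_id=company_id, event_type=event_type):
        es_res = search_company_events(
            es,
            company_id=company_id,
            event_type=event_type,
            body={"query": {"term": {"task": task_id}}},
        )

Callers now pass an EventType member instead of a raw type string or a precomputed index name, so index naming stays private to event_common and the per-type filter terms and routing arguments removed in this diff are no longer needed.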