diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py
index 6f6792a..dd9efb1 100644
--- a/apiserver/bll/event/event_bll.py
+++ b/apiserver/bll/event/event_bll.py
@@ -780,7 +780,7 @@ class EventBLL(object):
 
     def get_task_events(
         self,
-        company_id: str,
+        company_id: Union[str, Sequence[str]],
         task_id: Union[str, Sequence[str]],
         event_type: EventType,
         metrics: MetricVariants = None,
@@ -798,10 +798,16 @@ class EventBLL(object):
             with translate_errors_context():
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
-            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
+            company_ids = [company_id] if isinstance(company_id, str) else company_id
+            company_ids = [
+                c_id
+                for c_id in set(company_ids)
+                if not check_empty_data(self.es, c_id, event_type)
+            ]
+            if not company_ids:
                 return TaskEventsResult()
 
             task_ids = [task_id] if isinstance(task_id, str) else task_id
 
             must = []
             if metrics:
@@ -811,7 +822,7 @@ class EventBLL(object):
                 must.append({"terms": {"task": task_ids}})
             else:
                 tasks_iters = self.get_last_iters(
-                    company_id=company_id,
+                    company_id=company_ids,
                     event_type=event_type,
                     task_id=task_ids,
                     iters=last_iter_count,
@@ -845,7 +856,7 @@ class EventBLL(object):
             with translate_errors_context():
                 es_res = search_company_events(
                     self.es,
-                    company_id=company_id,
+                    company_id=company_ids,
                     event_type=event_type,
                     body=es_req,
                     ignore=404,
@@ -1028,13 +1039,19 @@ class EventBLL(object):
 
     def get_last_iters(
         self,
-        company_id: str,
+        company_id: Union[str, Sequence[str]],
        event_type: EventType,
         task_id: Union[str, Sequence[str]],
         iters: int,
         metrics: MetricVariants = None
     ) -> Mapping[str, Sequence]:
-        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
+        company_ids = [company_id] if isinstance(company_id, str) else company_id
+        company_ids = [
+            c_id
+            for c_id in set(company_ids)
+            if not check_empty_data(self.es, c_id, event_type)
+        ]
+        if not company_ids:
             return {}
 
         task_ids = [task_id] if isinstance(task_id, str) else task_id
@@ -1063,7 +1080,7 @@ class EventBLL(object):
 
         with translate_errors_context():
             es_res = search_company_events(
-                self.es, company_id=company_id, event_type=event_type, body=es_req,
+                self.es, company_id=company_ids, event_type=event_type, body=es_req,
             )
 
         if "aggregations" not in es_res:
diff --git a/apiserver/bll/event/event_common.py b/apiserver/bll/event/event_common.py
index 66849a2..de5d216 100644
--- a/apiserver/bll/event/event_common.py
+++ b/apiserver/bll/event/event_common.py
@@ -8,6 +8,7 @@ from elasticsearch import Elasticsearch
 
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
+from apiserver.database.model.task.task import Task
 from apiserver.tools import safe_get
 
 
@@ -22,6 +23,7 @@ class EventType(Enum):
 
 SINGLE_SCALAR_ITERATION = -(2 ** 31)
 MetricVariants = Mapping[str, Sequence[str]]
+TaskCompanies = Mapping[str, Sequence[Task]]
 
 
 class EventSettings:
@@ -52,9 +54,12 @@ class EventSettings:
         return int(self._max_es_allowed_aggregation_buckets * percentage)
 
 
-def get_index_name(company_id: str, event_type: str):
+def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
     event_type = event_type.lower().replace(" ", "_")
-    return f"events-{event_type}-{company_id.lower()}"
+    if isinstance(company_id, str):
+        company_id = [company_id]
+
+    return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
 
 
 def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
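
A minimal sketch (illustrative only, not part of the patch) of what the reworked get_index_name returns. Elasticsearch accepts a comma-separated index string as a multi-index search target, which is what lets a single search_company_events call now span several companies' indices. The event-type string below assumes EventType.metrics_scalar serializes to "training_stats_scalar", as in the EventType enum defined in this module:

    get_index_name("CompanyA", "training_stats_scalar")
    # -> "events-training_stats_scalar-companya"

    get_index_name(["CompanyA", "CompanyB"], "training_stats_scalar")
    # -> "events-training_stats_scalar-companya,events-training_stats_scalar-companyb"
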
diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py
index 4c3c616..a3e2dd8 100644
--- a/apiserver/bll/event/event_metrics.py
+++ b/apiserver/bll/event/event_metrics.py
@@ -18,11 +18,11 @@ from apiserver.bll.event.event_common import (
     get_metric_variants_condition,
     get_max_metric_and_variant_counts,
     SINGLE_SCALAR_ITERATION,
+    TaskCompanies,
 )
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
-from apiserver.database.model.task.task import Task
 from apiserver.tools import safe_get
 
 log = config.logger(__file__)
@@ -108,8 +108,7 @@ class EventMetrics:
 
     def compare_scalar_metrics_average_per_iter(
         self,
-        company_id,
-        tasks: Sequence[Task],
+        companies: TaskCompanies,
         samples,
         key: ScalarKeyEnum,
         metric_variants: MetricVariants = None,
@@ -119,28 +118,41 @@ class EventMetrics:
         The amount of points in each histogram should not exceed the requested samples
         """
         event_type = EventType.metrics_scalar
-        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
+        companies = {
+            company_id: tasks
+            for company_id, tasks in companies.items()
+            if not check_empty_data(
+                self.es, company_id=company_id, event_type=event_type
+            )
+        }
+        if not companies:
             return {}
 
-        task_name_by_id = {t.id: t.name for t in tasks}
         get_scalar_average_per_iter = partial(
             self._get_scalar_average_per_iter_core,
-            company_id=company_id,
             event_type=event_type,
             samples=samples,
             key=ScalarKey.resolve(key),
             metric_variants=metric_variants,
             run_parallel=False,
         )
-        task_ids = [t.id for t in tasks]
+        task_ids, company_ids = zip(
+            *(
+                (t.id, t.company)
+                for t in itertools.chain.from_iterable(companies.values())
+            )
+        )
         with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
             task_metrics = zip(
-                task_ids, pool.map(get_scalar_average_per_iter, task_ids)
+                task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
             )
 
+        task_names = {
+            t.id: t.name for t in itertools.chain.from_iterable(companies.values())
+        }
         res = defaultdict(lambda: defaultdict(dict))
         for task_id, task_data in task_metrics:
-            task_name = task_name_by_id[task_id]
+            task_name = task_names[task_id]
             for metric_key, metric_data in task_data.items():
                 for variant_key, variant_data in metric_data.items():
                     variant_data["name"] = task_name
@@ -149,18 +161,27 @@ class EventMetrics:
         return res
 
     def get_task_single_value_metrics(
-        self, company_id: str, tasks: Sequence[Task]
+        self, companies: TaskCompanies
     ) -> Mapping[str, dict]:
         """
         For the requested tasks return all the events delivered for the single iteration (-2**31)
         """
-        if check_empty_data(
-            self.es, company_id=company_id, event_type=EventType.metrics_scalar
-        ):
+        companies = {
+            company_id: [t.id for t in tasks]
+            for company_id, tasks in companies.items()
+            if not check_empty_data(
+                self.es, company_id=company_id, event_type=EventType.metrics_scalar
+            )
+        }
+        if not companies:
             return {}
 
-        task_ids = [t.id for t in tasks]
-        task_events = self._get_task_single_value_metrics(company_id, task_ids)
+        with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
+            task_events = list(
+                itertools.chain.from_iterable(
+                    pool.map(self._get_task_single_value_metrics, companies.items())
+                )
+            )
 
         def _get_value(event: dict):
             return {
@@ -174,8 +195,9 @@ class EventMetrics:
         }
 
     def _get_task_single_value_metrics(
-        self, company_id: str, task_ids: Sequence[str]
+        self, tasks: Tuple[str, Sequence[str]]
     ) -> Sequence[dict]:
+        company_id, task_ids = tasks
         es_req = {
             "size": 10000,
             "query": {
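
A side note on the fan-out in compare_scalar_metrics_average_per_iter: Executor.map, like the builtin map, accepts multiple iterables and draws one item from each per call, so the unzipped task_ids and company_ids tuples travel in lockstep. A self-contained toy illustration (the data and lambda are invented, not repo code):

    from concurrent.futures import ThreadPoolExecutor

    companies = {"company-a": ["t1", "t2"], "company-b": ["t3"]}  # company -> task ids
    # zip(*pairs) "unzips" (task, company) pairs into two parallel tuples
    task_ids, company_ids = zip(*((t, c) for c, ts in companies.items() for t in ts))

    with ThreadPoolExecutor(max_workers=4) as pool:
        # the mapped function is called as f(task_id, company_id), one pair per call,
        # and results come back in input order
        results = list(pool.map(lambda t, c: f"{c}/{t}", task_ids, company_ids))

    assert results == ["company-a/t1", "company-a/t2", "company-b/t3"]
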
diff --git a/apiserver/bll/event/metric_events_iterator.py b/apiserver/bll/event/metric_events_iterator.py
index f7d3beb..bc75174 100644
--- a/apiserver/bll/event/metric_events_iterator.py
+++ b/apiserver/bll/event/metric_events_iterator.py
@@ -75,18 +75,25 @@ class MetricEventsIterator:
 
     def get_task_events(
         self,
-        company_id: str,
+        companies: Mapping[str, str],
         task_metrics: Mapping[str, dict],
         iter_count: int,
         navigate_earlier: bool = True,
         refresh: bool = False,
         state_id: str = None,
     ) -> MetricEventsResult:
-        if check_empty_data(self.es, company_id, self.event_type):
+        companies = {
+            task_id: company_id
+            for task_id, company_id in companies.items()
+            if not check_empty_data(
+                self.es, company_id=company_id, event_type=self.event_type
+            )
+        }
+        if not companies:
             return MetricEventsResult()
 
         def init_state(state_: MetricEventsScrollState):
-            state_.tasks = self._init_task_states(company_id, task_metrics)
+            state_.tasks = self._init_task_states(companies, task_metrics)
 
         def validate_state(state_: MetricEventsScrollState):
             """
@@ -95,7 +102,7 @@ class MetricEventsIterator:
             Refresh the state if requested
             """
             if refresh:
-                self._reinit_outdated_task_states(company_id, state_, task_metrics)
+                self._reinit_outdated_task_states(companies, state_, task_metrics)
 
         with self.cache_manager.get_or_create_state(
             state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -112,7 +119,7 @@ class MetricEventsIterator:
                 pool.map(
                     partial(
                         self._get_task_metric_events,
-                        company_id=company_id,
+                        companies=companies,
                         iter_count=iter_count,
                         navigate_earlier=navigate_earlier,
                         specific_variants_requested=specific_variants_requested,
@@ -125,7 +132,7 @@ class MetricEventsIterator:
 
     def _reinit_outdated_task_states(
         self,
-        company_id,
+        companies: Mapping[str, str],
         state: MetricEventsScrollState,
         task_metrics: Mapping[str, dict],
     ):
@@ -133,9 +140,7 @@ class MetricEventsIterator:
         """
         Determine the metrics for which new event_type events were added since their states were initialized and re-init these states
         """
-        tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
-            "id", "metric_stats"
-        )
+        tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
 
         def get_last_update_times_for_task_metrics(
             task: Task,
@@ -175,7 +180,7 @@ class MetricEventsIterator:
             if metrics_to_recalc:
                 task_metrics_to_recalc[task] = metrics_to_recalc
 
-        updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
+        updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
 
         def merge_with_updated_task_states(
             old_state: TaskScrollState, updates: Sequence[TaskScrollState]
@@ -205,14 +210,14 @@ class MetricEventsIterator:
         ]
 
     def _init_task_states(
-        self, company_id: str, task_metrics: Mapping[str, dict]
+        self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
     ) -> Sequence[TaskScrollState]:
         """
         Returned initialized metric scroll stated for the requested task metrics
         """
         with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             task_metric_states = pool.map(
-                partial(self._init_metric_states_for_task, company_id=company_id),
+                partial(self._init_metric_states_for_task, companies=companies),
                 task_metrics.items(),
             )
 
@@ -232,13 +237,14 @@ class MetricEventsIterator:
         pass
 
     def _init_metric_states_for_task(
-        self, task_metrics: Tuple[str, dict], company_id: str
+        self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
     ) -> Sequence[MetricState]:
         """
         Return metric scroll states for the task filled with the variant states
         for the variants that reported any event_type events
         """
         task, metrics = task_metrics
+        company_id = companies[task]
         must = [{"term": {"task": task}}, *self._get_extra_conditions()]
         if metrics:
             must.append(get_metric_variants_condition(metrics))
@@ -319,7 +325,7 @@ class MetricEventsIterator:
     def _get_task_metric_events(
         self,
         task_state: TaskScrollState,
-        company_id: str,
+        companies: Mapping[str, str],
         iter_count: int,
         navigate_earlier: bool,
         specific_variants_requested: bool,
@@ -391,7 +397,10 @@ class MetricEventsIterator:
         }
         with translate_errors_context():
             es_res = search_company_events(
-                self.es, company_id=company_id, event_type=self.event_type, body=es_req,
+                self.es,
+                company_id=companies[task_state.task],
+                event_type=self.event_type,
+                body=es_req,
             )
         if "aggregations" not in es_res:
             return task_state.task, []
diff --git a/apiserver/services/events.py b/apiserver/services/events.py
index 19eb439..dfd86b0 100644
--- a/apiserver/services/events.py
+++ b/apiserver/services/events.py
@@ -32,7 +32,7 @@ from apiserver.apimodels.events import (
     TaskMetric,
 )
 from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_common import EventType, MetricVariants
+from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
 from apiserver.bll.event.events_iterator import Scroll
 from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
 from apiserver.bll.model import ModelBLL
@@ -464,15 +464,14 @@ def scalar_metrics_iter_histogram(
     call.result.data = metrics
 
 
-def _get_task_or_model_index_company(
+def _get_task_or_model_index_companies(
     company_id: str,
     task_ids: Sequence[str],
     model_events=False,
-) -> Tuple[str, Sequence[Task]]:
+) -> TaskCompanies:
     """
-    Verify that all tasks exists and belong to store data in the same company index
-    Return company and tasks
+    Returns lists of tasks grouped by company
     """
     tasks_or_models = _assert_task_or_model_exists(
         company_id, task_ids, model_events=model_events,
     )
@@ -485,13 +484,7 @@ def _get_task_or_model_index_company(
         )
         raise error_cls(company=company_id, ids=invalid)
 
-    companies = {t.get_index_company() for t in tasks_or_models}
-    if len(companies) > 1:
-        raise errors.bad_request.InvalidTaskId(
-            "only tasks from the same company are supported"
-        )
-
-    return companies.pop(), tasks_or_models
+    return bucketize(tasks_or_models, key=lambda t: t.get_index_company())
 
 
 @endpoint(
@@ -504,13 +497,12 @@ def multi_task_scalar_metrics_iter_histogram(
     task_ids = request.tasks
     if isinstance(task_ids, str):
         task_ids = [s.strip() for s in task_ids.split(",")]
-    company, tasks_or_models = _get_task_or_model_index_company(
-        company_id, task_ids, request.model_events
-    )
+
     call.result.data = dict(
         metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
-            company_id=company,
-            tasks=tasks_or_models,
+            companies=_get_task_or_model_index_companies(
+                company_id, task_ids, request.model_events
+            ),
             samples=request.samples,
             key=request.key,
         )
@@ -521,11 +513,11 @@
 def get_task_single_value_metrics(
     call, company_id: str, request: SingleValueMetricsRequest
 ):
-    company, tasks_or_models = _get_task_or_model_index_company(
-        company_id, request.tasks, request.model_events
+    res = event_bll.metrics.get_task_single_value_metrics(
+        companies=_get_task_or_model_index_companies(
+            company_id, request.tasks, request.model_events
+        ),
     )
-
-    res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
     call.result.data = dict(
         tasks=[{"task": task, "values": values} for task, values in res.items()]
     )
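
On the bucketize call in _get_task_or_model_index_companies above: assuming it is boltons.iterutils.bucketize (the import sits outside these hunks, so that is an inference), it groups items into a dict of lists keyed by the key function's result, preserving input order within each bucket:

    from boltons.iterutils import bucketize

    bucketize(range(5), key=lambda n: n % 2)
    # -> {0: [0, 2, 4], 1: [1, 3]}

Applied to tasks keyed by t.get_index_company(), this yields exactly the TaskCompanies mapping (company -> list of tasks) that the reworked BLL methods consume, replacing the old single-company restriction.
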
@@ -537,11 +529,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")
 
-    company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
+    companies = _get_task_or_model_index_companies(company_id, task_ids)
 
     # Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        company,
+        list(companies),
         task_ids,
         event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
@@ -549,8 +541,9 @@ def get_multi_task_plots_v1_7(call, company_id, _):
         last_iter_count=iters,
         scroll_id=scroll_id,
     )
-    task_names = {t.id: t.name for t in tasks_or_models}
-
+    task_names = {
+        t.id: t.name for t in itertools.chain.from_iterable(companies.values())
+    }
     return_events = _get_top_iter_unique_events_per_task(
         result.events, max_iters=iters, task_names=task_names
     )
@@ -564,18 +557,18 @@ def get_multi_task_plots_v1_7(call, company_id, _):
 
 
 def _get_multitask_plots(
-    company: str,
-    tasks_or_models: Sequence[Task],
+    companies: TaskCompanies,
     last_iters: int,
     metrics: MetricVariants = None,
     scroll_id=None,
     no_scroll=True,
     model_events=False,
 ) -> Tuple[dict, int, str]:
-    task_names = {t.id: t.name for t in tasks_or_models}
-
+    task_names = {
+        t.id: t.name for t in itertools.chain.from_iterable(companies.values())
+    }
     result = event_bll.get_task_events(
-        company_id=company,
+        company_id=list(companies),
         task_id=list(task_names),
         event_type=EventType.metrics_plot,
         metrics=metrics,
@@ -599,13 +592,11 @@ def get_multi_task_plots(call, company_id, _):
     no_scroll = call.data.get("no_scroll", False)
     model_events = call.data.get("model_events", False)
 
-    company, tasks_or_models = _get_task_or_model_index_company(
-        company_id, task_ids, model_events
+    companies = _get_task_or_model_index_companies(
+        company_id, task_ids, model_events=model_events
     )
-
     return_events, total_events, next_scroll_id = _get_multitask_plots(
-        company=company,
-        tasks_or_models=tasks_or_models,
+        companies=companies,
         last_iters=iters,
         scroll_id=scroll_id,
         no_scroll=no_scroll,
@@ -728,12 +719,11 @@ def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEven
 def task_plots(call, company_id, request: MetricEventsRequest):
     task_metrics = _task_metrics_dict_from_request(request.metrics)
     task_ids = list(task_metrics)
-    company, _ = _get_task_or_model_index_company(
+    task_or_models = _assert_task_or_model_exists(
         company_id, task_ids=task_ids, model_events=request.model_events
     )
-
     result = event_bll.plots_iterator.get_task_events(
-        company_id=company,
+        companies={t.id: t.get_index_company() for t in task_or_models},
         task_metrics=task_metrics,
         iter_count=request.iters,
         navigate_earlier=request.navigate_earlier,
@@ -824,12 +814,11 @@ def get_debug_images_v1_8(call, company_id, _):
 def get_debug_images(call, company_id, request: MetricEventsRequest):
     task_metrics = _task_metrics_dict_from_request(request.metrics)
     task_ids = list(task_metrics)
-    company, _ = _get_task_or_model_index_company(
+    task_or_models = _assert_task_or_model_exists(
         company_id, task_ids=task_ids, model_events=request.model_events
     )
-
     result = event_bll.debug_images_iterator.get_task_events(
-        company_id=company,
+        companies={t.id: t.get_index_company() for t in task_or_models},
         task_metrics=task_metrics,
         iter_count=request.iters,
         navigate_earlier=request.navigate_earlier,
@@ -922,11 +911,13 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
 
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
-    company, _ = _get_task_or_model_index_company(
-        company_id, request.tasks, model_events=request.model_events
+    task_or_models = _assert_task_or_model_exists(
+        company_id, request.tasks, model_events=request.model_events,
     )
     res = event_bll.metrics.get_task_metrics(
-        company, task_ids=request.tasks, event_type=request.event_type
+        task_or_models[0].get_index_company(),
+        task_ids=request.tasks,
+        event_type=request.event_type,
     )
     call.result.data = {
         "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
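
One idiom worth noting in the plot endpoints above: where a single company id used to be passed, list(companies) now forwards just the mapping's keys, which matches the widened company_id: Union[str, Sequence[str]] parameter of EventBLL.get_task_events. A toy illustration (stand-in values, not repo data):

    companies = {"company-a": ["task-1"], "company-b": ["task-2"]}  # TaskCompanies stand-in
    list(companies)  # -> ["company-a", "company-b"]  (keys only; the task lists are dropped)
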
diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py
index 37b0f08..a0bc694 100644
--- a/apiserver/services/reports.py
+++ b/apiserver/services/reports.py
@@ -1,5 +1,6 @@
 import textwrap
 from datetime import datetime
+from itertools import chain
 from typing import Sequence
 
 from apiserver.apimodels.reports import (
@@ -24,7 +25,7 @@ from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, TaskType, TaskStatus
 from apiserver.service_repo import APICall, endpoint
 from apiserver.services.events import (
-    _get_task_or_model_index_company,
+    _get_task_or_model_index_companies,
     event_bll,
     _get_metrics_response,
     _get_metric_variants_from_request,
@@ -214,10 +215,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
         return res
 
     task_ids = [task["id"] for task in tasks]
-    company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
+    companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
     if request.debug_images:
         result = event_bll.debug_images_iterator.get_task_events(
-            company_id=company,
+            companies={
+                t.id: t.company for t in chain.from_iterable(companies.values())
+            },
             task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
             iter_count=request.debug_images.iters,
         )
@@ -227,8 +230,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
 
     if request.plots:
         res["plots"] = _get_multitask_plots(
-            company=company,
-            tasks_or_models=tasks_or_models,
+            companies=companies,
             last_iters=request.plots.iters,
             metrics=_get_metric_variants_from_request(request.plots.metrics),
         )[0]
@@ -237,8 +239,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
         res[
             "scalar_metrics_iter_histogram"
         ] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
-            company_id=company_id,
-            tasks=tasks_or_models,
+            companies=companies,
             samples=request.scalar_metrics_iter_histogram.samples,
             key=request.scalar_metrics_iter_histogram.key,
             metric_variants=_get_metric_variants_from_request(