Removed the restriction that event comparison only supports tasks from the same company

allegroai 2022-12-21 18:42:40 +02:00
parent e66257761a
commit 14ff639bb0
6 changed files with 137 additions and 92 deletions

View File

@@ -780,7 +780,7 @@ class EventBLL(object):
def get_task_events(
self,
company_id: str,
company_id: Union[str, Sequence[str]],
task_id: Union[str, Sequence[str]],
event_type: EventType,
metrics: MetricVariants = None,
@@ -798,10 +798,21 @@ class EventBLL(object):
with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
task_ids = (
[task_id]
if isinstance(task_id, str)
else task_id
)
must = []
if metrics:
@@ -811,7 +822,7 @@ class EventBLL(object):
must.append({"terms": {"task": task_ids}})
else:
tasks_iters = self.get_last_iters(
company_id=company_id,
company_id=company_ids,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
@@ -845,7 +856,7 @@ class EventBLL(object):
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
company_id=company_ids,
event_type=event_type,
body=es_req,
ignore=404,
@@ -1028,13 +1039,19 @@ class EventBLL(object):
def get_last_iters(
self,
company_id: str,
company_id: Union[str, Sequence[str]],
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
metrics: MetricVariants = None
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
@@ -1063,7 +1080,7 @@ class EventBLL(object):
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
self.es, company_id=company_ids, event_type=event_type, body=es_req,
)
if "aggregations" not in es_res:

View File

@@ -8,6 +8,7 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
@@ -22,6 +23,7 @@ class EventType(Enum):
SINGLE_SCALAR_ITERATION = -(2 ** 31)
MetricVariants = Mapping[str, Sequence[str]]
TaskCompanies = Mapping[str, Sequence[Task]]
class EventSettings:
@@ -52,9 +54,12 @@ class EventSettings:
return int(self._max_es_allowed_aggregation_buckets * percentage)
def get_index_name(company_id: str, event_type: str):
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id.lower()}"
if isinstance(company_id, str):
company_id = [company_id]
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
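
The reworked get_index_name is what makes a multi-company query possible in a single request: Elasticsearch accepts a comma-separated list of index names wherever one index name is expected, so the function joins one per-company index per id. A runnable sketch reproducing the hunk above:

from typing import Sequence, Union

def get_index_name(company_id: Union[str, Sequence[str]], event_type: str) -> str:
    event_type = event_type.lower().replace(" ", "_")
    if isinstance(company_id, str):
        company_id = [company_id]
    # "idx1,idx2" is a valid multi-index search target in Elasticsearch
    return ",".join(
        f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id
    )

assert get_index_name("CompanyA", "plot") == "events-plot-companya"
assert get_index_name(["c1", "c2"], "training debug image") == (
    "events-training_debug_image-c1,events-training_debug_image-c2"
)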

View File

@@ -18,11 +18,11 @@ from apiserver.bll.event.event_common import (
get_metric_variants_condition,
get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION,
TaskCompanies,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
log = config.logger(__file__)
@@ -108,8 +108,7 @@ class EventMetrics:
def compare_scalar_metrics_average_per_iter(
self,
company_id,
tasks: Sequence[Task],
companies: TaskCompanies,
samples,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
@@ -119,28 +118,41 @@ class EventMetrics:
The number of points in each histogram should not exceed the requested samples
"""
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
companies = {
company_id: tasks
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=event_type
)
}
if not companies:
return {}
task_name_by_id = {t.id: t.name for t in tasks}
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
)
task_ids = [t.id for t in tasks]
task_ids, company_ids = zip(
*(
(t.id, t.company)
for t in itertools.chain.from_iterable(companies.values())
)
)
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
)
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
res = defaultdict(lambda: defaultdict(dict))
for task_id, task_data in task_metrics:
task_name = task_name_by_id[task_id]
task_name = task_names[task_id]
for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name
@@ -149,18 +161,27 @@ class EventMetrics:
return res
def get_task_single_value_metrics(
self, company_id: str, tasks: Sequence[Task]
self, companies: TaskCompanies
) -> Mapping[str, dict]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""
if check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
):
companies = {
company_id: [t.id for t in tasks]
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
)
}
if not companies:
return {}
task_ids = [t.id for t in tasks]
task_events = self._get_task_single_value_metrics(company_id, task_ids)
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
),
)
def _get_value(event: dict):
return {
@@ -174,8 +195,9 @@ class EventMetrics:
}
def _get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str]
self, tasks: Tuple[str, Sequence[str]]
) -> Sequence[dict]:
company_id, task_ids = tasks
es_req = {
"size": 10000,
"query": {

View File

@@ -75,18 +75,25 @@ class MetricEventsIterator:
def get_task_events(
self,
company_id: str,
companies: Mapping[str, str],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> MetricEventsResult:
if check_empty_data(self.es, company_id, self.event_type):
companies = {
task_id: company_id
for task_id, company_id in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=self.event_type
)
}
if not companies:
return MetricEventsResult()
def init_state(state_: MetricEventsScrollState):
state_.tasks = self._init_task_states(company_id, task_metrics)
state_.tasks = self._init_task_states(companies, task_metrics)
def validate_state(state_: MetricEventsScrollState):
"""
@@ -95,7 +102,7 @@ class MetricEventsIterator:
Refresh the state if requested
"""
if refresh:
self._reinit_outdated_task_states(company_id, state_, task_metrics)
self._reinit_outdated_task_states(companies, state_, task_metrics)
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state
@@ -112,7 +119,7 @@ class MetricEventsIterator:
pool.map(
partial(
self._get_task_metric_events,
company_id=company_id,
companies=companies,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
specific_variants_requested=specific_variants_requested,
@@ -125,7 +132,7 @@ class MetricEventsIterator:
def _reinit_outdated_task_states(
self,
company_id,
companies: Mapping[str, str],
state: MetricEventsScrollState,
task_metrics: Mapping[str, dict],
):
@@ -133,9 +140,7 @@ class MetricEventsIterator:
Determine the metrics for which new event_type events were added
since their states were initialized and re-init these states
"""
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
"id", "metric_stats"
)
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
def get_last_update_times_for_task_metrics(
task: Task,
@@ -175,7 +180,7 @@ class MetricEventsIterator:
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
@@ -205,14 +210,14 @@ class MetricEventsIterator:
]
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, dict]
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Return initialized metric scroll states for the requested task metrics
"""
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id),
partial(self._init_metric_states_for_task, companies=companies),
task_metrics.items(),
)
@@ -232,13 +237,14 @@ class MetricEventsIterator:
pass
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], company_id: str
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any event_type events
"""
task, metrics = task_metrics
company_id = companies[task]
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
if metrics:
must.append(get_metric_variants_condition(metrics))
@@ -319,7 +325,7 @@ class MetricEventsIterator:
def _get_task_metric_events(
self,
task_state: TaskScrollState,
company_id: str,
companies: Mapping[str, str],
iter_count: int,
navigate_earlier: bool,
specific_variants_requested: bool,
@@ -391,7 +397,10 @@ class MetricEventsIterator:
}
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
self.es,
company_id=companies[task_state.task],
event_type=self.event_type,
body=es_req,
)
if "aggregations" not in es_res:
return task_state.task, []
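
Across the iterator, the single company_id parameter becomes a task-id → company-id mapping, so every per-task search is routed to that task's own company index. A small illustration of the lookup flow — search_events is a hypothetical stand-in for search_company_events:

from typing import Mapping

def search_events(es, company_id: str, event_type: str, body: dict) -> dict:
    # hypothetical stand-in: the real search_company_events() derives the
    # index name from company_id and event_type before querying ES
    return {"index": f"events-{event_type}-{company_id}", "hits": []}

def get_task_metric_events(
    es, task_id: str, companies: Mapping[str, str], event_type: str
) -> dict:
    # each task is searched in its own company's index
    return search_events(
        es, company_id=companies[task_id], event_type=event_type, body={}
    )

companies = {"task-1": "company-a", "task-2": "company-b"}
assert get_task_metric_events(None, "task-2", companies, "plot")["index"] == (
    "events-plot-company-b"
)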

View File

@@ -32,7 +32,7 @@ from apiserver.apimodels.events import (
TaskMetric,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL
@@ -464,12 +464,11 @@ def scalar_metrics_iter_histogram(
call.result.data = metrics
def _get_task_or_model_index_company(
def _get_task_or_model_index_companies(
company_id: str, task_ids: Sequence[str], model_events=False,
) -> Tuple[str, Sequence[Task]]:
) -> TaskCompanies:
"""
Verify that all tasks exist and store data in the same company index
Return company and tasks
Returns lists of tasks grouped by company
"""
tasks_or_models = _assert_task_or_model_exists(
company_id, task_ids, model_events=model_events,
@@ -485,13 +484,7 @@ def _get_task_or_model_index_company(
)
raise error_cls(company=company_id, ids=invalid)
companies = {t.get_index_company() for t in tasks_or_models}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
return companies.pop(), tasks_or_models
return bucketize(tasks_or_models, key=lambda t: t.get_index_company())
@endpoint(
@ -504,13 +497,12 @@ def multi_task_scalar_metrics_iter_histogram(
task_ids = request.tasks
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
company, tasks_or_models = _get_task_or_model_index_company(
company_id, task_ids, request.model_events
)
call.result.data = dict(
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id=company,
tasks=tasks_or_models,
companies=_get_task_or_model_index_companies(
company_id, task_ids, request.model_events
),
samples=request.samples,
key=request.key,
)
@@ -521,11 +513,11 @@ def multi_task_scalar_metrics_iter_histogram(
def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest
):
company, tasks_or_models = _get_task_or_model_index_company(
company_id, request.tasks, request.model_events
res = event_bll.metrics.get_task_single_value_metrics(
companies=_get_task_or_model_index_companies(
company_id, request.tasks, request.model_events
),
)
res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
call.result.data = dict(
tasks=[{"task": task, "values": values} for task, values in res.items()]
)
@@ -537,11 +529,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
companies = _get_task_or_model_index_companies(company_id, task_ids)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events per combination
result = event_bll.get_task_events(
company,
list(companies),
task_ids,
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
@@ -549,8 +541,9 @@ def get_multi_task_plots_v1_7(call, company_id, _):
scroll_id=scroll_id,
)
task_names = {t.id: t.name for t in tasks_or_models}
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, task_names=task_names
)
@@ -564,18 +557,18 @@ def get_multi_task_plots_v1_7(call, company_id, _):
def _get_multitask_plots(
company: str,
tasks_or_models: Sequence[Task],
companies: TaskCompanies,
last_iters: int,
metrics: MetricVariants = None,
scroll_id=None,
no_scroll=True,
model_events=False,
) -> Tuple[dict, int, str]:
task_names = {t.id: t.name for t in tasks_or_models}
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
result = event_bll.get_task_events(
company_id=company,
company_id=list(companies),
task_id=list(task_names),
event_type=EventType.metrics_plot,
metrics=metrics,
@@ -599,13 +592,11 @@ def get_multi_task_plots(call, company_id, _):
no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False)
company, tasks_or_models = _get_task_or_model_index_company(
company_id, task_ids, model_events
companies = _get_task_or_model_index_companies(
company_id, task_ids, model_events=model_events
)
return_events, total_events, next_scroll_id = _get_multitask_plots(
company=company,
tasks_or_models=tasks_or_models,
companies=companies,
last_iters=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
@@ -728,12 +719,11 @@ def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEven
def task_plots(call, company_id, request: MetricEventsRequest):
task_metrics = _task_metrics_dict_from_request(request.metrics)
task_ids = list(task_metrics)
company, _ = _get_task_or_model_index_company(
task_or_models = _assert_task_or_model_exists(
company_id, task_ids=task_ids, model_events=request.model_events
)
result = event_bll.plots_iterator.get_task_events(
company_id=company,
companies={t.id: t.get_index_company() for t in task_or_models},
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
@@ -824,12 +814,11 @@ def get_debug_images_v1_8(call, company_id, _):
def get_debug_images(call, company_id, request: MetricEventsRequest):
task_metrics = _task_metrics_dict_from_request(request.metrics)
task_ids = list(task_metrics)
company, _ = _get_task_or_model_index_company(
task_or_models = _assert_task_or_model_exists(
company_id, task_ids=task_ids, model_events=request.model_events
)
result = event_bll.debug_images_iterator.get_task_events(
company_id=company,
companies={t.id: t.get_index_company() for t in task_or_models},
task_metrics=task_metrics,
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
@@ -922,11 +911,13 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
company, _ = _get_task_or_model_index_company(
company_id, request.tasks, model_events=request.model_events
task_or_models = _assert_task_or_model_exists(
company_id, request.tasks, model_events=request.model_events,
)
res = event_bll.metrics.get_task_metrics(
company, task_ids=request.tasks, event_type=request.event_type
task_or_models[0].get_index_company(),
task_ids=request.tasks,
event_type=request.event_type,
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]

View File

@@ -1,5 +1,6 @@
import textwrap
from datetime import datetime
from itertools import chain
from typing import Sequence
from apiserver.apimodels.reports import (
@@ -24,7 +25,7 @@ from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
from apiserver.service_repo import APICall, endpoint
from apiserver.services.events import (
_get_task_or_model_index_company,
_get_task_or_model_index_companies,
event_bll,
_get_metrics_response,
_get_metric_variants_from_request,
@@ -214,10 +215,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
return res
task_ids = [task["id"] for task in tasks]
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
if request.debug_images:
result = event_bll.debug_images_iterator.get_task_events(
company_id=company,
companies={
t.id: t.company for t in chain.from_iterable(companies.values())
},
task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
iter_count=request.debug_images.iters,
)
@@ -227,8 +230,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
if request.plots:
res["plots"] = _get_multitask_plots(
company=company,
tasks_or_models=tasks_or_models,
companies=companies,
last_iters=request.plots.iters,
metrics=_get_metric_variants_from_request(request.plots.metrics),
)[0]
@@ -237,8 +239,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
res[
"scalar_metrics_iter_histogram"
] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id=company_id,
tasks=tasks_or_models,
companies=companies,
samples=request.scalar_metrics_iter_histogram.samples,
key=request.scalar_metrics_iter_histogram.key,
metric_variants=_get_metric_variants_from_request(