Removed limit on event comparison for the same company tasks only

This commit is contained in:
allegroai 2022-12-21 18:42:40 +02:00
parent e66257761a
commit 14ff639bb0
6 changed files with 137 additions and 92 deletions

View File

@ -780,7 +780,7 @@ class EventBLL(object):
def get_task_events( def get_task_events(
self, self,
company_id: str, company_id: Union[str, Sequence[str]],
task_id: Union[str, Sequence[str]], task_id: Union[str, Sequence[str]],
event_type: EventType, event_type: EventType,
metrics: MetricVariants = None, metrics: MetricVariants = None,
@ -798,10 +798,21 @@ class EventBLL(object):
with translate_errors_context(): with translate_errors_context():
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h") es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else: else:
if check_empty_data(self.es, company_id=company_id, event_type=event_type): company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return TaskEventsResult() return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id task_ids = (
[task_id]
if isinstance(task_id, str)
else task_id
)
must = [] must = []
if metrics: if metrics:
@ -811,7 +822,7 @@ class EventBLL(object):
must.append({"terms": {"task": task_ids}}) must.append({"terms": {"task": task_ids}})
else: else:
tasks_iters = self.get_last_iters( tasks_iters = self.get_last_iters(
company_id=company_id, company_id=company_ids,
event_type=event_type, event_type=event_type,
task_id=task_ids, task_id=task_ids,
iters=last_iter_count, iters=last_iter_count,
@ -845,7 +856,7 @@ class EventBLL(object):
with translate_errors_context(): with translate_errors_context():
es_res = search_company_events( es_res = search_company_events(
self.es, self.es,
company_id=company_id, company_id=company_ids,
event_type=event_type, event_type=event_type,
body=es_req, body=es_req,
ignore=404, ignore=404,
@ -1028,13 +1039,19 @@ class EventBLL(object):
def get_last_iters( def get_last_iters(
self, self,
company_id: str, company_id: Union[str, Sequence[str]],
event_type: EventType, event_type: EventType,
task_id: Union[str, Sequence[str]], task_id: Union[str, Sequence[str]],
iters: int, iters: int,
metrics: MetricVariants = None metrics: MetricVariants = None
) -> Mapping[str, Sequence]: ) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type): company_ids = [company_id] if isinstance(company_id, str) else company_id
company_ids = [
c_id
for c_id in set(company_ids)
if not check_empty_data(self.es, c_id, event_type)
]
if not company_ids:
return {} return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id task_ids = [task_id] if isinstance(task_id, str) else task_id
@ -1063,7 +1080,7 @@ class EventBLL(object):
with translate_errors_context(): with translate_errors_context():
es_res = search_company_events( es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req, self.es, company_id=company_ids, event_type=event_type, body=es_req,
) )
if "aggregations" not in es_res: if "aggregations" not in es_res:

View File

@ -8,6 +8,7 @@ from elasticsearch import Elasticsearch
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get from apiserver.tools import safe_get
@ -22,6 +23,7 @@ class EventType(Enum):
SINGLE_SCALAR_ITERATION = -(2 ** 31) SINGLE_SCALAR_ITERATION = -(2 ** 31)
MetricVariants = Mapping[str, Sequence[str]] MetricVariants = Mapping[str, Sequence[str]]
TaskCompanies = Mapping[str, Sequence[Task]]
class EventSettings: class EventSettings:
@ -52,9 +54,12 @@ class EventSettings:
return int(self._max_es_allowed_aggregation_buckets * percentage) return int(self._max_es_allowed_aggregation_buckets * percentage)
def get_index_name(company_id: str, event_type: str): def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
event_type = event_type.lower().replace(" ", "_") event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id.lower()}" if isinstance(company_id, str):
company_id = [company_id]
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool: def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:

View File

@ -18,11 +18,11 @@ from apiserver.bll.event.event_common import (
get_metric_variants_condition, get_metric_variants_condition,
get_max_metric_and_variant_counts, get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION, SINGLE_SCALAR_ITERATION,
TaskCompanies,
) )
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get from apiserver.tools import safe_get
log = config.logger(__file__) log = config.logger(__file__)
@ -108,8 +108,7 @@ class EventMetrics:
def compare_scalar_metrics_average_per_iter( def compare_scalar_metrics_average_per_iter(
self, self,
company_id, companies: TaskCompanies,
tasks: Sequence[Task],
samples, samples,
key: ScalarKeyEnum, key: ScalarKeyEnum,
metric_variants: MetricVariants = None, metric_variants: MetricVariants = None,
@ -119,28 +118,41 @@ class EventMetrics:
The amount of points in each histogram should not exceed the requested samples The amount of points in each histogram should not exceed the requested samples
""" """
event_type = EventType.metrics_scalar event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type): companies = {
company_id: tasks
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=event_type
)
}
if not companies:
return {} return {}
task_name_by_id = {t.id: t.name for t in tasks}
get_scalar_average_per_iter = partial( get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core, self._get_scalar_average_per_iter_core,
company_id=company_id,
event_type=event_type, event_type=event_type,
samples=samples, samples=samples,
key=ScalarKey.resolve(key), key=ScalarKey.resolve(key),
metric_variants=metric_variants, metric_variants=metric_variants,
run_parallel=False, run_parallel=False,
) )
task_ids = [t.id for t in tasks] task_ids, company_ids = zip(
*(
(t.id, t.company)
for t in itertools.chain.from_iterable(companies.values())
)
)
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool: with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_metrics = zip( task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids) task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
) )
task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
res = defaultdict(lambda: defaultdict(dict)) res = defaultdict(lambda: defaultdict(dict))
for task_id, task_data in task_metrics: for task_id, task_data in task_metrics:
task_name = task_name_by_id[task_id] task_name = task_names[task_id]
for metric_key, metric_data in task_data.items(): for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items(): for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name variant_data["name"] = task_name
@ -149,18 +161,27 @@ class EventMetrics:
return res return res
def get_task_single_value_metrics( def get_task_single_value_metrics(
self, company_id: str, tasks: Sequence[Task] self, companies: TaskCompanies
) -> Mapping[str, dict]: ) -> Mapping[str, dict]:
""" """
For the requested tasks return all the events delivered for the single iteration (-2**31) For the requested tasks return all the events delivered for the single iteration (-2**31)
""" """
if check_empty_data( companies = {
company_id: [t.id for t in tasks]
for company_id, tasks in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar self.es, company_id=company_id, event_type=EventType.metrics_scalar
): )
}
if not companies:
return {} return {}
task_ids = [t.id for t in tasks] with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
task_events = self._get_task_single_value_metrics(company_id, task_ids) task_events = list(
itertools.chain.from_iterable(
pool.map(self._get_task_single_value_metrics, companies.items())
),
)
def _get_value(event: dict): def _get_value(event: dict):
return { return {
@ -174,8 +195,9 @@ class EventMetrics:
} }
def _get_task_single_value_metrics( def _get_task_single_value_metrics(
self, company_id: str, task_ids: Sequence[str] self, tasks: Tuple[str, Sequence[str]]
) -> Sequence[dict]: ) -> Sequence[dict]:
company_id, task_ids = tasks
es_req = { es_req = {
"size": 10000, "size": 10000,
"query": { "query": {

View File

@ -75,18 +75,25 @@ class MetricEventsIterator:
def get_task_events( def get_task_events(
self, self,
company_id: str, companies: Mapping[str, str],
task_metrics: Mapping[str, dict], task_metrics: Mapping[str, dict],
iter_count: int, iter_count: int,
navigate_earlier: bool = True, navigate_earlier: bool = True,
refresh: bool = False, refresh: bool = False,
state_id: str = None, state_id: str = None,
) -> MetricEventsResult: ) -> MetricEventsResult:
if check_empty_data(self.es, company_id, self.event_type): companies = {
task_id: company_id
for task_id, company_id in companies.items()
if not check_empty_data(
self.es, company_id=company_id, event_type=EventType.metrics_scalar
)
}
if not companies:
return MetricEventsResult() return MetricEventsResult()
def init_state(state_: MetricEventsScrollState): def init_state(state_: MetricEventsScrollState):
state_.tasks = self._init_task_states(company_id, task_metrics) state_.tasks = self._init_task_states(companies, task_metrics)
def validate_state(state_: MetricEventsScrollState): def validate_state(state_: MetricEventsScrollState):
""" """
@ -95,7 +102,7 @@ class MetricEventsIterator:
Refresh the state if requested Refresh the state if requested
""" """
if refresh: if refresh:
self._reinit_outdated_task_states(company_id, state_, task_metrics) self._reinit_outdated_task_states(companies, state_, task_metrics)
with self.cache_manager.get_or_create_state( with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state state_id=state_id, init_state=init_state, validate_state=validate_state
@ -112,7 +119,7 @@ class MetricEventsIterator:
pool.map( pool.map(
partial( partial(
self._get_task_metric_events, self._get_task_metric_events,
company_id=company_id, companies=companies,
iter_count=iter_count, iter_count=iter_count,
navigate_earlier=navigate_earlier, navigate_earlier=navigate_earlier,
specific_variants_requested=specific_variants_requested, specific_variants_requested=specific_variants_requested,
@ -125,7 +132,7 @@ class MetricEventsIterator:
def _reinit_outdated_task_states( def _reinit_outdated_task_states(
self, self,
company_id, companies: Mapping[str, str],
state: MetricEventsScrollState, state: MetricEventsScrollState,
task_metrics: Mapping[str, dict], task_metrics: Mapping[str, dict],
): ):
@ -133,9 +140,7 @@ class MetricEventsIterator:
Determine the metrics for which new event_type events were added Determine the metrics for which new event_type events were added
since their states were initialized and re-init these states since their states were initialized and re-init these states
""" """
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only( tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics( def get_last_update_times_for_task_metrics(
task: Task, task: Task,
@ -175,7 +180,7 @@ class MetricEventsIterator:
if metrics_to_recalc: if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc task_metrics_to_recalc[task] = metrics_to_recalc
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc) updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
def merge_with_updated_task_states( def merge_with_updated_task_states(
old_state: TaskScrollState, updates: Sequence[TaskScrollState] old_state: TaskScrollState, updates: Sequence[TaskScrollState]
@ -205,14 +210,14 @@ class MetricEventsIterator:
] ]
def _init_task_states( def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, dict] self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]: ) -> Sequence[TaskScrollState]:
""" """
Returned initialized metric scroll stated for the requested task metrics Returned initialized metric scroll stated for the requested task metrics
""" """
with ThreadPoolExecutor(EventSettings.max_workers) as pool: with ThreadPoolExecutor(EventSettings.max_workers) as pool:
task_metric_states = pool.map( task_metric_states = pool.map(
partial(self._init_metric_states_for_task, company_id=company_id), partial(self._init_metric_states_for_task, companies=companies),
task_metrics.items(), task_metrics.items(),
) )
@ -232,13 +237,14 @@ class MetricEventsIterator:
pass pass
def _init_metric_states_for_task( def _init_metric_states_for_task(
self, task_metrics: Tuple[str, dict], company_id: str self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
) -> Sequence[MetricState]: ) -> Sequence[MetricState]:
""" """
Return metric scroll states for the task filled with the variant states Return metric scroll states for the task filled with the variant states
for the variants that reported any event_type events for the variants that reported any event_type events
""" """
task, metrics = task_metrics task, metrics = task_metrics
company_id = companies[task]
must = [{"term": {"task": task}}, *self._get_extra_conditions()] must = [{"term": {"task": task}}, *self._get_extra_conditions()]
if metrics: if metrics:
must.append(get_metric_variants_condition(metrics)) must.append(get_metric_variants_condition(metrics))
@ -319,7 +325,7 @@ class MetricEventsIterator:
def _get_task_metric_events( def _get_task_metric_events(
self, self,
task_state: TaskScrollState, task_state: TaskScrollState,
company_id: str, companies: Mapping[str, str],
iter_count: int, iter_count: int,
navigate_earlier: bool, navigate_earlier: bool,
specific_variants_requested: bool, specific_variants_requested: bool,
@ -391,7 +397,10 @@ class MetricEventsIterator:
} }
with translate_errors_context(): with translate_errors_context():
es_res = search_company_events( es_res = search_company_events(
self.es, company_id=company_id, event_type=self.event_type, body=es_req, self.es,
company_id=companies[task_state.task],
event_type=self.event_type,
body=es_req,
) )
if "aggregations" not in es_res: if "aggregations" not in es_res:
return task_state.task, [] return task_state.task, []

View File

@ -32,7 +32,7 @@ from apiserver.apimodels.events import (
TaskMetric, TaskMetric,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
from apiserver.bll.event.events_iterator import Scroll from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
@ -464,12 +464,11 @@ def scalar_metrics_iter_histogram(
call.result.data = metrics call.result.data = metrics
def _get_task_or_model_index_company( def _get_task_or_model_index_companies(
company_id: str, task_ids: Sequence[str], model_events=False, company_id: str, task_ids: Sequence[str], model_events=False,
) -> Tuple[str, Sequence[Task]]: ) -> TaskCompanies:
""" """
Verify that all tasks exists and belong to store data in the same company index Returns lists of tasks grouped by company
Return company and tasks
""" """
tasks_or_models = _assert_task_or_model_exists( tasks_or_models = _assert_task_or_model_exists(
company_id, task_ids, model_events=model_events, company_id, task_ids, model_events=model_events,
@ -485,13 +484,7 @@ def _get_task_or_model_index_company(
) )
raise error_cls(company=company_id, ids=invalid) raise error_cls(company=company_id, ids=invalid)
companies = {t.get_index_company() for t in tasks_or_models} return bucketize(tasks_or_models, key=lambda t: t.get_index_company())
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
return companies.pop(), tasks_or_models
@endpoint( @endpoint(
@ -504,13 +497,12 @@ def multi_task_scalar_metrics_iter_histogram(
task_ids = request.tasks task_ids = request.tasks
if isinstance(task_ids, str): if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")] task_ids = [s.strip() for s in task_ids.split(",")]
company, tasks_or_models = _get_task_or_model_index_company(
company_id, task_ids, request.model_events
)
call.result.data = dict( call.result.data = dict(
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter( metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id=company, companies=_get_task_or_model_index_companies(
tasks=tasks_or_models, company_id, task_ids, request.model_events
),
samples=request.samples, samples=request.samples,
key=request.key, key=request.key,
) )
@ -521,11 +513,11 @@ def multi_task_scalar_metrics_iter_histogram(
def get_task_single_value_metrics( def get_task_single_value_metrics(
call, company_id: str, request: SingleValueMetricsRequest call, company_id: str, request: SingleValueMetricsRequest
): ):
company, tasks_or_models = _get_task_or_model_index_company( res = event_bll.metrics.get_task_single_value_metrics(
companies=_get_task_or_model_index_companies(
company_id, request.tasks, request.model_events company_id, request.tasks, request.model_events
),
) )
res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
call.result.data = dict( call.result.data = dict(
tasks=[{"task": task, "values": values} for task, values in res.items()] tasks=[{"task": task, "values": values} for task, values in res.items()]
) )
@ -537,11 +529,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
iters = call.data.get("iters", 1) iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id") scroll_id = call.data.get("scroll_id")
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids) companies = _get_task_or_model_index_companies(company_id, task_ids)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination # Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events( result = event_bll.get_task_events(
company, list(companies),
task_ids, task_ids,
event_type=EventType.metrics_plot, event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}], sort=[{"iter": {"order": "desc"}}],
@ -549,8 +541,9 @@ def get_multi_task_plots_v1_7(call, company_id, _):
scroll_id=scroll_id, scroll_id=scroll_id,
) )
task_names = {t.id: t.name for t in tasks_or_models} task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
return_events = _get_top_iter_unique_events_per_task( return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, task_names=task_names result.events, max_iters=iters, task_names=task_names
) )
@ -564,18 +557,18 @@ def get_multi_task_plots_v1_7(call, company_id, _):
def _get_multitask_plots( def _get_multitask_plots(
company: str, companies: TaskCompanies,
tasks_or_models: Sequence[Task],
last_iters: int, last_iters: int,
metrics: MetricVariants = None, metrics: MetricVariants = None,
scroll_id=None, scroll_id=None,
no_scroll=True, no_scroll=True,
model_events=False, model_events=False,
) -> Tuple[dict, int, str]: ) -> Tuple[dict, int, str]:
task_names = {t.id: t.name for t in tasks_or_models} task_names = {
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
}
result = event_bll.get_task_events( result = event_bll.get_task_events(
company_id=company, company_id=list(companies),
task_id=list(task_names), task_id=list(task_names),
event_type=EventType.metrics_plot, event_type=EventType.metrics_plot,
metrics=metrics, metrics=metrics,
@ -599,13 +592,11 @@ def get_multi_task_plots(call, company_id, _):
no_scroll = call.data.get("no_scroll", False) no_scroll = call.data.get("no_scroll", False)
model_events = call.data.get("model_events", False) model_events = call.data.get("model_events", False)
company, tasks_or_models = _get_task_or_model_index_company( companies = _get_task_or_model_index_companies(
company_id, task_ids, model_events company_id, task_ids, model_events=model_events
) )
return_events, total_events, next_scroll_id = _get_multitask_plots( return_events, total_events, next_scroll_id = _get_multitask_plots(
company=company, companies=companies,
tasks_or_models=tasks_or_models,
last_iters=iters, last_iters=iters,
scroll_id=scroll_id, scroll_id=scroll_id,
no_scroll=no_scroll, no_scroll=no_scroll,
@ -728,12 +719,11 @@ def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEven
def task_plots(call, company_id, request: MetricEventsRequest): def task_plots(call, company_id, request: MetricEventsRequest):
task_metrics = _task_metrics_dict_from_request(request.metrics) task_metrics = _task_metrics_dict_from_request(request.metrics)
task_ids = list(task_metrics) task_ids = list(task_metrics)
company, _ = _get_task_or_model_index_company( task_or_models = _assert_task_or_model_exists(
company_id, task_ids=task_ids, model_events=request.model_events company_id, task_ids=task_ids, model_events=request.model_events
) )
result = event_bll.plots_iterator.get_task_events( result = event_bll.plots_iterator.get_task_events(
company_id=company, companies={t.id: t.get_index_company() for t in task_or_models},
task_metrics=task_metrics, task_metrics=task_metrics,
iter_count=request.iters, iter_count=request.iters,
navigate_earlier=request.navigate_earlier, navigate_earlier=request.navigate_earlier,
@ -824,12 +814,11 @@ def get_debug_images_v1_8(call, company_id, _):
def get_debug_images(call, company_id, request: MetricEventsRequest): def get_debug_images(call, company_id, request: MetricEventsRequest):
task_metrics = _task_metrics_dict_from_request(request.metrics) task_metrics = _task_metrics_dict_from_request(request.metrics)
task_ids = list(task_metrics) task_ids = list(task_metrics)
company, _ = _get_task_or_model_index_company( task_or_models = _assert_task_or_model_exists(
company_id, task_ids=task_ids, model_events=request.model_events company_id, task_ids=task_ids, model_events=request.model_events
) )
result = event_bll.debug_images_iterator.get_task_events( result = event_bll.debug_images_iterator.get_task_events(
company_id=company, companies={t.id: t.get_index_company() for t in task_or_models},
task_metrics=task_metrics, task_metrics=task_metrics,
iter_count=request.iters, iter_count=request.iters,
navigate_earlier=request.navigate_earlier, navigate_earlier=request.navigate_earlier,
@ -922,11 +911,13 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest) @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest): def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
company, _ = _get_task_or_model_index_company( task_or_models = _assert_task_or_model_exists(
company_id, request.tasks, model_events=request.model_events company_id, request.tasks, model_events=request.model_events,
) )
res = event_bll.metrics.get_task_metrics( res = event_bll.metrics.get_task_metrics(
company, task_ids=request.tasks, event_type=request.event_type task_or_models[0].get_index_company(),
task_ids=request.tasks,
event_type=request.event_type,
) )
call.result.data = { call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res] "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]

View File

@ -1,5 +1,6 @@
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from itertools import chain
from typing import Sequence from typing import Sequence
from apiserver.apimodels.reports import ( from apiserver.apimodels.reports import (
@ -24,7 +25,7 @@ from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType, TaskStatus from apiserver.database.model.task.task import Task, TaskType, TaskStatus
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
from apiserver.services.events import ( from apiserver.services.events import (
_get_task_or_model_index_company, _get_task_or_model_index_companies,
event_bll, event_bll,
_get_metrics_response, _get_metrics_response,
_get_metric_variants_from_request, _get_metric_variants_from_request,
@ -214,10 +215,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
return res return res
task_ids = [task["id"] for task in tasks] task_ids = [task["id"] for task in tasks]
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids) companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
if request.debug_images: if request.debug_images:
result = event_bll.debug_images_iterator.get_task_events( result = event_bll.debug_images_iterator.get_task_events(
company_id=company, companies={
t.id: t.company for t in chain.from_iterable(companies.values())
},
task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images), task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
iter_count=request.debug_images.iters, iter_count=request.debug_images.iters,
) )
@ -227,8 +230,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
if request.plots: if request.plots:
res["plots"] = _get_multitask_plots( res["plots"] = _get_multitask_plots(
company=company, companies=companies,
tasks_or_models=tasks_or_models,
last_iters=request.plots.iters, last_iters=request.plots.iters,
metrics=_get_metric_variants_from_request(request.plots.metrics), metrics=_get_metric_variants_from_request(request.plots.metrics),
)[0] )[0]
@ -237,8 +239,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
res[ res[
"scalar_metrics_iter_histogram" "scalar_metrics_iter_histogram"
] = event_bll.metrics.compare_scalar_metrics_average_per_iter( ] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id=company_id, companies=companies,
tasks=tasks_or_models,
samples=request.scalar_metrics_iter_histogram.samples, samples=request.scalar_metrics_iter_histogram.samples,
key=request.scalar_metrics_iter_histogram.key, key=request.scalar_metrics_iter_histogram.key,
metric_variants=_get_metric_variants_from_request( metric_variants=_get_metric_variants_from_request(