mirror of
https://github.com/clearml/clearml-server
synced 2025-05-10 14:50:44 +00:00
Removed limit on event comparison for the same company tasks only
This commit is contained in:
parent
e66257761a
commit
14ff639bb0
@ -780,7 +780,7 @@ class EventBLL(object):
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
task_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
metrics: MetricVariants = None,
|
||||
@ -798,10 +798,21 @@ class EventBLL(object):
|
||||
with translate_errors_context():
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||
company_ids = [
|
||||
c_id
|
||||
for c_id in set(company_ids)
|
||||
if not check_empty_data(self.es, c_id, event_type)
|
||||
]
|
||||
if not company_ids:
|
||||
return TaskEventsResult()
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
task_ids = (
|
||||
[task_id]
|
||||
if isinstance(task_id, str)
|
||||
else task_id
|
||||
)
|
||||
|
||||
|
||||
must = []
|
||||
if metrics:
|
||||
@ -811,7 +822,7 @@ class EventBLL(object):
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
@ -845,7 +856,7 @@ class EventBLL(object):
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
company_id=company_ids,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
@ -1028,13 +1039,19 @@ class EventBLL(object):
|
||||
|
||||
def get_last_iters(
|
||||
self,
|
||||
company_id: str,
|
||||
company_id: Union[str, Sequence[str]],
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
metrics: MetricVariants = None
|
||||
) -> Mapping[str, Sequence]:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||
company_ids = [
|
||||
c_id
|
||||
for c_id in set(company_ids)
|
||||
if not check_empty_data(self.es, c_id, event_type)
|
||||
]
|
||||
if not company_ids:
|
||||
return {}
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
@ -1063,7 +1080,7 @@ class EventBLL(object):
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
self.es, company_id=company_ids, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
|
@ -8,6 +8,7 @@ from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
|
||||
@ -22,6 +23,7 @@ class EventType(Enum):
|
||||
|
||||
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
TaskCompanies = Mapping[str, Sequence[Task]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
@ -52,9 +54,12 @@ class EventSettings:
|
||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||
|
||||
|
||||
def get_index_name(company_id: str, event_type: str):
|
||||
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id.lower()}"
|
||||
if isinstance(company_id, str):
|
||||
company_id = [company_id]
|
||||
|
||||
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
|
||||
|
||||
|
||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||
|
@ -18,11 +18,11 @@ from apiserver.bll.event.event_common import (
|
||||
get_metric_variants_condition,
|
||||
get_max_metric_and_variant_counts,
|
||||
SINGLE_SCALAR_ITERATION,
|
||||
TaskCompanies,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -108,8 +108,7 @@ class EventMetrics:
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
tasks: Sequence[Task],
|
||||
companies: TaskCompanies,
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
@ -119,28 +118,41 @@ class EventMetrics:
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
companies = {
|
||||
company_id: tasks
|
||||
for company_id, tasks in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=event_type
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return {}
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in tasks}
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
run_parallel=False,
|
||||
)
|
||||
task_ids = [t.id for t in tasks]
|
||||
task_ids, company_ids = zip(
|
||||
*(
|
||||
(t.id, t.company)
|
||||
for t in itertools.chain.from_iterable(companies.values())
|
||||
)
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
|
||||
)
|
||||
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
task_name = task_names[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
@ -149,18 +161,27 @@ class EventMetrics:
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, company_id: str, tasks: Sequence[Task]
|
||||
self, companies: TaskCompanies
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
if check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
):
|
||||
companies = {
|
||||
company_id: [t.id for t in tasks]
|
||||
for company_id, tasks in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return {}
|
||||
|
||||
task_ids = [t.id for t in tasks]
|
||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_events = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(self._get_task_single_value_metrics, companies.items())
|
||||
),
|
||||
)
|
||||
|
||||
def _get_value(event: dict):
|
||||
return {
|
||||
@ -174,8 +195,9 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
def _get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
self, tasks: Tuple[str, Sequence[str]]
|
||||
) -> Sequence[dict]:
|
||||
company_id, task_ids = tasks
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
|
@ -75,18 +75,25 @@ class MetricEventsIterator:
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
companies: Mapping[str, str],
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> MetricEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.event_type):
|
||||
companies = {
|
||||
task_id: company_id
|
||||
for task_id, company_id in companies.items()
|
||||
if not check_empty_data(
|
||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||
)
|
||||
}
|
||||
if not companies:
|
||||
return MetricEventsResult()
|
||||
|
||||
def init_state(state_: MetricEventsScrollState):
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
state_.tasks = self._init_task_states(companies, task_metrics)
|
||||
|
||||
def validate_state(state_: MetricEventsScrollState):
|
||||
"""
|
||||
@ -95,7 +102,7 @@ class MetricEventsIterator:
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if refresh:
|
||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
||||
self._reinit_outdated_task_states(companies, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
@ -112,7 +119,7 @@ class MetricEventsIterator:
|
||||
pool.map(
|
||||
partial(
|
||||
self._get_task_metric_events,
|
||||
company_id=company_id,
|
||||
companies=companies,
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
specific_variants_requested=specific_variants_requested,
|
||||
@ -125,7 +132,7 @@ class MetricEventsIterator:
|
||||
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
companies: Mapping[str, str],
|
||||
state: MetricEventsScrollState,
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
@ -133,9 +140,7 @@ class MetricEventsIterator:
|
||||
Determine the metrics for which new event_type events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
|
||||
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
@ -175,7 +180,7 @@ class MetricEventsIterator:
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
||||
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
|
||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||
@ -205,14 +210,14 @@ class MetricEventsIterator:
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
||||
partial(self._init_metric_states_for_task, companies=companies),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
@ -232,13 +237,14 @@ class MetricEventsIterator:
|
||||
pass
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any event_type events
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
company_id = companies[task]
|
||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||
if metrics:
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
@ -319,7 +325,7 @@ class MetricEventsIterator:
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
companies: Mapping[str, str],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
specific_variants_requested: bool,
|
||||
@ -391,7 +397,10 @@ class MetricEventsIterator:
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
||||
self.es,
|
||||
company_id=companies[task_state.task],
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return task_state.task, []
|
||||
|
@ -32,7 +32,7 @@ from apiserver.apimodels.events import (
|
||||
TaskMetric,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||
from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.model import ModelBLL
|
||||
@ -464,12 +464,11 @@ def scalar_metrics_iter_histogram(
|
||||
call.result.data = metrics
|
||||
|
||||
|
||||
def _get_task_or_model_index_company(
|
||||
def _get_task_or_model_index_companies(
|
||||
company_id: str, task_ids: Sequence[str], model_events=False,
|
||||
) -> Tuple[str, Sequence[Task]]:
|
||||
) -> TaskCompanies:
|
||||
"""
|
||||
Verify that all tasks exists and belong to store data in the same company index
|
||||
Return company and tasks
|
||||
Returns lists of tasks grouped by company
|
||||
"""
|
||||
tasks_or_models = _assert_task_or_model_exists(
|
||||
company_id, task_ids, model_events=model_events,
|
||||
@ -485,13 +484,7 @@ def _get_task_or_model_index_company(
|
||||
)
|
||||
raise error_cls(company=company_id, ids=invalid)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks_or_models}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
return companies.pop(), tasks_or_models
|
||||
return bucketize(tasks_or_models, key=lambda t: t.get_index_company())
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -504,13 +497,12 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
task_ids = request.tasks
|
||||
if isinstance(task_ids, str):
|
||||
task_ids = [s.strip() for s in task_ids.split(",")]
|
||||
company, tasks_or_models = _get_task_or_model_index_company(
|
||||
company_id, task_ids, request.model_events
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
||||
company_id=company,
|
||||
tasks=tasks_or_models,
|
||||
companies=_get_task_or_model_index_companies(
|
||||
company_id, task_ids, request.model_events
|
||||
),
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
)
|
||||
@ -521,11 +513,11 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
def get_task_single_value_metrics(
|
||||
call, company_id: str, request: SingleValueMetricsRequest
|
||||
):
|
||||
company, tasks_or_models = _get_task_or_model_index_company(
|
||||
company_id, request.tasks, request.model_events
|
||||
res = event_bll.metrics.get_task_single_value_metrics(
|
||||
companies=_get_task_or_model_index_companies(
|
||||
company_id, request.tasks, request.model_events
|
||||
),
|
||||
)
|
||||
|
||||
res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
|
||||
call.result.data = dict(
|
||||
tasks=[{"task": task, "values": values} for task, values in res.items()]
|
||||
)
|
||||
@ -537,11 +529,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
|
||||
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
||||
|
||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company,
|
||||
list(companies),
|
||||
task_ids,
|
||||
event_type=EventType.metrics_plot,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@ -549,8 +541,9 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
task_names = {t.id: t.name for t in tasks_or_models}
|
||||
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, task_names=task_names
|
||||
)
|
||||
@ -564,18 +557,18 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
|
||||
|
||||
def _get_multitask_plots(
|
||||
company: str,
|
||||
tasks_or_models: Sequence[Task],
|
||||
companies: TaskCompanies,
|
||||
last_iters: int,
|
||||
metrics: MetricVariants = None,
|
||||
scroll_id=None,
|
||||
no_scroll=True,
|
||||
model_events=False,
|
||||
) -> Tuple[dict, int, str]:
|
||||
task_names = {t.id: t.name for t in tasks_or_models}
|
||||
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
}
|
||||
result = event_bll.get_task_events(
|
||||
company_id=company,
|
||||
company_id=list(companies),
|
||||
task_id=list(task_names),
|
||||
event_type=EventType.metrics_plot,
|
||||
metrics=metrics,
|
||||
@ -599,13 +592,11 @@ def get_multi_task_plots(call, company_id, _):
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
company, tasks_or_models = _get_task_or_model_index_company(
|
||||
company_id, task_ids, model_events
|
||||
companies = _get_task_or_model_index_companies(
|
||||
company_id, task_ids, model_events=model_events
|
||||
)
|
||||
|
||||
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
||||
company=company,
|
||||
tasks_or_models=tasks_or_models,
|
||||
companies=companies,
|
||||
last_iters=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
@ -728,12 +719,11 @@ def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEven
|
||||
def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
||||
task_ids = list(task_metrics)
|
||||
company, _ = _get_task_or_model_index_company(
|
||||
task_or_models = _assert_task_or_model_exists(
|
||||
company_id, task_ids=task_ids, model_events=request.model_events
|
||||
)
|
||||
|
||||
result = event_bll.plots_iterator.get_task_events(
|
||||
company_id=company,
|
||||
companies={t.id: t.get_index_company() for t in task_or_models},
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -824,12 +814,11 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
||||
task_ids = list(task_metrics)
|
||||
company, _ = _get_task_or_model_index_company(
|
||||
task_or_models = _assert_task_or_model_exists(
|
||||
company_id, task_ids=task_ids, model_events=request.model_events
|
||||
)
|
||||
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company,
|
||||
companies={t.id: t.get_index_company() for t in task_or_models},
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -922,11 +911,13 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
company, _ = _get_task_or_model_index_company(
|
||||
company_id, request.tasks, model_events=request.model_events
|
||||
task_or_models = _assert_task_or_model_exists(
|
||||
company_id, request.tasks, model_events=request.model_events,
|
||||
)
|
||||
res = event_bll.metrics.get_task_metrics(
|
||||
company, task_ids=request.tasks, event_type=request.event_type
|
||||
task_or_models[0].get_index_company(),
|
||||
task_ids=request.tasks,
|
||||
event_type=request.event_type,
|
||||
)
|
||||
call.result.data = {
|
||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||
|
@ -1,5 +1,6 @@
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from itertools import chain
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.reports import (
|
||||
@ -24,7 +25,7 @@ from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.events import (
|
||||
_get_task_or_model_index_company,
|
||||
_get_task_or_model_index_companies,
|
||||
event_bll,
|
||||
_get_metrics_response,
|
||||
_get_metric_variants_from_request,
|
||||
@ -214,10 +215,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
return res
|
||||
|
||||
task_ids = [task["id"] for task in tasks]
|
||||
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
|
||||
companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
|
||||
if request.debug_images:
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company,
|
||||
companies={
|
||||
t.id: t.company for t in chain.from_iterable(companies.values())
|
||||
},
|
||||
task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
|
||||
iter_count=request.debug_images.iters,
|
||||
)
|
||||
@ -227,8 +230,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
|
||||
if request.plots:
|
||||
res["plots"] = _get_multitask_plots(
|
||||
company=company,
|
||||
tasks_or_models=tasks_or_models,
|
||||
companies=companies,
|
||||
last_iters=request.plots.iters,
|
||||
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
||||
)[0]
|
||||
@ -237,8 +239,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
res[
|
||||
"scalar_metrics_iter_histogram"
|
||||
] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
||||
company_id=company_id,
|
||||
tasks=tasks_or_models,
|
||||
companies=companies,
|
||||
samples=request.scalar_metrics_iter_histogram.samples,
|
||||
key=request.scalar_metrics_iter_histogram.key,
|
||||
metric_variants=_get_metric_variants_from_request(
|
||||
|
Loading…
Reference in New Issue
Block a user