mirror of
https://github.com/clearml/clearml-server
synced 2025-05-12 15:50:47 +00:00
Removed limit on event comparison for the same company tasks only
This commit is contained in:
parent
e66257761a
commit
14ff639bb0
@ -780,7 +780,7 @@ class EventBLL(object):
|
|||||||
|
|
||||||
def get_task_events(
|
def get_task_events(
|
||||||
self,
|
self,
|
||||||
company_id: str,
|
company_id: Union[str, Sequence[str]],
|
||||||
task_id: Union[str, Sequence[str]],
|
task_id: Union[str, Sequence[str]],
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
metrics: MetricVariants = None,
|
metrics: MetricVariants = None,
|
||||||
@ -798,10 +798,21 @@ class EventBLL(object):
|
|||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||||
else:
|
else:
|
||||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||||
|
company_ids = [
|
||||||
|
c_id
|
||||||
|
for c_id in set(company_ids)
|
||||||
|
if not check_empty_data(self.es, c_id, event_type)
|
||||||
|
]
|
||||||
|
if not company_ids:
|
||||||
return TaskEventsResult()
|
return TaskEventsResult()
|
||||||
|
|
||||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
task_ids = (
|
||||||
|
[task_id]
|
||||||
|
if isinstance(task_id, str)
|
||||||
|
else task_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
must = []
|
must = []
|
||||||
if metrics:
|
if metrics:
|
||||||
@ -811,7 +822,7 @@ class EventBLL(object):
|
|||||||
must.append({"terms": {"task": task_ids}})
|
must.append({"terms": {"task": task_ids}})
|
||||||
else:
|
else:
|
||||||
tasks_iters = self.get_last_iters(
|
tasks_iters = self.get_last_iters(
|
||||||
company_id=company_id,
|
company_id=company_ids,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
task_id=task_ids,
|
task_id=task_ids,
|
||||||
iters=last_iter_count,
|
iters=last_iter_count,
|
||||||
@ -845,7 +856,7 @@ class EventBLL(object):
|
|||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
es_res = search_company_events(
|
es_res = search_company_events(
|
||||||
self.es,
|
self.es,
|
||||||
company_id=company_id,
|
company_id=company_ids,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
body=es_req,
|
body=es_req,
|
||||||
ignore=404,
|
ignore=404,
|
||||||
@ -1028,13 +1039,19 @@ class EventBLL(object):
|
|||||||
|
|
||||||
def get_last_iters(
|
def get_last_iters(
|
||||||
self,
|
self,
|
||||||
company_id: str,
|
company_id: Union[str, Sequence[str]],
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
task_id: Union[str, Sequence[str]],
|
task_id: Union[str, Sequence[str]],
|
||||||
iters: int,
|
iters: int,
|
||||||
metrics: MetricVariants = None
|
metrics: MetricVariants = None
|
||||||
) -> Mapping[str, Sequence]:
|
) -> Mapping[str, Sequence]:
|
||||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||||
|
company_ids = [
|
||||||
|
c_id
|
||||||
|
for c_id in set(company_ids)
|
||||||
|
if not check_empty_data(self.es, c_id, event_type)
|
||||||
|
]
|
||||||
|
if not company_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||||
@ -1063,7 +1080,7 @@ class EventBLL(object):
|
|||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
es_res = search_company_events(
|
es_res = search_company_events(
|
||||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
self.es, company_id=company_ids, event_type=event_type, body=es_req,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "aggregations" not in es_res:
|
if "aggregations" not in es_res:
|
||||||
|
@ -8,6 +8,7 @@ from elasticsearch import Elasticsearch
|
|||||||
|
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
|
from apiserver.database.model.task.task import Task
|
||||||
from apiserver.tools import safe_get
|
from apiserver.tools import safe_get
|
||||||
|
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ class EventType(Enum):
|
|||||||
|
|
||||||
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
SINGLE_SCALAR_ITERATION = -(2 ** 31)
|
||||||
MetricVariants = Mapping[str, Sequence[str]]
|
MetricVariants = Mapping[str, Sequence[str]]
|
||||||
|
TaskCompanies = Mapping[str, Sequence[Task]]
|
||||||
|
|
||||||
|
|
||||||
class EventSettings:
|
class EventSettings:
|
||||||
@ -52,9 +54,12 @@ class EventSettings:
|
|||||||
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
return int(self._max_es_allowed_aggregation_buckets * percentage)
|
||||||
|
|
||||||
|
|
||||||
def get_index_name(company_id: str, event_type: str):
|
def get_index_name(company_id: Union[str, Sequence[str]], event_type: str):
|
||||||
event_type = event_type.lower().replace(" ", "_")
|
event_type = event_type.lower().replace(" ", "_")
|
||||||
return f"events-{event_type}-{company_id.lower()}"
|
if isinstance(company_id, str):
|
||||||
|
company_id = [company_id]
|
||||||
|
|
||||||
|
return ",".join(f"events-{event_type}-{(c_id or '').lower()}" for c_id in company_id)
|
||||||
|
|
||||||
|
|
||||||
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
|
||||||
|
@ -18,11 +18,11 @@ from apiserver.bll.event.event_common import (
|
|||||||
get_metric_variants_condition,
|
get_metric_variants_condition,
|
||||||
get_max_metric_and_variant_counts,
|
get_max_metric_and_variant_counts,
|
||||||
SINGLE_SCALAR_ITERATION,
|
SINGLE_SCALAR_ITERATION,
|
||||||
|
TaskCompanies,
|
||||||
)
|
)
|
||||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
from apiserver.database.model.task.task import Task
|
|
||||||
from apiserver.tools import safe_get
|
from apiserver.tools import safe_get
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
@ -108,8 +108,7 @@ class EventMetrics:
|
|||||||
|
|
||||||
def compare_scalar_metrics_average_per_iter(
|
def compare_scalar_metrics_average_per_iter(
|
||||||
self,
|
self,
|
||||||
company_id,
|
companies: TaskCompanies,
|
||||||
tasks: Sequence[Task],
|
|
||||||
samples,
|
samples,
|
||||||
key: ScalarKeyEnum,
|
key: ScalarKeyEnum,
|
||||||
metric_variants: MetricVariants = None,
|
metric_variants: MetricVariants = None,
|
||||||
@ -119,28 +118,41 @@ class EventMetrics:
|
|||||||
The amount of points in each histogram should not exceed the requested samples
|
The amount of points in each histogram should not exceed the requested samples
|
||||||
"""
|
"""
|
||||||
event_type = EventType.metrics_scalar
|
event_type = EventType.metrics_scalar
|
||||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
companies = {
|
||||||
|
company_id: tasks
|
||||||
|
for company_id, tasks in companies.items()
|
||||||
|
if not check_empty_data(
|
||||||
|
self.es, company_id=company_id, event_type=event_type
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if not companies:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
task_name_by_id = {t.id: t.name for t in tasks}
|
|
||||||
get_scalar_average_per_iter = partial(
|
get_scalar_average_per_iter = partial(
|
||||||
self._get_scalar_average_per_iter_core,
|
self._get_scalar_average_per_iter_core,
|
||||||
company_id=company_id,
|
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
samples=samples,
|
samples=samples,
|
||||||
key=ScalarKey.resolve(key),
|
key=ScalarKey.resolve(key),
|
||||||
metric_variants=metric_variants,
|
metric_variants=metric_variants,
|
||||||
run_parallel=False,
|
run_parallel=False,
|
||||||
)
|
)
|
||||||
task_ids = [t.id for t in tasks]
|
task_ids, company_ids = zip(
|
||||||
|
*(
|
||||||
|
(t.id, t.company)
|
||||||
|
for t in itertools.chain.from_iterable(companies.values())
|
||||||
|
)
|
||||||
|
)
|
||||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||||
task_metrics = zip(
|
task_metrics = zip(
|
||||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
task_ids, pool.map(get_scalar_average_per_iter, task_ids, company_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
task_names = {
|
||||||
|
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||||
|
}
|
||||||
res = defaultdict(lambda: defaultdict(dict))
|
res = defaultdict(lambda: defaultdict(dict))
|
||||||
for task_id, task_data in task_metrics:
|
for task_id, task_data in task_metrics:
|
||||||
task_name = task_name_by_id[task_id]
|
task_name = task_names[task_id]
|
||||||
for metric_key, metric_data in task_data.items():
|
for metric_key, metric_data in task_data.items():
|
||||||
for variant_key, variant_data in metric_data.items():
|
for variant_key, variant_data in metric_data.items():
|
||||||
variant_data["name"] = task_name
|
variant_data["name"] = task_name
|
||||||
@ -149,18 +161,27 @@ class EventMetrics:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def get_task_single_value_metrics(
|
def get_task_single_value_metrics(
|
||||||
self, company_id: str, tasks: Sequence[Task]
|
self, companies: TaskCompanies
|
||||||
) -> Mapping[str, dict]:
|
) -> Mapping[str, dict]:
|
||||||
"""
|
"""
|
||||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||||
"""
|
"""
|
||||||
if check_empty_data(
|
companies = {
|
||||||
|
company_id: [t.id for t in tasks]
|
||||||
|
for company_id, tasks in companies.items()
|
||||||
|
if not check_empty_data(
|
||||||
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||||
):
|
)
|
||||||
|
}
|
||||||
|
if not companies:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
task_ids = [t.id for t in tasks]
|
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
task_events = list(
|
||||||
|
itertools.chain.from_iterable(
|
||||||
|
pool.map(self._get_task_single_value_metrics, companies.items())
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _get_value(event: dict):
|
def _get_value(event: dict):
|
||||||
return {
|
return {
|
||||||
@ -174,8 +195,9 @@ class EventMetrics:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _get_task_single_value_metrics(
|
def _get_task_single_value_metrics(
|
||||||
self, company_id: str, task_ids: Sequence[str]
|
self, tasks: Tuple[str, Sequence[str]]
|
||||||
) -> Sequence[dict]:
|
) -> Sequence[dict]:
|
||||||
|
company_id, task_ids = tasks
|
||||||
es_req = {
|
es_req = {
|
||||||
"size": 10000,
|
"size": 10000,
|
||||||
"query": {
|
"query": {
|
||||||
|
@ -75,18 +75,25 @@ class MetricEventsIterator:
|
|||||||
|
|
||||||
def get_task_events(
|
def get_task_events(
|
||||||
self,
|
self,
|
||||||
company_id: str,
|
companies: Mapping[str, str],
|
||||||
task_metrics: Mapping[str, dict],
|
task_metrics: Mapping[str, dict],
|
||||||
iter_count: int,
|
iter_count: int,
|
||||||
navigate_earlier: bool = True,
|
navigate_earlier: bool = True,
|
||||||
refresh: bool = False,
|
refresh: bool = False,
|
||||||
state_id: str = None,
|
state_id: str = None,
|
||||||
) -> MetricEventsResult:
|
) -> MetricEventsResult:
|
||||||
if check_empty_data(self.es, company_id, self.event_type):
|
companies = {
|
||||||
|
task_id: company_id
|
||||||
|
for task_id, company_id in companies.items()
|
||||||
|
if not check_empty_data(
|
||||||
|
self.es, company_id=company_id, event_type=EventType.metrics_scalar
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if not companies:
|
||||||
return MetricEventsResult()
|
return MetricEventsResult()
|
||||||
|
|
||||||
def init_state(state_: MetricEventsScrollState):
|
def init_state(state_: MetricEventsScrollState):
|
||||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
state_.tasks = self._init_task_states(companies, task_metrics)
|
||||||
|
|
||||||
def validate_state(state_: MetricEventsScrollState):
|
def validate_state(state_: MetricEventsScrollState):
|
||||||
"""
|
"""
|
||||||
@ -95,7 +102,7 @@ class MetricEventsIterator:
|
|||||||
Refresh the state if requested
|
Refresh the state if requested
|
||||||
"""
|
"""
|
||||||
if refresh:
|
if refresh:
|
||||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
self._reinit_outdated_task_states(companies, state_, task_metrics)
|
||||||
|
|
||||||
with self.cache_manager.get_or_create_state(
|
with self.cache_manager.get_or_create_state(
|
||||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||||
@ -112,7 +119,7 @@ class MetricEventsIterator:
|
|||||||
pool.map(
|
pool.map(
|
||||||
partial(
|
partial(
|
||||||
self._get_task_metric_events,
|
self._get_task_metric_events,
|
||||||
company_id=company_id,
|
companies=companies,
|
||||||
iter_count=iter_count,
|
iter_count=iter_count,
|
||||||
navigate_earlier=navigate_earlier,
|
navigate_earlier=navigate_earlier,
|
||||||
specific_variants_requested=specific_variants_requested,
|
specific_variants_requested=specific_variants_requested,
|
||||||
@ -125,7 +132,7 @@ class MetricEventsIterator:
|
|||||||
|
|
||||||
def _reinit_outdated_task_states(
|
def _reinit_outdated_task_states(
|
||||||
self,
|
self,
|
||||||
company_id,
|
companies: Mapping[str, str],
|
||||||
state: MetricEventsScrollState,
|
state: MetricEventsScrollState,
|
||||||
task_metrics: Mapping[str, dict],
|
task_metrics: Mapping[str, dict],
|
||||||
):
|
):
|
||||||
@ -133,9 +140,7 @@ class MetricEventsIterator:
|
|||||||
Determine the metrics for which new event_type events were added
|
Determine the metrics for which new event_type events were added
|
||||||
since their states were initialized and re-init these states
|
since their states were initialized and re-init these states
|
||||||
"""
|
"""
|
||||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
tasks = Task.objects(id__in=list(task_metrics)).only("id", "metric_stats")
|
||||||
"id", "metric_stats"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_last_update_times_for_task_metrics(
|
def get_last_update_times_for_task_metrics(
|
||||||
task: Task,
|
task: Task,
|
||||||
@ -175,7 +180,7 @@ class MetricEventsIterator:
|
|||||||
if metrics_to_recalc:
|
if metrics_to_recalc:
|
||||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||||
|
|
||||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
updated_task_states = self._init_task_states(companies, task_metrics_to_recalc)
|
||||||
|
|
||||||
def merge_with_updated_task_states(
|
def merge_with_updated_task_states(
|
||||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||||
@ -205,14 +210,14 @@ class MetricEventsIterator:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _init_task_states(
|
def _init_task_states(
|
||||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
self, companies: Mapping[str, str], task_metrics: Mapping[str, dict]
|
||||||
) -> Sequence[TaskScrollState]:
|
) -> Sequence[TaskScrollState]:
|
||||||
"""
|
"""
|
||||||
Returned initialized metric scroll stated for the requested task metrics
|
Returned initialized metric scroll stated for the requested task metrics
|
||||||
"""
|
"""
|
||||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||||
task_metric_states = pool.map(
|
task_metric_states = pool.map(
|
||||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
partial(self._init_metric_states_for_task, companies=companies),
|
||||||
task_metrics.items(),
|
task_metrics.items(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -232,13 +237,14 @@ class MetricEventsIterator:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _init_metric_states_for_task(
|
def _init_metric_states_for_task(
|
||||||
self, task_metrics: Tuple[str, dict], company_id: str
|
self, task_metrics: Tuple[str, dict], companies: Mapping[str, str]
|
||||||
) -> Sequence[MetricState]:
|
) -> Sequence[MetricState]:
|
||||||
"""
|
"""
|
||||||
Return metric scroll states for the task filled with the variant states
|
Return metric scroll states for the task filled with the variant states
|
||||||
for the variants that reported any event_type events
|
for the variants that reported any event_type events
|
||||||
"""
|
"""
|
||||||
task, metrics = task_metrics
|
task, metrics = task_metrics
|
||||||
|
company_id = companies[task]
|
||||||
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
must = [{"term": {"task": task}}, *self._get_extra_conditions()]
|
||||||
if metrics:
|
if metrics:
|
||||||
must.append(get_metric_variants_condition(metrics))
|
must.append(get_metric_variants_condition(metrics))
|
||||||
@ -319,7 +325,7 @@ class MetricEventsIterator:
|
|||||||
def _get_task_metric_events(
|
def _get_task_metric_events(
|
||||||
self,
|
self,
|
||||||
task_state: TaskScrollState,
|
task_state: TaskScrollState,
|
||||||
company_id: str,
|
companies: Mapping[str, str],
|
||||||
iter_count: int,
|
iter_count: int,
|
||||||
navigate_earlier: bool,
|
navigate_earlier: bool,
|
||||||
specific_variants_requested: bool,
|
specific_variants_requested: bool,
|
||||||
@ -391,7 +397,10 @@ class MetricEventsIterator:
|
|||||||
}
|
}
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
es_res = search_company_events(
|
es_res = search_company_events(
|
||||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
self.es,
|
||||||
|
company_id=companies[task_state.task],
|
||||||
|
event_type=self.event_type,
|
||||||
|
body=es_req,
|
||||||
)
|
)
|
||||||
if "aggregations" not in es_res:
|
if "aggregations" not in es_res:
|
||||||
return task_state.task, []
|
return task_state.task, []
|
||||||
|
@ -32,7 +32,7 @@ from apiserver.apimodels.events import (
|
|||||||
TaskMetric,
|
TaskMetric,
|
||||||
)
|
)
|
||||||
from apiserver.bll.event import EventBLL
|
from apiserver.bll.event import EventBLL
|
||||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies
|
||||||
from apiserver.bll.event.events_iterator import Scroll
|
from apiserver.bll.event.events_iterator import Scroll
|
||||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||||
from apiserver.bll.model import ModelBLL
|
from apiserver.bll.model import ModelBLL
|
||||||
@ -464,12 +464,11 @@ def scalar_metrics_iter_histogram(
|
|||||||
call.result.data = metrics
|
call.result.data = metrics
|
||||||
|
|
||||||
|
|
||||||
def _get_task_or_model_index_company(
|
def _get_task_or_model_index_companies(
|
||||||
company_id: str, task_ids: Sequence[str], model_events=False,
|
company_id: str, task_ids: Sequence[str], model_events=False,
|
||||||
) -> Tuple[str, Sequence[Task]]:
|
) -> TaskCompanies:
|
||||||
"""
|
"""
|
||||||
Verify that all tasks exists and belong to store data in the same company index
|
Returns lists of tasks grouped by company
|
||||||
Return company and tasks
|
|
||||||
"""
|
"""
|
||||||
tasks_or_models = _assert_task_or_model_exists(
|
tasks_or_models = _assert_task_or_model_exists(
|
||||||
company_id, task_ids, model_events=model_events,
|
company_id, task_ids, model_events=model_events,
|
||||||
@ -485,13 +484,7 @@ def _get_task_or_model_index_company(
|
|||||||
)
|
)
|
||||||
raise error_cls(company=company_id, ids=invalid)
|
raise error_cls(company=company_id, ids=invalid)
|
||||||
|
|
||||||
companies = {t.get_index_company() for t in tasks_or_models}
|
return bucketize(tasks_or_models, key=lambda t: t.get_index_company())
|
||||||
if len(companies) > 1:
|
|
||||||
raise errors.bad_request.InvalidTaskId(
|
|
||||||
"only tasks from the same company are supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
return companies.pop(), tasks_or_models
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
@ -504,13 +497,12 @@ def multi_task_scalar_metrics_iter_histogram(
|
|||||||
task_ids = request.tasks
|
task_ids = request.tasks
|
||||||
if isinstance(task_ids, str):
|
if isinstance(task_ids, str):
|
||||||
task_ids = [s.strip() for s in task_ids.split(",")]
|
task_ids = [s.strip() for s in task_ids.split(",")]
|
||||||
company, tasks_or_models = _get_task_or_model_index_company(
|
|
||||||
company_id, task_ids, request.model_events
|
|
||||||
)
|
|
||||||
call.result.data = dict(
|
call.result.data = dict(
|
||||||
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
||||||
company_id=company,
|
companies=_get_task_or_model_index_companies(
|
||||||
tasks=tasks_or_models,
|
company_id, task_ids, request.model_events
|
||||||
|
),
|
||||||
samples=request.samples,
|
samples=request.samples,
|
||||||
key=request.key,
|
key=request.key,
|
||||||
)
|
)
|
||||||
@ -521,11 +513,11 @@ def multi_task_scalar_metrics_iter_histogram(
|
|||||||
def get_task_single_value_metrics(
|
def get_task_single_value_metrics(
|
||||||
call, company_id: str, request: SingleValueMetricsRequest
|
call, company_id: str, request: SingleValueMetricsRequest
|
||||||
):
|
):
|
||||||
company, tasks_or_models = _get_task_or_model_index_company(
|
res = event_bll.metrics.get_task_single_value_metrics(
|
||||||
|
companies=_get_task_or_model_index_companies(
|
||||||
company_id, request.tasks, request.model_events
|
company_id, request.tasks, request.model_events
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
|
|
||||||
call.result.data = dict(
|
call.result.data = dict(
|
||||||
tasks=[{"task": task, "values": values} for task, values in res.items()]
|
tasks=[{"task": task, "values": values} for task, values in res.items()]
|
||||||
)
|
)
|
||||||
@ -537,11 +529,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
|||||||
iters = call.data.get("iters", 1)
|
iters = call.data.get("iters", 1)
|
||||||
scroll_id = call.data.get("scroll_id")
|
scroll_id = call.data.get("scroll_id")
|
||||||
|
|
||||||
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
|
companies = _get_task_or_model_index_companies(company_id, task_ids)
|
||||||
|
|
||||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||||
result = event_bll.get_task_events(
|
result = event_bll.get_task_events(
|
||||||
company,
|
list(companies),
|
||||||
task_ids,
|
task_ids,
|
||||||
event_type=EventType.metrics_plot,
|
event_type=EventType.metrics_plot,
|
||||||
sort=[{"iter": {"order": "desc"}}],
|
sort=[{"iter": {"order": "desc"}}],
|
||||||
@ -549,8 +541,9 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
|||||||
scroll_id=scroll_id,
|
scroll_id=scroll_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_names = {t.id: t.name for t in tasks_or_models}
|
task_names = {
|
||||||
|
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||||
|
}
|
||||||
return_events = _get_top_iter_unique_events_per_task(
|
return_events = _get_top_iter_unique_events_per_task(
|
||||||
result.events, max_iters=iters, task_names=task_names
|
result.events, max_iters=iters, task_names=task_names
|
||||||
)
|
)
|
||||||
@ -564,18 +557,18 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
|||||||
|
|
||||||
|
|
||||||
def _get_multitask_plots(
|
def _get_multitask_plots(
|
||||||
company: str,
|
companies: TaskCompanies,
|
||||||
tasks_or_models: Sequence[Task],
|
|
||||||
last_iters: int,
|
last_iters: int,
|
||||||
metrics: MetricVariants = None,
|
metrics: MetricVariants = None,
|
||||||
scroll_id=None,
|
scroll_id=None,
|
||||||
no_scroll=True,
|
no_scroll=True,
|
||||||
model_events=False,
|
model_events=False,
|
||||||
) -> Tuple[dict, int, str]:
|
) -> Tuple[dict, int, str]:
|
||||||
task_names = {t.id: t.name for t in tasks_or_models}
|
task_names = {
|
||||||
|
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||||
|
}
|
||||||
result = event_bll.get_task_events(
|
result = event_bll.get_task_events(
|
||||||
company_id=company,
|
company_id=list(companies),
|
||||||
task_id=list(task_names),
|
task_id=list(task_names),
|
||||||
event_type=EventType.metrics_plot,
|
event_type=EventType.metrics_plot,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
@ -599,13 +592,11 @@ def get_multi_task_plots(call, company_id, _):
|
|||||||
no_scroll = call.data.get("no_scroll", False)
|
no_scroll = call.data.get("no_scroll", False)
|
||||||
model_events = call.data.get("model_events", False)
|
model_events = call.data.get("model_events", False)
|
||||||
|
|
||||||
company, tasks_or_models = _get_task_or_model_index_company(
|
companies = _get_task_or_model_index_companies(
|
||||||
company_id, task_ids, model_events
|
company_id, task_ids, model_events=model_events
|
||||||
)
|
)
|
||||||
|
|
||||||
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
||||||
company=company,
|
companies=companies,
|
||||||
tasks_or_models=tasks_or_models,
|
|
||||||
last_iters=iters,
|
last_iters=iters,
|
||||||
scroll_id=scroll_id,
|
scroll_id=scroll_id,
|
||||||
no_scroll=no_scroll,
|
no_scroll=no_scroll,
|
||||||
@ -728,12 +719,11 @@ def _get_metrics_response(metric_events: Sequence[tuple]) -> Sequence[MetricEven
|
|||||||
def task_plots(call, company_id, request: MetricEventsRequest):
|
def task_plots(call, company_id, request: MetricEventsRequest):
|
||||||
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
||||||
task_ids = list(task_metrics)
|
task_ids = list(task_metrics)
|
||||||
company, _ = _get_task_or_model_index_company(
|
task_or_models = _assert_task_or_model_exists(
|
||||||
company_id, task_ids=task_ids, model_events=request.model_events
|
company_id, task_ids=task_ids, model_events=request.model_events
|
||||||
)
|
)
|
||||||
|
|
||||||
result = event_bll.plots_iterator.get_task_events(
|
result = event_bll.plots_iterator.get_task_events(
|
||||||
company_id=company,
|
companies={t.id: t.get_index_company() for t in task_or_models},
|
||||||
task_metrics=task_metrics,
|
task_metrics=task_metrics,
|
||||||
iter_count=request.iters,
|
iter_count=request.iters,
|
||||||
navigate_earlier=request.navigate_earlier,
|
navigate_earlier=request.navigate_earlier,
|
||||||
@ -824,12 +814,11 @@ def get_debug_images_v1_8(call, company_id, _):
|
|||||||
def get_debug_images(call, company_id, request: MetricEventsRequest):
|
def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||||
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
task_metrics = _task_metrics_dict_from_request(request.metrics)
|
||||||
task_ids = list(task_metrics)
|
task_ids = list(task_metrics)
|
||||||
company, _ = _get_task_or_model_index_company(
|
task_or_models = _assert_task_or_model_exists(
|
||||||
company_id, task_ids=task_ids, model_events=request.model_events
|
company_id, task_ids=task_ids, model_events=request.model_events
|
||||||
)
|
)
|
||||||
|
|
||||||
result = event_bll.debug_images_iterator.get_task_events(
|
result = event_bll.debug_images_iterator.get_task_events(
|
||||||
company_id=company,
|
companies={t.id: t.get_index_company() for t in task_or_models},
|
||||||
task_metrics=task_metrics,
|
task_metrics=task_metrics,
|
||||||
iter_count=request.iters,
|
iter_count=request.iters,
|
||||||
navigate_earlier=request.navigate_earlier,
|
navigate_earlier=request.navigate_earlier,
|
||||||
@ -922,11 +911,13 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
|
|||||||
|
|
||||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||||
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||||
company, _ = _get_task_or_model_index_company(
|
task_or_models = _assert_task_or_model_exists(
|
||||||
company_id, request.tasks, model_events=request.model_events
|
company_id, request.tasks, model_events=request.model_events,
|
||||||
)
|
)
|
||||||
res = event_bll.metrics.get_task_metrics(
|
res = event_bll.metrics.get_task_metrics(
|
||||||
company, task_ids=request.tasks, event_type=request.event_type
|
task_or_models[0].get_index_company(),
|
||||||
|
task_ids=request.tasks,
|
||||||
|
event_type=request.event_type,
|
||||||
)
|
)
|
||||||
call.result.data = {
|
call.result.data = {
|
||||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import textwrap
|
import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from itertools import chain
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from apiserver.apimodels.reports import (
|
from apiserver.apimodels.reports import (
|
||||||
@ -24,7 +25,7 @@ from apiserver.database.model.project import Project
|
|||||||
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
|
from apiserver.database.model.task.task import Task, TaskType, TaskStatus
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
from apiserver.services.events import (
|
from apiserver.services.events import (
|
||||||
_get_task_or_model_index_company,
|
_get_task_or_model_index_companies,
|
||||||
event_bll,
|
event_bll,
|
||||||
_get_metrics_response,
|
_get_metrics_response,
|
||||||
_get_metric_variants_from_request,
|
_get_metric_variants_from_request,
|
||||||
@ -214,10 +215,12 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
task_ids = [task["id"] for task in tasks]
|
task_ids = [task["id"] for task in tasks]
|
||||||
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
|
companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
|
||||||
if request.debug_images:
|
if request.debug_images:
|
||||||
result = event_bll.debug_images_iterator.get_task_events(
|
result = event_bll.debug_images_iterator.get_task_events(
|
||||||
company_id=company,
|
companies={
|
||||||
|
t.id: t.company for t in chain.from_iterable(companies.values())
|
||||||
|
},
|
||||||
task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
|
task_metrics=_get_task_metrics_from_request(task_ids, request.debug_images),
|
||||||
iter_count=request.debug_images.iters,
|
iter_count=request.debug_images.iters,
|
||||||
)
|
)
|
||||||
@ -227,8 +230,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
|||||||
|
|
||||||
if request.plots:
|
if request.plots:
|
||||||
res["plots"] = _get_multitask_plots(
|
res["plots"] = _get_multitask_plots(
|
||||||
company=company,
|
companies=companies,
|
||||||
tasks_or_models=tasks_or_models,
|
|
||||||
last_iters=request.plots.iters,
|
last_iters=request.plots.iters,
|
||||||
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
metrics=_get_metric_variants_from_request(request.plots.metrics),
|
||||||
)[0]
|
)[0]
|
||||||
@ -237,8 +239,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
|||||||
res[
|
res[
|
||||||
"scalar_metrics_iter_histogram"
|
"scalar_metrics_iter_histogram"
|
||||||
] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
] = event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
||||||
company_id=company_id,
|
companies=companies,
|
||||||
tasks=tasks_or_models,
|
|
||||||
samples=request.scalar_metrics_iter_histogram.samples,
|
samples=request.scalar_metrics_iter_histogram.samples,
|
||||||
key=request.scalar_metrics_iter_histogram.key,
|
key=request.scalar_metrics_iter_histogram.key,
|
||||||
metric_variants=_get_metric_variants_from_request(
|
metric_variants=_get_metric_variants_from_request(
|
||||||
|
Loading…
Reference in New Issue
Block a user