Support querying task events per specific metrics and variants

This commit is contained in:
allegroai 2021-07-25 14:29:41 +03:00
parent 677bb3ba6d
commit 9069cfe1da
7 changed files with 245 additions and 78 deletions

View File

@ -14,12 +14,20 @@ from apiserver.utilities.stringenum import StringEnum
class HistogramRequestBase(Base): class HistogramRequestBase(Base):
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)]) samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter) key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class MetricVariants(Base):
metric: str = StringField(required=True)
variants: Sequence[str] = ListField(
items_types=str, validators=Length(minimum_value=1)
)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase): class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True) task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@ -39,6 +47,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base): class TaskMetric(Base):
task: str = StringField(required=True) task: str = StringField(required=True)
metric: str = StringField(default=None) metric: str = StringField(default=None)
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base): class DebugImagesRequest(Base):
@ -59,8 +68,8 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant): class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField() iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False) refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
class NextDebugImageSampleRequest(Base): class NextDebugImageSampleRequest(Base):
@ -102,3 +111,10 @@ class TaskMetricsRequest(Base):
items_types=str, validators=[Length(minimum_value=1)] items_types=str, validators=[Length(minimum_value=1)]
) )
event_type: EventType = ActualEnumField(EventType, required=True) event_type: EventType = ActualEnumField(EventType, required=True)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)

View File

@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from operator import itemgetter from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set from typing import Sequence, Tuple, Optional, Mapping
import attr import attr
import dpath import dpath
@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
check_empty_data, check_empty_data,
search_company_events, search_company_events,
EventType, EventType,
get_metric_variants_condition,
) )
from apiserver.bll.redis_cache_manager import RedisCacheManager from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
@ -74,7 +75,7 @@ class DebugImagesIterator:
def get_task_events( def get_task_events(
self, self,
company_id: str, company_id: str,
task_metrics: Mapping[str, Set[str]], task_metrics: Mapping[str, dict],
iter_count: int, iter_count: int,
navigate_earlier: bool = True, navigate_earlier: bool = True,
refresh: bool = False, refresh: bool = False,
@ -118,7 +119,7 @@ class DebugImagesIterator:
self, self,
company_id, company_id,
state: DebugImageEventsScrollState, state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]], task_metrics: Mapping[str, dict],
): ):
""" """
Determine the metrics for which new debug image events were added Determine the metrics for which new debug image events were added
@ -158,11 +159,11 @@ class DebugImagesIterator:
task_metrics_to_recalc = {} task_metrics_to_recalc = {}
for task, metrics_times in update_times.items(): for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task] old_metric_states = task_metric_states[task]
metrics_to_recalc = set( metrics_to_recalc = {
m m: task_metrics[task].get(m)
for m, t in metrics_times.items() for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t if m not in old_metric_states or old_metric_states[m].timestamp < t
) }
if metrics_to_recalc: if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc task_metrics_to_recalc[task] = metrics_to_recalc
@ -196,7 +197,7 @@ class DebugImagesIterator:
] ]
def _init_task_states( def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]] self, company_id: str, task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]: ) -> Sequence[TaskScrollState]:
""" """
Return initialized metric scroll states for the requested task metrics Return initialized metric scroll states for the requested task metrics
@ -213,7 +214,7 @@ class DebugImagesIterator:
] ]
def _init_metric_states_for_task( def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Set[str]], company_id: str self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]: ) -> Sequence[MetricState]:
""" """
Return metric scroll states for the task filled with the variant states Return metric scroll states for the task filled with the variant states
@ -222,10 +223,11 @@ class DebugImagesIterator:
task, metrics = task_metrics task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}] must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics: if metrics:
must.append({"terms": {"metric": list(metrics)}}) must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
es_req: dict = { es_req: dict = {
"size": 0, "size": 0,
"query": {"bool": {"must": must}}, "query": query,
"aggs": { "aggs": {
"metrics": { "metrics": {
"terms": { "terms": {

View File

@ -6,9 +6,8 @@ from collections import defaultdict
from contextlib import closing from contextlib import closing
from datetime import datetime from datetime import datetime
from operator import attrgetter from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, Dict from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import six
from elasticsearch import helpers from elasticsearch import helpers
from elasticsearch.helpers import BulkIndexError from elasticsearch.helpers import BulkIndexError
from mongoengine import Q from mongoengine import Q
@ -22,6 +21,8 @@ from apiserver.bll.event.event_common import (
check_empty_data, check_empty_data,
search_company_events, search_company_events,
delete_company_events, delete_company_events,
MetricVariants,
get_metric_variants_condition,
) )
from apiserver.bll.util import parallel_chunked_decorator from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils from apiserver.database import utils as dbutils
@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
# noinspection PyTypeChecker # noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType)) EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published) LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2**63 - 1 MAX_LONG = 2 ** 63 - 1
MIN_LONG = -2**63 MIN_LONG = -(2 ** 63)
class PlotFields: class PlotFields:
@ -94,7 +95,7 @@ class EventBLL(object):
def add_events( def add_events(
self, company_id, events, worker, allow_locked_tasks=False self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]: ) -> Tuple[int, int, dict]:
actions = [] actions: List[dict] = []
task_ids = set() task_ids = set()
task_iteration = defaultdict(lambda: 0) task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict( task_last_scalar_events = nested_dict(
@ -197,7 +198,6 @@ class EventBLL(object):
actions.append(es_action) actions.append(es_action)
action: Dict[dict]
plot_actions = [ plot_actions = [
action["_source"] action["_source"]
for action in actions for action in actions
@ -260,7 +260,8 @@ class EventBLL(object):
invalid_iterations_count = errors_per_type.get(invalid_iteration_error) invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count: if invalid_iterations_count:
raise BulkIndexError( raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error] f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
) )
if not added: if not added:
@ -466,10 +467,16 @@ class EventBLL(object):
task_id: str, task_id: str,
num_last_iterations: int, num_last_iterations: int,
event_type: EventType, event_type: EventType,
metric_variants: MetricVariants = None,
): ):
if check_empty_data(self.es, company_id=company_id, event_type=event_type): if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return [] return []
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req: dict = { es_req: dict = {
"size": 0, "size": 0,
"aggs": { "aggs": {
@ -499,7 +506,7 @@ class EventBLL(object):
}, },
} }
}, },
"query": {"bool": {"must": [{"term": {"task": task_id}}]}}, "query": query,
} }
with translate_errors_context(), TimingContext( with translate_errors_context(), TimingContext(
@ -527,6 +534,7 @@ class EventBLL(object):
sort=None, sort=None,
size: int = 500, size: int = 500,
scroll_id: str = None, scroll_id: str = None,
metric_variants: MetricVariants = None,
): ):
if scroll_id == self.empty_scroll: if scroll_id == self.empty_scroll:
return TaskEventsResult() return TaskEventsResult()
@ -555,6 +563,8 @@ class EventBLL(object):
if last_iterations_per_plot is None: if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}}) must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else: else:
should = [] should = []
for i, task_id in enumerate(tasks): for i, task_id in enumerate(tasks):
@ -563,6 +573,7 @@ class EventBLL(object):
task_id=task_id, task_id=task_id,
num_last_iterations=last_iterations_per_plot, num_last_iterations=last_iterations_per_plot,
event_type=event_type, event_type=event_type,
metric_variants=metric_variants,
) )
if not last_iters: if not last_iters:
continue continue
@ -669,19 +680,19 @@ class EventBLL(object):
sort=None, sort=None,
size=500, size=500,
scroll_id=None, scroll_id=None,
): ) -> TaskEventsResult:
if scroll_id == self.empty_scroll: if scroll_id == self.empty_scroll:
return [], scroll_id, 0 return TaskEventsResult()
if scroll_id: if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"): with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h") es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else: else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if check_empty_data(self.es, company_id=company_id, event_type=event_type): if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult() return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = [] must = []
if metric: if metric:
must.append({"term": {"metric": metric}}) must.append({"term": {"metric": metric}})
@ -691,26 +702,24 @@ class EventBLL(object):
if last_iter_count is None: if last_iter_count is None:
must.append({"terms": {"task": task_ids}}) must.append({"terms": {"task": task_ids}})
else: else:
should = [] tasks_iters = self.get_last_iters(
for i, task_id in enumerate(task_ids): company_id=company_id,
last_iters = self.get_last_iters( event_type=event_type,
company_id=company_id, task_id=task_ids,
event_type=event_type, iters=last_iter_count,
task_id=task_id, )
iters=last_iter_count, should = [
) {
if not last_iters: "bool": {
continue "must": [
should.append( {"term": {"task": task}},
{ {"terms": {"iter": last_iters}},
"bool": { ]
"must": [
{"term": {"task": task_id}},
{"terms": {"iter": last_iters}},
]
}
} }
) }
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should: if not should:
return TaskEventsResult() return TaskEventsResult()
must.append({"bool": {"should": should}}) must.append({"bool": {"should": should}})
@ -748,6 +757,7 @@ class EventBLL(object):
if check_empty_data(self.es, company_id=company_id, event_type=event_type): if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {} return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
es_req = { es_req = {
"size": 0, "size": 0,
"aggs": { "aggs": {
@ -768,7 +778,7 @@ class EventBLL(object):
}, },
} }
}, },
"query": {"bool": {"must": [{"term": {"task": task_id}}]}}, "query": query,
} }
with translate_errors_context(), TimingContext( with translate_errors_context(), TimingContext(
@ -787,21 +797,24 @@ class EventBLL(object):
return metrics return metrics
def get_task_latest_scalar_values(self, company_id: str, task_id: str): def get_task_latest_scalar_values(
self, company_id, task_id
) -> Tuple[Sequence[dict], int]:
event_type = EventType.metrics_scalar event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type): if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {} return [], 0
query = {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
}
es_req = { es_req = {
"size": 0, "size": 0,
"query": { "query": query,
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
},
"aggs": { "aggs": {
"metrics": { "metrics": {
"terms": { "terms": {
@ -905,11 +918,16 @@ class EventBLL(object):
return iterations, vectors return iterations, vectors
def get_last_iters( def get_last_iters(
self, company_id: str, event_type: EventType, task_id: str, iters: int self,
): company_id: str,
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type): if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return [] return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
es_req: dict = { es_req: dict = {
"size": 0, "size": 0,
"aggs": { "aggs": {
@ -921,7 +939,7 @@ class EventBLL(object):
} }
} }
}, },
"query": {"bool": {"must": [{"term": {"task": task_id}}]}}, "query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
} }
with translate_errors_context(), TimingContext("es", "task_last_iter"): with translate_errors_context(), TimingContext("es", "task_last_iter"):
@ -930,9 +948,12 @@ class EventBLL(object):
) )
if "aggregations" not in es_res: if "aggregations" not in es_res:
return [] return {}
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]] return {
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
def delete_task_events(self, company_id, task_id, allow_locked=False): def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context(): with translate_errors_context():
@ -965,7 +986,9 @@ class EventBLL(object):
so it should be checked by the calling code so it should be checked by the calling code
""" """
es_req = {"query": {"terms": {"task": task_ids}}} es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"): with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
es_res = delete_company_events( es_res = delete_company_events(
es=self.es, es=self.es,
company_id=company_id, company_id=company_id,

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Union, Sequence from typing import Union, Sequence, Mapping
from boltons.typeutils import classproperty from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@ -16,6 +16,9 @@ class EventType(Enum):
all = "*" all = "*"
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings: class EventSettings:
@classproperty @classproperty
def max_workers(self): def max_workers(self):
@ -64,3 +67,23 @@ def delete_company_events(
) -> dict: ) -> dict:
es_index = get_index_name(company_id, event_type.value) es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs) return es.delete_by_query(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(
metric_variants: MetricVariants,
) -> Sequence:
conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"terms": {"variant": variants}},
]
}
}
if variants
else {"term": {"metric": metric}}
for metric, variants in metric_variants.items()
]
return {"bool": {"should": conditions}}

View File

@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
EventSettings, EventSettings,
search_company_events, search_company_events,
check_empty_data, check_empty_data,
MetricVariants,
get_metric_variants_condition,
) )
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config from apiserver.config_repo import config
@ -34,7 +36,12 @@ class EventMetrics:
self.es = es self.es = es
def get_scalar_metrics_average_per_iter( def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum self,
company_id: str,
task_id: str,
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
) -> dict: ) -> dict:
""" """
Get scalar metric histogram per metric and variant Get scalar metric histogram per metric and variant
@ -46,7 +53,12 @@ class EventMetrics:
return {} return {}
return self._get_scalar_average_per_iter_core( return self._get_scalar_average_per_iter_core(
task_id, company_id, event_type, samples, ScalarKey.resolve(key) task_id=task_id,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
) )
def _get_scalar_average_per_iter_core( def _get_scalar_average_per_iter_core(
@ -57,6 +69,7 @@ class EventMetrics:
samples: int, samples: int,
key: ScalarKey, key: ScalarKey,
run_parallel: bool = True, run_parallel: bool = True,
metric_variants: MetricVariants = None,
) -> dict: ) -> dict:
intervals = self._get_task_metric_intervals( intervals = self._get_task_metric_intervals(
company_id=company_id, company_id=company_id,
@ -64,6 +77,7 @@ class EventMetrics:
task_id=task_id, task_id=task_id,
samples=samples, samples=samples,
field=key.field, field=key.field,
metric_variants=metric_variants,
) )
if not intervals: if not intervals:
return {} return {}
@ -197,6 +211,7 @@ class EventMetrics:
task_id: str, task_id: str,
samples: int, samples: int,
field: str = "iter", field: str = "iter",
metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]: ) -> Sequence[MetricInterval]:
""" """
Calculate interval per task metric variant so that the resulting Calculate interval per task metric variant so that the resulting
@ -204,9 +219,14 @@ class EventMetrics:
Return the list of metric variant intervals as the following tuple: Return the list of metric variant intervals as the following tuple:
(metric, variant, interval, samples) (metric, variant, interval, samples)
""" """
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req = { es_req = {
"size": 0, "size": 0,
"query": {"term": {"task": task_id}}, "query": query,
"aggs": { "aggs": {
"metrics": { "metrics": {
"terms": { "terms": {

View File

@ -1,6 +1,18 @@
{ {
_description : "Provides an API for running tasks to report events collected by the system." _description : "Provides an API for running tasks to report events collected by the system."
_definitions { _definitions {
metric_variants {
type: object
metric {
description: The metric name
type: string
}
variants {
type: array
description: The names of the metric variants
items {type: string}
}
}
metrics_scalar_event { metrics_scalar_event {
description: "Used for reporting scalar metrics during training task" description: "Used for reporting scalar metrics during training task"
type: object type: object
@ -193,6 +205,29 @@
description: "Task ID" description: "Task ID"
type: string type: string
} }
metric {
description: "Metric name"
type: string
}
}
}
task_metric_variants {
type: object
required: [task]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
variants {
description: Metric variant names
type: array
items {type: string}
}
} }
} }
task_log_event { task_log_event {
@ -376,7 +411,7 @@
metrics { metrics {
type: array type: array
items { "$ref": "#/definitions/task_metric" } items { "$ref": "#/definitions/task_metric" }
description: "List metrics for which the events will be retrieved" description: "List of task metrics for which the events will be retrieved"
} }
iters { iters {
type: integer type: integer
@ -411,6 +446,17 @@
} }
} }
} }
"2.14": ${debug_images."2.7"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/task_metric_variants" }
}
}
}
}
} }
get_debug_image_sample { get_debug_image_sample {
"2.12": { "2.12": {
@ -804,6 +850,17 @@
} }
} }
} }
"2.14": ${get_task_plots."2.1"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
} }
get_multi_task_plots { get_multi_task_plots {
"2.1" { "2.1" {
@ -962,6 +1019,17 @@
} }
} }
} }
"2.14": ${scalar_metrics_iter_histogram."2.1"} {
request {
properties {
metrics {
type: array
description: List of metrics and variants
items { "$ref": "#/definitions/metric_variants" }
}
}
}
}
} }
multi_task_scalar_metrics_iter_histogram { multi_task_scalar_metrics_iter_histogram {
"2.1" { "2.1" {

View File

@ -3,6 +3,7 @@ from collections import defaultdict
from operator import itemgetter from operator import itemgetter
import attr import attr
from typing import Sequence, Optional
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.apimodels.events import ( from apiserver.apimodels.events import (
@ -17,9 +18,11 @@ from apiserver.apimodels.events import (
LogOrderEnum, LogOrderEnum,
GetDebugImageSampleRequest, GetDebugImageSampleRequest,
NextDebugImageSampleRequest, NextDebugImageSampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.task import TaskBLL from apiserver.bll.task import TaskBLL
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json from apiserver.utilities import json
@ -321,7 +324,7 @@ def get_task_latest_scalar_values(call, company_id, _):
) )
last_iters = event_bll.get_last_iters( last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1 company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
) ).get(task_id)
call.result.data = dict( call.result.data = dict(
metrics=metrics, metrics=metrics,
last_iter=last_iters[0] if last_iters else 0, last_iter=last_iters[0] if last_iters else 0,
@ -494,11 +497,22 @@ def get_task_plots_v1_7(call, company_id, _):
) )
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"]) def _get_metric_variants_from_request(
def get_task_plots(call, company_id, _): req_metrics: Sequence[ApiMetrics],
task_id = call.data["task"] ) -> Optional[MetricVariants]:
iters = call.data.get("iters", 1) if not req_metrics:
scroll_id = call.data.get("scroll_id") return None
return {m.metric: m.variants for m in req_metrics}
@endpoint(
"events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
)
def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists( task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin") company_id, task_id, allow_public=True, only=("company", "company_origin")
@ -509,6 +523,7 @@ def get_task_plots(call, company_id, _):
sort=[{"iter": {"order": "desc"}}], sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters, last_iterations_per_plot=iters,
scroll_id=scroll_id, scroll_id=scroll_id,
metric_variants=_get_metric_variants_from_request(request.metrics),
) )
return_events = result.events return_events = result.events
@ -594,9 +609,9 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=DebugImageResponse, response_data_model=DebugImageResponse,
) )
def get_debug_images(call, company_id, request: DebugImagesRequest): def get_debug_images(call, company_id, request: DebugImagesRequest):
task_metrics = defaultdict(set) task_metrics = defaultdict(dict)
for tm in request.metrics: for tm in request.metrics:
task_metrics[tm.task].add(tm.metric) task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values(): for metrics in task_metrics.values():
if None in metrics: if None in metrics:
metrics.clear() metrics.clear()
@ -734,11 +749,11 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
def _get_top_iter_unique_events(events, max_iters): def _get_top_iter_unique_events(events, max_iters):
top_unique_events = defaultdict(lambda: []) top_unique_events = defaultdict(lambda: [])
for e in events: for ev in events:
key = e.get("metric", "") + e.get("variant", "") key = ev.get("metric", "") + ev.get("variant", "")
evs = top_unique_events[key] evs = top_unique_events[key]
if len(evs) < max_iters: if len(evs) < max_iters:
evs.append(e) evs.append(ev)
unique_events = list( unique_events = list(
itertools.chain.from_iterable(list(top_unique_events.values())) itertools.chain.from_iterable(list(top_unique_events.values()))
) )