mirror of https://github.com/clearml/clearml-server (synced 2025-03-03 10:43:10 +00:00)
Support querying task events per specific metrics and variants
parent 677bb3ba6d · commit 9069cfe1da
@@ -14,12 +14,20 @@ from apiserver.utilities.stringenum import StringEnum


 class HistogramRequestBase(Base):
-    samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
+    samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
     key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)


+class MetricVariants(Base):
+    metric: str = StringField(required=True)
+    variants: Sequence[str] = ListField(
+        items_types=str, validators=Length(minimum_value=1)
+    )
+
+
 class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
     task: str = StringField(required=True)
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)


 class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -39,6 +47,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
 class TaskMetric(Base):
     task: str = StringField(required=True)
     metric: str = StringField(default=None)
+    variants: Sequence[str] = ListField(items_types=str)


 class DebugImagesRequest(Base):
@@ -59,8 +68,8 @@ class TaskMetricVariant(Base):

 class GetDebugImageSampleRequest(TaskMetricVariant):
     iteration: Optional[int] = IntField()
-    scroll_id: Optional[str] = StringField()
     refresh: bool = BoolField(default=False)
+    scroll_id: Optional[str] = StringField()


 class NextDebugImageSampleRequest(Base):
@@ -102,3 +111,10 @@ class TaskMetricsRequest(Base):
         items_types=str, validators=[Length(minimum_value=1)]
     )
     event_type: EventType = ActualEnumField(EventType, required=True)
+
+
+class TaskPlotsRequest(Base):
+    task: str = StringField(required=True)
+    iters: int = IntField(default=1)
+    scroll_id: str = StringField()
+    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
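
Note: the new MetricVariants request model pairs a metric name with an optional list of variant names, and the new `metrics` fields let histogram and plot requests scope their results. A minimal sketch of a request payload using the new field; the metric and variant names are made-up examples, not part of the commit:

# Hypothetical payload for events.scalar_metrics_iter_histogram (2.14);
# "loss", "accuracy" and the variant names are illustrative.
payload = {
    "task": "<task-id>",
    "key": "iter",
    "samples": 2000,
    "metrics": [
        {"metric": "loss", "variants": ["train", "validation"]},
        {"metric": "accuracy"},  # omitting variants matches all variants
    ],
}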
@@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
 from datetime import datetime
 from functools import partial
 from operator import itemgetter
-from typing import Sequence, Tuple, Optional, Mapping, Set
+from typing import Sequence, Tuple, Optional, Mapping

 import attr
 import dpath
@@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
     check_empty_data,
     search_company_events,
     EventType,
+    get_metric_variants_condition,
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
@@ -74,7 +75,7 @@ class DebugImagesIterator:
     def get_task_events(
         self,
         company_id: str,
-        task_metrics: Mapping[str, Set[str]],
+        task_metrics: Mapping[str, dict],
         iter_count: int,
         navigate_earlier: bool = True,
         refresh: bool = False,
@@ -118,7 +119,7 @@ class DebugImagesIterator:
         self,
         company_id,
         state: DebugImageEventsScrollState,
-        task_metrics: Mapping[str, Set[str]],
+        task_metrics: Mapping[str, dict],
     ):
         """
         Determine the metrics for which new debug image events were added
@@ -158,11 +159,11 @@ class DebugImagesIterator:
         task_metrics_to_recalc = {}
         for task, metrics_times in update_times.items():
             old_metric_states = task_metric_states[task]
-            metrics_to_recalc = set(
-                m
+            metrics_to_recalc = {
+                m: task_metrics[task].get(m)
                 for m, t in metrics_times.items()
                 if m not in old_metric_states or old_metric_states[m].timestamp < t
-            )
+            }
             if metrics_to_recalc:
                 task_metrics_to_recalc[task] = metrics_to_recalc

@@ -196,7 +197,7 @@ class DebugImagesIterator:
         ]

     def _init_task_states(
-        self, company_id: str, task_metrics: Mapping[str, Set[str]]
+        self, company_id: str, task_metrics: Mapping[str, dict]
     ) -> Sequence[TaskScrollState]:
         """
         Returned initialized metric scroll stated for the requested task metrics
@@ -213,7 +214,7 @@ class DebugImagesIterator:
         ]

     def _init_metric_states_for_task(
-        self, task_metrics: Tuple[str, Set[str]], company_id: str
+        self, task_metrics: Tuple[str, dict], company_id: str
     ) -> Sequence[MetricState]:
         """
         Return metric scroll states for the task filled with the variant states
@@ -222,10 +223,11 @@ class DebugImagesIterator:
         task, metrics = task_metrics
         must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
         if metrics:
-            must.append({"terms": {"metric": list(metrics)}})
+            must.append(get_metric_variants_condition(metrics))
+        query = {"bool": {"must": must}}
         es_req: dict = {
             "size": 0,
-            "query": {"bool": {"must": must}},
+            "query": query,
             "aggs": {
                 "metrics": {
                     "terms": {
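
Note: across the iterator, `task_metrics` changes shape, from a set of metric names per task to a mapping of metric name to its requested variants, where an empty value means all variants. A small sketch of the two shapes, with made-up task, metric and variant names:

# Shape before this commit: metric names only, per task
task_metrics_old = {"task-1": {"confusion_matrix", "samples"}}

# Shape after: metric name -> requested variants (None/empty = all variants)
task_metrics_new = {"task-1": {"confusion_matrix": ["val"], "samples": None}}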
@@ -6,9 +6,8 @@ from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
 from operator import attrgetter
-from typing import Sequence, Set, Tuple, Optional, Dict
+from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union

-import six
 from elasticsearch import helpers
 from elasticsearch.helpers import BulkIndexError
 from mongoengine import Q
@@ -22,6 +21,8 @@ from apiserver.bll.event.event_common import (
     check_empty_data,
     search_company_events,
     delete_company_events,
+    MetricVariants,
+    get_metric_variants_condition,
 )
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
@@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
 # noinspection PyTypeChecker
 EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
 LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
-MAX_LONG = 2**63 - 1
-MIN_LONG = -2**63
+MAX_LONG = 2 ** 63 - 1
+MIN_LONG = -(2 ** 63)


 class PlotFields:
@@ -94,7 +95,7 @@ class EventBLL(object):
     def add_events(
         self, company_id, events, worker, allow_locked_tasks=False
     ) -> Tuple[int, int, dict]:
-        actions = []
+        actions: List[dict] = []
        task_ids = set()
        task_iteration = defaultdict(lambda: 0)
        task_last_scalar_events = nested_dict(
@@ -197,7 +198,6 @@ class EventBLL(object):

                 actions.append(es_action)

-        action: Dict[dict]
         plot_actions = [
             action["_source"]
             for action in actions
@@ -260,7 +260,8 @@ class EventBLL(object):
             invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
             if invalid_iterations_count:
                 raise BulkIndexError(
-                    f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
+                    f"{invalid_iterations_count} document(s) failed to index.",
+                    [invalid_iteration_error],
                 )

         if not added:
@@ -466,10 +467,16 @@ class EventBLL(object):
         task_id: str,
         num_last_iterations: int,
         event_type: EventType,
+        metric_variants: MetricVariants = None,
     ):
         if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

+        must = [{"term": {"task": task_id}}]
+        if metric_variants:
+            must.append(get_metric_variants_condition(metric_variants))
+        query = {"bool": {"must": must}}
+
         es_req: dict = {
             "size": 0,
             "aggs": {
@@ -499,7 +506,7 @@ class EventBLL(object):
                     },
                 }
             },
-            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
+            "query": query,
         }

         with translate_errors_context(), TimingContext(
@@ -527,6 +534,7 @@ class EventBLL(object):
         sort=None,
         size: int = 500,
         scroll_id: str = None,
+        metric_variants: MetricVariants = None,
     ):
         if scroll_id == self.empty_scroll:
             return TaskEventsResult()
@@ -555,6 +563,8 @@ class EventBLL(object):

         if last_iterations_per_plot is None:
             must.append({"terms": {"task": tasks}})
+            if metric_variants:
+                must.append(get_metric_variants_condition(metric_variants))
         else:
             should = []
             for i, task_id in enumerate(tasks):
@@ -563,6 +573,7 @@ class EventBLL(object):
                     task_id=task_id,
                     num_last_iterations=last_iterations_per_plot,
                     event_type=event_type,
+                    metric_variants=metric_variants,
                 )
                 if not last_iters:
                     continue
@@ -669,19 +680,19 @@ class EventBLL(object):
         sort=None,
         size=500,
         scroll_id=None,
-    ):
+    ) -> TaskEventsResult:
         if scroll_id == self.empty_scroll:
-            return [], scroll_id, 0
+            return TaskEventsResult()

         if scroll_id:
             with translate_errors_context(), TimingContext("es", "get_task_events"):
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
-            task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
-
             if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

+            task_ids = [task_id] if isinstance(task_id, str) else task_id
+
             must = []
             if metric:
                 must.append({"term": {"metric": metric}})
@@ -691,26 +702,24 @@ class EventBLL(object):
             if last_iter_count is None:
                 must.append({"terms": {"task": task_ids}})
             else:
-                should = []
-                for i, task_id in enumerate(task_ids):
-                    last_iters = self.get_last_iters(
-                        company_id=company_id,
-                        event_type=event_type,
-                        task_id=task_id,
-                        iters=last_iter_count,
-                    )
-                    if not last_iters:
-                        continue
-                    should.append(
-                        {
-                            "bool": {
-                                "must": [
-                                    {"term": {"task": task_id}},
-                                    {"terms": {"iter": last_iters}},
-                                ]
-                            }
-                        }
-                    )
+                tasks_iters = self.get_last_iters(
+                    company_id=company_id,
+                    event_type=event_type,
+                    task_id=task_ids,
+                    iters=last_iter_count,
+                )
+                should = [
+                    {
+                        "bool": {
+                            "must": [
+                                {"term": {"task": task}},
+                                {"terms": {"iter": last_iters}},
+                            ]
+                        }
+                    }
+                    for task, last_iters in tasks_iters.items()
+                    if last_iters
+                ]
                 if not should:
                     return TaskEventsResult()
                 must.append({"bool": {"should": should}})
@@ -748,6 +757,7 @@ class EventBLL(object):
         if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

+        query = {"bool": {"must": [{"term": {"task": task_id}}]}}
         es_req = {
             "size": 0,
             "aggs": {
@@ -768,7 +778,7 @@ class EventBLL(object):
                     },
                 }
             },
-            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
+            "query": query,
         }

         with translate_errors_context(), TimingContext(
@@ -787,21 +797,24 @@ class EventBLL(object):

         return metrics

-    def get_task_latest_scalar_values(self, company_id: str, task_id: str):
+    def get_task_latest_scalar_values(
+        self, company_id, task_id
+    ) -> Tuple[Sequence[dict], int]:
         event_type = EventType.metrics_scalar
         if check_empty_data(self.es, company_id=company_id, event_type=event_type):
-            return {}
+            return [], 0

+        query = {
+            "bool": {
+                "must": [
+                    {"query_string": {"query": "value:>0"}},
+                    {"term": {"task": task_id}},
+                ]
+            }
+        }
         es_req = {
             "size": 0,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"query_string": {"query": "value:>0"}},
-                        {"term": {"task": task_id}},
-                    ]
-                }
-            },
+            "query": query,
             "aggs": {
                 "metrics": {
                     "terms": {
@@ -905,11 +918,16 @@ class EventBLL(object):
         return iterations, vectors

     def get_last_iters(
-        self, company_id: str, event_type: EventType, task_id: str, iters: int
-    ):
+        self,
+        company_id: str,
+        event_type: EventType,
+        task_id: Union[str, Sequence[str]],
+        iters: int,
+    ) -> Mapping[str, Sequence]:
         if check_empty_data(self.es, company_id=company_id, event_type=event_type):
-            return []
+            return {}

+        task_ids = [task_id] if isinstance(task_id, str) else task_id
         es_req: dict = {
             "size": 0,
             "aggs": {
@@ -921,7 +939,7 @@ class EventBLL(object):
                     }
                 }
             },
-            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
+            "query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
         }

         with translate_errors_context(), TimingContext("es", "task_last_iter"):
@@ -930,9 +948,12 @@ class EventBLL(object):
             )

         if "aggregations" not in es_res:
-            return []
+            return {}

-        return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
+        return {
+            tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
+            for tb in es_res["aggregations"]["tasks"]["buckets"]
+        }

     def delete_task_events(self, company_id, task_id, allow_locked=False):
         with translate_errors_context():
@@ -965,7 +986,9 @@ class EventBLL(object):
         so it should be checked by the calling code
         """
         es_req = {"query": {"terms": {"task": task_ids}}}
-        with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
+        with translate_errors_context(), TimingContext(
+            "es", "delete_multi_tasks_events"
+        ):
             es_res = delete_company_events(
                 es=self.es,
                 company_id=company_id,
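
Note: get_last_iters now accepts one task id or a sequence of ids and returns a mapping of task id to its last iterations, read from a per-task "tasks" terms aggregation. A sketch of parsing that aggregation, mirroring the new return expression; the bucket values are invented:

# Parsing the "tasks"/"iters" aggregation into {task: [iterations]}.
es_aggregations = {
    "tasks": {
        "buckets": [
            {"key": "task-1", "iters": {"buckets": [{"key": 105}, {"key": 104}]}},
            {"key": "task-2", "iters": {"buckets": [{"key": 98}, {"key": 97}]}},
        ]
    }
}
last_iters = {
    tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
    for tb in es_aggregations["tasks"]["buckets"]
}
assert last_iters == {"task-1": [105, 104], "task-2": [98, 97]}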
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Union, Sequence
+from typing import Union, Sequence, Mapping

 from boltons.typeutils import classproperty
 from elasticsearch import Elasticsearch
@@ -16,6 +16,9 @@ class EventType(Enum):
     all = "*"


+MetricVariants = Mapping[str, Sequence[str]]
+
+
 class EventSettings:
     @classproperty
     def max_workers(self):
@@ -64,3 +67,23 @@ def delete_company_events(
 ) -> dict:
     es_index = get_index_name(company_id, event_type.value)
     return es.delete_by_query(index=es_index, body=body, **kwargs)
+
+
+def get_metric_variants_condition(
+    metric_variants: MetricVariants,
+) -> Sequence:
+    conditions = [
+        {
+            "bool": {
+                "must": [
+                    {"term": {"metric": metric}},
+                    {"terms": {"variant": variants}},
+                ]
+            }
+        }
+        if variants
+        else {"term": {"metric": metric}}
+        for metric, variants in metric_variants.items()
+    ]
+
+    return {"bool": {"should": conditions}}
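
Note: get_metric_variants_condition is the core of the change, turning the MetricVariants mapping into an ES bool/should condition (despite the Sequence annotation, the body returns a dict). A sketch of input and output, with made-up metric and variant names:

# Input mapping; a None/empty variants value matches every variant.
metric_variants = {"loss": ["train"], "accuracy": None}

# Condition produced by get_metric_variants_condition for that input:
condition = {
    "bool": {
        "should": [
            {
                "bool": {
                    "must": [
                        {"term": {"metric": "loss"}},
                        {"terms": {"variant": ["train"]}},
                    ]
                }
            },
            {"term": {"metric": "accuracy"}},
        ]
    }
}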
@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
     EventSettings,
     search_company_events,
     check_empty_data,
+    MetricVariants,
+    get_metric_variants_condition,
 )
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
@@ -34,7 +36,12 @@ class EventMetrics:
         self.es = es

     def get_scalar_metrics_average_per_iter(
-        self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
+        self,
+        company_id: str,
+        task_id: str,
+        samples: int,
+        key: ScalarKeyEnum,
+        metric_variants: MetricVariants = None,
     ) -> dict:
         """
         Get scalar metric histogram per metric and variant
@@ -46,7 +53,12 @@ class EventMetrics:
             return {}

         return self._get_scalar_average_per_iter_core(
-            task_id, company_id, event_type, samples, ScalarKey.resolve(key)
+            task_id=task_id,
+            company_id=company_id,
+            event_type=event_type,
+            samples=samples,
+            key=ScalarKey.resolve(key),
+            metric_variants=metric_variants,
         )

     def _get_scalar_average_per_iter_core(
@@ -57,6 +69,7 @@ class EventMetrics:
         samples: int,
         key: ScalarKey,
         run_parallel: bool = True,
+        metric_variants: MetricVariants = None,
     ) -> dict:
         intervals = self._get_task_metric_intervals(
             company_id=company_id,
@@ -64,6 +77,7 @@ class EventMetrics:
             task_id=task_id,
             samples=samples,
             field=key.field,
+            metric_variants=metric_variants,
         )
         if not intervals:
             return {}
@@ -197,6 +211,7 @@ class EventMetrics:
         task_id: str,
         samples: int,
         field: str = "iter",
+        metric_variants: MetricVariants = None,
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
@@ -204,9 +219,14 @@ class EventMetrics:
         Return the list og metric variant intervals as the following tuple:
         (metric, variant, interval, samples)
         """
+        must = [{"term": {"task": task_id}}]
+        if metric_variants:
+            must.append(get_metric_variants_condition(metric_variants))
+        query = {"bool": {"must": must}}
+
         es_req = {
             "size": 0,
-            "query": {"term": {"task": task_id}},
+            "query": query,
             "aggs": {
                 "metrics": {
                     "terms": {
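
Note: with metric_variants threaded through, callers can scope the scalar histogram to specific metrics and variants without touching existing call sites (the default None keeps the old behavior). A usage sketch; the event_metrics instance and all ids/names are assumed, not taken from the commit:

# Assumes an initialized EventMetrics instance bound to an ES client.
histogram = event_metrics.get_scalar_metrics_average_per_iter(
    company_id="<company-id>",
    task_id="<task-id>",
    samples=2000,
    key=ScalarKeyEnum.iter,
    metric_variants={"loss": ["train"]},  # None would fetch all metrics
)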
@@ -1,6 +1,18 @@
 {
     _description : "Provides an API for running tasks to report events collected by the system."
     _definitions {
+        metric_variants {
+            type: object
+            metric {
+                description: The metric name
+                type: string
+            }
+            variants {
+                type: array
+                description: The names of the metric variants
+                items {type: string}
+            }
+        }
         metrics_scalar_event {
             description: "Used for reporting scalar metrics during training task"
             type: object
@@ -193,6 +205,29 @@
                     description: "Task ID"
                     type: string
                 }
+                metric {
+                    description: "Metric name"
+                    type: string
+                }
             }
         }
+        task_metric_variants {
+            type: object
+            required: [task]
+            properties {
+                task {
+                    description: "Task ID"
+                    type: string
+                }
+                metric {
+                    description: "Metric name"
+                    type: string
+                }
+                variants {
+                    description: Metric variant names
+                    type: array
+                    items {type: string}
+                }
+            }
+        }
         task_log_event {
@@ -376,7 +411,7 @@
                 metrics {
                     type: array
                     items { "$ref": "#/definitions/task_metric" }
-                    description: "List metrics for which the envents will be retreived"
+                    description: "List of task metrics for which the envents will be retreived"
                 }
                 iters {
                     type: integer
@@ -411,6 +446,17 @@
                 }
             }
         }
+        "2.14": ${debug_images."2.7"} {
+            request {
+                properties {
+                    metrics {
+                        type: array
+                        description: List of metrics and variants
+                        items { "$ref": "#/definitions/task_metric_variants" }
+                    }
+                }
+            }
+        }
     }
     get_debug_image_sample {
         "2.12": {
@@ -804,6 +850,17 @@
                 }
             }
         }
+        "2.14": ${get_task_plots."2.1"} {
+            request {
+                properties {
+                    metrics {
+                        type: array
+                        description: List of metrics and variants
+                        items { "$ref": "#/definitions/metric_variants" }
+                    }
+                }
+            }
+        }
     }
     get_multi_task_plots {
         "2.1" {
@@ -962,6 +1019,17 @@
                 }
             }
         }
+        "2.14": ${scalar_metrics_iter_histogram."2.1"} {
+            request {
+                properties {
+                    metrics {
+                        type: array
+                        description: List of metrics and variants
+                        items { "$ref": "#/definitions/metric_variants" }
+                    }
+                }
+            }
+        }
     }
     multi_task_scalar_metrics_iter_histogram {
         "2.1" {
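
Note: the schema changes register the metrics filter on the 2.14 versions of debug_images, get_task_plots and scalar_metrics_iter_histogram. A hedged sketch of invoking one of them over HTTP; host, port, credentials and ids are placeholders, not defined by this commit:

# Hypothetical call to the 2.14 events.get_task_plots endpoint.
import requests

resp = requests.post(
    "http://<apiserver-host>:8008/events.get_task_plots",
    json={
        "task": "<task-id>",
        "iters": 1,
        "metrics": [{"metric": "loss", "variants": ["train"]}],
    },
    auth=("<access_key>", "<secret_key>"),
)
print(resp.json())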
@@ -3,6 +3,7 @@ from collections import defaultdict
 from operator import itemgetter

 import attr
+from typing import Sequence, Optional

 from apiserver.apierrors import errors
 from apiserver.apimodels.events import (
@@ -17,9 +18,11 @@ from apiserver.apimodels.events import (
     LogOrderEnum,
     GetDebugImageSampleRequest,
     NextDebugImageSampleRequest,
+    MetricVariants as ApiMetrics,
+    TaskPlotsRequest,
 )
 from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_common import EventType
+from apiserver.bll.event.event_common import EventType, MetricVariants
 from apiserver.bll.task import TaskBLL
 from apiserver.service_repo import APICall, endpoint
 from apiserver.utilities import json
@@ -321,7 +324,7 @@ def get_task_latest_scalar_values(call, company_id, _):
     )
     last_iters = event_bll.get_last_iters(
         company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
-    )
+    ).get(task_id)
     call.result.data = dict(
         metrics=metrics,
         last_iter=last_iters[0] if last_iters else 0,
@@ -494,11 +497,22 @@ def get_task_plots_v1_7(call, company_id, _):
     )


-@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
-def get_task_plots(call, company_id, _):
-    task_id = call.data["task"]
-    iters = call.data.get("iters", 1)
-    scroll_id = call.data.get("scroll_id")
+def _get_metric_variants_from_request(
+    req_metrics: Sequence[ApiMetrics],
+) -> Optional[MetricVariants]:
+    if not req_metrics:
+        return None
+
+    return {m.metric: m.variants for m in req_metrics}
+
+
+@endpoint(
+    "events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
+)
+def get_task_plots(call, company_id, request: TaskPlotsRequest):
+    task_id = request.task
+    iters = request.iters
+    scroll_id = request.scroll_id

     task = task_bll.assert_exists(
         company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -509,6 +523,7 @@ def get_task_plots(call, company_id, _):
         sort=[{"iter": {"order": "desc"}}],
         last_iterations_per_plot=iters,
         scroll_id=scroll_id,
+        metric_variants=_get_metric_variants_from_request(request.metrics),
     )

     return_events = result.events
@@ -594,9 +609,9 @@ def get_debug_images_v1_8(call, company_id, _):
     response_data_model=DebugImageResponse,
 )
 def get_debug_images(call, company_id, request: DebugImagesRequest):
-    task_metrics = defaultdict(set)
+    task_metrics = defaultdict(dict)
     for tm in request.metrics:
-        task_metrics[tm.task].add(tm.metric)
+        task_metrics[tm.task][tm.metric] = tm.variants
     for metrics in task_metrics.values():
         if None in metrics:
             metrics.clear()
@@ -734,11 +749,11 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):

 def _get_top_iter_unique_events(events, max_iters):
     top_unique_events = defaultdict(lambda: [])
-    for e in events:
-        key = e.get("metric", "") + e.get("variant", "")
+    for ev in events:
+        key = ev.get("metric", "") + ev.get("variant", "")
         evs = top_unique_events[key]
         if len(evs) < max_iters:
-            evs.append(e)
+            evs.append(ev)
     unique_events = list(
         itertools.chain.from_iterable(list(top_unique_events.values()))
     )
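
Note: _get_metric_variants_from_request is the bridge between the API models and the BLL-level MetricVariants mapping. A sketch of the conversion it performs, with made-up values (ApiMetrics is the alias for the apimodels MetricVariants class imported above):

# ApiMetrics entries -> {metric: variants} mapping consumed by the BLL.
req_metrics = [
    ApiMetrics(metric="loss", variants=["train", "validation"]),
    ApiMetrics(metric="accuracy"),  # variants left unset
]
metric_variants = {m.metric: m.variants for m in req_metrics}
# -> {"loss": ["train", "validation"], "accuracy": None}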