mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Support iterating over all task metrics in task debug images
This commit is contained in:
parent
c034c1a986
commit
0d5174c453
@ -38,7 +38,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
|||||||
|
|
||||||
class TaskMetric(Base):
|
class TaskMetric(Base):
|
||||||
task: str = StringField(required=True)
|
task: str = StringField(required=True)
|
||||||
metric: str = StringField(required=True)
|
metric: str = StringField(default=None)
|
||||||
|
|
||||||
|
|
||||||
class DebugImagesRequest(Base):
|
class DebugImagesRequest(Base):
|
||||||
|
|||||||
@ -3,17 +3,16 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from operator import attrgetter, itemgetter
|
from operator import attrgetter, itemgetter
|
||||||
from typing import Sequence, Tuple, Optional, Mapping
|
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import dpath
|
import dpath
|
||||||
from boltons.iterutils import bucketize
|
from boltons.iterutils import bucketize, first
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
from jsonmodels.fields import StringField, ListField, IntField
|
from jsonmodels.fields import StringField, ListField, IntField
|
||||||
from jsonmodels.models import Base
|
from jsonmodels.models import Base
|
||||||
from redis import StrictRedis
|
from redis import StrictRedis
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
|
||||||
from apiserver.apimodels import JsonSerializableMixin
|
from apiserver.apimodels import JsonSerializableMixin
|
||||||
from apiserver.bll.event.event_common import (
|
from apiserver.bll.event.event_common import (
|
||||||
EventSettings,
|
EventSettings,
|
||||||
@ -73,7 +72,7 @@ class DebugImagesIterator:
|
|||||||
def get_task_events(
|
def get_task_events(
|
||||||
self,
|
self,
|
||||||
company_id: str,
|
company_id: str,
|
||||||
metrics: Sequence[Tuple[str, str]],
|
task_metrics: Mapping[str, Set[str]],
|
||||||
iter_count: int,
|
iter_count: int,
|
||||||
navigate_earlier: bool = True,
|
navigate_earlier: bool = True,
|
||||||
refresh: bool = False,
|
refresh: bool = False,
|
||||||
@ -83,8 +82,7 @@ class DebugImagesIterator:
|
|||||||
return DebugImagesResult()
|
return DebugImagesResult()
|
||||||
|
|
||||||
def init_state(state_: DebugImageEventsScrollState):
|
def init_state(state_: DebugImageEventsScrollState):
|
||||||
unique_metrics = set(metrics)
|
state_.metrics = self._init_metric_states(company_id, task_metrics)
|
||||||
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
|
|
||||||
|
|
||||||
def validate_state(state_: DebugImageEventsScrollState):
|
def validate_state(state_: DebugImageEventsScrollState):
|
||||||
"""
|
"""
|
||||||
@ -92,14 +90,8 @@ class DebugImagesIterator:
|
|||||||
as requested in the current call.
|
as requested in the current call.
|
||||||
Refresh the state if requested
|
Refresh the state if requested
|
||||||
"""
|
"""
|
||||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
|
||||||
if state_metrics != set(metrics):
|
|
||||||
raise errors.bad_request.InvalidScrollId(
|
|
||||||
"Task metrics stored in the state do not match the passed ones",
|
|
||||||
scroll_id=state_.id,
|
|
||||||
)
|
|
||||||
if refresh:
|
if refresh:
|
||||||
self._reinit_outdated_metric_states(company_id, state_)
|
self._reinit_outdated_metric_states(company_id, state_, task_metrics)
|
||||||
for metric_state in state_.metrics:
|
for metric_state in state_.metrics:
|
||||||
metric_state.reset()
|
metric_state.reset()
|
||||||
|
|
||||||
@ -123,14 +115,16 @@ class DebugImagesIterator:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def _reinit_outdated_metric_states(
|
def _reinit_outdated_metric_states(
|
||||||
self, company_id, state: DebugImageEventsScrollState
|
self,
|
||||||
|
company_id,
|
||||||
|
state: DebugImageEventsScrollState,
|
||||||
|
task_metrics: Mapping[str, Set[str]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Determines the metrics for which new debug image events were added
|
Determines the metrics for which new debug image events were added
|
||||||
since their states were initialized and reinits these states
|
since their states were initialized and reinits these states
|
||||||
"""
|
"""
|
||||||
task_ids = set(metric.task for metric in state.metrics)
|
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||||
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
|
|
||||||
"id", "metric_stats"
|
"id", "metric_stats"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -140,6 +134,7 @@ class DebugImagesIterator:
|
|||||||
if not metric_stats:
|
if not metric_stats:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
requested_metrics = task_metrics[task.id]
|
||||||
return [
|
return [
|
||||||
(
|
(
|
||||||
(task.id, stats.metric),
|
(task.id, stats.metric),
|
||||||
@ -147,6 +142,7 @@ class DebugImagesIterator:
|
|||||||
)
|
)
|
||||||
for stats in metric_stats.values()
|
for stats in metric_stats.values()
|
||||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||||
|
and (not requested_metrics or stats.metric in requested_metrics)
|
||||||
]
|
]
|
||||||
|
|
||||||
update_times = dict(
|
update_times = dict(
|
||||||
@ -154,32 +150,31 @@ class DebugImagesIterator:
|
|||||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
outdated_metrics = [
|
|
||||||
|
metrics_to_update = defaultdict(set)
|
||||||
|
for (task, metric), update_time in update_times.items():
|
||||||
|
state_metric = first(
|
||||||
|
m for m in state.metrics if m.task == task and m.name == metric
|
||||||
|
)
|
||||||
|
if not state_metric or state_metric.timestamp < update_time:
|
||||||
|
metrics_to_update[task].add(metric)
|
||||||
|
|
||||||
|
if metrics_to_update:
|
||||||
|
state.metrics = [
|
||||||
|
*(
|
||||||
metric
|
metric
|
||||||
for metric in state.metrics
|
for metric in state.metrics
|
||||||
if (metric.task, metric.name) in update_times
|
if metric.name not in metrics_to_update.get(metric.task, [])
|
||||||
and update_times[metric.task, metric.name] > metric.timestamp
|
|
||||||
]
|
|
||||||
state.metrics = [
|
|
||||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
|
||||||
*(
|
|
||||||
self._init_metric_states(
|
|
||||||
company_id,
|
|
||||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
|
*(self._init_metric_states(company_id, metrics_to_update)),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _init_metric_states(
|
def _init_metric_states(
|
||||||
self, company_id: str, metrics: Sequence[Tuple[str, str]]
|
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||||
) -> Sequence[MetricScrollState]:
|
) -> Sequence[MetricScrollState]:
|
||||||
"""
|
"""
|
||||||
Returned initialized metric scroll stated for the requested task metrics
|
Returned initialized metric scroll stated for the requested task metrics
|
||||||
"""
|
"""
|
||||||
tasks = defaultdict(list)
|
|
||||||
for (task, metric) in metrics:
|
|
||||||
tasks[task].append(metric)
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||||
return list(
|
return list(
|
||||||
chain.from_iterable(
|
chain.from_iterable(
|
||||||
@ -187,30 +182,25 @@ class DebugImagesIterator:
|
|||||||
partial(
|
partial(
|
||||||
self._init_metric_states_for_task, company_id=company_id
|
self._init_metric_states_for_task, company_id=company_id
|
||||||
),
|
),
|
||||||
tasks.items(),
|
task_metrics.items(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_metric_states_for_task(
|
def _init_metric_states_for_task(
|
||||||
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
|
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||||
) -> Sequence[MetricScrollState]:
|
) -> Sequence[MetricScrollState]:
|
||||||
"""
|
"""
|
||||||
Return metric scroll states for the task filled with the variant states
|
Return metric scroll states for the task filled with the variant states
|
||||||
for the variants that reported any debug images
|
for the variants that reported any debug images
|
||||||
"""
|
"""
|
||||||
task, metrics = task_metrics
|
task, metrics = task_metrics
|
||||||
|
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
|
||||||
|
if metrics:
|
||||||
|
must.append({"terms": {"metric": list(metrics)}})
|
||||||
es_req: dict = {
|
es_req: dict = {
|
||||||
"size": 0,
|
"size": 0,
|
||||||
"query": {
|
"query": {"bool": {"must": must}},
|
||||||
"bool": {
|
|
||||||
"must": [
|
|
||||||
{"term": {"task": task}},
|
|
||||||
{"terms": {"metric": metrics}},
|
|
||||||
{"exists": {"field": "url"}},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"aggs": {
|
"aggs": {
|
||||||
"metrics": {
|
"metrics": {
|
||||||
"terms": {
|
"terms": {
|
||||||
@ -254,10 +244,7 @@ class DebugImagesIterator:
|
|||||||
|
|
||||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||||
es_res = search_company_events(
|
es_res = search_company_events(
|
||||||
self.es,
|
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||||
company_id=company_id,
|
|
||||||
event_type=self.EVENT_TYPE,
|
|
||||||
body=es_req,
|
|
||||||
)
|
)
|
||||||
if "aggregations" not in es_res:
|
if "aggregations" not in es_res:
|
||||||
return []
|
return []
|
||||||
@ -397,10 +384,7 @@ class DebugImagesIterator:
|
|||||||
}
|
}
|
||||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||||
es_res = search_company_events(
|
es_res = search_company_events(
|
||||||
self.es,
|
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||||
company_id=company_id,
|
|
||||||
event_type=self.EVENT_TYPE,
|
|
||||||
body=es_req,
|
|
||||||
)
|
)
|
||||||
if "aggregations" not in es_res:
|
if "aggregations" not in es_res:
|
||||||
return metric.task, metric.name, []
|
return metric.task, metric.name, []
|
||||||
|
|||||||
@ -110,12 +110,15 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
|||||||
if not metrics:
|
if not metrics:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
task_metrics = [(task, metric) for metric in metrics]
|
task_metrics = {task: set(metrics)}
|
||||||
scroll_id = None
|
scroll_id = None
|
||||||
urls = defaultdict(set)
|
urls = defaultdict(set)
|
||||||
while True:
|
while True:
|
||||||
res = event_bll.debug_images_iterator.get_task_events(
|
res = event_bll.debug_images_iterator.get_task_events(
|
||||||
company_id=company, metrics=task_metrics, iter_count=100, state_id=scroll_id
|
company_id=company,
|
||||||
|
task_metrics=task_metrics,
|
||||||
|
iter_count=100,
|
||||||
|
state_id=scroll_id,
|
||||||
)
|
)
|
||||||
if not res.metric_events or not any(
|
if not res.metric_events or not any(
|
||||||
events for _, _, events in res.metric_events
|
events for _, _, events in res.metric_events
|
||||||
|
|||||||
@ -187,14 +187,14 @@
|
|||||||
}
|
}
|
||||||
task_metric {
|
task_metric {
|
||||||
type: object
|
type: object
|
||||||
required: [task, metric]
|
required: [task]
|
||||||
properties {
|
properties {
|
||||||
task {
|
task {
|
||||||
description: "Task ID"
|
description: "Task ID"
|
||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
metric {
|
metric {
|
||||||
description: "Metric name"
|
description: "Metric name. If not specified then all metrics for this task will be returned"
|
||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -594,10 +594,16 @@ def get_debug_images_v1_8(call, company_id, _):
|
|||||||
response_data_model=DebugImageResponse,
|
response_data_model=DebugImageResponse,
|
||||||
)
|
)
|
||||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||||
task_ids = {m.task for m in request.metrics}
|
task_metrics = defaultdict(set)
|
||||||
|
for tm in request.metrics:
|
||||||
|
task_metrics[tm.task].add(tm.metric)
|
||||||
|
for metrics in task_metrics.values():
|
||||||
|
if None in metrics:
|
||||||
|
metrics.clear()
|
||||||
|
|
||||||
tasks = task_bll.assert_exists(
|
tasks = task_bll.assert_exists(
|
||||||
company_id,
|
company_id,
|
||||||
task_ids=task_ids,
|
task_ids=list(task_metrics),
|
||||||
allow_public=True,
|
allow_public=True,
|
||||||
only=("company", "company_origin"),
|
only=("company", "company_origin"),
|
||||||
)
|
)
|
||||||
@ -610,7 +616,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
|
|||||||
|
|
||||||
result = event_bll.debug_images_iterator.get_task_events(
|
result = event_bll.debug_images_iterator.get_task_events(
|
||||||
company_id=next(iter(companies)),
|
company_id=next(iter(companies)),
|
||||||
metrics=[(m.task, m.metric) for m in request.metrics],
|
task_metrics=task_metrics,
|
||||||
iter_count=request.iters,
|
iter_count=request.iters,
|
||||||
navigate_earlier=request.navigate_earlier,
|
navigate_earlier=request.navigate_earlier,
|
||||||
refresh=request.refresh,
|
refresh=request.refresh,
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Sequence
|
from typing import Sequence, Mapping
|
||||||
|
|
||||||
|
|
||||||
from apiserver.es_factory import es_factory
|
from apiserver.es_factory import es_factory
|
||||||
from apiserver.tests.automated import TestService
|
from apiserver.tests.automated import TestService
|
||||||
@ -128,6 +127,98 @@ class TestTaskDebugImages(TestService):
|
|||||||
|
|
||||||
def test_task_debug_images(self):
|
def test_task_debug_images(self):
|
||||||
task = self._temp_task()
|
task = self._temp_task()
|
||||||
|
|
||||||
|
# test empty
|
||||||
|
res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
|
||||||
|
self.assertFalse(res.metrics)
|
||||||
|
res = self.api.events.debug_images(
|
||||||
|
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
|
||||||
|
)
|
||||||
|
self.assertFalse(res.metrics)
|
||||||
|
|
||||||
|
# test not empty
|
||||||
|
metrics = {
|
||||||
|
"Metric1": ["Variant1", "Variant2"],
|
||||||
|
"Metric2": ["Variant3", "Variant4"],
|
||||||
|
}
|
||||||
|
events = [
|
||||||
|
self._create_task_event(
|
||||||
|
task=task,
|
||||||
|
iteration=1,
|
||||||
|
metric=metric,
|
||||||
|
variant=variant,
|
||||||
|
url=f"{metric}_{variant}_{1}",
|
||||||
|
)
|
||||||
|
for metric, variants in metrics.items()
|
||||||
|
for variant in variants
|
||||||
|
]
|
||||||
|
self.send_batch(events)
|
||||||
|
scroll_id = self._assertTaskMetrics(
|
||||||
|
task=task, expected_metrics=metrics, iterations=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# test refresh
|
||||||
|
update = {
|
||||||
|
"Metric2": ["Variant3", "Variant4", "Variant5"],
|
||||||
|
"Metric3": ["VariantA", "VariantB"],
|
||||||
|
}
|
||||||
|
events = [
|
||||||
|
self._create_task_event(
|
||||||
|
task=task,
|
||||||
|
iteration=2,
|
||||||
|
metric=metric,
|
||||||
|
variant=variant,
|
||||||
|
url=f"{metric}_{variant}_{2}",
|
||||||
|
)
|
||||||
|
for metric, variants in update.items()
|
||||||
|
for variant in variants
|
||||||
|
]
|
||||||
|
self.send_batch(events)
|
||||||
|
# without refresh the metric states are not updated
|
||||||
|
scroll_id = self._assertTaskMetrics(
|
||||||
|
task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# with refresh there are new metrics and existing ones are updated
|
||||||
|
metrics.update(update)
|
||||||
|
self._assertTaskMetrics(
|
||||||
|
task=task,
|
||||||
|
expected_metrics=metrics,
|
||||||
|
iterations=1,
|
||||||
|
scroll_id=scroll_id,
|
||||||
|
refresh=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _assertTaskMetrics(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
expected_metrics: Mapping[str, Sequence[str]],
|
||||||
|
iterations,
|
||||||
|
scroll_id: str = None,
|
||||||
|
refresh=False,
|
||||||
|
) -> str:
|
||||||
|
res = self.api.events.debug_images(
|
||||||
|
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
|
||||||
|
)
|
||||||
|
self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics))
|
||||||
|
if not iterations:
|
||||||
|
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
||||||
|
return res.scroll_id
|
||||||
|
|
||||||
|
for metric_data in res.metrics:
|
||||||
|
expected_variants = set(expected_metrics[metric_data.metric])
|
||||||
|
self.assertEqual(len(metric_data.iterations), iterations)
|
||||||
|
for it_data in metric_data.iterations:
|
||||||
|
self.assertEqual(
|
||||||
|
set(e.variant for e in it_data.events), expected_variants
|
||||||
|
)
|
||||||
|
|
||||||
|
return res.scroll_id
|
||||||
|
|
||||||
|
def test_get_debug_images_navigation(self):
|
||||||
|
task = self._temp_task()
|
||||||
metric = "Metric1"
|
metric = "Metric1"
|
||||||
variants = [("Variant1", 7), ("Variant2", 4)]
|
variants = [("Variant1", 7), ("Variant2", 4)]
|
||||||
iterations = 10
|
iterations = 10
|
||||||
@ -195,7 +286,7 @@ class TestTaskDebugImages(TestService):
|
|||||||
expected_page: int,
|
expected_page: int,
|
||||||
iters: int = 5,
|
iters: int = 5,
|
||||||
**extra_params,
|
**extra_params,
|
||||||
):
|
) -> str:
|
||||||
res = self.api.events.debug_images(
|
res = self.api.events.debug_images(
|
||||||
metrics=[{"task": task, "metric": metric}],
|
metrics=[{"task": task, "metric": metric}],
|
||||||
iters=iters,
|
iters=iters,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user