mirror of
https://github.com/clearml/clearml-server
synced 2025-03-09 21:51:54 +00:00
Support iterating over all task metrics in task debug images
This commit is contained in:
parent
c034c1a986
commit
0d5174c453
@ -38,7 +38,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
|
@ -3,17 +3,16 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize
|
||||
from boltons.iterutils import bucketize, first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
@ -73,7 +72,7 @@ class DebugImagesIterator:
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
@ -83,8 +82,7 @@ class DebugImagesIterator:
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
unique_metrics = set(metrics)
|
||||
state_.metrics = self._init_metric_states(company_id, list(unique_metrics))
|
||||
state_.metrics = self._init_metric_states(company_id, task_metrics)
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
@ -92,14 +90,8 @@ class DebugImagesIterator:
|
||||
as requested in the current call.
|
||||
Refresh the state if requested
|
||||
"""
|
||||
state_metrics = set((m.task, m.name) for m in state_.metrics)
|
||||
if state_metrics != set(metrics):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task metrics stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, state_)
|
||||
self._reinit_outdated_metric_states(company_id, state_, task_metrics)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
|
||||
@ -123,14 +115,16 @@ class DebugImagesIterator:
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
self, company_id, state: DebugImageEventsScrollState
|
||||
self,
|
||||
company_id,
|
||||
state: DebugImageEventsScrollState,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
since their states were initialized and reinits these states
|
||||
"""
|
||||
task_ids = set(metric.task for metric in state.metrics)
|
||||
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
@ -140,6 +134,7 @@ class DebugImagesIterator:
|
||||
if not metric_stats:
|
||||
return []
|
||||
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
@ -147,6 +142,7 @@ class DebugImagesIterator:
|
||||
)
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
]
|
||||
|
||||
update_times = dict(
|
||||
@ -154,32 +150,31 @@ class DebugImagesIterator:
|
||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
)
|
||||
)
|
||||
outdated_metrics = [
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if (metric.task, metric.name) in update_times
|
||||
and update_times[metric.task, metric.name] > metric.timestamp
|
||||
]
|
||||
state.metrics = [
|
||||
*(metric for metric in state.metrics if metric not in outdated_metrics),
|
||||
*(
|
||||
self._init_metric_states(
|
||||
company_id,
|
||||
[(metric.task, metric.name) for metric in outdated_metrics],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
metrics_to_update = defaultdict(set)
|
||||
for (task, metric), update_time in update_times.items():
|
||||
state_metric = first(
|
||||
m for m in state.metrics if m.task == task and m.name == metric
|
||||
)
|
||||
if not state_metric or state_metric.timestamp < update_time:
|
||||
metrics_to_update[task].add(metric)
|
||||
|
||||
if metrics_to_update:
|
||||
state.metrics = [
|
||||
*(
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if metric.name not in metrics_to_update.get(metric.task, [])
|
||||
),
|
||||
*(self._init_metric_states(company_id, metrics_to_update)),
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
self, company_id: str, metrics: Sequence[Tuple[str, str]]
|
||||
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
tasks = defaultdict(list)
|
||||
for (task, metric) in metrics:
|
||||
tasks[task].append(metric)
|
||||
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
@ -187,30 +182,25 @@ class DebugImagesIterator:
|
||||
partial(
|
||||
self._init_metric_states_for_task, company_id=company_id
|
||||
),
|
||||
tasks.items(),
|
||||
task_metrics.items(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Sequence[str]], company_id: str
|
||||
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||
) -> Sequence[MetricScrollState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
"""
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
|
||||
if metrics:
|
||||
must.append({"terms": {"metric": list(metrics)}})
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
@ -254,10 +244,7 @@ class DebugImagesIterator:
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
@ -397,10 +384,7 @@ class DebugImagesIterator:
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
@ -110,12 +110,15 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
if not metrics:
|
||||
return set()
|
||||
|
||||
task_metrics = [(task, metric) for metric in metrics]
|
||||
task_metrics = {task: set(metrics)}
|
||||
scroll_id = None
|
||||
urls = defaultdict(set)
|
||||
while True:
|
||||
res = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company, metrics=task_metrics, iter_count=100, state_id=scroll_id
|
||||
company_id=company,
|
||||
task_metrics=task_metrics,
|
||||
iter_count=100,
|
||||
state_id=scroll_id,
|
||||
)
|
||||
if not res.metric_events or not any(
|
||||
events for _, _, events in res.metric_events
|
||||
|
@ -187,14 +187,14 @@
|
||||
}
|
||||
task_metric {
|
||||
type: object
|
||||
required: [task, metric]
|
||||
required: [task]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name"
|
||||
description: "Metric name. If not specified then all metrics for this task will be returned"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
|
@ -594,10 +594,16 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
response_data_model=DebugImageResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
task_ids = {m.task for m in request.metrics}
|
||||
task_metrics = defaultdict(set)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task].add(tm.metric)
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
)
|
||||
@ -610,7 +616,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
metrics=[(m.task, m.metric) for m in request.metrics],
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
refresh=request.refresh,
|
||||
|
@ -1,6 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from typing import Sequence, Mapping
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.tests.automated import TestService
|
||||
@ -128,6 +127,98 @@ class TestTaskDebugImages(TestService):
|
||||
|
||||
def test_task_debug_images(self):
|
||||
task = self._temp_task()
|
||||
|
||||
# test empty
|
||||
res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
|
||||
self.assertFalse(res.metrics)
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
|
||||
)
|
||||
self.assertFalse(res.metrics)
|
||||
|
||||
# test not empty
|
||||
metrics = {
|
||||
"Metric1": ["Variant1", "Variant2"],
|
||||
"Metric2": ["Variant3", "Variant4"],
|
||||
}
|
||||
events = [
|
||||
self._create_task_event(
|
||||
task=task,
|
||||
iteration=1,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
url=f"{metric}_{variant}_{1}",
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for variant in variants
|
||||
]
|
||||
self.send_batch(events)
|
||||
scroll_id = self._assertTaskMetrics(
|
||||
task=task, expected_metrics=metrics, iterations=1
|
||||
)
|
||||
|
||||
# test refresh
|
||||
update = {
|
||||
"Metric2": ["Variant3", "Variant4", "Variant5"],
|
||||
"Metric3": ["VariantA", "VariantB"],
|
||||
}
|
||||
events = [
|
||||
self._create_task_event(
|
||||
task=task,
|
||||
iteration=2,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
url=f"{metric}_{variant}_{2}",
|
||||
)
|
||||
for metric, variants in update.items()
|
||||
for variant in variants
|
||||
]
|
||||
self.send_batch(events)
|
||||
# without refresh the metric states are not updated
|
||||
scroll_id = self._assertTaskMetrics(
|
||||
task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
|
||||
)
|
||||
|
||||
# with refresh there are new metrics and existing ones are updated
|
||||
metrics.update(update)
|
||||
self._assertTaskMetrics(
|
||||
task=task,
|
||||
expected_metrics=metrics,
|
||||
iterations=1,
|
||||
scroll_id=scroll_id,
|
||||
refresh=True,
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
def _assertTaskMetrics(
|
||||
self,
|
||||
task: str,
|
||||
expected_metrics: Mapping[str, Sequence[str]],
|
||||
iterations,
|
||||
scroll_id: str = None,
|
||||
refresh=False,
|
||||
) -> str:
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
|
||||
)
|
||||
self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics))
|
||||
if not iterations:
|
||||
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
||||
return res.scroll_id
|
||||
|
||||
for metric_data in res.metrics:
|
||||
expected_variants = set(expected_metrics[metric_data.metric])
|
||||
self.assertEqual(len(metric_data.iterations), iterations)
|
||||
for it_data in metric_data.iterations:
|
||||
self.assertEqual(
|
||||
set(e.variant for e in it_data.events), expected_variants
|
||||
)
|
||||
|
||||
return res.scroll_id
|
||||
|
||||
def test_get_debug_images_navigation(self):
|
||||
task = self._temp_task()
|
||||
metric = "Metric1"
|
||||
variants = [("Variant1", 7), ("Variant2", 4)]
|
||||
iterations = 10
|
||||
@ -195,7 +286,7 @@ class TestTaskDebugImages(TestService):
|
||||
expected_page: int,
|
||||
iters: int = 5,
|
||||
**extra_params,
|
||||
):
|
||||
) -> str:
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}],
|
||||
iters=iters,
|
||||
|
Loading…
Reference in New Issue
Block a user