mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 02:33:02 +00:00
Return results per task iterations in debug images request
This commit is contained in:
parent
6b0c45a861
commit
b67aa05d6f
@ -89,7 +89,6 @@ class IterationEvents(Base):
|
||||
|
||||
class MetricEvents(Base):
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
|
||||
|
||||
|
||||
|
@ -1,13 +1,12 @@
|
||||
from collections import defaultdict
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from operator import attrgetter, itemgetter
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from boltons.iterutils import bucketize, first
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField
|
||||
from jsonmodels.models import Base
|
||||
@ -27,19 +26,22 @@ from apiserver.database.model.task.task import Task
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
class VariantScrollState(Base):
|
||||
name: str = StringField(required=True)
|
||||
recycle_url_marker: str = StringField()
|
||||
class VariantState(Base):
|
||||
variant: str = StringField(required=True)
|
||||
last_invalid_iteration: int = IntField()
|
||||
|
||||
|
||||
class MetricScrollState(Base):
|
||||
class MetricState(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[VariantState] = ListField([VariantState], required=True)
|
||||
timestamp: int = IntField(default=0)
|
||||
|
||||
|
||||
class TaskScrollState(Base):
|
||||
task: str = StringField(required=True)
|
||||
name: str = StringField(required=True)
|
||||
metrics: Sequence[MetricState] = ListField([MetricState], required=True)
|
||||
last_min_iter: Optional[int] = IntField()
|
||||
last_max_iter: Optional[int] = IntField()
|
||||
timestamp: int = IntField(default=0)
|
||||
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
|
||||
|
||||
def reset(self):
|
||||
"""Reset the scrolling state for the metric"""
|
||||
@ -48,7 +50,7 @@ class MetricScrollState(Base):
|
||||
|
||||
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
|
||||
tasks: Sequence[TaskScrollState] = ListField([TaskScrollState])
|
||||
warning: str = StringField()
|
||||
|
||||
|
||||
@ -82,7 +84,7 @@ class DebugImagesIterator:
|
||||
return DebugImagesResult()
|
||||
|
||||
def init_state(state_: DebugImageEventsScrollState):
|
||||
state_.metrics = self._init_metric_states(company_id, task_metrics)
|
||||
state_.tasks = self._init_task_states(company_id, task_metrics)
|
||||
|
||||
def validate_state(state_: DebugImageEventsScrollState):
|
||||
"""
|
||||
@ -91,9 +93,7 @@ class DebugImagesIterator:
|
||||
Refresh the state if requested
|
||||
"""
|
||||
if refresh:
|
||||
self._reinit_outdated_metric_states(company_id, state_, task_metrics)
|
||||
for metric_state in state_.metrics:
|
||||
metric_state.reset()
|
||||
self._reinit_outdated_task_states(company_id, state_, task_metrics)
|
||||
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
@ -108,88 +108,113 @@ class DebugImagesIterator:
|
||||
iter_count=iter_count,
|
||||
navigate_earlier=navigate_earlier,
|
||||
),
|
||||
state.metrics,
|
||||
state.tasks,
|
||||
)
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _reinit_outdated_metric_states(
|
||||
def _reinit_outdated_task_states(
|
||||
self,
|
||||
company_id,
|
||||
state: DebugImageEventsScrollState,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
):
|
||||
"""
|
||||
Determines the metrics for which new debug image events were added
|
||||
since their states were initialized and reinits these states
|
||||
Determine the metrics for which new debug image events were added
|
||||
since their states were initialized and re-init these states
|
||||
"""
|
||||
tasks = Task.objects(id__in=list(task_metrics), company=company_id).only(
|
||||
"id", "metric_stats"
|
||||
)
|
||||
|
||||
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
|
||||
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
|
||||
def get_last_update_times_for_task_metrics(
|
||||
task: Task,
|
||||
) -> Mapping[str, datetime]:
|
||||
"""For metrics that reported debug image events get mapping of the metric name to the last update times"""
|
||||
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
|
||||
if not metric_stats:
|
||||
return []
|
||||
return {}
|
||||
|
||||
requested_metrics = task_metrics[task.id]
|
||||
return [
|
||||
(
|
||||
(task.id, stats.metric),
|
||||
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
|
||||
)
|
||||
return {
|
||||
stats.metric: stats.event_stats_by_type[
|
||||
self.EVENT_TYPE.value
|
||||
].last_update
|
||||
for stats in metric_stats.values()
|
||||
if self.EVENT_TYPE.value in stats.event_stats_by_type
|
||||
and (not requested_metrics or stats.metric in requested_metrics)
|
||||
]
|
||||
}
|
||||
|
||||
update_times = dict(
|
||||
chain.from_iterable(
|
||||
get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
update_times = {
|
||||
task.id: get_last_update_times_for_task_metrics(task) for task in tasks
|
||||
}
|
||||
task_metric_states = {
|
||||
task_state.task: {
|
||||
metric_state.metric: metric_state for metric_state in task_state.metrics
|
||||
}
|
||||
for task_state in state.tasks
|
||||
}
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = set(
|
||||
m
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
)
|
||||
)
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
metrics_to_update = defaultdict(set)
|
||||
for (task, metric), update_time in update_times.items():
|
||||
state_metric = first(
|
||||
m for m in state.metrics if m.task == task and m.name == metric
|
||||
updated_task_states = self._init_task_states(company_id, task_metrics_to_recalc)
|
||||
|
||||
def merge_with_updated_task_states(
|
||||
old_state: TaskScrollState, updates: Sequence[TaskScrollState]
|
||||
) -> TaskScrollState:
|
||||
task = old_state.task
|
||||
updated_state = first(uts for uts in updates if uts.task == task)
|
||||
if not updated_state:
|
||||
old_state.reset()
|
||||
return old_state
|
||||
|
||||
updated_metrics = [m.metric for m in updated_state.metrics]
|
||||
return TaskScrollState(
|
||||
task=task,
|
||||
metrics=[
|
||||
*updated_state.metrics,
|
||||
*(
|
||||
old_metric
|
||||
for old_metric in old_state.metrics
|
||||
if old_metric.metric not in updated_metrics
|
||||
),
|
||||
],
|
||||
)
|
||||
if not state_metric or state_metric.timestamp < update_time:
|
||||
metrics_to_update[task].add(metric)
|
||||
|
||||
if metrics_to_update:
|
||||
state.metrics = [
|
||||
*(
|
||||
metric
|
||||
for metric in state.metrics
|
||||
if metric.name not in metrics_to_update.get(metric.task, [])
|
||||
),
|
||||
*(self._init_metric_states(company_id, metrics_to_update)),
|
||||
]
|
||||
state.tasks = [
|
||||
merge_with_updated_task_states(task_state, updated_task_states)
|
||||
for task_state in state.tasks
|
||||
]
|
||||
|
||||
def _init_metric_states(
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||
) -> Sequence[MetricScrollState]:
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
"""
|
||||
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
self._init_metric_states_for_task, company_id=company_id
|
||||
),
|
||||
task_metrics.items(),
|
||||
)
|
||||
)
|
||||
task_metric_states = pool.map(
|
||||
partial(self._init_metric_states_for_task, company_id=company_id),
|
||||
task_metrics.items(),
|
||||
)
|
||||
|
||||
return [
|
||||
TaskScrollState(task=task, metrics=metric_states,)
|
||||
for task, metric_states in zip(task_metrics, task_metric_states)
|
||||
]
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||
) -> Sequence[MetricScrollState]:
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
for the variants that reported any debug images
|
||||
@ -249,12 +274,12 @@ class DebugImagesIterator:
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
def init_variant_scroll_state(variant: dict):
|
||||
def init_variant_state(variant: dict):
|
||||
"""
|
||||
Return new variant scroll state for the passed variant bucket
|
||||
Return new variant state for the passed variant bucket
|
||||
If the image urls get recycled then fill the last_invalid_iteration field
|
||||
"""
|
||||
state = VariantScrollState(name=variant["key"])
|
||||
state = VariantState(variant=variant["key"])
|
||||
top_iter_url = dpath.get(variant, "urls/buckets")[0]
|
||||
iters = dpath.get(top_iter_url, "iters/hits/hits")
|
||||
if len(iters) > 1:
|
||||
@ -262,102 +287,52 @@ class DebugImagesIterator:
|
||||
return state
|
||||
|
||||
return [
|
||||
MetricScrollState(
|
||||
task=task,
|
||||
name=metric["key"],
|
||||
MetricState(
|
||||
metric=metric["key"],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
variants=[
|
||||
init_variant_scroll_state(variant)
|
||||
init_variant_state(variant)
|
||||
for variant in dpath.get(metric, "variants/buckets")
|
||||
],
|
||||
timestamp=dpath.get(metric, "last_event_timestamp/value"),
|
||||
)
|
||||
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
|
||||
]
|
||||
|
||||
def _get_task_metric_events(
|
||||
self,
|
||||
metric: MetricScrollState,
|
||||
task_state: TaskScrollState,
|
||||
company_id: str,
|
||||
iter_count: int,
|
||||
navigate_earlier: bool,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Return task metric events grouped by iterations
|
||||
Update metric scroll state
|
||||
Update task scroll state
|
||||
"""
|
||||
if metric.last_max_iter is None:
|
||||
if not task_state.metrics:
|
||||
return task_state.task, []
|
||||
|
||||
if task_state.last_max_iter is None:
|
||||
# the first fetch is always from the latest iteration to the earlier ones
|
||||
navigate_earlier = True
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"term": {"task": task_state.task}},
|
||||
{"terms": {"metric": [m.metric for m in task_state.metrics]}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
range_condition = None
|
||||
if navigate_earlier and metric.last_min_iter is not None:
|
||||
range_condition = {"lt": metric.last_min_iter}
|
||||
elif not navigate_earlier and metric.last_max_iter is not None:
|
||||
range_condition = {"gt": metric.last_max_iter}
|
||||
if navigate_earlier and task_state.last_min_iter is not None:
|
||||
range_condition = {"lt": task_state.last_min_iter}
|
||||
elif not navigate_earlier and task_state.last_max_iter is not None:
|
||||
range_condition = {"gt": task_state.last_max_iter}
|
||||
if range_condition:
|
||||
must_conditions.append({"range": {"iter": range_condition}})
|
||||
|
||||
if navigate_earlier:
|
||||
"""
|
||||
When navigating to earlier iterations consider only
|
||||
variants whose invalid iterations border is lower than
|
||||
our starting iteration. For these variants make sure
|
||||
that only events from the valid iterations are returned
|
||||
"""
|
||||
if not metric.last_min_iter:
|
||||
variants = metric.variants
|
||||
else:
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is None
|
||||
or v.last_invalid_iteration < metric.last_min_iter
|
||||
)
|
||||
if not variants:
|
||||
return metric.task, metric.name, []
|
||||
must_conditions.append(
|
||||
{"terms": {"variant": list(v.name for v in variants)}}
|
||||
)
|
||||
else:
|
||||
"""
|
||||
When navigating to later iterations all variants may be relevant.
|
||||
For the variants whose invalid border is higher than our starting
|
||||
iteration make sure that only events from valid iterations are returned
|
||||
"""
|
||||
variants = list(
|
||||
v
|
||||
for v in metric.variants
|
||||
if v.last_invalid_iteration is not None
|
||||
and v.last_invalid_iteration > metric.last_max_iter
|
||||
)
|
||||
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"variant": v.name}},
|
||||
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
if v.last_invalid_iteration is not None
|
||||
]
|
||||
if variants_conditions:
|
||||
must_not_conditions.append({"bool": {"should": variants_conditions}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {"must": must_conditions, "must_not": must_not_conditions}
|
||||
},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
@ -366,15 +341,26 @@ class DebugImagesIterator:
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"field": "metric",
|
||||
"size": EventSettings.max_metrics_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {"sort": {"url": {"order": "desc"}}}
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventSettings.max_variants_count,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"url": {"order": "desc"}}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -387,74 +373,41 @@ class DebugImagesIterator:
|
||||
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req,
|
||||
)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
return task_state.task, []
|
||||
|
||||
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
|
||||
invalid_iterations = {
|
||||
(m.metric, v.variant): v.last_invalid_iteration
|
||||
for m in task_state.metrics
|
||||
for v in m.variants
|
||||
}
|
||||
|
||||
def is_valid_event(event: dict) -> bool:
|
||||
key = event.get("metric"), event.get("variant")
|
||||
if key not in invalid_iterations:
|
||||
return False
|
||||
|
||||
max_invalid = invalid_iterations[key]
|
||||
return max_invalid is None or event.get("iter") > max_invalid
|
||||
|
||||
def get_iteration_events(it_: dict) -> Sequence:
|
||||
return [
|
||||
ev["_source"]
|
||||
for v in variant_buckets
|
||||
for m in dpath.get(it_, "metrics/buckets")
|
||||
for v in dpath.get(m, "variants/buckets")
|
||||
for ev in dpath.get(v, "events/hits/hits")
|
||||
if is_valid_event(ev["_source"])
|
||||
]
|
||||
|
||||
iterations = [
|
||||
{
|
||||
"iter": it["key"],
|
||||
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
|
||||
}
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets")
|
||||
]
|
||||
iterations = []
|
||||
for it in dpath.get(es_res, "aggregations/iters/buckets"):
|
||||
events = get_iteration_events(it)
|
||||
if events:
|
||||
iterations.append({"iter": it["key"], "events": events})
|
||||
|
||||
if not navigate_earlier:
|
||||
iterations.sort(key=itemgetter("iter"), reverse=True)
|
||||
if iterations:
|
||||
metric.last_max_iter = iterations[0]["iter"]
|
||||
metric.last_min_iter = iterations[-1]["iter"]
|
||||
task_state.last_max_iter = iterations[0]["iter"]
|
||||
task_state.last_min_iter = iterations[-1]["iter"]
|
||||
|
||||
# Commented for now since the last invalid iteration is calculated in the beginning
|
||||
# if navigate_earlier and any(
|
||||
# variant.last_invalid_iteration is None for variant in variants
|
||||
# ):
|
||||
# """
|
||||
# Variants validation flags due to recycling can
|
||||
# be set only on navigation to earlier frames
|
||||
# """
|
||||
# iterations = self._update_variants_invalid_iterations(variants, iterations)
|
||||
|
||||
return metric.task, metric.name, iterations
|
||||
|
||||
@staticmethod
|
||||
def _update_variants_invalid_iterations(
|
||||
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
This code is currently not in used since the invalid iterations
|
||||
are calculated during MetricState initialization
|
||||
For variants that do not have recycle url marker set it from the
|
||||
first event
|
||||
For variants that do not have last_invalid_iteration set check if the
|
||||
recycle marker was reached on a certain iteration and set it to the
|
||||
corresponding iteration
|
||||
For variants that have a newly set last_invalid_iteration remove
|
||||
events from the invalid iterations
|
||||
Return the updated iterations list
|
||||
"""
|
||||
variants_lookup = bucketize(variants, attrgetter("name"))
|
||||
for it in iterations:
|
||||
iteration = it["iter"]
|
||||
events_to_remove = []
|
||||
for event in it["events"]:
|
||||
variant = variants_lookup[event["variant"]][0]
|
||||
if (
|
||||
variant.last_invalid_iteration
|
||||
and variant.last_invalid_iteration >= iteration
|
||||
):
|
||||
events_to_remove.append(event)
|
||||
continue
|
||||
event_url = event.get("url")
|
||||
if not variant.recycle_url_marker:
|
||||
variant.recycle_url_marker = event_url
|
||||
elif variant.recycle_url_marker == event_url:
|
||||
variant.last_invalid_iteration = iteration
|
||||
events_to_remove.append(event)
|
||||
if events_to_remove:
|
||||
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
|
||||
return [it for it in iterations if it["events"]]
|
||||
return task_state.task, iterations
|
||||
|
@ -1,4 +1,3 @@
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
|
||||
@ -133,7 +132,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
|
||||
task_metrics = {task: set(metrics)}
|
||||
scroll_id = None
|
||||
urls = defaultdict(set)
|
||||
urls = set()
|
||||
while True:
|
||||
res = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company,
|
||||
@ -142,17 +141,16 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
state_id=scroll_id,
|
||||
)
|
||||
if not res.metric_events or not any(
|
||||
events for _, _, events in res.metric_events
|
||||
iterations for _, iterations in res.metric_events
|
||||
):
|
||||
break
|
||||
|
||||
scroll_id = res.next_scroll_id
|
||||
for _, metric, iterations in res.metric_events:
|
||||
metric_urls = set(ev.get("url") for it in iterations for ev in it["events"])
|
||||
metric_urls.discard(None)
|
||||
urls[metric].update(metric_urls)
|
||||
for task, iterations in res.metric_events:
|
||||
urls.update(ev.get("url") for it in iterations for ev in it["events"])
|
||||
|
||||
return set(chain.from_iterable(urls.values()))
|
||||
urls.discard({None})
|
||||
return urls
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
|
@ -193,10 +193,6 @@
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name. If not specified then all metrics for this task will be returned"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
task_log_event {
|
||||
@ -370,7 +366,7 @@
|
||||
}
|
||||
}
|
||||
"2.7" {
|
||||
description: "Get the debug image events for the requested amount of iterations per each task's metric"
|
||||
description: "Get the debug image events for the requested amount of iterations per each task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
|
@ -628,13 +628,12 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
metrics=[
|
||||
MetricEvents(
|
||||
task=task,
|
||||
metric=metric,
|
||||
iterations=[
|
||||
IterationEvents(iter=iteration["iter"], events=iteration["events"])
|
||||
for iteration in iterations
|
||||
],
|
||||
)
|
||||
for (task, metric, iterations) in result.metric_events
|
||||
for (task, iterations) in result.metric_events
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -130,11 +130,11 @@ class TestTaskDebugImages(TestService):
|
||||
|
||||
# test empty
|
||||
res = self.api.events.debug_images(metrics=[{"task": task}], iters=5)
|
||||
self.assertFalse(res.metrics)
|
||||
self.assertFalse(res.metrics[0].iterations)
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
|
||||
)
|
||||
self.assertFalse(res.metrics)
|
||||
self.assertFalse(res.metrics[0].iterations)
|
||||
|
||||
# test not empty
|
||||
metrics = {
|
||||
@ -180,10 +180,9 @@ class TestTaskDebugImages(TestService):
|
||||
)
|
||||
|
||||
# with refresh there are new metrics and existing ones are updated
|
||||
metrics.update(update)
|
||||
self._assertTaskMetrics(
|
||||
task=task,
|
||||
expected_metrics=metrics,
|
||||
expected_metrics=update,
|
||||
iterations=1,
|
||||
scroll_id=scroll_id,
|
||||
refresh=True,
|
||||
@ -202,17 +201,16 @@ class TestTaskDebugImages(TestService):
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
|
||||
)
|
||||
self.assertEqual(set(m.metric for m in res.metrics), set(expected_metrics))
|
||||
if not iterations:
|
||||
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
||||
return res.scroll_id
|
||||
|
||||
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
|
||||
for metric_data in res.metrics:
|
||||
expected_variants = set(expected_metrics[metric_data.metric])
|
||||
self.assertEqual(len(metric_data.iterations), iterations)
|
||||
for it_data in metric_data.iterations:
|
||||
self.assertEqual(
|
||||
set(e.variant for e in it_data.events), expected_variants
|
||||
set((e.metric, e.variant) for e in it_data.events), expected_variants
|
||||
)
|
||||
|
||||
return res.scroll_id
|
||||
@ -227,7 +225,7 @@ class TestTaskDebugImages(TestService):
|
||||
res = self.api.events.debug_images(
|
||||
metrics=[{"task": task, "metric": metric}], iters=5,
|
||||
)
|
||||
self.assertFalse(res.metrics)
|
||||
self.assertFalse(res.metrics[0].iterations)
|
||||
|
||||
# create events
|
||||
events = [
|
||||
@ -295,7 +293,6 @@ class TestTaskDebugImages(TestService):
|
||||
)
|
||||
data = res["metrics"][0]
|
||||
self.assertEqual(data["task"], task)
|
||||
self.assertEqual(data["metric"], metric)
|
||||
left_iterations = max(0, max(unique_images) - expected_page * iters)
|
||||
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
|
||||
for it in data["iterations"]:
|
||||
|
Loading…
Reference in New Issue
Block a user