Add support for pagination in events.debug_images
Commit 6c8508eb7f (parent 69714d5b5c), from the mirror of https://github.com/clearml/clearml-server
@@ -89,6 +89,8 @@ _error_codes = {
         1003: ('worker_registered', 'worker is already registered'),
         1004: ('worker_not_registered', 'worker is not registered'),
         1005: ('worker_stats_not_found', 'worker stats not found'),
+        1104: ('invalid_scroll_id', 'Invalid scroll id'),
     },

     (401, 'unauthorized'): {
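The new 1104 subcode backs the InvalidScrollId error that the debug images iterator raises when a cached scroll state no longer matches the requested metrics. A minimal sketch of how the generated error class is used (this mirrors the iterator code further down; state_id is whatever scroll id the client passed):

    # Raised while paginating debug images when the cached scroll state does not
    # match the metrics in the request; surfaces to the client as subcode 1104.
    from apierrors import errors

    raise errors.bad_request.InvalidScrollId(
        "while getting debug images events", scroll_id=state_id
    )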
@@ -1,9 +1,12 @@
 from typing import Sequence
 
-from jsonmodels.fields import StringField
+from jsonmodels import validators
+from jsonmodels.fields import StringField, BoolField
 from jsonmodels.models import Base
+from jsonmodels.validators import Length
 
 from apimodels import ListField, IntField, ActualEnumField
+from bll.event.event_metrics import EventType
 from bll.event.scalar_key import ScalarKeyEnum
@@ -17,4 +20,44 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
 
 
 class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
-    tasks: Sequence[str] = ListField(items_types=str, required=True)
+    tasks: Sequence[str] = ListField(
+        items_types=str, validators=[Length(minimum_value=1)]
+    )
+
+
+class TaskMetric(Base):
+    task: str = StringField(required=True)
+    metric: str = StringField(required=True)
+
+
+class DebugImagesRequest(Base):
+    metrics: Sequence[TaskMetric] = ListField(
+        items_types=TaskMetric, validators=[Length(minimum_value=1)]
+    )
+    iters: int = IntField(default=1, validators=validators.Min(1))
+    navigate_earlier: bool = BoolField(default=True)
+    refresh: bool = BoolField(default=False)
+    scroll_id: str = StringField()
+
+
+class IterationEvents(Base):
+    iter: int = IntField()
+    events: Sequence[dict] = ListField(items_types=dict)
+
+
+class MetricEvents(Base):
+    task: str = StringField()
+    metric: str = StringField()
+    iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
+
+
+class DebugImageResponse(Base):
+    metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
+    scroll_id: str = StringField()
+
+
+class TaskMetricsRequest(Base):
+    tasks: Sequence[str] = ListField(
+        items_types=str, validators=[Length(minimum_value=1)]
+    )
+    event_type: EventType = ActualEnumField(EventType, required=True)
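For orientation, a request body that these models would accept might look like the sketch below; the task ids are hypothetical placeholders, not values from this commit:

    # Hypothetical events.debug_images request payload matching DebugImagesRequest
    payload = {
        "metrics": [
            {"task": "<task-id-1>", "metric": "Metric1"},
            {"task": "<task-id-2>", "metric": "Metric2"},
        ],
        "iters": 2,                # latest two iterations per metric
        "navigate_earlier": True,  # walk from later iterations to earlier ones
        "refresh": False,
        # "scroll_id": "<scroll id from a previous call>",  # request the next page
    }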
server/bll/event/debug_images_iterator.py (new file, 464 lines)
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter

import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping

import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from bll.event.event_metrics import EventMetrics
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField

from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps


class VariantScrollState(Base):
    name: str = StringField(required=True)
    recycle_url_marker: str = StringField()
    last_invalid_iteration: int = IntField()


class MetricScrollState(Base):
    task: str = StringField(required=True)
    name: str = StringField(required=True)
    last_min_iter: Optional[int] = IntField()
    last_max_iter: Optional[int] = IntField()
    timestamp: int = IntField(default=0)
    variants: Sequence[VariantScrollState] = ListField([VariantScrollState])

    def reset(self):
        """Reset the scrolling state for the metric"""
        self.last_min_iter = self.last_max_iter = None


class DebugImageEventsScrollState(Base):
    id: str = StringField(required=True)
    metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])

    def to_json(self):
        return dumps(self.to_struct())

    @classmethod
    def from_json(cls, s):
        return cls(**loads(s))


@attr.s(auto_attribs=True)
class DebugImagesResult(object):
    metric_events: Sequence[tuple] = []
    next_scroll_id: str = None


class DebugImagesIterator:
    EVENT_TYPE = "training_debug_image"
    STATE_EXPIRATION_SECONDS = 3600

    @property
    def _max_workers(self):
        return config.get("services.events.max_metrics_concurrency", 4)

    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        self.es = es
        self.cache_manager = RedisCacheManager(
            state_class=DebugImageEventsScrollState,
            redis=redis,
            expiration_interval=self.STATE_EXPIRATION_SECONDS,
        )

    def get_task_events(
        self,
        company_id: str,
        metrics: Sequence[Tuple[str, str]],
        iter_count: int,
        navigate_earlier: bool = True,
        refresh: bool = False,
        state_id: str = None,
    ) -> DebugImagesResult:
        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
        if not self.es.indices.exists(es_index):
            return DebugImagesResult()

        unique_metrics = set(metrics)
        state = self.cache_manager.get_state(state_id) if state_id else None
        if not state:
            state = DebugImageEventsScrollState(
                id=database.utils.id(),
                metrics=self._init_metric_states(es_index, list(unique_metrics)),
            )
        else:
            state_metrics = set((m.task, m.name) for m in state.metrics)
            if state_metrics != unique_metrics:
                raise errors.bad_request.InvalidScrollId(
                    "while getting debug images events", scroll_id=state_id
                )

        if refresh:
            self._reinit_outdated_metric_states(company_id, es_index, state)
            for metric_state in state.metrics:
                metric_state.reset()

        res = DebugImagesResult(next_scroll_id=state.id)
        try:
            with ThreadPoolExecutor(self._max_workers) as pool:
                res.metric_events = list(
                    pool.map(
                        partial(
                            self._get_task_metric_events,
                            es_index=es_index,
                            iter_count=iter_count,
                            navigate_earlier=navigate_earlier,
                        ),
                        state.metrics,
                    )
                )
        finally:
            self.cache_manager.set_state(state)

        return res

    def _reinit_outdated_metric_states(
        self, company_id, es_index, state: DebugImageEventsScrollState
    ):
        """
        Determines the metrics for which new debug image events were added
        since their states were initialized and reinits these states
        """
        task_ids = set(metric.task for metric in state.metrics)
        tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
            "id", "metric_stats"
        )

        def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
            """For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
            metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
            if not metric_stats:
                return []

            return [
                (
                    (task.id, stats.metric),
                    stats.event_stats_by_type[self.EVENT_TYPE].last_update,
                )
                for stats in metric_stats.values()
                if self.EVENT_TYPE in stats.event_stats_by_type
            ]

        update_times = dict(
            chain.from_iterable(
                get_last_update_times_for_task_metrics(task) for task in tasks
            )
        )
        outdated_metrics = [
            metric
            for metric in state.metrics
            if (metric.task, metric.name) in update_times
            and update_times[metric.task, metric.name] > metric.timestamp
        ]
        state.metrics = [
            *(metric for metric in state.metrics if metric not in outdated_metrics),
            *(
                self._init_metric_states(
                    es_index,
                    [(metric.task, metric.name) for metric in outdated_metrics],
                )
            ),
        ]

    def _init_metric_states(
        self, es_index, metrics: Sequence[Tuple[str, str]]
    ) -> Sequence[MetricScrollState]:
        """
        Return initialized metric scroll states for the requested task metrics
        """
        tasks = defaultdict(list)
        for (task, metric) in metrics:
            tasks[task].append(metric)

        with ThreadPoolExecutor(self._max_workers) as pool:
            return list(
                chain.from_iterable(
                    pool.map(
                        partial(self._init_metric_states_for_task, es_index=es_index),
                        tasks.items(),
                    )
                )
            )

    def _init_metric_states_for_task(
        self, task_metrics: Tuple[str, Sequence[str]], es_index
    ) -> Sequence[MetricScrollState]:
        """
        Return metric scroll states for the task filled with the variant states
        for the variants that reported any debug images
        """
        task, metrics = task_metrics
        es_req: dict = {
            "size": 0,
            "query": {
                "bool": {
                    "must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
                }
            },
            "aggs": {
                "metrics": {
                    "terms": {
                        "field": "metric",
                        "size": EventMetrics.MAX_METRICS_COUNT,
                    },
                    "aggs": {
                        "last_event_timestamp": {"max": {"field": "timestamp"}},
                        "variants": {
                            "terms": {
                                "field": "variant",
                                "size": EventMetrics.MAX_VARIANTS_COUNT,
                            },
                            "aggs": {
                                "urls": {
                                    "terms": {
                                        "field": "url",
                                        "order": {"max_iter": "desc"},
                                        "size": 1,  # we need only one url from the most recent iteration
                                    },
                                    "aggs": {
                                        "max_iter": {"max": {"field": "iter"}},
                                        "iters": {
                                            "top_hits": {
                                                "sort": {"iter": {"order": "desc"}},
                                                "size": 2,  # need two last iterations so that we can take
                                                # the second one as invalid
                                                "_source": "iter",
                                            }
                                        },
                                    },
                                }
                            },
                        },
                    },
                }
            },
        }

        with translate_errors_context(), TimingContext("es", "_init_metric_states"):
            es_res = self.es.search(index=es_index, body=es_req, routing=task)
        if "aggregations" not in es_res:
            return []

        def init_variant_scroll_state(variant: dict):
            """
            Return new variant scroll state for the passed variant bucket
            If the image urls get recycled then fill the last_invalid_iteration field
            """
            state = VariantScrollState(name=variant["key"])
            top_iter_url = dpath.get(variant, "urls/buckets")[0]
            iters = dpath.get(top_iter_url, "iters/hits/hits")
            if len(iters) > 1:
                state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
            return state

        return [
            MetricScrollState(
                task=task,
                name=metric["key"],
                variants=[
                    init_variant_scroll_state(variant)
                    for variant in dpath.get(metric, "variants/buckets")
                ],
                timestamp=dpath.get(metric, "last_event_timestamp/value"),
            )
            for metric in dpath.get(es_res, "aggregations/metrics/buckets")
        ]

    def _get_task_metric_events(
        self,
        metric: MetricScrollState,
        es_index: str,
        iter_count: int,
        navigate_earlier: bool,
    ) -> Tuple:
        """
        Return task metric events grouped by iterations
        Update metric scroll state
        """
        if metric.last_max_iter is None:
            # the first fetch is always from the latest iteration to the earlier ones
            navigate_earlier = True

        must_conditions = [
            {"term": {"task": metric.task}},
            {"term": {"metric": metric.name}},
        ]
        must_not_conditions = []

        range_condition = None
        if navigate_earlier and metric.last_min_iter is not None:
            range_condition = {"lt": metric.last_min_iter}
        elif not navigate_earlier and metric.last_max_iter is not None:
            range_condition = {"gt": metric.last_max_iter}
        if range_condition:
            must_conditions.append({"range": {"iter": range_condition}})

        if navigate_earlier:
            """
            When navigating to earlier iterations consider only
            variants whose invalid iterations border is lower than
            our starting iteration. For these variants make sure
            that only events from the valid iterations are returned
            """
            if not metric.last_min_iter:
                variants = metric.variants
            else:
                variants = list(
                    v
                    for v in metric.variants
                    if v.last_invalid_iteration is None
                    or v.last_invalid_iteration < metric.last_min_iter
                )
                if not variants:
                    return metric.task, metric.name, []
            must_conditions.append(
                {"terms": {"variant": list(v.name for v in variants)}}
            )
        else:
            """
            When navigating to later iterations all variants may be relevant.
            For the variants whose invalid border is higher than our starting
            iteration make sure that only events from valid iterations are returned
            """
            variants = list(
                v
                for v in metric.variants
                if v.last_invalid_iteration is not None
                and v.last_invalid_iteration > metric.last_max_iter
            )

        variants_conditions = [
            {
                "bool": {
                    "must": [
                        {"term": {"variant": v.name}},
                        {"range": {"iter": {"lte": v.last_invalid_iteration}}},
                    ]
                }
            }
            for v in variants
            if v.last_invalid_iteration is not None
        ]
        if variants_conditions:
            must_not_conditions.append({"bool": {"should": variants_conditions}})

        es_req = {
            "size": 0,
            "query": {
                "bool": {"must": must_conditions, "must_not": must_not_conditions}
            },
            "aggs": {
                "iters": {
                    "terms": {
                        "field": "iter",
                        "size": iter_count,
                        "order": {"_term": "desc" if navigate_earlier else "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
                                "size": EventMetrics.MAX_VARIANTS_COUNT,
                            },
                            "aggs": {
                                "events": {
                                    "top_hits": {"sort": {"url": {"order": "desc"}}}
                                }
                            },
                        }
                    },
                }
            },
        }
        with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
            es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
        if "aggregations" not in es_res:
            return metric.task, metric.name, []

        def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
            return [
                ev["_source"]
                for v in variant_buckets
                for ev in dpath.get(v, "events/hits/hits")
            ]

        iterations = [
            {
                "iter": it["key"],
                "events": get_iteration_events(dpath.get(it, "variants/buckets")),
            }
            for it in dpath.get(es_res, "aggregations/iters/buckets")
        ]
        if not navigate_earlier:
            iterations.sort(key=itemgetter("iter"), reverse=True)
        if iterations:
            metric.last_max_iter = iterations[0]["iter"]
            metric.last_min_iter = iterations[-1]["iter"]

        # Commented for now since the last invalid iteration is calculated in the beginning
        # if navigate_earlier and any(
        #     variant.last_invalid_iteration is None for variant in variants
        # ):
        #     """
        #     Variants validation flags due to recycling can
        #     be set only on navigation to earlier frames
        #     """
        #     iterations = self._update_variants_invalid_iterations(variants, iterations)

        return metric.task, metric.name, iterations

    @staticmethod
    def _update_variants_invalid_iterations(
        variants: Sequence[VariantScrollState], iterations: Sequence[dict]
    ) -> Sequence[dict]:
        """
        This code is currently not in use since the invalid iterations
        are calculated during MetricState initialization
        For variants that do not have recycle url marker set it from the
        first event
        For variants that do not have last_invalid_iteration set check if the
        recycle marker was reached on a certain iteration and set it to the
        corresponding iteration
        For variants that have a newly set last_invalid_iteration remove
        events from the invalid iterations
        Return the updated iterations list
        """
        variants_lookup = bucketize(variants, attrgetter("name"))
        for it in iterations:
            iteration = it["iter"]
            events_to_remove = []
            for event in it["events"]:
                variant = variants_lookup[event["variant"]][0]
                if (
                    variant.last_invalid_iteration
                    and variant.last_invalid_iteration >= iteration
                ):
                    events_to_remove.append(event)
                    continue
                event_url = event.get("url")
                if not variant.recycle_url_marker:
                    variant.recycle_url_marker = event_url
                elif variant.recycle_url_marker == event_url:
                    variant.last_invalid_iteration = iteration
                    events_to_remove.append(event)
            if events_to_remove:
                it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
        return [it for it in iterations if it["events"]]
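A rough usage sketch of the iterator's pagination loop follows. The Elasticsearch and Redis clients and the ids are assumed for illustration; in the server they come from es_factory and redman as shown in the EventBLL changes below.

    # Sketch only: paginate debug images for one task metric, two iterations per page.
    iterator = DebugImagesIterator(redis=redis_client, es=es_client)

    scroll_id = None
    while True:
        result = iterator.get_task_events(
            company_id="<company-id>",
            metrics=[("<task-id>", "Metric1")],
            iter_count=2,
            navigate_earlier=True,
            state_id=scroll_id,           # None on the first call
        )
        scroll_id = result.next_scroll_id  # reuse for the next page
        if not any(iterations for (_, _, iterations) in result.metric_events):
            break  # no more valid iterations to fetch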
@@ -2,7 +2,6 @@ import hashlib
 from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
-from enum import Enum
 from operator import attrgetter
 from typing import Sequence
 
@@ -15,46 +14,39 @@ from nested_dict import nested_dict
 import database.utils as dbutils
 import es_factory
 from apierrors import errors
-from bll.event.event_metrics import EventMetrics
+from bll.event.debug_images_iterator import DebugImagesIterator
+from bll.event.event_metrics import EventMetrics, EventType
 from bll.task import TaskBLL
 from config import config
 from database.errors import translate_errors_context
 from database.model.task.task import Task, TaskStatus
+from redis_manager import redman
 from timing_context import TimingContext
 from utilities.dicts import flatten_nested_items
 
 
-class EventType(Enum):
-    metrics_scalar = "training_stats_scalar"
-    metrics_vector = "training_stats_vector"
-    metrics_image = "training_debug_image"
-    metrics_plot = "plot"
-    task_log = "log"
-
-
 # noinspection PyTypeChecker
 EVENT_TYPES = set(map(attrgetter("value"), EventType))
 
 LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class TaskEventsResult(object):
-    events = attr.ib(type=list, default=attr.Factory(list))
-    total_events = attr.ib(type=int, default=0)
-    next_scroll_id = attr.ib(type=str, default=None)
+    total_events: int = 0
+    next_scroll_id: str = None
+    events: list = attr.ib(factory=list)
 
 
 class EventBLL(object):
     id_fields = ("task", "iter", "metric", "variant", "key")
 
-    def __init__(self, events_es=None):
+    def __init__(self, events_es=None, redis=None):
         self.es = events_es or es_factory.connect("events")
         self._metrics = EventMetrics(self.es)
         self._skip_iteration_for_metric = set(
             config.get("services.events.ignore_iteration.metrics", [])
         )
+        self.redis = redis or redman.connection("apiserver")
+        self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
 
     @property
     def metrics(self) -> EventMetrics:
@@ -64,9 +56,12 @@ class EventBLL(object):
         actions = []
         task_ids = set()
         task_iteration = defaultdict(lambda: 0)
-        task_last_events = nested_dict(
+        task_last_scalar_events = nested_dict(
             3, dict
         )  # task_id -> metric_hash -> variant_hash -> MetricEvent
+        task_last_events = nested_dict(
+            3, dict
+        )  # task_id -> metric_hash -> event_type -> MetricEvent
 
         for event in events:
             # remove spaces from event type
@@ -108,6 +103,9 @@ class EventBLL(object):
                 event["value"] = event["values"]
                 del event["values"]
 
+            event["metric"] = event.get("metric") or ""
+            event["variant"] = event.get("variant") or ""
+
             index_name = EventMetrics.get_index_name(company_id, event_type)
             es_action = {
                 "_op_type": "index",  # overwrite if exists with same ID
@@ -132,9 +130,12 @@ class EventBLL(object):
                 ):
                     task_iteration[task_id] = max(iter, task_iteration[task_id])
 
+                self._update_last_metric_events_for_task(
+                    last_events=task_last_events[task_id], event=event,
+                )
                 if event_type == EventType.metrics_scalar.value:
-                    self._update_last_metric_event_for_task(
-                        task_last_events=task_last_events, task_id=task_id, event=event
+                    self._update_last_scalar_events_for_task(
+                        last_events=task_last_scalar_events[task_id], event=event
                     )
             else:
                 es_action["_routing"] = task_id
@@ -187,6 +188,7 @@ class EventBLL(object):
                     task_id=task_id,
                     now=now,
                     iter_max=task_iteration.get(task_id),
+                    last_scalar_events=task_last_scalar_events.get(task_id),
                     last_events=task_last_events.get(task_id),
                 )
 
@@ -202,12 +204,12 @@ class EventBLL(object):
 
         return added, errors_in_bulk
 
-    def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
+    def _update_last_scalar_events_for_task(self, last_events, event):
         """
-        Update task_last_events structure for the provided task_id with the provided event details if this event is more
+        Update last_events structure with the provided event details if this event is more
         recent than the currently stored event for its metric/variant combination.
-        task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
+        last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
         key conflicts due to invalid characters and/or long field names.
         """
         metric = event.get("metric")
@@ -218,13 +220,34 @@ class EventBLL(object):
         metric_hash = dbutils.hash_field_name(metric)
         variant_hash = dbutils.hash_field_name(variant)
 
-        last_events = task_last_events[task_id]
-
         timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
         if timestamp is None or timestamp < event["timestamp"]:
             last_events[metric_hash][variant_hash] = event
 
-    def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
+    def _update_last_metric_events_for_task(self, last_events, event):
+        """
+        Update last_events structure with the provided event details if this event is more
+        recent than the currently stored event for its metric/event_type combination.
+        last_events contains [metric_name -> event_type -> event]
+        """
+        metric = event.get("metric")
+        event_type = event.get("type")
+        if not (metric and event_type):
+            return
+
+        timestamp = last_events[metric][event_type].get("timestamp", None)
+        if timestamp is None or timestamp < event["timestamp"]:
+            last_events[metric][event_type] = event
+
+    def _update_task(
+        self,
+        company_id,
+        task_id,
+        now,
+        iter_max=None,
+        last_scalar_events=None,
+        last_events=None,
+    ):
         """
         Update task information in DB with aggregated results after handling event(s) related to this task.
 
@@ -237,15 +260,18 @@ class EventBLL(object):
         if iter_max is not None:
             fields["last_iteration_max"] = iter_max
 
-        if last_events:
-            fields["last_values"] = list(
+        if last_scalar_events:
+            fields["last_scalar_values"] = list(
                 flatten_nested_items(
-                    last_events,
+                    last_scalar_events,
                     nesting=2,
                     include_leaves=["value", "metric", "variant"],
                 )
            )
 
+        if last_events:
+            fields["last_events"] = last_events
+
         if not fields:
             return False
 
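The add-events path now maintains two nested accumulators per task: one for scalar events keyed by hashed metric/variant names (as before), and a new one keyed by metric name and event type. Roughly, with illustrative values:

    # task_last_scalar_events: task_id -> hash(metric) -> hash(variant) -> event
    {"<task-id>": {"<hash(Metric1)>": {"<hash(Variant1)>": {"value": 0.7, "timestamp": 1568000000000}}}}

    # task_last_events: task_id -> metric name -> event type -> event
    {"<task-id>": {"Metric1": {"training_debug_image": {"url": "img_3.png", "timestamp": 1568000000000}}}}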
@@ -1,13 +1,13 @@
 import itertools
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from functools import partial
 from operator import itemgetter
+from typing import Sequence, Tuple, Callable, Iterable
 
 from boltons.iterutils import bucketize
 from elasticsearch import Elasticsearch
-from typing import Sequence, Tuple, Callable, Iterable
 
 from mongoengine import Q
 
 from apierrors import errors
@@ -21,6 +21,14 @@ from utilities import safe_get
 log = config.logger(__file__)
 
 
+class EventType(Enum):
+    metrics_scalar = "training_stats_scalar"
+    metrics_vector = "training_stats_vector"
+    metrics_image = "training_debug_image"
+    metrics_plot = "plot"
+    task_log = "log"
+
+
 class EventMetrics:
     MAX_TASKS_COUNT = 50
     MAX_METRICS_COUNT = 200
@@ -66,7 +74,8 @@ class EventMetrics:
         """
         if len(task_ids) > self.MAX_TASKS_COUNT:
             raise errors.BadRequest(
-                f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", len(task_ids)
+                f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
+                len(task_ids),
             )
 
         task_name_by_id = {}
@@ -168,9 +177,7 @@ class EventMetrics:
         with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
             metrics = itertools.chain.from_iterable(
                 pool.map(
-                    partial(
-                        get_func, task_ids=task_ids, es_index=es_index, key=key
-                    ),
+                    partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
                     intervals,
                 )
             )
@@ -440,3 +447,50 @@ class EventMetrics:
                 ]
             }
         }
+
+    def get_tasks_metrics(
+        self, company_id, task_ids: Sequence, event_type: EventType
+    ) -> Sequence:
+        """
+        For the requested tasks return all the metrics that
+        reported events of the requested types
+        """
+        es_index = EventMetrics.get_index_name(company_id, event_type.value)
+        if not self.es.indices.exists(es_index):
+            return {}
+
+        max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
+        with ThreadPoolExecutor(max_concurrency) as pool:
+            res = pool.map(
+                partial(
+                    self._get_task_metrics, es_index=es_index, event_type=event_type,
+                ),
+                task_ids,
+            )
+        return list(zip(task_ids, res))
+
+    def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
+        es_req = {
+            "size": 0,
+            "query": {
+                "bool": {
+                    "must": [
+                        {"term": {"task": task_id}},
+                        {"term": {"type": event_type.value}},
+                    ]
+                }
+            },
+            "aggs": {
+                "metrics": {
+                    "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}
+                }
+            },
+        }
+
+        with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
+            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
+
+        return [
+            metric["key"]
+            for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
+        ]
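The result is a list of (task_id, metrics) pairs, one per requested task, in the request order; a sketch with illustrative ids:

    # get_tasks_metrics(company_id, task_ids=["t1", "t2"], event_type=EventType.metrics_image)
    # would return something like:
    [("t1", ["Metric1", "Metric2"]), ("t2", ["Metric3"])]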
server/bll/redis_cache_manager.py (new file, 44 lines)
from typing import Optional, TypeVar, Generic, Type

from redis import StrictRedis

from timing_context import TimingContext

T = TypeVar("T")


class RedisCacheManager(Generic[T]):
    """
    Class for store/retrieve of state objects from redis

    self.state_class - class of the state
    self.redis - instance of redis
    self.expiration_interval - expiration interval in seconds
    """

    def __init__(
        self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
    ):
        self.state_class = state_class
        self.redis = redis
        self.expiration_interval = expiration_interval

    def set_state(self, state: T) -> None:
        redis_key = self._get_redis_key(state.id)
        with TimingContext("redis", "cache_set_state"):
            self.redis.set(redis_key, state.to_json())
            self.redis.expire(redis_key, self.expiration_interval)

    def get_state(self, state_id) -> Optional[T]:
        redis_key = self._get_redis_key(state_id)
        with TimingContext("redis", "cache_get_state"):
            response = self.redis.get(redis_key)
            if response:
                return self.state_class.from_json(response)

    def delete_state(self, state_id) -> None:
        with TimingContext("redis", "cache_delete_state"):
            self.redis.delete(self._get_redis_key(state_id))

    def _get_redis_key(self, state_id):
        return f"{self.state_class}/{state_id}"
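A minimal usage sketch, assuming a StrictRedis instance and the scroll-state class defined in the iterator module above; the key expires after the configured interval:

    # Sketch: cache a scroll state for one hour and read it back by its id.
    manager = RedisCacheManager(
        state_class=DebugImageEventsScrollState,
        redis=redis_client,          # assumed StrictRedis instance
        expiration_interval=3600,
    )

    state = DebugImageEventsScrollState(id="<scroll-id>", metrics=[])
    manager.set_state(state)                       # serialized via state.to_json(), stored with TTL
    restored = manager.get_state("<scroll-id>")    # returns None once the key has expired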
@@ -3,13 +3,14 @@ from datetime import datetime, timedelta
 from operator import attrgetter
 from random import random
 from time import sleep
-from typing import Collection, Sequence, Tuple, Any, Optional, List
+from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
 
 import pymongo.results
 import six
 from mongoengine import Q
 from six import string_types
 
+import database.utils as dbutils
 import es_factory
 from apierrors import errors
 from apimodels.tasks import Artifact as ApiArtifact
@@ -17,6 +18,7 @@ from config import config
 from database.errors import translate_errors_context
 from database.model.model import Model
 from database.model.project import Project
+from database.model.task.metrics import EventStats, MetricEventStats
 from database.model.task.output import Output
 from database.model.task.task import (
     Task,
@@ -197,7 +199,9 @@ class TaskBLL(object):
             system_tags=system_tags or [],
             type=task.type,
             script=task.script,
-            output=Output(destination=task.output.destination) if task.output else None,
+            output=Output(destination=task.output.destination)
+            if task.output
+            else None,
             execution=execution_dict,
         )
         cls.validate(new_task)
@@ -277,7 +281,8 @@ class TaskBLL(object):
         last_update: datetime = None,
         last_iteration: int = None,
         last_iteration_max: int = None,
-        last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
+        last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
+        last_events: Dict[str, Dict[str, dict]] = None,
         **extra_updates,
     ):
         """
@@ -289,7 +294,8 @@ class TaskBLL(object):
             task's last iteration value.
         :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
             if the current task's last iteration value is smaller than the provided value.
-        :param last_values: Last reported metrics summary (value, metric, variant).
+        :param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
+        :param last_events: Last reported metrics summary (value, metric, event type).
         :param extra_updates: Extra task updates to include in this update call.
         :return:
         """
@@ -300,17 +306,33 @@ class TaskBLL(object):
         elif last_iteration_max is not None:
             extra_updates.update(max__last_iteration=last_iteration_max)
 
-        if last_values is not None:
+        if last_scalar_values is not None:
 
             def op_path(op, *path):
                 return "__".join((op, "last_metrics") + path)
 
-            for path, value in last_values:
+            for path, value in last_scalar_values:
                 extra_updates[op_path("set", *path)] = value
                 if path[-1] == "value":
                     extra_updates[op_path("min", *path[:-1], "min_value")] = value
                     extra_updates[op_path("max", *path[:-1], "max_value")] = value
 
+        if last_events is not None:
+
+            def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
+                return {
+                    event_type: EventStats(last_update=event["timestamp"])
+                    for event_type, event in metric_data.items()
+                }
+
+            metric_stats = {
+                dbutils.hash_field_name(metric_key): MetricEventStats(
+                    metric=metric_key, event_stats_by_type=events_per_type(metric_data),
+                )
+                for metric_key, metric_data in last_events.items()
+            }
+            extra_updates["metric_stats"] = metric_stats
+
         Task.objects(id=task_id, company=company_id).update(
             upsert=False, last_update=last_update, **extra_updates
         )
@@ -32,6 +32,11 @@ mongo {
 }
 
 redis {
+    apiserver {
+        host: "127.0.0.1"
+        port: 6379
+        db: 0
+    }
     workers {
         host: "127.0.0.1"
         port: 6379
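The new apiserver pool is what redman.connection("apiserver") resolves when EventBLL builds its Redis client. As a standalone illustration only (this is an assumption about an equivalent connection, not the redman implementation), the entry above describes roughly:

    from redis import StrictRedis

    # Roughly equivalent to the apiserver pool configured above.
    redis_client = StrictRedis(host="127.0.0.1", port=6379, db=0)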
@@ -1,10 +1,18 @@
-from mongoengine import EmbeddedDocument, StringField, DynamicField
+from mongoengine import (
+    EmbeddedDocument,
+    StringField,
+    DynamicField,
+    LongField,
+    EmbeddedDocumentField,
+)
+
+from database.fields import SafeMapField
 
 
 class MetricEvent(EmbeddedDocument):
     meta = {
         # For backwards compatibility reasons
-        'strict': False,
+        "strict": False,
     }
 
     metric = StringField(required=True)
@@ -12,3 +20,20 @@ class MetricEvent(EmbeddedDocument):
     value = DynamicField(required=True)
     min_value = DynamicField()  # for backwards compatibility reasons
     max_value = DynamicField()  # for backwards compatibility reasons
+
+
+class EventStats(EmbeddedDocument):
+    meta = {
+        # For backwards compatibility reasons
+        "strict": False,
+    }
+    last_update = LongField()
+
+
+class MetricEventStats(EmbeddedDocument):
+    meta = {
+        # For backwards compatibility reasons
+        "strict": False,
+    }
+    metric = StringField(required=True)
+    event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))
@@ -22,7 +22,7 @@ from database.model.base import ProperDictMixin
 from database.model.model_labels import ModelLabels
 from database.model.project import Project
 from database.utils import get_options
-from .metrics import MetricEvent
+from .metrics import MetricEvent, MetricEventStats
 from .output import Output
 
 DEFAULT_LAST_ITERATION = 0
@@ -162,3 +162,4 @@ class Task(AttributedDocument):
     last_update = DateTimeField()
     last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
     last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
+    metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
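After TaskBLL.update_statistics runs, a task document carries per-metric event statistics keyed by the hashed metric name. The persisted structure looks roughly like this (keys and values are illustrative):

    # Task.metric_stats as persisted by TaskBLL.update_statistics
    {
        "<hash(Metric1)>": {
            "metric": "Metric1",
            "event_stats_by_type": {
                "training_debug_image": {"last_update": 1568000000000},
            },
        },
    }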
@@ -171,6 +171,30 @@
             critical
         ]
     }
+    event_type_enum {
+        type: string
+        enum: [
+            training_stats_scalar
+            training_stats_vector
+            training_debug_image
+            plot
+            log
+        ]
+    }
+    task_metric {
+        type: object
+        required: [task, metric]
+        properties {
+            task {
+                description: "Task ID"
+                type: string
+            }
+            metric {
+                description: "Metric name"
+                type: string
+            }
+        }
+    }
     task_log_event {
         description: """A log event associated with a task."""
         type: object
@@ -319,6 +343,84 @@
             }
         }
     }
+    "2.7" {
+        description: "Get the debug image events for the requested amount of iterations per each task's metric"
+        request {
+            type: object
+            required: [
+                metrics
+            ]
+            properties {
+                metrics {
+                    type: array
+                    items { "$ref": "#/definitions/task_metric" }
+                    description: "List of metrics for which the events will be retrieved"
+                }
+                iters {
+                    type: integer
+                    description: "Max number of latest iterations for which to return debug images"
+                }
+                navigate_earlier {
+                    type: boolean
+                    description: "If set then events are retrieved from later iterations to earlier ones. Otherwise from earlier iterations to the later. The default is True"
+                }
+                refresh {
+                    type: boolean
+                    description: "If set then scroll will be moved to the latest iterations. The default is False"
+                }
+                scroll_id {
+                    type: string
+                    description: "Scroll ID of previous call (used for getting more results)"
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                metrics {
+                    type: array
+                    items: { type: object }
+                    description: "Debug image events grouped by task metrics and iterations"
+                }
+                scroll_id {
+                    type: string
+                    description: "Scroll ID for getting more results"
+                }
+            }
+        }
+    }
+}
+get_task_metrics {
+    "2.7": {
+        description: "For each task, get a list of metrics for which the requested event type was reported"
+        request {
+            type: object
+            required: [
+                tasks
+            ]
+            properties {
+                tasks {
+                    type: array
+                    items { type: string }
+                    description: "Task IDs"
+                }
+                event_type {
+                    "description": "Event type"
+                    "$ref": "#/definitions/event_type_enum"
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                metrics {
+                    type: array
+                    items { type: object }
+                    description: "List of tasks with their metrics"
+                }
+            }
+        }
+    }
 }
 get_task_log {
     "1.5" {
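For reference, a get_task_metrics exchange under this schema might look like the following (task ids are illustrative placeholders):

    # Request body
    {"tasks": ["<task-id-1>", "<task-id-2>"], "event_type": "training_debug_image"}

    # Response body
    {"metrics": [
        {"task": "<task-id-1>", "metrics": ["Metric1", "Metric2"]},
        {"task": "<task-id-2>", "metrics": ["Metric3"]},
    ]}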
@@ -2,12 +2,15 @@ import itertools
 from collections import defaultdict
 from operator import itemgetter
 
-import six
-
 from apierrors import errors
 from apimodels.events import (
     MultiTaskScalarMetricsIterHistogramRequest,
     ScalarMetricsIterHistogramRequest,
+    DebugImagesRequest,
+    DebugImageResponse,
+    MetricEvents,
+    IterationEvents,
+    TaskMetricsRequest,
 )
 from bll.event import EventBLL
 from bll.event.event_metrics import EventMetrics
@@ -299,7 +302,7 @@ def multi_task_scalar_metrics_iter_histogram(
     call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
 ):
     task_ids = req_model.tasks
-    if isinstance(task_ids, six.string_types):
+    if isinstance(task_ids, str):
         task_ids = [s.strip() for s in task_ids.split(",")]
     # Note, bll already validates task ids as it needs their names
     call.result.data = dict(
@@ -481,7 +484,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
 
 
 @endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
-def get_debug_images(call, company_id, req_model):
+def get_debug_images_v1_8(call, company_id, req_model):
     task_id = call.data["task"]
     iters = call.data.get("iters") or 1
     scroll_id = call.data.get("scroll_id")
@@ -507,6 +510,53 @@ def get_debug_images(call, company_id, req_model):
     )
 
 
+@endpoint(
+    "events.debug_images",
+    min_version="2.7",
+    request_data_model=DebugImagesRequest,
+    response_data_model=DebugImageResponse,
+)
+def get_debug_images(call, company_id, req_model: DebugImagesRequest):
+    tasks = set(m.task for m in req_model.metrics)
+    task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
+    result = event_bll.debug_images_iterator.get_task_events(
+        company_id=company_id,
+        metrics=[(m.task, m.metric) for m in req_model.metrics],
+        iter_count=req_model.iters,
+        navigate_earlier=req_model.navigate_earlier,
+        refresh=req_model.refresh,
+        state_id=req_model.scroll_id,
+    )
+
+    call.result.data_model = DebugImageResponse(
+        scroll_id=result.next_scroll_id,
+        metrics=[
+            MetricEvents(
+                task=task,
+                metric=metric,
+                iterations=[
+                    IterationEvents(iter=iteration["iter"], events=iteration["events"])
+                    for iteration in iterations
+                ],
+            )
+            for (task, metric, iterations) in result.metric_events
+        ],
+    )
+
+
+@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
+def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
+    task_bll.assert_exists(
+        call.identity.company, task_ids=req_model.tasks, allow_public=True
+    )
+    res = event_bll.metrics.get_tasks_metrics(
+        company_id, task_ids=req_model.tasks, event_type=req_model.event_type
+    )
+    call.result.data = {
+        "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
+    }
+
+
 @endpoint("events.delete_for_task", required_fields=["task"])
 def delete_for_task(call, company_id, req_model):
     task_id = call.data["task"]
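The 2.7 endpoint groups events by metric and then by iteration. A sketch of the data behind DebugImageResponse for one metric and two iterations (ids and URLs are illustrative):

    {
        "scroll_id": "<next-scroll-id>",
        "metrics": [
            {
                "task": "<task-id>",
                "metric": "Metric1",
                "iterations": [
                    {"iter": 9, "events": [{"variant": "Variant1", "url": "img_9.png"}]},
                    {"iter": 8, "events": [{"variant": "Variant1", "url": "img_8.png"}]},
                ],
            }
        ],
    }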
@@ -4,19 +4,16 @@ Comprehensive test of all(?) use cases of datasets and frames
 import json
 import time
 import unittest
+from functools import partial
 from statistics import mean
 from typing import Sequence
 
 import es_factory
-from config import config
 from tests.automated import TestService
 
-log = config.logger(__file__)
-
 
 class TestTaskEvents(TestService):
-    def setUp(self, version="1.7"):
+    def setUp(self, version="2.7"):
         super().setUp(version=version)
 
     def _temp_task(self, name="test task events"):
@@ -25,13 +22,14 @@ class TestTaskEvents(TestService):
         )
         return self.create_temp("tasks", **task_input)
 
-    def _create_task_event(self, type_, task, iteration):
+    def _create_task_event(self, type_, task, iteration, **kwargs):
         return {
             "worker": "test",
             "type": type_,
             "task": task,
             "iter": iteration,
             "timestamp": es_factory.get_timestamp_millis(),
+            **kwargs,
         }
 
     def _copy_and_update(self, src_obj, new_data):
@@ -39,6 +37,134 @@ class TestTaskEvents(TestService):
         obj.update(new_data)
         return obj
 
+    def test_task_metrics(self):
+        tasks = {
+            self._temp_task(): {
+                "Metric1": ["training_debug_image"],
+                "Metric2": ["training_debug_image", "log"],
+            },
+            self._temp_task(): {"Metric3": ["training_debug_image"]},
+        }
+        events = [
+            self._create_task_event(
+                event_type,
+                task=task,
+                iteration=1,
+                metric=metric,
+                variant="Test variant",
+            )
+            for task, metrics in tasks.items()
+            for metric, event_types in metrics.items()
+            for event_type in event_types
+        ]
+        self.send_batch(events)
+        self._assert_task_metrics(tasks, "training_debug_image")
+        self._assert_task_metrics(tasks, "log")
+        self._assert_task_metrics(tasks, "training_stats_scalar")
+
+    def _assert_task_metrics(self, tasks: dict, event_type: str):
+        res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
+        for task, metrics in tasks.items():
+            res_metrics = next(
+                (tm.metrics for tm in res.metrics if tm.task == task), ()
+            )
+            self.assertEqual(
+                set(res_metrics),
+                set(
+                    metric for metric, events in metrics.items() if event_type in events
+                ),
+            )
+
+    def test_task_debug_images(self):
+        task = self._temp_task()
+        metric = "Metric1"
+        variants = [("Variant1", 7), ("Variant2", 4)]
+        iterations = 10
+
+        # test empty
+        res = self.api.events.debug_images(
+            metrics=[{"task": task, "metric": metric}],
+            iters=5,
+        )
+        self.assertFalse(res.metrics)
+
+        # create events
+        events = [
+            self._create_task_event(
+                "training_debug_image",
+                task=task,
+                iteration=n,
+                metric=metric,
+                variant=variant,
+                url=f"{metric}_{variant}_{n % unique_images}",
+            )
+            for n in range(iterations)
+            for (variant, unique_images) in variants
+        ]
+        self.send_batch(events)
+
+        # init testing
+        unique_images = [unique for (_, unique) in variants]
+        scroll_id = None
+        assert_debug_images = partial(
+            self._assertDebugImages,
+            task=task,
+            metric=metric,
+            max_iter=iterations - 1,
+            unique_images=unique_images,
+        )
+
+        # test forward navigation
+        for page in range(3):
+            scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
+
+        # test backwards navigation
+        scroll_id = assert_debug_images(
+            scroll_id=scroll_id, page=0, navigate_earlier=False
+        )
+
+        # beyond the latest iteration and back
+        res = self.api.events.debug_images(
+            metrics=[{"task": task, "metric": metric}],
+            iters=5,
+            scroll_id=scroll_id,
+            navigate_earlier=False,
+        )
+        self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
+        assert_debug_images(scroll_id=scroll_id, page=1)
+
+        # refresh
+        assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
+
+    def _assertDebugImages(
+        self,
+        task,
+        metric,
+        max_iter: int,
+        unique_images: Sequence[int],
+        scroll_id,
+        page: int,
+        iters: int = 5,
+        **extra_params,
+    ):
+        res = self.api.events.debug_images(
+            metrics=[{"task": task, "metric": metric}],
+            iters=iters,
+            scroll_id=scroll_id,
+            **extra_params,
+        )
+        data = res["metrics"][0]
+        self.assertEqual(data["task"], task)
+        self.assertEqual(data["metric"], metric)
+        left_iterations = max(0, max(unique_images) - page * iters)
+        self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
+        for it in data["iterations"]:
+            events_per_iter = sum(
+                1 for unique in unique_images if unique > max_iter - it["iter"]
+            )
+            self.assertEqual(len(it["events"]), events_per_iter)
+        return res.scroll_id
+
     def test_task_logs(self):
         events = []
         task = self._temp_task()