Add support for pagination in events.debug_images

This commit is contained in:
allegroai 2020-03-01 18:00:07 +02:00
parent 69714d5b5c
commit 6c8508eb7f
13 changed files with 1021 additions and 57 deletions

View File

@ -89,6 +89,8 @@ _error_codes = {
1003: ('worker_registered', 'worker is already registered'),
1004: ('worker_not_registered', 'worker is not registered'),
1005: ('worker_stats_not_found', 'worker stats not found'),
1104: ('invalid_scroll_id', 'Invalid scroll id'),
},
(401, 'unauthorized'): {

View File

@ -1,9 +1,12 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apimodels import ListField, IntField, ActualEnumField
from bll.event.event_metrics import EventType
from bll.event.scalar_key import ScalarKeyEnum
@ -17,4 +20,44 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(items_types=str, required=True)
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
class DebugImagesRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
iters: int = IntField(default=1, validators=validators.Min(1))
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class TaskMetricsRequest(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)

View File

@ -0,0 +1,464 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping
import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from bll.event.event_metrics import EventMetrics
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField
from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
def to_json(self):
return dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**loads(s))
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
STATE_EXPIRATION_SECONDS = 3600
@property
def _max_workers(self):
return config.get("services.events.max_metrics_concurrency", 4)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=self.STATE_EXPIRATION_SECONDS,
)
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return DebugImagesResult()
unique_metrics = set(metrics)
state = self.cache_manager.get_state(state_id) if state_id else None
if not state:
state = DebugImageEventsScrollState(
id=database.utils.id(),
metrics=self._init_metric_states(es_index, list(unique_metrics)),
)
else:
state_metrics = set((m.task, m.name) for m in state.metrics)
if state_metrics != unique_metrics:
raise errors.bad_request.InvalidScrollId(
"while getting debug images events", scroll_id=state_id
)
if refresh:
self._reinit_outdated_metric_states(company_id, es_index, state)
for metric_state in state.metrics:
metric_state.reset()
res = DebugImagesResult(next_scroll_id=state.id)
try:
with ThreadPoolExecutor(self._max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
es_index=es_index,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
)
)
finally:
self.cache_manager.set_state(state)
return res
def _reinit_outdated_metric_states(
self, company_id, es_index, state: DebugImageEventsScrollState
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
)
for stats in metric_stats.values()
if self.EVENT_TYPE in stats.event_stats_by_type
]
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
es_index,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
]
def _init_metric_states(
self, es_index, metrics: Sequence[Tuple[str, str]]
) -> Sequence[MetricScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(self._max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(self._init_metric_states_for_task, es_index=es_index),
tasks.items(),
)
)
)
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], es_index
) -> Sequence[MetricScrollState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
}
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
variants=[
init_variant_scroll_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
metric: MetricScrollState,
es_index: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
"""
if metric.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_term": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
if "aggregations" not in es_res:
return metric.task, metric.name, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for ev in dpath.get(v, "events/hits/hits")
]
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]

View File

@ -2,7 +2,6 @@ import hashlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from enum import Enum
from operator import attrgetter
from typing import Sequence
@ -15,46 +14,39 @@ from nested_dict import nested_dict
import database.utils as dbutils
import es_factory
from apierrors import errors
from bll.event.event_metrics import EventMetrics
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task, TaskStatus
from redis_manager import redman
from timing_context import TimingContext
from utilities.dicts import flatten_nested_items
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
# noinspection PyTypeChecker
EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s
@attr.s(auto_attribs=True)
class TaskEventsResult(object):
events = attr.ib(type=list, default=attr.Factory(list))
total_events = attr.ib(type=int, default=0)
next_scroll_id = attr.ib(type=str, default=None)
total_events: int = 0
next_scroll_id: str = None
events: list = attr.ib(factory=list)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
def __init__(self, events_es=None):
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
self._metrics = EventMetrics(self.es)
self._skip_iteration_for_metric = set(
config.get("services.events.ignore_iteration.metrics", [])
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
@property
def metrics(self) -> EventMetrics:
@ -64,9 +56,12 @@ class EventBLL(object):
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_events = nested_dict(
task_last_scalar_events = nested_dict(
3, dict
) # task_id -> metric_hash -> variant_hash -> MetricEvent
task_last_events = nested_dict(
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
for event in events:
# remove spaces from event type
@ -108,6 +103,9 @@ class EventBLL(object):
event["value"] = event["values"]
del event["values"]
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
index_name = EventMetrics.get_index_name(company_id, event_type)
es_action = {
"_op_type": "index", # overwrite if exists with same ID
@ -132,9 +130,12 @@ class EventBLL(object):
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_metric_event_for_task(
task_last_events=task_last_events, task_id=task_id, event=event
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
else:
es_action["_routing"] = task_id
@ -187,6 +188,7 @@ class EventBLL(object):
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
@ -202,12 +204,12 @@ class EventBLL(object):
return added, errors_in_bulk
def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
def _update_last_scalar_events_for_task(self, last_events, event):
"""
Update task_last_events structure for the provided task_id with the provided event details if this event is more
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
@ -218,13 +220,34 @@ class EventBLL(object):
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_events = task_last_events[task_id]
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
def _update_last_metric_events_for_task(self, last_events, event):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/event_type combination.
last_events contains [metric_name -> event_type -> event]
"""
metric = event.get("metric")
event_type = event.get("type")
if not (metric and event_type):
return
timestamp = last_events[metric][event_type].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric][event_type] = event
def _update_task(
self,
company_id,
task_id,
now,
iter_max=None,
last_scalar_events=None,
last_events=None,
):
"""
Update task information in DB with aggregated results after handling event(s) related to this task.
@ -237,15 +260,18 @@ class EventBLL(object):
if iter_max is not None:
fields["last_iteration_max"] = iter_max
if last_events:
fields["last_values"] = list(
if last_scalar_events:
fields["last_scalar_values"] = list(
flatten_nested_items(
last_events,
last_scalar_events,
nesting=2,
include_leaves=["value", "metric", "variant"],
)
)
if last_events:
fields["last_events"] = last_events
if not fields:
return False

View File

@ -1,13 +1,13 @@
import itertools
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Callable, Iterable
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from typing import Sequence, Tuple, Callable, Iterable
from mongoengine import Q
from apierrors import errors
@ -21,6 +21,14 @@ from utilities import safe_get
log = config.logger(__file__)
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
class EventMetrics:
MAX_TASKS_COUNT = 50
MAX_METRICS_COUNT = 200
@ -66,7 +74,8 @@ class EventMetrics:
"""
if len(task_ids) > self.MAX_TASKS_COUNT:
raise errors.BadRequest(
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", len(task_ids)
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
len(task_ids),
)
task_name_by_id = {}
@ -168,9 +177,7 @@ class EventMetrics:
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(
partial(
get_func, task_ids=task_ids, es_index=es_index, key=key
),
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
intervals,
)
)
@ -440,3 +447,50 @@ class EventMetrics:
]
}
}
def get_tasks_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
"""
es_index = EventMetrics.get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index):
return {}
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_concurrency) as pool:
res = pool.map(
partial(
self._get_task_metrics, es_index=es_index, event_type=event_type,
),
task_ids,
)
return list(zip(task_ids, res))
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"type": event_type.value}},
]
}
},
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
]

View File

@ -0,0 +1,44 @@
from typing import Optional, TypeVar, Generic, Type
from redis import StrictRedis
from timing_context import TimingContext
T = TypeVar("T")
class RedisCacheManager(Generic[T]):
"""
Class for store/retreive of state objects from redis
self.state_class - class of the state
self.redis - instance of redis
self.expiration_interval - expiration interval in seconds
"""
def __init__(
self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
):
self.state_class = state_class
self.redis = redis
self.expiration_interval = expiration_interval
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"

View File

@ -3,13 +3,14 @@ from datetime import datetime, timedelta
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
@ -17,6 +18,7 @@ from config import config
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import EventStats, MetricEventStats
from database.model.task.output import Output
from database.model.task.task import (
Task,
@ -197,7 +199,9 @@ class TaskBLL(object):
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
output=Output(destination=task.output.destination)
if task.output
else None,
execution=execution_dict,
)
cls.validate(new_task)
@ -277,7 +281,8 @@ class TaskBLL(object):
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_events: Dict[str, Dict[str, dict]] = None,
**extra_updates,
):
"""
@ -289,7 +294,8 @@ class TaskBLL(object):
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_values: Last reported metrics summary (value, metric, variant).
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:
"""
@ -300,17 +306,33 @@ class TaskBLL(object):
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_values is not None:
if last_scalar_values is not None:
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
for path, value in last_values:
for path, value in last_scalar_values:
extra_updates[op_path("set", *path)] = value
if path[-1] == "value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
extra_updates[op_path("max", *path[:-1], "max_value")] = value
if last_events is not None:
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
return {
event_type: EventStats(last_update=event["timestamp"])
for event_type, event in metric_data.items()
}
metric_stats = {
dbutils.hash_field_name(metric_key): MetricEventStats(
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
)
for metric_key, metric_data in last_events.items()
}
extra_updates["metric_stats"] = metric_stats
Task.objects(id=task_id, company=company_id).update(
upsert=False, last_update=last_update, **extra_updates
)

View File

@ -32,6 +32,11 @@ mongo {
}
redis {
apiserver {
host: "127.0.0.1"
port: 6379
db: 0
}
workers {
host: "127.0.0.1"
port: 6379

View File

@ -1,10 +1,18 @@
from mongoengine import EmbeddedDocument, StringField, DynamicField
from mongoengine import (
EmbeddedDocument,
StringField,
DynamicField,
LongField,
EmbeddedDocumentField,
)
from database.fields import SafeMapField
class MetricEvent(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
'strict': False,
"strict": False,
}
metric = StringField(required=True)
@ -12,3 +20,20 @@ class MetricEvent(EmbeddedDocument):
value = DynamicField(required=True)
min_value = DynamicField() # for backwards compatibility reasons
max_value = DynamicField() # for backwards compatibility reasons
class EventStats(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
"strict": False,
}
last_update = LongField()
class MetricEventStats(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
"strict": False,
}
metric = StringField(required=True)
event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))

View File

@ -22,7 +22,7 @@ from database.model.base import ProperDictMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .metrics import MetricEvent, MetricEventStats
from .output import Output
DEFAULT_LAST_ITERATION = 0
@ -162,3 +162,4 @@ class Task(AttributedDocument):
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))

View File

@ -171,6 +171,30 @@
critical
]
}
event_type_enum {
type: string
enum: [
training_stats_scalar
training_stats_vector
training_debug_image
plot
log
]
}
task_metric {
type: object
required: [task, metric]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
}
}
task_log_event {
description: """A log event associated with a task."""
type: object
@ -319,6 +343,84 @@
}
}
}
"2.7" {
description: "Get the debug image events for the requested amount of iterations per each task's metric"
request {
type: object
required: [
metrics
]
properties {
metrics {
type: array
items { "$ref": "#/definitions/task_metric" }
description: "List metrics for which the envents will be retreived"
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
navigate_earlier {
type: boolean
description: "If set then events are retreived from later iterations to earlier ones. Otherwise from earlier iterations to the later. The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest iterations. The default is False"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
metrics {
type: array
items: { type: object }
description: "Debug image events grouped by task metrics and iterations"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_task_metrics{
"2.7": {
description: "For each task, get a list of metrics for which the requested event type was reported"
request {
type: object
required: [
tasks
]
properties {
tasks {
type: array
items { type: string }
description: "Task IDs"
}
event_type {
"description": "Event type"
"$ref": "#/definitions/event_type_enum"
}
}
}
response {
type: object
properties {
metrics {
type: array
items { type: object }
description: "List of task with their metrics"
}
}
}
}
}
get_task_log {
"1.5" {

View File

@ -2,12 +2,15 @@ import itertools
from collections import defaultdict
from operator import itemgetter
import six
from apierrors import errors
from apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@ -299,7 +302,7 @@ def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
if isinstance(task_ids, six.string_types):
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
@ -481,7 +484,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images(call, company_id, req_model):
def get_debug_images_v1_8(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@ -507,6 +510,53 @@ def get_debug_images(call, company_id, req_model):
)
@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
tasks = set(m.task for m in req_model.metrics)
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
result = event_bll.debug_images_iterator.get_task_events(
company_id=company_id,
metrics=[(m.task, m.metric) for m in req_model.metrics],
iter_count=req_model.iters,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
)
call.result.data_model = DebugImageResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
metric=metric,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, metric, iterations) in result.metric_events
],
)
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
task_bll.assert_exists(
call.identity.company, task_ids=req_model.tasks, allow_public=True
)
res = event_bll.metrics.get_tasks_metrics(
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]

View File

@ -4,19 +4,16 @@ Comprehensive test of all(?) use cases of datasets and frames
import json
import time
import unittest
from functools import partial
from statistics import mean
from typing import Sequence
import es_factory
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestTaskEvents(TestService):
def setUp(self, version="1.7"):
def setUp(self, version="2.7"):
super().setUp(version=version)
def _temp_task(self, name="test task events"):
@ -25,13 +22,14 @@ class TestTaskEvents(TestService):
)
return self.create_temp("tasks", **task_input)
def _create_task_event(self, type_, task, iteration):
def _create_task_event(self, type_, task, iteration, **kwargs):
return {
"worker": "test",
"type": type_,
"task": task,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis(),
**kwargs,
}
def _copy_and_update(self, src_obj, new_data):
@ -39,6 +37,134 @@ class TestTaskEvents(TestService):
obj.update(new_data)
return obj
def test_task_metrics(self):
tasks = {
self._temp_task(): {
"Metric1": ["training_debug_image"],
"Metric2": ["training_debug_image", "log"],
},
self._temp_task(): {"Metric3": ["training_debug_image"]},
}
events = [
self._create_task_event(
event_type,
task=task,
iteration=1,
metric=metric,
variant="Test variant",
)
for task, metrics in tasks.items()
for metric, event_types in metrics.items()
for event_type in event_types
]
self.send_batch(events)
self._assert_task_metrics(tasks, "training_debug_image")
self._assert_task_metrics(tasks, "log")
self._assert_task_metrics(tasks, "training_stats_scalar")
def _assert_task_metrics(self, tasks: dict, event_type: str):
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
for task, metrics in tasks.items():
res_metrics = next(
(tm.metrics for tm in res.metrics if tm.task == task), ()
)
self.assertEqual(
set(res_metrics),
set(
metric for metric, events in metrics.items() if event_type in events
),
)
def test_task_debug_images(self):
task = self._temp_task()
metric = "Metric1"
variants = [("Variant1", 7), ("Variant2", 4)]
iterations = 10
# test empty
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
)
self.assertFalse(res.metrics)
# create events
events = [
self._create_task_event(
"training_debug_image",
task=task,
iteration=n,
metric=metric,
variant=variant,
url=f"{metric}_{variant}_{n % unique_images}",
)
for n in range(iterations)
for (variant, unique_images) in variants
]
self.send_batch(events)
# init testing
unique_images = [unique for (_, unique) in variants]
scroll_id = None
assert_debug_images = partial(
self._assertDebugImages,
task=task,
metric=metric,
max_iter=iterations - 1,
unique_images=unique_images,
)
# test forward navigation
for page in range(3):
scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
# test backwards navigation
scroll_id = assert_debug_images(
scroll_id=scroll_id, page=0, navigate_earlier=False
)
# beyond the latest iteration and back
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
scroll_id=scroll_id,
navigate_earlier=False,
)
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
assert_debug_images(scroll_id=scroll_id, page=1)
# refresh
assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
def _assertDebugImages(
self,
task,
metric,
max_iter: int,
unique_images: Sequence[int],
scroll_id,
page: int,
iters: int = 5,
**extra_params,
):
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=iters,
scroll_id=scroll_id,
**extra_params,
)
data = res["metrics"][0]
self.assertEqual(data["task"], task)
self.assertEqual(data["metric"], metric)
left_iterations = max(0, max(unique_images) - page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:
events_per_iter = sum(
1 for unique in unique_images if unique > max_iter - it["iter"]
)
self.assertEqual(len(it["events"]), events_per_iter)
return res.scroll_id
def test_task_logs(self):
events = []
task = self._temp_task()