Compare commits

22 Commits

Author SHA1 Message Date
allegroai
50c373cf0d Version bump to v0.14.1 2020-03-16 18:47:35 +02:00
allegroai
394a9de5fa Update docs with AMI IDs for v0.14.1 2020-03-16 18:47:20 +02:00
allegroai
fb5c06e9c3 Version bump to v0.14.0 2020-03-05 20:03:48 +02:00
allegroai
1a9bbc9420 Update docs with AMI IDs for v0.14.0 2020-03-05 20:03:33 +02:00
allegroai
294da32401 Fix getting empty metrics from task 2020-03-05 14:57:20 +02:00
allegroai
7f00672010 Fix missing routing value when downloading tasks events 2020-03-05 14:55:40 +02:00
allegroai
99bf89a360 Add pre-populate feature to allow starting a new server installation with packaged example experiments 2020-03-05 14:54:34 +02:00
allegroai
6c8508eb7f Add support for pagination in events.debug_images 2020-03-01 18:00:07 +02:00
allegroai
69714d5b5c Use top-level module for api version number instead of a fixed value 2020-03-01 17:51:03 +02:00
allegroai
f9516ec7d3 Fix ActualEnumField initialization in case default was not provided 2020-03-01 17:47:47 +02:00
allegroai
6fdde93dee Add migration script 2020-03-01 17:46:10 +02:00
allegroai
7afc71ec91 Update requirements 2020-02-26 17:26:59 +02:00
allegroai
4595117d91 Support setting fileserver upload folder using an environment variable 2020-02-26 17:26:46 +02:00
allegroai
8630cc1021 Fix queue update time to update when task is taken from queue, not when queried 2020-02-20 18:26:56 +02:00
allegroai
135885b609 Improve unit test for entity ordering 2020-02-04 18:21:13 +02:00
allegroai
eb0865662c Fix projects aggregation on tasks with invalid status 2020-02-04 18:21:04 +02:00
allegroai
b7b94e7ae5 Add more validation when parsing task call 2020-02-04 18:19:07 +02:00
allegroai
72be8bee19 Limit metrics and variants to avoid ES error 2020-02-04 18:18:26 +02:00
allegroai
0722b20c1c Fix task scalars comparison aggregation 2020-02-04 18:16:27 +02:00
allegroai
a392a0e6ff Fix request field required constraint 2020-02-04 18:12:30 +02:00
allegroai
e22fa2f478 Limit dpath requirement 2020-02-04 18:09:55 +02:00
allegroai
8b49c1ac06 Update docs with AWS AMI IDs for v0.13.0 2020-01-07 14:40:09 +02:00
37 changed files with 1854 additions and 419 deletions

View File

@@ -22,6 +22,8 @@ services:
TRAINS_MONGODB_SERVICE_PORT: 27017
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
ports:
- "8008:8008"
networks:

View File

@@ -50,44 +50,81 @@ To upgrade the AMI:
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
### Latest version AMI - v0.13.0 (auto update)<a name="autoupdate"></a>
### Latest version AMI - v0.14.1 (auto update)<a name="autoupdate"></a>
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
* **eu-north-1** : ami-003024b7b575d3f2a
* **ap-south-1** : ami-0d784c7ac2ab4cc72
* **eu-west-3** : ami-091d745be445b69db
* **eu-west-2** : ami-0a4ebf5d45c672411
* **eu-west-1** : ami-021e3421c50d1482c
* **ap-northeast-2** : ami-0d0a25ec610d6d122
* **ap-northeast-1** : ami-01d896f9ae5d87890
* **sa-east-1** : ami-09bcb93835428a412
* **ca-central-1** : ami-077fa58c9f73690c7
* **ap-southeast-1** : ami-046fe4832b077b517
* **ap-southeast-2** : ami-0ab9acb41f8abbba7
* **eu-central-1** : ami-079be664aae12db00
* **us-east-2** : ami-0d48555f80cb7993a
* **us-west-1** : ami-0ed85ab91a7bb5a8a
* **us-west-2** : ami-0b4fe4ca18e9b1227
* **us-east-1** : ami-043b95dd034e581e6
* **eu-north-1** : ami-033fd0d9163e0a36e
* **ap-south-1** : ami-0cdd7f9880336f9b5
* **eu-west-3** : ami-085f508eef9f650d5
* **eu-west-2** : ami-0936deb94a193502d
* **eu-west-1** : ami-0c20b3620f1c6fff7
* **ap-northeast-2** : ami-0707c9790f8a224c4
* **ap-northeast-1** : ami-04595f0745c090328
* **sa-east-1** : ami-03898299742d43ad4
* **ca-central-1** : ami-0502dfea5223d572a
* **ap-southeast-1** : ami-02aa1f9308404e464
* **ap-southeast-2** : ami-0b66189b90df79b38
* **eu-central-1** : ami-0eb919d8234c49cdc
* **us-east-2** : ami-02fb63fca1b9f0d4b
* **us-west-1** : ami-01fdda7351725a689
* **us-west-2** : ami-004a2f40cdc095870
* **us-east-1** : ami-0a8acd1172ffebc7e
### v0.14.1 (static update)
* **eu-north-1** : ami-0ccdf4700a6c989ad
* **ap-south-1** : ami-0f6a8de6441d64a68
* **eu-west-3** : ami-0c51b9a8e7b3371cc
* **eu-west-2** : ami-099c094598f72bcb5
* **eu-west-1** : ami-0af20d5e4ab764212
* **ap-northeast-2** : ami-011455e8d852e02d6
* **ap-northeast-1** : ami-0211827ee11d6ed9c
* **sa-east-1** : ami-07509b07aa4554dc2
* **ca-central-1** : ami-07153c171d97e460e
* **ap-southeast-1** : ami-042d61c497063675b
* **ap-southeast-2** : ami-0dcf27f88bd2dd622
* **eu-central-1** : ami-0ae29f89d9bcb1a95
* **us-east-2** : ami-053144df2cea2bd97
* **us-west-1** : ami-0f703537206ee05f1
* **us-west-2** : ami-007c954572c86a583
* **us-east-1** : ami-07c59cbc7541f58e9
### v0.14.0 (static update)
* **eu-north-1** : ami-02de71586ec496e38
* **ap-south-1** : ami-074b03849b51852e5
* **eu-west-3** : ami-022c388835e0eeb03
* **eu-west-2** : ami-0a151c236c6b27707
* **eu-west-1** : ami-06de69b06b4e73312
* **ap-northeast-2** : ami-0ee821b72d9f669b1
* **ap-northeast-1** : ami-03687ae215e64e100
* **sa-east-1** : ami-01eb83364b7f667af
* **ca-central-1** : ami-02e9b35f9c90377e6
* **ap-southeast-1** : ami-0d3ab5ab0048fea51
* **ap-southeast-2** : ami-0bd39d908fe3a9e06
* **eu-central-1** : ami-0b8638701311b35c4
* **us-east-2** : ami-02ff039693fc3a614
* **us-west-1** : ami-08634f7dfb608a9a7
* **us-west-2** : ami-034d693ef742b9333
* **us-east-1** : ami-0b828b05c323dde7f
### v0.13.0 (static update)
* **eu-north-1** : ami-0e26c3af1663428dc
* **ap-south-1** : ami-07451eb44f51380a8
* **eu-west-3** : ami-0108e506c6e0be8d8
* **eu-west-2** : ami-0fc1fdbc7699f0dde
* **eu-west-1** : ami-0efbf8d2f580a9cee
* **ap-northeast-2** : ami-08f0bbd7e08d0603e
* **ap-northeast-1** : ami-024522bea34dbe3ce
* **sa-east-1** : ami-0fe5b6e0ddc1553d9
* **ca-central-1** : ami-0037c26178a584ade
* **ap-southeast-1** : ami-049dbcc0f0a6dba20
* **ap-southeast-2** : ami-02d1ce8d31c27f187
* **eu-central-1** : ami-0550b14b40371182a
* **us-east-2** : ami-040a1f16ceda8f255
* **us-west-1** : ami-003b5673c08d68cdb
* **us-west-2** : ami-0fec951d8043da62d
* **us-east-1** : ami-049694de0137fdea4
* **eu-north-1** : ami-0d9c74a015e7510d8
* **ap-south-1** : ami-02acd6dd0659bb5c1
* **eu-west-3** : ami-0f0cc5cb6d9afd194
* **eu-west-2** : ami-0298fdc0860206ed9
* **eu-west-1** : ami-0cdc072e528401d5e
* **ap-northeast-2** : ami-0055579cc95b0e53e
* **ap-northeast-1** : ami-0ced7becb9b83b5d0
* **sa-east-1** : ami-033345d0f16a1b5e4
* **ca-central-1** : ami-06c63b05aed47ae67
* **ap-southeast-1** : ami-09f0355f367f30602
* **ap-southeast-2** : ami-0bd2314163ce0fba0
* **eu-central-1** : ami-05fbae957df63e366
* **us-east-2** : ami-050c51b5b4074d3fc
* **us-west-1** : ami-06ad513073d4e5a19
* **us-west-2** : ami-0c96e1361d1d4ca94
* **us-east-1** : ami-07b669040d1eea213
### v0.12.1 (static update)
* **eu-north-1** : ami-003118a8103286d84

View File

@@ -14,6 +14,9 @@ app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
if os.environ.get("TRAINS_UPLOAD_FOLDER"):
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER")
@app.route("/", methods=["POST"])
def upload():

1
server/api_version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "2.7.0"

View File

@@ -89,6 +89,8 @@ _error_codes = {
1003: ('worker_registered', 'worker is already registered'),
1004: ('worker_not_registered', 'worker is not registered'),
1005: ('worker_stats_not_found', 'worker stats not found'),
1104: ('invalid_scroll_id', 'Invalid scroll id'),
},
(401, 'unauthorized'): {

View File

@@ -168,7 +168,7 @@ class ActualEnumField(fields.StringField):
validator_cls = EnumValidator if required else NullableEnumValidator
validators = [*(validators or []), validator_cls(*choices)]
super().__init__(
default=default and self.parse_value(default),
default=self.parse_value(default) if default else NotSet,
*args,
required=required,
validators=validators,

View File

@@ -1,9 +1,12 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from apimodels import ListField, IntField, ActualEnumField
from bll.event.event_metrics import EventType
from bll.event.scalar_key import ScalarKeyEnum
@@ -17,4 +20,44 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(items_types=str)
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
class DebugImagesRequest(Base):
metrics: Sequence[TaskMetric] = ListField(
items_types=TaskMetric, validators=[Length(minimum_value=1)]
)
iters: int = IntField(default=1, validators=validators.Min(1))
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
class MetricEvents(Base):
task: str = StringField()
metric: str = StringField()
iterations: Sequence[IterationEvents] = ListField(items_types=IterationEvents)
class DebugImageResponse(Base):
metrics: Sequence[MetricEvents] = ListField(items_types=MetricEvents)
scroll_id: str = StringField()
class TaskMetricsRequest(Base):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)

View File

@@ -0,0 +1,464 @@
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping
import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from bll.event.event_metrics import EventMetrics
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField
from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps
class VariantScrollState(Base):
name: str = StringField(required=True)
recycle_url_marker: str = StringField()
last_invalid_iteration: int = IntField()
class MetricScrollState(Base):
task: str = StringField(required=True)
name: str = StringField(required=True)
last_min_iter: Optional[int] = IntField()
last_max_iter: Optional[int] = IntField()
timestamp: int = IntField(default=0)
variants: Sequence[VariantScrollState] = ListField([VariantScrollState])
def reset(self):
"""Reset the scrolling state for the metric"""
self.last_min_iter = self.last_max_iter = None
class DebugImageEventsScrollState(Base):
id: str = StringField(required=True)
metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])
def to_json(self):
return dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**loads(s))
@attr.s(auto_attribs=True)
class DebugImagesResult(object):
metric_events: Sequence[tuple] = []
next_scroll_id: str = None
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
STATE_EXPIRATION_SECONDS = 3600
@property
def _max_workers(self):
return config.get("services.events.max_metrics_concurrency", 4)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageEventsScrollState,
redis=redis,
expiration_interval=self.STATE_EXPIRATION_SECONDS,
)
def get_task_events(
self,
company_id: str,
metrics: Sequence[Tuple[str, str]],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
) -> DebugImagesResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return DebugImagesResult()
unique_metrics = set(metrics)
state = self.cache_manager.get_state(state_id) if state_id else None
if not state:
state = DebugImageEventsScrollState(
id=database.utils.id(),
metrics=self._init_metric_states(es_index, list(unique_metrics)),
)
else:
state_metrics = set((m.task, m.name) for m in state.metrics)
if state_metrics != unique_metrics:
raise errors.bad_request.InvalidScrollId(
"while getting debug images events", scroll_id=state_id
)
if refresh:
self._reinit_outdated_metric_states(company_id, es_index, state)
for metric_state in state.metrics:
metric_state.reset()
res = DebugImagesResult(next_scroll_id=state.id)
try:
with ThreadPoolExecutor(self._max_workers) as pool:
res.metric_events = list(
pool.map(
partial(
self._get_task_metric_events,
es_index=es_index,
iter_count=iter_count,
navigate_earlier=navigate_earlier,
),
state.metrics,
)
)
finally:
self.cache_manager.set_state(state)
return res
def _reinit_outdated_metric_states(
self, company_id, es_index, state: DebugImageEventsScrollState
):
"""
Determines the metrics for which new debug image events were added
since their states were initialized and reinits these states
"""
task_ids = set(metric.task for metric in state.metrics)
tasks = Task.objects(id__in=list(task_ids), company=company_id).only(
"id", "metric_stats"
)
def get_last_update_times_for_task_metrics(task: Task) -> Sequence[Tuple]:
"""For metrics that reported debug image events get tuples of task_id/metric_name and last update times"""
metric_stats: Mapping[str, MetricEventStats] = task.metric_stats
if not metric_stats:
return []
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
)
for stats in metric_stats.values()
if self.EVENT_TYPE in stats.event_stats_by_type
]
update_times = dict(
chain.from_iterable(
get_last_update_times_for_task_metrics(task) for task in tasks
)
)
outdated_metrics = [
metric
for metric in state.metrics
if (metric.task, metric.name) in update_times
and update_times[metric.task, metric.name] > metric.timestamp
]
state.metrics = [
*(metric for metric in state.metrics if metric not in outdated_metrics),
*(
self._init_metric_states(
es_index,
[(metric.task, metric.name) for metric in outdated_metrics],
)
),
]
def _init_metric_states(
self, es_index, metrics: Sequence[Tuple[str, str]]
) -> Sequence[MetricScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
"""
tasks = defaultdict(list)
for (task, metric) in metrics:
tasks[task].append(metric)
with ThreadPoolExecutor(self._max_workers) as pool:
return list(
chain.from_iterable(
pool.map(
partial(self._init_metric_states_for_task, es_index=es_index),
tasks.items(),
)
)
)
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Sequence[str]], es_index
) -> Sequence[MetricScrollState]:
"""
Return metric scroll states for the task filled with the variant states
for the variants that reported any debug images
"""
task, metrics = task_metrics
es_req: dict = {
"size": 0,
"query": {
"bool": {
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
}
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"last_event_timestamp": {"max": {"field": "timestamp"}},
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"urls": {
"terms": {
"field": "url",
"order": {"max_iter": "desc"},
"size": 1, # we need only one url from the most recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 2, # need two last iterations so that we can take
# the second one as invalid
"_source": "iter",
}
},
},
}
},
},
},
}
},
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
if "aggregations" not in es_res:
return []
def init_variant_scroll_state(variant: dict):
"""
Return new variant scroll state for the passed variant bucket
If the image urls get recycled then fill the last_invalid_iteration field
"""
state = VariantScrollState(name=variant["key"])
top_iter_url = dpath.get(variant, "urls/buckets")[0]
iters = dpath.get(top_iter_url, "iters/hits/hits")
if len(iters) > 1:
state.last_invalid_iteration = dpath.get(iters[1], "_source/iter")
return state
return [
MetricScrollState(
task=task,
name=metric["key"],
variants=[
init_variant_scroll_state(variant)
for variant in dpath.get(metric, "variants/buckets")
],
timestamp=dpath.get(metric, "last_event_timestamp/value"),
)
for metric in dpath.get(es_res, "aggregations/metrics/buckets")
]
def _get_task_metric_events(
self,
metric: MetricScrollState,
es_index: str,
iter_count: int,
navigate_earlier: bool,
) -> Tuple:
"""
Return task metric events grouped by iterations
Update metric scroll state
"""
if metric.last_max_iter is None:
# the first fetch is always from the latest iteration to the earlier ones
navigate_earlier = True
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
]
must_not_conditions = []
range_condition = None
if navigate_earlier and metric.last_min_iter is not None:
range_condition = {"lt": metric.last_min_iter}
elif not navigate_earlier and metric.last_max_iter is not None:
range_condition = {"gt": metric.last_max_iter}
if range_condition:
must_conditions.append({"range": {"iter": range_condition}})
if navigate_earlier:
"""
When navigating to earlier iterations consider only
variants whose invalid iterations border is lower than
our starting iteration. For these variants make sure
that only events from the valid iterations are returned
"""
if not metric.last_min_iter:
variants = metric.variants
else:
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is None
or v.last_invalid_iteration < metric.last_min_iter
)
if not variants:
return metric.task, metric.name, []
must_conditions.append(
{"terms": {"variant": list(v.name for v in variants)}}
)
else:
"""
When navigating to later iterations all variants may be relevant.
For the variants whose invalid border is higher than our starting
iteration make sure that only events from valid iterations are returned
"""
variants = list(
v
for v in metric.variants
if v.last_invalid_iteration is not None
and v.last_invalid_iteration > metric.last_max_iter
)
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"lte": v.last_invalid_iteration}}},
]
}
}
for v in variants
if v.last_invalid_iteration is not None
]
if variants_conditions:
must_not_conditions.append({"bool": {"should": variants_conditions}})
es_req = {
"size": 0,
"query": {
"bool": {"must": must_conditions, "must_not": must_not_conditions}
},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_term": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"events": {
"top_hits": {"sort": {"url": {"order": "desc"}}}
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
if "aggregations" not in es_res:
return metric.task, metric.name, []
def get_iteration_events(variant_buckets: Sequence[dict]) -> Sequence:
return [
ev["_source"]
for v in variant_buckets
for ev in dpath.get(v, "events/hits/hits")
]
iterations = [
{
"iter": it["key"],
"events": get_iteration_events(dpath.get(it, "variants/buckets")),
}
for it in dpath.get(es_res, "aggregations/iters/buckets")
]
if not navigate_earlier:
iterations.sort(key=itemgetter("iter"), reverse=True)
if iterations:
metric.last_max_iter = iterations[0]["iter"]
metric.last_min_iter = iterations[-1]["iter"]
# Commented for now since the last invalid iteration is calculated in the beginning
# if navigate_earlier and any(
# variant.last_invalid_iteration is None for variant in variants
# ):
# """
# Variants validation flags due to recycling can
# be set only on navigation to earlier frames
# """
# iterations = self._update_variants_invalid_iterations(variants, iterations)
return metric.task, metric.name, iterations
@staticmethod
def _update_variants_invalid_iterations(
variants: Sequence[VariantScrollState], iterations: Sequence[dict]
) -> Sequence[dict]:
"""
This code is currently not in used since the invalid iterations
are calculated during MetricState initialization
For variants that do not have recycle url marker set it from the
first event
For variants that do not have last_invalid_iteration set check if the
recycle marker was reached on a certain iteration and set it to the
corresponding iteration
For variants that have a newly set last_invalid_iteration remove
events from the invalid iterations
Return the updated iterations list
"""
variants_lookup = bucketize(variants, attrgetter("name"))
for it in iterations:
iteration = it["iter"]
events_to_remove = []
for event in it["events"]:
variant = variants_lookup[event["variant"]][0]
if (
variant.last_invalid_iteration
and variant.last_invalid_iteration >= iteration
):
events_to_remove.append(event)
continue
event_url = event.get("url")
if not variant.recycle_url_marker:
variant.recycle_url_marker = event_url
elif variant.recycle_url_marker == event_url:
variant.last_invalid_iteration = iteration
events_to_remove.append(event)
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]

View File

@@ -2,7 +2,6 @@ import hashlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from enum import Enum
from operator import attrgetter
from typing import Sequence
@@ -15,44 +14,39 @@ from nested_dict import nested_dict
import database.utils as dbutils
import es_factory
from apierrors import errors
from bll.event.event_metrics import EventMetrics
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task, TaskStatus
from redis_manager import redman
from timing_context import TimingContext
from utilities.dicts import flatten_nested_items
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
# noinspection PyTypeChecker
EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s
@attr.s(auto_attribs=True)
class TaskEventsResult(object):
events = attr.ib(type=list, default=attr.Factory(list))
total_events = attr.ib(type=int, default=0)
next_scroll_id = attr.ib(type=str, default=None)
total_events: int = 0
next_scroll_id: str = None
events: list = attr.ib(factory=list)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
def __init__(self, events_es=None):
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
self._metrics = EventMetrics(self.es)
self._skip_iteration_for_metric = set(config.get("services.events.ignore_iteration.metrics", []))
self._skip_iteration_for_metric = set(
config.get("services.events.ignore_iteration.metrics", [])
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
@property
def metrics(self) -> EventMetrics:
@@ -62,9 +56,12 @@ class EventBLL(object):
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_events = nested_dict(
task_last_scalar_events = nested_dict(
3, dict
) # task_id -> metric_hash -> variant_hash -> MetricEvent
task_last_events = nested_dict(
3, dict
) # task_id -> metric_hash -> event_type -> MetricEvent
for event in events:
# remove spaces from event type
@@ -106,6 +103,9 @@ class EventBLL(object):
event["value"] = event["values"]
del event["values"]
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
index_name = EventMetrics.get_index_name(company_id, event_type)
es_action = {
"_op_type": "index", # overwrite if exists with same ID
@@ -124,12 +124,18 @@ class EventBLL(object):
if task_id is not None:
es_action["_routing"] = task_id
task_ids.add(task_id)
if iter is not None and event.get("metric") not in self._skip_iteration_for_metric:
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
):
task_iteration[task_id] = max(iter, task_iteration[task_id])
self._update_last_metric_events_for_task(
last_events=task_last_events[task_id], event=event,
)
if event_type == EventType.metrics_scalar.value:
self._update_last_metric_event_for_task(
task_last_events=task_last_events, task_id=task_id, event=event
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_id], event=event
)
else:
es_action["_routing"] = task_id
@@ -182,6 +188,7 @@ class EventBLL(object):
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
@@ -197,12 +204,12 @@ class EventBLL(object):
return added, errors_in_bulk
def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
def _update_last_scalar_events_for_task(self, last_events, event):
"""
Update task_last_events structure for the provided task_id with the provided event details if this event is more
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
@@ -213,13 +220,34 @@ class EventBLL(object):
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_events = task_last_events[task_id]
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
def _update_last_metric_events_for_task(self, last_events, event):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/event_type combination.
last_events contains [metric_name -> event_type -> event]
"""
metric = event.get("metric")
event_type = event.get("type")
if not (metric and event_type):
return
timestamp = last_events[metric][event_type].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric][event_type] = event
def _update_task(
self,
company_id,
task_id,
now,
iter_max=None,
last_scalar_events=None,
last_events=None,
):
"""
Update task information in DB with aggregated results after handling event(s) related to this task.
@@ -232,15 +260,18 @@ class EventBLL(object):
if iter_max is not None:
fields["last_iteration_max"] = iter_max
if last_events:
fields["last_values"] = list(
if last_scalar_events:
fields["last_scalar_values"] = list(
flatten_nested_items(
last_events,
last_scalar_events,
nesting=2,
include_leaves=["value", "metric", "variant"],
)
)
if last_events:
fields["last_events"] = last_events
if not fields:
return False
@@ -279,7 +310,9 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
es_res = self.es.search(
index=es_index, body=es_req, scroll="1h", routing=task_id
)
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
next_scroll_id = es_res["_scroll_id"]
@@ -297,10 +330,16 @@ class EventBLL(object):
"size": 0,
"aggs": {
"metrics": {
"terms": {"field": "metric"},
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"variants": {
"terms": {"field": "variant"},
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"iters": {
"terms": {
@@ -499,8 +538,18 @@ class EventBLL(object):
"size": 0,
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": 200},
"aggs": {"variants": {"terms": {"field": "variant", "size": 200}}},
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
}
}
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
@@ -540,14 +589,14 @@ class EventBLL(object):
"metrics": {
"terms": {
"field": "metric",
"size": 1000,
"size": EventMetrics.MAX_METRICS_COUNT,
"order": {"_term": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": 1000,
"size": EventMetrics.MAX_VARIANTS_COUNT,
"order": {"_term": "asc"},
},
"aggs": {

View File

@@ -1,12 +1,13 @@
import itertools
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Callable, Iterable
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from typing import Sequence, Tuple, Callable
from mongoengine import Q
from apierrors import errors
@@ -20,10 +21,19 @@ from utilities import safe_get
log = config.logger(__file__)
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
class EventMetrics:
MAX_TASKS_COUNT = 100
MAX_TASKS_COUNT = 50
MAX_METRICS_COUNT = 200
MAX_VARIANTS_COUNT = 500
MAX_AGGS_ELEMENTS_COUNT = 50
def __init__(self, es: Elasticsearch):
self.es = es
@@ -62,6 +72,12 @@ class EventMetrics:
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
if len(task_ids) > self.MAX_TASKS_COUNT:
raise errors.BadRequest(
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
len(task_ids),
)
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
@@ -97,6 +113,31 @@ class EventMetrics:
MetricInterval = Tuple[int, Sequence[TaskMetric]]
MetricData = Tuple[str, dict]
def _split_metrics_by_max_aggs_count(
self, task_metrics: Sequence[TaskMetric]
) -> Iterable[Sequence[TaskMetric]]:
"""
Return task metrics in groups where amount of task metrics in each group
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
variants while always preserving all their tasks in the same group
"""
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
yield task_metrics
return
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
groups = []
for group in tm_grouped.values():
groups.append(group)
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
yield list(itertools.chain(*groups))
groups = []
if groups:
yield list(itertools.chain(*groups))
return
def _run_get_scalar_metrics_as_parallel(
self,
company_id: str,
@@ -126,21 +167,25 @@ class EventMetrics:
if not intervals:
return {}
with ThreadPoolExecutor(len(intervals)) as pool:
metrics = list(
itertools.chain.from_iterable(
pool.map(
partial(
get_func, task_ids=task_ids, es_index=es_index, key=key
),
intervals,
)
intervals = list(
itertools.chain.from_iterable(
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
for i, tms in intervals
)
)
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
intervals,
)
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
return ret
def _get_metric_intervals(
@@ -310,7 +355,13 @@ class EventMetrics:
"variants": {
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
"aggs": {
"tasks": {"terms": {"field": "task"}, "aggs": aggregation}
"tasks": {
"terms": {
"field": "task",
"size": self.MAX_TASKS_COUNT,
},
"aggs": aggregation,
}
},
}
},
@@ -396,3 +447,50 @@ class EventMetrics:
]
}
}
def get_tasks_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence[Tuple]:
"""
For the requested tasks return all the metrics that
reported events of the requested types
"""
es_index = EventMetrics.get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index):
return [(tid, []) for tid in task_ids]
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_concurrency) as pool:
res = pool.map(
partial(
self._get_task_metrics, es_index=es_index, event_type=event_type,
),
task_ids,
)
return list(zip(task_ids, res))
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"type": event_type.value}},
]
}
},
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}
}
},
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
return [
metric["key"]
for metric in safe_get(es_res, "aggregations/metrics/buckets", default=[])
]

View File

@@ -9,9 +9,12 @@ import es_factory
from apierrors import errors
from bll.queue.queue_metrics import QueueMetrics
from bll.workers import WorkerBLL
from config import config
from database.errors import translate_errors_context
from database.model.queue import Queue, Entry
log = config.logger(__file__)
class QueueBLL(object):
def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
@@ -189,9 +192,7 @@ class QueueBLL(object):
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(**query).modify(
pop__entries=-1, last_update=datetime.utcnow(), upsert=False
)
queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
@@ -200,6 +201,11 @@ class QueueBLL(object):
if not queue.entries:
return
try:
Queue.objects(**query).update(last_update=datetime.utcnow())
except Exception:
log.exception("Error while updating Queue.last_update")
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:

View File

@@ -0,0 +1,44 @@
from typing import Optional, TypeVar, Generic, Type
from redis import StrictRedis
from timing_context import TimingContext
T = TypeVar("T")
class RedisCacheManager(Generic[T]):
"""
Class for store/retreive of state objects from redis
self.state_class - class of the state
self.redis - instance of redis
self.expiration_interval - expiration interval in seconds
"""
def __init__(
self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
):
self.state_class = state_class
self.redis = redis
self.expiration_interval = expiration_interval
def set_state(self, state: T) -> None:
redis_key = self._get_redis_key(state.id)
with TimingContext("redis", "cache_set_state"):
self.redis.set(redis_key, state.to_json())
self.redis.expire(redis_key, self.expiration_interval)
def get_state(self, state_id) -> Optional[T]:
redis_key = self._get_redis_key(state_id)
with TimingContext("redis", "cache_get_state"):
response = self.redis.get(redis_key)
if response:
return self.state_class.from_json(response)
def delete_state(self, state_id) -> None:
with TimingContext("redis", "cache_delete_state"):
self.redis.delete(self._get_redis_key(state_id))
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"

View File

@@ -3,13 +3,14 @@ from datetime import datetime, timedelta
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
@@ -17,6 +18,7 @@ from config import config
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import EventStats, MetricEventStats
from database.model.task.output import Output
from database.model.task.task import (
Task,
@@ -197,7 +199,9 @@ class TaskBLL(object):
system_tags=system_tags or [],
type=task.type,
script=task.script,
output=Output(destination=task.output.destination) if task.output else None,
output=Output(destination=task.output.destination)
if task.output
else None,
execution=execution_dict,
)
cls.validate(new_task)
@@ -277,7 +281,8 @@ class TaskBLL(object):
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_scalar_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
last_events: Dict[str, Dict[str, dict]] = None,
**extra_updates,
):
"""
@@ -289,7 +294,8 @@ class TaskBLL(object):
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_values: Last reported metrics summary (value, metric, variant).
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:
"""
@@ -300,17 +306,33 @@ class TaskBLL(object):
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_values is not None:
if last_scalar_values is not None:
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
for path, value in last_values:
for path, value in last_scalar_values:
extra_updates[op_path("set", *path)] = value
if path[-1] == "value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
extra_updates[op_path("max", *path[:-1], "max_value")] = value
if last_events is not None:
def events_per_type(metric_data: Dict[str, dict]) -> Dict[str, EventStats]:
return {
event_type: EventStats(last_update=event["timestamp"])
for event_type, event in metric_data.items()
}
metric_stats = {
dbutils.hash_field_name(metric_key): MetricEventStats(
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
)
for metric_key, metric_data in last_events.items()
}
extra_updates["metric_stats"] = metric_stats
Task.objects(id=task_id, company=company_id).update(
upsert=False, last_update=last_update, **extra_updates
)

View File

@@ -47,7 +47,7 @@ class BasicConfig:
def logger(self, name):
if Path(name).is_file():
name = Path(name).stem
path = ".".join((self.prefix, Path(name).stem))
path = ".".join((self.prefix, name))
return logging.getLogger(path)
def _read_extra_env_config_values(self):

View File

@@ -34,6 +34,12 @@
aggregate {
allow_disk_use: true
}
pre_populate {
enabled: false
zip_file: "/path/to/export.zip"
fail_on_error: false
}
}
auth {

View File

@@ -32,6 +32,11 @@ mongo {
}
redis {
apiserver {
host: "127.0.0.1"
port: 6379
db: 0
}
workers {
host: "127.0.0.1"
port: 6379

View File

@@ -2,4 +2,8 @@ es_index_prefix: "events"
ignore_iteration {
metrics: [":monitor:machine", ":monitor:gpu"]
}
}
# max number of concurrent queries to ES when calculating events metrics
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4

View File

@@ -1,10 +1,18 @@
from mongoengine import EmbeddedDocument, StringField, DynamicField
from mongoengine import (
EmbeddedDocument,
StringField,
DynamicField,
LongField,
EmbeddedDocumentField,
)
from database.fields import SafeMapField
class MetricEvent(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
'strict': False,
"strict": False,
}
metric = StringField(required=True)
@@ -12,3 +20,20 @@ class MetricEvent(EmbeddedDocument):
value = DynamicField(required=True)
min_value = DynamicField() # for backwards compatibility reasons
max_value = DynamicField() # for backwards compatibility reasons
class EventStats(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
"strict": False,
}
last_update = LongField()
class MetricEventStats(EmbeddedDocument):
meta = {
# For backwards compatibility reasons
"strict": False,
}
metric = StringField(required=True)
event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))

View File

@@ -22,7 +22,7 @@ from database.model.base import ProperDictMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .metrics import MetricEvent, MetricEventStats
from .output import Output
DEFAULT_LAST_ITERATION = 0
@@ -162,3 +162,4 @@ class Task(AttributedDocument):
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))

View File

@@ -96,7 +96,12 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
continue
if desc:
if callable(desc):
desc(value)
try:
desc(value)
except TypeError:
raise ParseCallError(f"expecting {desc.__name__}", field=field)
except Exception as ex:
raise ParseCallError(str(ex), field=field)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(
value, desc

View File

@@ -0,0 +1,27 @@
from furl import furl
from config import config
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
log = config.logger(__file__)
class MissingElasticConfiguration(Exception):
"""
Exception when cluster configuration is not found in config files
"""
pass
def init_es_data():
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
log.info(res)

View File

@@ -1,222 +0,0 @@
import importlib.util
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import attr
from furl import furl
from mongoengine.connection import get_db
from semantic_version import Version
import database.utils
from bll.queue import QueueBLL
from config import config
from config.info import get_default_company
from database import Database
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
from database.model.user import User
from database.model.version import Version as DatabaseVersion
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
from service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
migration_dir = Path(__file__).resolve().parent / "mongo" / "migrations"
class MissingElasticConfiguration(Exception):
"""
Exception when cluster configuration is not found in config files
"""
pass
def init_es_data():
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
log.info(res)
def _ensure_company():
company_id = get_default_company()
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
return company_id
def _ensure_default_queue(company):
"""
If no queue is present for the company then
create a new one and mark it as a default
"""
queue = Queue.objects(company=company).only("id").first()
if queue:
return
QueueBLL.create(company, name="default", system_tags=["default"])
def _ensure_auth_user(user_data, company_id):
ensure_credentials = {"key", "secret"}.issubset(user_data.keys())
if ensure_credentials:
user = AuthUser.objects(
credentials__match=Credentials(
key=user_data["key"], secret=user_data["secret"]
)
).first()
if user:
return user.id
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_data.get("id", f"__{user_data['name']}__"),
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
if ensure_credentials
else None,
)
user.save()
return user.id
def _ensure_user(user: FixedUser, company_id: str):
if User.objects(id=user.user_id).first():
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
_ensure_auth_user(user_data=data, company_id=company_id)
given_name, _, family_name = user.name.partition(" ")
User(
id=user.user_id,
company=company_id,
name=user.name,
given_name=given_name,
family_name=family_name,
).save()
def _apply_migrations():
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
raise ValueError(f"Failed parsing migration version from file: {ex}")
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
migration_log = log.getChild("mongodb_migration")
for script_version in sorted(new_scripts.keys()):
script = new_scripts[script_version]
spec = importlib.util.spec_from_file_location(script.stem, str(script))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for alias, func_name in dbs.items():
func = getattr(module, func_name, None)
if not func:
continue
try:
migration_log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
except Exception:
migration_log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError("Migration failed, aborting. Please restore backup.")
DatabaseVersion(
id=database.utils.id(),
num=script.stem,
created=datetime.utcnow(),
desc="Applied on server startup",
).save()
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))
def init_mongo_data():
try:
_apply_migrations()
_ensure_uuid()
company_id = _ensure_company()
_ensure_default_queue(company_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
for user in users:
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_auth_user(user, company_id)
if FixedUser.enabled():
log.info("Fixed users mode is enabled")
FixedUser.validate()
for user in FixedUser.from_config():
try:
_ensure_user(user, company_id)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
except Exception as ex:
log.exception("Failed initializing mongodb")

View File

@@ -0,0 +1,70 @@
from pathlib import Path
from config import config
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations
from .pre_populate import PrePopulate
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
log = config.logger(__package__)
def init_mongo_data():
try:
empty_dbs = _apply_migrations(log)
_ensure_uuid()
company_id = _ensure_company(log)
_ensure_default_queue(company_id)
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
if not zip_file or not Path(zip_file).is_file():
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
user_id = _ensure_backend_user(
"__allegroai__", company_id, "Allegro.ai"
)
PrePopulate.import_from_zip(zip_file, user_id=user_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
for user in users:
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_auth_user(user, company_id, log=log)
if FixedUser.enabled():
log.info("Fixed users mode is enabled")
FixedUser.validate()
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, company_id, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
except Exception as ex:
log.exception("Failed initializing mongodb")

View File

@@ -0,0 +1,86 @@
import importlib.util
from datetime import datetime
from logging import Logger
from pathlib import Path
from mongoengine.connection import get_db
from semantic_version import Version
import database.utils
from database import Database
from database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
def _apply_migrations(log: Logger) -> bool:
"""
Apply migrations as found in the migration dir.
Returns a boolean indicating whether the database was empty prior to migration.
"""
log = log.getChild(Path(__file__).stem)
log.info(f"Started mongodb migrations")
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
empty_dbs = not any(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
raise ValueError(f"Failed parsing migration version from file: {ex}")
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
for script_version in sorted(new_scripts):
script = new_scripts[script_version]
if empty_dbs:
log.info(f"Skipping migration {script.name} (empty databases)")
else:
spec = importlib.util.spec_from_file_location(script.stem, str(script))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for alias, func_name in dbs.items():
func = getattr(module, func_name, None)
if not func:
continue
try:
log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
except Exception:
log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError(
"Migration failed, aborting. Please restore backup."
)
DatabaseVersion(
id=database.utils.id(),
num=script.stem,
created=datetime.utcnow(),
desc="Applied on server startup",
).save()
log.info("Finished mongodb migrations")
return empty_dbs

View File

@@ -0,0 +1,153 @@
import importlib
from collections import defaultdict
from datetime import datetime
from os.path import splitext
from typing import List, Optional, Any, Type, Set, Dict
from zipfile import ZipFile, ZIP_BZIP2
import mongoengine
from tqdm import tqdm
class PrePopulate:
@classmethod
def export_to_zip(
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
):
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
cls._export(zfile, experiments, projects)
@classmethod
def import_from_zip(cls, filename: str, user_id: str = None):
with ZipFile(filename) as zfile:
cls._import(zfile, user_id)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[List[str]]
) -> List[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
resolved = {i.id for i in items}
missing = ids - resolved
for name_candidate in missing:
results = list(cls.objects(name=name_candidate))
if not results:
print(f"ERROR: no match for `{name_candidate}`")
exit(1)
elif len(results) > 1:
print(f"ERROR: more than one match for `{name_candidate}`")
exit(1)
items.append(results[0])
return items
@classmethod
def _resolve_entities(
cls, experiments: List[str] = None, projects: List[str] = None
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
from database.model.project import Project
from database.model.task.task import Task
entities = defaultdict(set)
if projects:
print("Reading projects...")
entities[Project].update(cls._resolve_type(Project, projects))
print("--> Reading project experiments...")
objs = Task.objects(
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
)
entities[Task].update(o for o in objs if o.id not in (experiments or []))
if experiments:
print("Reading experiments...")
entities[Task].update(cls._resolve_type(Task, experiments))
print("--> Reading experiments projects...")
objs = Project.objects(
id__in=list(set(filter(None, (p.project for p in entities[Task]))))
)
project_ids = {p.id for p in entities[Project]}
entities[Project].update(o for o in objs if o.id not in project_ids)
return entities
@classmethod
def _cleanup_task(cls, task):
from database.model.task.task import TaskStatus
task.completed = None
task.started = None
if task.execution:
task.execution.model = None
task.execution.model_desc = None
task.execution.model_labels = None
if task.output:
task.output.model = None
task.status = TaskStatus.created
task.comment = "Auto generated by Allegro.ai"
task.created = datetime.utcnow()
task.last_iteration = 0
task.last_update = task.created
task.status_changed = task.created
task.status_message = ""
task.status_reason = ""
task.user = ""
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
from database.model.task.task import Task
if entity_cls == Task:
cls._cleanup_task(entity)
@classmethod
def _export(
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
):
entities = cls._resolve_entities(experiments, projects)
for cls_, items in entities.items():
if not items:
continue
filename = f"{cls_.__module__}.{cls_.__name__}.json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with writer.open(filename, "w") as f:
f.write("[\n".encode("utf-8"))
last = len(items) - 1
for i, item in enumerate(items):
cls._cleanup_entity(cls_, item)
f.write(item.to_json().encode("utf-8"))
if i != last:
f.write(",".encode("utf-8"))
f.write("\n".encode("utf-8"))
f.write("]\n".encode("utf-8"))
@staticmethod
def _import(reader: ZipFile, user_id: str = None):
for file_info in reader.filelist:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
with reader.open(file_info) as f:
for item in tqdm(
f.readlines(),
desc=f"Writing {cls_.__name__.lower()}s into database",
unit="doc",
):
item = (
item.decode("utf-8")
.strip()
.lstrip("[")
.rstrip("]")
.rstrip(",")
.strip()
)
if not item:
continue
doc = cls_.from_json(item)
if user_id is not None and hasattr(doc, "user"):
doc.user = user_id
doc.save(force_insert=True)

View File

@@ -0,0 +1,74 @@
from datetime import datetime
from logging import Logger
import attr
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.user import User
from service_repo.auth.fixed_user import FixedUser
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger):
ensure_credentials = {"key", "secret"}.issubset(user_data)
if ensure_credentials:
user = AuthUser.objects(
credentials__match=Credentials(
key=user_data["key"], secret=user_data["secret"]
)
).first()
if user:
return user.id
log.info(f"Creating user: {user_data['name']}")
user = AuthUser(
id=user_data.get("id", f"__{user_data['name']}__"),
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])]
if ensure_credentials
else None,
)
user.save()
return user.id
def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
given_name, _, family_name = user_name.partition(" ")
User(
id=user_id,
company=company_id,
name=user_name,
given_name=given_name,
family_name=family_name,
).save()
return user_id
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
if User.objects(id=user.user_id).first():
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
given_name, _, family_name = user.name.partition(" ")
User(
id=user.user_id,
company=company_id,
name=user.name,
given_name=given_name,
family_name=family_name,
).save()

View File

@@ -0,0 +1,40 @@
from logging import Logger
from uuid import uuid4
from bll.queue import QueueBLL
from config import config
from config.info import get_default_company
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
log = config.logger(__file__)
def _ensure_company(log: Logger):
company_id = get_default_company()
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
return company_id
def _ensure_default_queue(company):
"""
If no queue is present for the company then
create a new one and mark it as a default
"""
queue = Queue.objects(company=company).only("id").first()
if queue:
return
QueueBLL.create(company, name="default", system_tags=["default"])
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))

View File

@@ -0,0 +1,46 @@
import hashlib
from pymongo.database import Database, Collection
from service_repo.auth.fixed_user import FixedUser
def _get_ids():
if not FixedUser.enabled():
return
return {
hashlib.md5(f"{user.username}:{user.password}".encode()).hexdigest(): user.user_id
for user in FixedUser.from_config()
}
def _switch_uuid(collection: Collection, uuid_field: str, uuids: dict):
docs = list(collection.find({uuid_field: {"$in": [uuids]}}))
if not docs:
return
replaced_uuids = [doc[uuid_field] for doc in docs]
for doc in docs:
doc[uuid_field] = uuids[doc[uuid_field]]
collection.insert_many(docs)
collection.delete_many({uuid_field: {"$in": replaced_uuids}})
def migrate_auth(db: Database):
uuids = _get_ids()
if not uuids:
return
collection = db["user"]
collection.drop_index("name_1_company_1")
_switch_uuid(collection=collection, uuid_field="_id", uuids=uuids)
def migrate_backend(db: Database):
uuids = _get_ids()
if not uuids:
return
for name in ("project", "task", "model"):
_switch_uuid(collection=db[name], uuid_field="user", uuids=uuids)

View File

@@ -1,29 +1,30 @@
six
Flask>=0.12.2
elasticsearch>=5.0.0,<6.0.0
pyhocon>=0.3.35
requests>=2.13.0
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
Flask-Cors>=3.0.5
Flask-Compress>=1.4.0
mongoengine==0.16.2
jsonmodels>=2.3
pyjwt>=1.3.0
gunicorn>=19.7.1
Jinja2==2.10
python-rapidjson>=0.6.3
jsonschema>=2.6.0
dpath>=1.4.2
funcsigs==1.0.2
luqum>=0.7.2
attrs>=19.1.0
nested_dict>=1.61
related>=0.7.2
validators>=0.12.4
fastjsonschema>=2.8
boltons>=19.1.0
semantic_version>=2.6.0,<3
dpath>=1.4.2,<2.0
elasticsearch>=5.0.0,<6.0.0
fastjsonschema>=2.8
Flask-Compress>=1.4.0
Flask-Cors>=3.0.5
Flask>=0.12.2
funcsigs==1.0.2
furl>=2.0.0
redis>=2.10.5
gunicorn>=19.7.1
humanfriendly==4.18
Jinja2==2.10
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.7.2
mongoengine==0.16.2
nested_dict>=1.61
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt>=1.3.0
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
python-rapidjson>=0.6.3
redis>=2.10.5
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.0,<3
six
tqdm
validators>=0.12.4

View File

@@ -171,6 +171,30 @@
critical
]
}
event_type_enum {
type: string
enum: [
training_stats_scalar
training_stats_vector
training_debug_image
plot
log
]
}
task_metric {
type: object
required: [task, metric]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
}
}
task_log_event {
description: """A log event associated with a task."""
type: object
@@ -319,6 +343,84 @@
}
}
}
"2.7" {
description: "Get the debug image events for the requested amount of iterations per each task's metric"
request {
type: object
required: [
metrics
]
properties {
metrics {
type: array
items { "$ref": "#/definitions/task_metric" }
description: "List metrics for which the envents will be retreived"
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
navigate_earlier {
type: boolean
description: "If set then events are retreived from later iterations to earlier ones. Otherwise from earlier iterations to the later. The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest iterations. The default is False"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
metrics {
type: array
items: { type: object }
description: "Debug image events grouped by task metrics and iterations"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_task_metrics{
"2.7": {
description: "For each task, get a list of metrics for which the requested event type was reported"
request {
type: object
required: [
tasks
]
properties {
tasks {
type: array
items { type: string }
description: "Task IDs"
}
event_type {
"description": "Event type"
"$ref": "#/definitions/event_type_enum"
}
}
}
response {
type: object
properties {
metrics {
type: array
items { type: object }
description: "List of task with their metrics"
}
}
}
}
}
get_task_log {
"1.5" {

View File

@@ -10,7 +10,8 @@ import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from init_data import init_es_data, init_mongo_data
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError

View File

@@ -9,6 +9,7 @@ import jsonmodels.models
import timing_context
from apierrors import APIError
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
from api_version import __version__ as _api_version_
from config import config
from service_repo.base import PartialVersion
from .apicall import APICall
@@ -34,7 +35,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.6")
_max_version = PartialVersion(".".join(_api_version_.split(".")[:2]))
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -166,7 +167,7 @@ class ServiceRepo(object):
return
assert isinstance(endpoint, Endpoint)
call.actual_endpoint_version: PartialVersion = endpoint.min_version
call.actual_endpoint_version = endpoint.min_version
call.requires_authorization = endpoint.authorize
return endpoint

View File

@@ -2,12 +2,15 @@ import itertools
from collections import defaultdict
from operator import itemgetter
import six
from apierrors import errors
from apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@@ -299,7 +302,7 @@ def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
if isinstance(task_ids, six.string_types):
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
@@ -481,7 +484,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images(call, company_id, req_model):
def get_debug_images_v1_8(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
@@ -507,6 +510,53 @@ def get_debug_images(call, company_id, req_model):
)
@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
tasks = set(m.task for m in req_model.metrics)
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
result = event_bll.debug_images_iterator.get_task_events(
company_id=company_id,
metrics=[(m.task, m.metric) for m in req_model.metrics],
iter_count=req_model.iters,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
)
call.result.data_model = DebugImageResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
metric=metric,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, metric, iterations) in result.metric_events
],
)
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
task_bll.assert_exists(
call.identity.company, task_ids=req_model.tasks, allow_public=True
)
res = event_bll.metrics.get_tasks_metrics(
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
}
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]

View File

@@ -61,7 +61,7 @@ def get_by_id(call):
def make_projects_get_all_pipelines(project_ids, specific_state=None):
archived = EntityVisibility.archived.value
def ensure_system_tags():
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
@@ -73,6 +73,9 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
"then": [],
"else": "$system_tags",
}
},
"status": {
"$ifNull": ["$status", "unknown"]
}
}
}
@@ -80,7 +83,7 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
ensure_system_tags(),
ensure_valid_fields(),
{
"$group": {
"_id": {
@@ -153,7 +156,7 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
"project": {"$in": project_ids},
}
},
ensure_system_tags(),
ensure_valid_fields(),
{
# for each project
"$group": group_step

View File

@@ -1,14 +1,14 @@
import operator
from time import sleep
from typing import Sequence
from typing import Sequence, Mapping
from tests.automated import TestService
class TestEntityOrdering(TestService):
test_comment = "Entity ordering test"
only_fields = ["id", "started", "comment"]
only_fields = ["id", "started", "comment", "execution.parameters"]
def setUp(self, **kwargs):
super().setUp(**kwargs)
@@ -27,6 +27,9 @@ class TestEntityOrdering(TestService):
# sort by the same field that we use for the search
self._assertGetTasksWithOrdering(order_by="comment")
# sort by parameter which type is not part of db schema
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
def test_order_with_paging(self):
order_field = "started"
# all results in one page
@@ -52,7 +55,7 @@ class TestEntityOrdering(TestService):
def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
return self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=[order_by] if order_by else None,
order_by=[order_by] if isinstance(order_by, str) else order_by,
comment=self.test_comment,
page=page,
page_size=page_size,
@@ -63,12 +66,19 @@ class TestEntityOrdering(TestService):
Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end
"""
if None in vals:
first_null_idx = vals.index(None)
none_tail = vals[first_null_idx:]
vals = vals[:first_null_idx]
self.assertTrue(all(val is None for val in none_tail))
self.assertTrue(all(val is not None for val in vals))
empty = [None, "", [], {}]
empty_value = None
idx = 0
for idx, val in enumerate(vals):
if val in empty:
empty_value = val
break
if idx < len(vals) - 1:
none_tail = vals[idx:]
vals = vals[:idx]
self.assertTrue(all(val == empty_value for val in none_tail))
self.assertTrue(all(val != empty_value for val in vals))
if ascending:
cmp = operator.le
@@ -76,10 +86,18 @@ class TestEntityOrdering(TestService):
cmp = operator.ge
self.assertTrue(all(cmp(i, j) for i, j in zip(vals, vals[1:])))
def _get_value_for_path(self, data: Mapping, field_path: Sequence[str]):
val = None
for name in field_path:
val = data.get(name)
data = val if isinstance(val, dict) else {}
return val
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=[order_by] if order_by else None,
order_by=[order_by] if isinstance(order_by, str) else order_by,
comment=self.test_comment,
**kwargs,
).tasks
@@ -87,12 +105,17 @@ class TestEntityOrdering(TestService):
if order_by:
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [t.get(field_name) for t in tasks]
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))
def _create_tasks(self):
tasks = [self._temp_task() for _ in range(10)]
for _, task in zip(range(5), tasks):
tasks = [
self._temp_task(
**(dict(execution={"parameters": {"test": f"{i}"} if i >= 5 else {}}))
)
for i in range(10)
]
for idx, task in zip(range(5), tasks):
self.api.tasks.started(task=task)
sleep(0.1)
return tasks

View File

@@ -2,83 +2,199 @@
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import time
import unittest
from functools import partial
from statistics import mean
from typing import Sequence
import es_factory
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestTaskEvents(TestService):
def setUp(self, version="1.7"):
def setUp(self, version="2.7"):
super().setUp(version=version)
self.created_tasks = []
self.task = dict(
name="test task events",
type="training",
input=dict(mapping={}, view=dict(entries=[])),
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
)
res, self.task_id = self.api.send("tasks.create", self.task, extract="id")
assert res.meta.result_code == 200
self.created_tasks.append(self.task_id)
return self.create_temp("tasks", **task_input)
def tearDown(self):
log.info("Cleanup...")
for task_id in self.created_tasks:
try:
self.api.send("tasks.delete", dict(task=task_id, force=True))
except Exception as ex:
log.exception(ex)
def create_task_event(self, type, iteration):
def _create_task_event(self, type_, task, iteration, **kwargs):
return {
"worker": "test",
"type": type,
"task": self.task_id,
"type": type_,
"task": task,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis()
"timestamp": es_factory.get_timestamp_millis(),
**kwargs,
}
def copy_and_update(self, src_obj, new_data):
def _copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj
def test_task_metrics(self):
tasks = {
self._temp_task(): {
"Metric1": ["training_debug_image"],
"Metric2": ["training_debug_image", "log"],
},
self._temp_task(): {"Metric3": ["training_debug_image"]},
}
events = [
self._create_task_event(
event_type,
task=task,
iteration=1,
metric=metric,
variant="Test variant",
)
for task, metrics in tasks.items()
for metric, event_types in metrics.items()
for event_type in event_types
]
self.send_batch(events)
self._assert_task_metrics(tasks, "training_debug_image")
self._assert_task_metrics(tasks, "log")
self._assert_task_metrics(tasks, "training_stats_scalar")
def _assert_task_metrics(self, tasks: dict, event_type: str):
res = self.api.events.get_task_metrics(tasks=list(tasks), event_type=event_type)
for task, metrics in tasks.items():
res_metrics = next(
(tm.metrics for tm in res.metrics if tm.task == task), ()
)
self.assertEqual(
set(res_metrics),
set(
metric for metric, events in metrics.items() if event_type in events
),
)
def test_task_debug_images(self):
task = self._temp_task()
metric = "Metric1"
variants = [("Variant1", 7), ("Variant2", 4)]
iterations = 10
# test empty
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
)
self.assertFalse(res.metrics)
# create events
events = [
self._create_task_event(
"training_debug_image",
task=task,
iteration=n,
metric=metric,
variant=variant,
url=f"{metric}_{variant}_{n % unique_images}",
)
for n in range(iterations)
for (variant, unique_images) in variants
]
self.send_batch(events)
# init testing
unique_images = [unique for (_, unique) in variants]
scroll_id = None
assert_debug_images = partial(
self._assertDebugImages,
task=task,
metric=metric,
max_iter=iterations - 1,
unique_images=unique_images,
)
# test forward navigation
for page in range(3):
scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
# test backwards navigation
scroll_id = assert_debug_images(
scroll_id=scroll_id, page=0, navigate_earlier=False
)
# beyond the latest iteration and back
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=5,
scroll_id=scroll_id,
navigate_earlier=False,
)
self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
assert_debug_images(scroll_id=scroll_id, page=1)
# refresh
assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
def _assertDebugImages(
self,
task,
metric,
max_iter: int,
unique_images: Sequence[int],
scroll_id,
page: int,
iters: int = 5,
**extra_params,
):
res = self.api.events.debug_images(
metrics=[{"task": task, "metric": metric}],
iters=iters,
scroll_id=scroll_id,
**extra_params,
)
data = res["metrics"][0]
self.assertEqual(data["task"], task)
self.assertEqual(data["metric"], metric)
left_iterations = max(0, max(unique_images) - page * iters)
self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
for it in data["iterations"]:
events_per_iter = sum(
1 for unique in unique_images if unique > max_iter - it["iter"]
)
self.assertEqual(len(it["events"]), events_per_iter)
return res.scroll_id
def test_task_logs(self):
events = []
for iter in range(10):
log_event = self.create_task_event("log", iteration=iter)
task = self._temp_task()
for iter_ in range(10):
log_event = self._create_task_event("log", task, iteration=iter_)
events.append(
self.copy_and_update(
self._copy_and_update(
log_event,
{"msg": "This is a log message from test task iter " + str(iter)},
{"msg": "This is a log message from test task iter " + str(iter_)},
)
)
# sleep so timestamp is not the same
import time
time.sleep(0.01)
self.send_batch(events)
data = self.api.events.get_task_log(task=self.task_id)
data = self.api.events.get_task_log(task=task)
assert len(data["events"]) == 10
self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_log(task=self.task_id)
self.api.tasks.reset(task=task)
data = self.api.events.get_task_log(task=task)
assert len(data["events"]) == 0
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self.create_task_event("training_stats_scalar", iteration),
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
@@ -88,19 +204,65 @@ class TestTaskEvents(TestService):
self.send_batch(events)
for key in None, "iter", "timestamp", "iso_time":
with self.subTest(key=key):
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, key=key)
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key)
self.assertIn(metric, data)
self.assertIn(variant, data[metric])
self.assertIn("x", data[metric][variant])
self.assertIn("y", data[metric][variant])
def test_multitask_events_many_metrics(self):
tasks = [
self._temp_task(name="test events1"),
self._temp_task(name="test events2"),
]
iter_count = 10
metrics_count = 10
variants_count = 10
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": f"Metric{metric_idx}",
"variant": f"Variant{variant_idx}",
"value": iteration,
}
for iteration in range(iter_count)
for task in tasks
for metric_idx in range(metrics_count)
for variant_idx in range(variants_count)
]
self.send_batch(events)
data = self.api.events.multi_task_scalar_metrics_iter_histogram(tasks=tasks)
self._assert_metrics_and_variants(
data.metrics,
metrics=metrics_count,
variants=variants_count,
tasks=tasks,
iterations=iter_count,
)
def _assert_metrics_and_variants(
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
):
self.assertEqual(len(data), metrics)
for m in range(metrics):
metric_data = data[f"Metric{m}"]
self.assertEqual(len(metric_data), variants)
for v in range(variants):
variant_data = metric_data[f"Variant{v}"]
self.assertEqual(len(variant_data), len(tasks))
for t in tasks:
task_data = variant_data[t]
self.assertEqual(len(task_data["x"]), iterations)
self.assertEqual(len(task_data["y"]), iterations)
def test_task_metric_value_intervals(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self.create_task_event("training_stats_scalar", iteration),
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
@@ -109,13 +271,13 @@ class TestTaskEvents(TestService):
]
self.send_batch(events)
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id)
data = self.api.events.scalar_metrics_iter_histogram(task=task)
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=100)
data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=100)
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=10)
data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=10)
self._assert_metrics_histogram(data[metric][variant], iter_count, 10)
def _assert_metrics_histogram(self, data, iters, samples):
@@ -130,7 +292,8 @@ class TestTaskEvents(TestService):
)
def test_task_plots(self):
event = self.create_task_event("plot", 0)
task = self._temp_task()
event = self._create_task_event("plot", task, 0)
event["metric"] = "roc"
event.update(
{
@@ -179,7 +342,7 @@ class TestTaskEvents(TestService):
)
self.send(event)
event = self.create_task_event("plot", 100)
event = self._create_task_event("plot", task, 100)
event["metric"] = "confusion"
event.update(
{
@@ -222,11 +385,11 @@ class TestTaskEvents(TestService):
)
self.send(event)
data = self.api.events.get_task_plots(task=self.task_id)
data = self.api.events.get_task_plots(task=task)
assert len(data["plots"]) == 2
self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_plots(task=self.task_id)
self.api.tasks.reset(task=task)
data = self.api.events.get_task_plots(task=task)
assert len(data["plots"]) == 0
def send_batch(self, events):

View File

@@ -1 +1 @@
__version__ = "0.13.0"
__version__ = "0.14.1"