mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add Artifacts support, changed tags to system_tags and added user tags
Add hyper parameter sorting Add min/max value for all time series metrics
This commit is contained in:
18
migration/mongodb/0.12.1.py
Normal file
18
migration/mongodb/0.12.1.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
from database.utils import partition_tags
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
for name in ("project", "task", "model"):
|
||||
collection: Collection = db[name]
|
||||
for doc in collection.find(projection=["tags", "system_tags"]):
|
||||
tags = doc.get("tags")
|
||||
if tags is not None:
|
||||
user_tags, system_tags = partition_tags(
|
||||
name, tags, doc.get("system_tags", [])
|
||||
)
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"system_tags": system_tags, "tags": user_tags}}
|
||||
)
|
||||
@@ -83,7 +83,8 @@ _error_codes = {
|
||||
21: ('bad_credentials', 'unauthorized (malformed credentials)'),
|
||||
22: ('invalid_credentials', 'unauthorized (invalid credentials)'),
|
||||
30: ('invalid_token', 'invalid token'),
|
||||
31: ('blocked_token', 'token is blocked')
|
||||
31: ('blocked_token', 'token is blocked'),
|
||||
40: ('invalid_fixed_user', 'fixed user ID was not found')
|
||||
},
|
||||
|
||||
(403, 'forbidden'): {
|
||||
|
||||
@@ -4,11 +4,10 @@ from enum import Enum
|
||||
from typing import Union, Type, Iterable
|
||||
|
||||
import jsonmodels.errors
|
||||
import jsonmodels.validators
|
||||
import six
|
||||
import validators
|
||||
from jsonmodels import fields
|
||||
from jsonmodels.fields import _LazyType
|
||||
from jsonmodels.fields import _LazyType, NotSet
|
||||
from jsonmodels.models import Base as ModelBase
|
||||
from jsonmodels.validators import Enum as EnumValidator
|
||||
from luqum.parser import parser, ParseError
|
||||
@@ -25,6 +24,12 @@ def make_default(field_cls, default_value):
|
||||
|
||||
|
||||
class ListField(fields.ListField):
|
||||
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
|
||||
if default is not NotSet and callable(default):
|
||||
default = default()
|
||||
|
||||
super(ListField, self).__init__(items_types, *args, default=default, **kwargs)
|
||||
|
||||
def _cast_value(self, value):
|
||||
try:
|
||||
return super(ListField, self)._cast_value(value)
|
||||
@@ -144,6 +149,46 @@ class EnumField(fields.StringField):
|
||||
return super().parse_value(value)
|
||||
|
||||
|
||||
class ActualEnumField(fields.StringField):
|
||||
@property
|
||||
def types(self):
|
||||
return (self.__enum,)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enum_class: Type[Enum],
|
||||
*args,
|
||||
validators=None,
|
||||
required=False,
|
||||
default=None,
|
||||
**kwargs
|
||||
):
|
||||
self.__enum = enum_class
|
||||
# noinspection PyTypeChecker
|
||||
choices = list(enum_class)
|
||||
validator_cls = EnumValidator if required else NullableEnumValidator
|
||||
validators = [*(validators or []), validator_cls(*choices)]
|
||||
super().__init__(
|
||||
default=default and self.parse_value(default),
|
||||
*args,
|
||||
required=required,
|
||||
validators=validators,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value is None and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
return self.__enum(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
def to_struct(self, value):
|
||||
return super().to_struct(value.value)
|
||||
|
||||
|
||||
class EmailField(fields.StringField):
|
||||
def validate(self, value):
|
||||
super().validate(value)
|
||||
@@ -160,3 +205,12 @@ class DomainField(fields.StringField):
|
||||
return
|
||||
if validators.domain(value) is not True:
|
||||
raise errors.bad_request.InvalidDomainName()
|
||||
|
||||
|
||||
class StringEnum(Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField, DateTimeField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Max, Enum
|
||||
|
||||
@@ -79,6 +79,7 @@ class Credentials(Base):
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
|
||||
20
server/apimodels/events.py
Normal file
20
server/apimodels/events.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.scalar_key import ScalarKeyEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=10000)
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(items_types=str)
|
||||
@@ -11,6 +11,7 @@ class CreateModelRequest(models.Base):
|
||||
uri = fields.StringField(required=True)
|
||||
labels = DictField(value_types=string_types+(int,), required=True)
|
||||
tags = ListField(items_types=string_types)
|
||||
system_tags = ListField(items_types=string_types)
|
||||
comment = fields.StringField()
|
||||
public = fields.BoolField(default=False)
|
||||
project = fields.StringField()
|
||||
|
||||
16
server/apimodels/projects.py
Normal file
16
server/apimodels/projects.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
project = fields.StringField()
|
||||
|
||||
|
||||
class GetHyperParamReq(ProjectReq):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
||||
class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
@@ -57,5 +57,5 @@ class CreateRequest(TaskData):
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskType)))
|
||||
|
||||
|
||||
class PingRequest(models.Base):
|
||||
class PingRequest(TaskRequest):
|
||||
task = StringField(required=True)
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from operator import attrgetter
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
|
||||
import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.task import TaskBLL
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from database.model.task.metrics import MetricEvent
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
@@ -44,7 +44,12 @@ class EventBLL(object):
|
||||
id_fields = ["task", "iter", "metric", "variant", "key"]
|
||||
|
||||
def __init__(self, events_es=None):
|
||||
self.es = events_es if events_es is not None else es_factory.connect("events")
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
self._metrics = EventMetrics(self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
return self._metrics
|
||||
|
||||
def add_events(self, company_id, events, worker):
|
||||
actions = []
|
||||
@@ -94,7 +99,7 @@ class EventBLL(object):
|
||||
event["value"] = event["values"]
|
||||
del event["values"]
|
||||
|
||||
index_name = EventBLL.get_index_name(company_id, event_type)
|
||||
index_name = EventMetrics.get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
@@ -154,13 +159,6 @@ class EventBLL(object):
|
||||
else:
|
||||
errors_in_bulk.append(info)
|
||||
|
||||
last_metrics = {
|
||||
t.id: t.to_proper_dict().get("last_metrics", {})
|
||||
for t in Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"last_metrics"
|
||||
)
|
||||
}
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
@@ -173,7 +171,6 @@ class EventBLL(object):
|
||||
now=now,
|
||||
iter=task_iteration.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
last_metrics=last_metrics.get(task_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
@@ -210,9 +207,7 @@ class EventBLL(object):
|
||||
if timestamp is None or timestamp < event["timestamp"]:
|
||||
last_events[metric_hash][variant_hash] = event
|
||||
|
||||
def _update_task(
|
||||
self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None
|
||||
):
|
||||
def _update_task(self, company_id, task_id, now, iter=None, last_events=None):
|
||||
"""
|
||||
Update task information in DB with aggregated results after handling event(s) related to this task.
|
||||
|
||||
@@ -226,23 +221,13 @@ class EventBLL(object):
|
||||
fields["last_iteration"] = iter
|
||||
|
||||
if last_events:
|
||||
|
||||
def get_metric_event(ev):
|
||||
me = MetricEvent.from_dict(**ev)
|
||||
if "timestamp" in ev:
|
||||
me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000)
|
||||
return me
|
||||
|
||||
new_last_metrics = nested_dict(2, MetricEvent)
|
||||
new_last_metrics.update(last_metrics)
|
||||
|
||||
for metric_hash, variants in last_events.items():
|
||||
for variant_hash, event in variants.items():
|
||||
new_last_metrics[metric_hash][variant_hash] = get_metric_event(
|
||||
event
|
||||
)
|
||||
|
||||
fields["last_metrics"] = new_last_metrics.to_dict()
|
||||
fields["last_values"] = list(
|
||||
flatten_nested_items(
|
||||
last_events,
|
||||
nesting=2,
|
||||
include_leaves=["value", "metric", "variant"],
|
||||
)
|
||||
)
|
||||
|
||||
if not fields:
|
||||
return False
|
||||
@@ -270,7 +255,7 @@ class EventBLL(object):
|
||||
if event_type is None:
|
||||
event_type = "*"
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, event_type)
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [], None, 0
|
||||
@@ -290,6 +275,125 @@ class EventBLL(object):
|
||||
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
|
||||
):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return []
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric"},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant"},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_term": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
if event_type:
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
return [
|
||||
(metric["key"], variant["key"], iter["key"])
|
||||
for metric in es_res["aggregations"]["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
for iter in variant["iters"]["buckets"]
|
||||
]
|
||||
|
||||
def get_task_plots(
|
||||
self,
|
||||
company_id: str,
|
||||
tasks: Sequence[str],
|
||||
last_iterations_per_plot: int = None,
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
event_type = "plot"
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
|
||||
for metric, variant, iter in last_iters:
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"term": {"iter": iter}},
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(tasks)
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
# scroll id may be missing when queering a totally empty DB
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
@@ -311,7 +415,7 @@ class EventBLL(object):
|
||||
if event_type is None:
|
||||
event_type = "*"
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, event_type)
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
@@ -374,7 +478,7 @@ class EventBLL(object):
|
||||
|
||||
def get_metrics_and_variants(self, company_id, task_id, event_type):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, event_type)
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type)
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
@@ -405,7 +509,7 @@ class EventBLL(object):
|
||||
return metrics
|
||||
|
||||
def get_task_latest_scalar_values(self, company_id, task_id):
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
|
||||
es_index = EventMetrics.get_index_name(company_id, "training_stats_scalar")
|
||||
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
@@ -488,147 +592,9 @@ class EventBLL(object):
|
||||
metrics.append(metric_summary)
|
||||
return metrics, max_timestamp
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self, company_id, task_ids, allow_public=True
|
||||
):
|
||||
assert isinstance(task_ids, list)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
|
||||
"aggs": {
|
||||
"metric_and_variant": {
|
||||
"terms": {
|
||||
"script": "doc['metric'].value +'/'+ doc['variant'].value",
|
||||
"size": 10000,
|
||||
},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task"},
|
||||
"aggs": {"avg_val": {"avg": {"field": "value"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_comparison"):
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return
|
||||
|
||||
metrics = {}
|
||||
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
|
||||
iteration = int(iter_bucket["key"])
|
||||
for metric_bucket in iter_bucket["metric_and_variant"]["buckets"]:
|
||||
metric_name = metric_bucket["key"]
|
||||
if metrics.get(metric_name) is None:
|
||||
metrics[metric_name] = {}
|
||||
|
||||
metric_data = metrics[metric_name]
|
||||
for task_bucket in metric_bucket["tasks"]["buckets"]:
|
||||
task_id = task_bucket["key"]
|
||||
value = task_bucket["avg_val"]["value"]
|
||||
if metric_data.get(task_id) is None:
|
||||
metric_data[task_id] = {
|
||||
"x": [],
|
||||
"y": [],
|
||||
"name": task_name_by_id[
|
||||
task_id
|
||||
], # todo: lookup task name from id
|
||||
}
|
||||
metric_data[task_id]["x"].append(iteration)
|
||||
metric_data[task_id]["y"].append(value)
|
||||
|
||||
return metrics
|
||||
|
||||
def get_scalar_metrics_average_per_iter(self, company_id, task_id):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 200,
|
||||
"order": {"_term": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": 500,
|
||||
"order": {"_term": "desc"},
|
||||
},
|
||||
"aggs": {"avg_val": {"avg": {"field": "value"}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
|
||||
metrics = {}
|
||||
if "aggregations" in es_res:
|
||||
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
|
||||
iteration = int(iter_bucket["key"])
|
||||
for metric_bucket in iter_bucket["metrics"]["buckets"]:
|
||||
metric_name = metric_bucket["key"]
|
||||
if metrics.get(metric_name) is None:
|
||||
metrics[metric_name] = {}
|
||||
|
||||
metric_data = metrics[metric_name]
|
||||
for variant_bucket in metric_bucket["variants"]["buckets"]:
|
||||
variant = variant_bucket["key"]
|
||||
value = variant_bucket["avg_val"]["value"]
|
||||
if metric_data.get(variant) is None:
|
||||
metric_data[variant] = {"x": [], "y": [], "name": variant}
|
||||
metric_data[variant]["x"].append(iteration)
|
||||
metric_data[variant]["y"].append(value)
|
||||
return metrics
|
||||
|
||||
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_vector")
|
||||
es_index = EventMetrics.get_index_name(company_id, "training_stats_vector")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [], []
|
||||
|
||||
@@ -685,7 +651,7 @@ class EventBLL(object):
|
||||
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
|
||||
|
||||
def delete_task_events(self, company_id, task_id):
|
||||
es_index = EventBLL.get_index_name(company_id, "*")
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
@@ -693,8 +659,3 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return "events-%s-%s" % (event_type, company_id)
|
||||
|
||||
398
server/bll/event/event_metrics.py
Normal file
398
server/bll/event/event_metrics.py
Normal file
@@ -0,0 +1,398 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apierrors import errors
|
||||
from bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 100
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return f"events-{event_type}-{company_id}"
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
|
||||
return self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=[task_id],
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average,
|
||||
)
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
ret = self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=task_ids,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average_per_task,
|
||||
)
|
||||
|
||||
for metric_data in ret.values():
|
||||
for variant_data in metric_data.values():
|
||||
for task_id, task_data in variant_data.items():
|
||||
task_data["name"] = task_name_by_id[task_id]
|
||||
|
||||
return ret
|
||||
|
||||
TaskMetric = Tuple[str, str, str]
|
||||
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
task_ids: Sequence[str],
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
get_func: Callable[
|
||||
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
|
||||
],
|
||||
) -> dict:
|
||||
"""
|
||||
Group metrics per interval length and execute get_func for each group in parallel
|
||||
:param company_id: id of the company
|
||||
:params task_ids: ids of the tasks to collect data for
|
||||
:param samples: maximum number of samples per metric
|
||||
:param get_func: callable that given metric names for the same interval
|
||||
performs histogram aggregation for the metrics and return the aggregated data
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
intervals = self._get_metric_intervals(
|
||||
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
|
||||
)
|
||||
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(len(intervals)) as pool:
|
||||
metrics = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
get_func, task_ids=task_ids, es_index=es_index, key=key
|
||||
),
|
||||
intervals,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
return ret
|
||||
|
||||
def _get_metric_intervals(
|
||||
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed sample.
|
||||
Return metric variants grouped by interval value with 10% rounding
|
||||
For samples==0 return empty list
|
||||
"""
|
||||
default_intervals = [(1, [])]
|
||||
if not samples:
|
||||
return default_intervals
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return default_intervals
|
||||
|
||||
intervals = [
|
||||
(
|
||||
task["key"],
|
||||
metric["key"],
|
||||
variant["key"],
|
||||
self._calculate_metric_interval(variant, samples),
|
||||
)
|
||||
for task in aggs_result["tasks"]["buckets"]
|
||||
for metric in task["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
metric_intervals = []
|
||||
upper_border = 0
|
||||
interval_metrics = None
|
||||
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
|
||||
if not interval_metrics or interval > upper_border:
|
||||
interval_metrics = []
|
||||
metric_intervals.append((interval, interval_metrics))
|
||||
upper_border = interval + int(interval * 0.1)
|
||||
interval_metrics.append((task, metric, variant))
|
||||
|
||||
return metric_intervals
|
||||
|
||||
@staticmethod
|
||||
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceeds the samples
|
||||
"""
|
||||
count = safe_get(metric_variant, "count/value")
|
||||
if not count or count < samples:
|
||||
return 1
|
||||
|
||||
min_index = safe_get(metric_variant, "min_index/value", default=0)
|
||||
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
|
||||
return max(1, int(max_index - min_index + 1) // samples)
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
Note: the function works with a single task only
|
||||
"""
|
||||
|
||||
assert len(task_ids) == 1
|
||||
interval, task_metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
"name": variant["key"],
|
||||
**key.get_iterations_data(variant),
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
def _get_scalar_average_per_task(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, task_metrics = metrics_interval
|
||||
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {"terms": {"field": "task"}, "aggs": aggregation}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
task["key"]: key.get_iterations_data(task)
|
||||
for task in variant["tasks"]["buckets"]
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
return {
|
||||
key: {**value, "aggs": {**value.get("aggs", {}), **average_agg}}
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_metrics_and_tasks(
|
||||
self,
|
||||
es_index: str,
|
||||
aggs: dict,
|
||||
task_ids: Sequence[str],
|
||||
task_metrics: Sequence[TaskMetric],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
if task_metrics:
|
||||
condition = {
|
||||
"should": [
|
||||
self._build_metric_terms(task, metric, variant)
|
||||
for task, metric, variant in task_metrics
|
||||
]
|
||||
}
|
||||
else:
|
||||
condition = {"must": [{"terms": {"task": task_ids}}]}
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"bool": condition},
|
||||
"aggs": aggs,
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
|
||||
"""
|
||||
Build query term for a metric + variant
|
||||
"""
|
||||
return {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
161
server/bll/event/scalar_key.py
Normal file
161
server/bll/event/scalar_key.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Module for polymorphism over different types of X axes in scalar aggregations
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from apimodels import StringEnum
|
||||
from bll.util import extract_properties_to_lists
|
||||
from config import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ScalarKeyEnum(StringEnum):
|
||||
"""
|
||||
String enum representing X axes key
|
||||
"""
|
||||
|
||||
iter = auto()
|
||||
timestamp = auto()
|
||||
iso_time = auto()
|
||||
|
||||
|
||||
class ScalarKey(ABC):
|
||||
"""
|
||||
Abstract scalar key
|
||||
"""
|
||||
|
||||
_enum_to_key = {}
|
||||
bucket_key_key = "key"
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def enum_value(self) -> ScalarKeyEnum:
|
||||
"""
|
||||
Enum value accepted in API requests
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Key name. Used as arbitrary internal key in elasticsearch queries
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def field(self) -> str:
|
||||
"""
|
||||
Event key to aggregate by
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
"""
|
||||
Get aggregation for this type of key
|
||||
:param interval: elasticsearch aggregation interval
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Save a mapping from enum values to key class
|
||||
"""
|
||||
if cls.enum_value not in ScalarKeyEnum:
|
||||
raise ValueError(f"{cls.enum_value!r} not in {ScalarKeyEnum.__name__}")
|
||||
if cls.enum_value in cls._enum_to_key:
|
||||
log.warning(
|
||||
f"'{cls.enum_value.value}' is already registered to {ScalarKey.__name__}"
|
||||
)
|
||||
cls._enum_to_key[cls.enum_value] = cls
|
||||
|
||||
@classmethod
|
||||
def resolve(cls, key: ScalarKeyEnum):
|
||||
"""
|
||||
Create a key instance from enum instance
|
||||
"""
|
||||
return cls._enum_to_key[key]()
|
||||
|
||||
def get_iterations_data(self, iter_buckets: dict) -> dict:
|
||||
"""
|
||||
Convert a list of bucket entries to `x`s array and `y`s array
|
||||
"""
|
||||
return extract_properties_to_lists(
|
||||
("x", "y"),
|
||||
iter_buckets[self.name]["buckets"],
|
||||
self._get_iterations_data_single,
|
||||
)
|
||||
|
||||
def _get_iterations_data_single(self, iter_data):
|
||||
"""
|
||||
Extract x value and y value from a single bucket item
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by timestamp in milliseconds since epoch
|
||||
"""
|
||||
|
||||
name = "timestamp"
|
||||
field = "timestamp"
|
||||
enum_value = ScalarKeyEnum.timestamp
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": interval,
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by iteration number
|
||||
"""
|
||||
|
||||
name = "iters"
|
||||
field = "iter"
|
||||
enum_value = ScalarKeyEnum.iter
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"histogram": {"field": "iter", "interval": interval, "min_doc_count": 1}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
Aggregate by time formatted as ISO strings
|
||||
"""
|
||||
|
||||
name = "iso_time"
|
||||
field = "timestamp"
|
||||
enum_value = ScalarKeyEnum.iso_time
|
||||
bucket_key_key = "key_as_string"
|
||||
|
||||
def get_aggregation(self, interval: int) -> dict:
|
||||
return {
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": interval,
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _get_iterations_data_single(self, iter_data):
|
||||
return iter_data[self.bucket_key_key], iter_data["avg_val"]["value"]
|
||||
@@ -2,8 +2,7 @@ import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from time import sleep
|
||||
from typing import Mapping, Collection
|
||||
from urllib.parse import urlparse
|
||||
from typing import Collection, Sequence, Tuple, Any
|
||||
|
||||
import six
|
||||
from mongoengine import Q
|
||||
@@ -13,12 +12,15 @@ import es_factory
|
||||
from apierrors import errors
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.fields import OutputDestinationField
|
||||
from database.model.model import Model
|
||||
from database.model.project import Project
|
||||
from database.model.task.metrics import MetricEvent
|
||||
from database.model.task.output import Output
|
||||
from database.model.task.task import Task, TaskStatus, TaskStatusMessage, TaskTags
|
||||
from database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from timing_context import TimingContext
|
||||
@@ -143,7 +145,7 @@ class TaskBLL(object):
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task, force=False):
|
||||
def validate(cls, task: Task):
|
||||
assert isinstance(task, Task)
|
||||
|
||||
if task.parent and not Task.get(
|
||||
@@ -154,24 +156,12 @@ class TaskBLL(object):
|
||||
if task.project:
|
||||
Project.get_for_writing(company=task.company, id=task.project)
|
||||
|
||||
model = cls.validate_execution_model(task)
|
||||
if model and not force and not model.ready:
|
||||
raise errors.bad_request.ModelNotReady(
|
||||
"can't be used in a task", model=model.id
|
||||
)
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
if task.execution:
|
||||
if task.execution.parameters:
|
||||
cls._validate_execution_parameters(task.execution.parameters)
|
||||
|
||||
if task.output and task.output.destination:
|
||||
parsed_url = urlparse(task.output.destination)
|
||||
if parsed_url.scheme not in OutputDestinationField.schemes:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"unsupported scheme for output destination",
|
||||
dest=task.output.destination,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_execution_parameters(parameters):
|
||||
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
|
||||
@@ -236,7 +226,7 @@ class TaskBLL(object):
|
||||
last_update: datetime = None,
|
||||
last_iteration: int = None,
|
||||
last_iteration_max: int = None,
|
||||
last_metrics: Mapping[str, Mapping[str, MetricEvent]] = None,
|
||||
last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
|
||||
**extra_updates,
|
||||
):
|
||||
"""
|
||||
@@ -248,7 +238,7 @@ class TaskBLL(object):
|
||||
task's last iteration value.
|
||||
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
||||
if the current task's last iteration value is smaller than the provided value.
|
||||
:param last_metrics: Last reported metrics summary.
|
||||
:param last_values: Last reported metrics summary (value, metric, variant).
|
||||
:param extra_updates: Extra task updates to include in this update call.
|
||||
:return:
|
||||
"""
|
||||
@@ -259,10 +249,18 @@ class TaskBLL(object):
|
||||
elif last_iteration_max is not None:
|
||||
extra_updates.update(max__last_iteration=last_iteration_max)
|
||||
|
||||
if last_metrics is not None:
|
||||
extra_updates.update(last_metrics=last_metrics)
|
||||
if last_values is not None:
|
||||
|
||||
return Task.objects(id=task_id, company=company_id).update(
|
||||
def op_path(op, *path):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_values:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
if path[-1] == "value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
|
||||
Task.objects(id=task_id, company=company_id).update(
|
||||
upsert=False, last_update=last_update, **extra_updates
|
||||
)
|
||||
|
||||
@@ -378,11 +376,11 @@ class TaskBLL(object):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
only=("status", "project", "tags", "last_update"),
|
||||
only=("status", "project", "tags", "system_tags", "last_update"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
if TaskTags.development in task.tags:
|
||||
if TaskSystemTags.development in task.system_tags:
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
@@ -448,3 +446,55 @@ class TaskBLL(object):
|
||||
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed stopping tasks: {str(ex)}")
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_execution_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[str]]:
|
||||
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": company_id,
|
||||
"execution.parameters": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
}
|
||||
},
|
||||
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
|
||||
{"$unwind": "$parameters"},
|
||||
{"$group": {"_id": "$parameters.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"total": 1,
|
||||
"results": {"$slice": ["$results", page * page_size, page_size]},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.objects.aggregate(*pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [r["_id"] for r in result.get("results", [])]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
@@ -66,6 +66,10 @@ class ChangeStatusRequest(object):
|
||||
)
|
||||
|
||||
update_project_time(project_id)
|
||||
|
||||
# make sure that _raw_ queries are not returned back to the client
|
||||
fields.pop("__raw__", None)
|
||||
|
||||
return dict(updated=updated, fields=fields)
|
||||
|
||||
def validate_transition(self, current_status):
|
||||
@@ -135,9 +139,11 @@ def get_possible_status_changes(current_status):
|
||||
:return possible states from current state
|
||||
"""
|
||||
possible = state_machine.get(current_status)
|
||||
assert (
|
||||
possible is not None
|
||||
), f"Current status {current_status} not supported by state machine"
|
||||
if possible is None:
|
||||
raise errors.server_error.InternalError(
|
||||
f"Current status {current_status} not supported by state machine"
|
||||
)
|
||||
|
||||
return possible
|
||||
|
||||
|
||||
|
||||
20
server/bll/util.py
Normal file
20
server/bll/util.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Callable, Tuple
|
||||
|
||||
|
||||
def extract_properties_to_lists(
|
||||
key_names: Sequence[str],
|
||||
data: Sequence[dict],
|
||||
extract_func: Optional[Callable[[dict], Tuple]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Given a list of dictionaries and names of dictionary keys
|
||||
builds a dictionary with the requested keys and values lists
|
||||
:param key_names: names of the keys in the resulting dictionary
|
||||
:param data: sequence of dictionaries to extract values from
|
||||
:param extract_func: the optional callable that extracts properties
|
||||
from a dictionary and put them in a tuple in the order corresponding to
|
||||
key_names. If not specified then properties are extracted according to key_names
|
||||
"""
|
||||
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
|
||||
return dict(zip(key_names, map(list, value_sequences)))
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
@@ -46,6 +50,20 @@ class BasicConfig:
|
||||
path = ".".join((self.prefix, Path(name).stem))
|
||||
return logging.getLogger(path)
|
||||
|
||||
def _read_extra_env_config_values(self):
|
||||
""" Loads extra configuration from environment-injected values """
|
||||
result = ConfigTree()
|
||||
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
|
||||
|
||||
keys = sorted(k for k in os.environ if k.startswith(prefix))
|
||||
for key in keys:
|
||||
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
|
||||
result = ConfigTree.merge_configs(
|
||||
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _read_env_paths(self, key):
|
||||
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
|
||||
if value is None:
|
||||
@@ -64,12 +82,17 @@ class BasicConfig:
|
||||
|
||||
def _load(self, verbose=True):
|
||||
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
|
||||
extra_config_values = self._read_extra_env_config_values()
|
||||
configs = [
|
||||
self._read_recursive(path, verbose=verbose)
|
||||
for path in [self.folder] + extra_config_paths
|
||||
]
|
||||
|
||||
self._config = reduce(
|
||||
lambda config, path: ConfigTree.merge_configs(
|
||||
config, self._read_recursive(path, verbose=verbose), copy_trees=True
|
||||
lambda last, config: ConfigTree.merge_configs(
|
||||
last, config, copy_trees=True
|
||||
),
|
||||
[self.folder] + extra_config_paths,
|
||||
configs + [extra_config_values],
|
||||
ConfigTree(),
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@
|
||||
version {
|
||||
required: false
|
||||
default: 1.0
|
||||
# if set then calls to endpoints with the version
|
||||
# greater that the current max version will be rejected
|
||||
check_max_version: false
|
||||
}
|
||||
|
||||
mongo {
|
||||
|
||||
@@ -18,3 +18,11 @@ def get_version():
|
||||
return (root / "VERSION").read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_commit_number():
|
||||
try:
|
||||
return (root / "COMMIT").read_text().strip()
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
|
||||
@@ -11,11 +11,14 @@ from config import config
|
||||
from .defs import Database
|
||||
from .utils import get_items
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
log = config.logger("database")
|
||||
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = "MONGODB_SERVICE_SERVICE_HOST"
|
||||
OVERRIDE_HOST_ENV_KEY = ("MONGODB_SERVICE_HOST", "MONGODB_SERVICE_SERVICE_HOST")
|
||||
OVERRIDE_PORT_ENV_KEY = "MONGODB_SERVICE_PORT"
|
||||
|
||||
_entries = []
|
||||
|
||||
@@ -34,19 +37,27 @@ def initialize():
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_hostname = getenv(OVERRIDE_HOST_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = getenv(OVERRIDE_PORT_ENV_KEY)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from operator import itemgetter
|
||||
from sys import maxsize
|
||||
from typing import Type, Tuple
|
||||
|
||||
import six
|
||||
from mongoengine import (
|
||||
@@ -11,6 +12,7 @@ from mongoengine import (
|
||||
SortedListField,
|
||||
MapField,
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
|
||||
|
||||
@@ -88,104 +90,6 @@ class CustomFloatField(FloatField):
|
||||
self.error("Float value must be greater than %s" % str(self.greater_than))
|
||||
|
||||
|
||||
# TODO: bucket name should be at most 63 characters....
|
||||
aws_s3_bucket_only_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
|
||||
)
|
||||
|
||||
aws_s3_url_with_bucket_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))" # domain...
|
||||
)
|
||||
|
||||
non_aws_s3_regex = (
|
||||
r"^s3://"
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
|
||||
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))" # bucket name
|
||||
)
|
||||
|
||||
google_gs_bucket_only_regex = (
|
||||
r"^gs://"
|
||||
r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)" # bucket name
|
||||
)
|
||||
|
||||
file_regex = r"^file://"
|
||||
|
||||
generic_url_regex = (
|
||||
r"^%s://" # scheme placeholder
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
|
||||
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
|
||||
r"(?::\d+)?" # optional port
|
||||
)
|
||||
|
||||
path_suffix = r"(?:/?|[/?]\S+)$"
|
||||
file_path_suffix = r"(?:/\S*[^/]+)$"
|
||||
|
||||
|
||||
class _RegexURLField(StringField):
|
||||
_regex = []
|
||||
|
||||
def __init__(self, regex, **kwargs):
|
||||
super(_RegexURLField, self).__init__(**kwargs)
|
||||
regex = regex if isinstance(regex, (tuple, list)) else [regex]
|
||||
self._regex = [
|
||||
re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
|
||||
for e in regex
|
||||
]
|
||||
|
||||
def validate(self, value):
|
||||
# Check first if the scheme is valid
|
||||
if not any(regex for regex in self._regex if regex.match(value)):
|
||||
self.error("Invalid URL: {}".format(value))
|
||||
return
|
||||
|
||||
|
||||
class OutputDestinationField(_RegexURLField):
|
||||
""" A field representing task output URL """
|
||||
|
||||
schemes = ["s3", "gs", "file"]
|
||||
_expressions = (
|
||||
aws_s3_bucket_only_regex + path_suffix,
|
||||
aws_s3_url_with_bucket_regex + path_suffix,
|
||||
non_aws_s3_regex + path_suffix,
|
||||
google_gs_bucket_only_regex + path_suffix,
|
||||
file_regex + path_suffix,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(OutputDestinationField, self).__init__(self._expressions, **kwargs)
|
||||
|
||||
|
||||
class SupportedURLField(_RegexURLField):
|
||||
""" A field representing a model URL """
|
||||
|
||||
schemes = ["s3", "gs", "file", "http", "https"]
|
||||
|
||||
_expressions = tuple(
|
||||
pattern + file_path_suffix
|
||||
for pattern in (
|
||||
aws_s3_bucket_only_regex,
|
||||
aws_s3_url_with_bucket_regex,
|
||||
non_aws_s3_regex,
|
||||
google_gs_bucket_only_regex,
|
||||
file_regex,
|
||||
(generic_url_regex % "http"),
|
||||
(generic_url_regex % "https"),
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(SupportedURLField, self).__init__(self._expressions, **kwargs)
|
||||
|
||||
|
||||
class StrippedStringField(StringField):
|
||||
def __init__(
|
||||
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
|
||||
@@ -235,3 +139,42 @@ class SafeDictField(DictField):
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
|
||||
|
||||
class SafeSortedListField(SortedListField):
|
||||
"""
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
except TypeError:
|
||||
return self._safe_to_mongo(*args, **kwargs)
|
||||
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
def key(v): return str(itemgetter(self._ordering)(v))
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
|
||||
class UnionField(DynamicField):
|
||||
def __init__(self, types, *args, **kwargs):
|
||||
super(UnionField, self).__init__(*args, **kwargs)
|
||||
self.types: Tuple[Type] = tuple(types)
|
||||
|
||||
def validate(self, value, clean=True):
|
||||
if not isinstance(value, self.types):
|
||||
type_names = [t.__name__ for t in self.types]
|
||||
expected = " or ".join(
|
||||
filter(
|
||||
None,
|
||||
(", ".join(type_names[:-1]), type_names[-1]))
|
||||
)
|
||||
self.error(
|
||||
f"Expected {expected}, got {type(value).__name__}: {value}"
|
||||
)
|
||||
super(UnionField, self).validate(value, clean)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import Document, StringField
|
||||
|
||||
from apierrors import errors
|
||||
@@ -54,3 +56,7 @@ def validate_id(cls, company, **kwargs):
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
)
|
||||
|
||||
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
|
||||
@@ -45,6 +45,7 @@ class Role(object):
|
||||
class Credentials(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
last_used = DateTimeField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection
|
||||
from typing import Collection, Sequence
|
||||
|
||||
from boltons.iterutils import first
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document
|
||||
from six import string_types
|
||||
@@ -13,7 +14,12 @@ from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
from database.props import PropsMixin
|
||||
from database.query import RegexQ, RegexWrapper
|
||||
from database.utils import get_company_or_none_constraint, get_fields_with_attr
|
||||
from database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_with_attr,
|
||||
field_exists,
|
||||
field_does_not_exist,
|
||||
)
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@@ -68,7 +74,7 @@ class GetMixin(PropsMixin):
|
||||
def __init__(
|
||||
self,
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "id"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
datetime_fields=None,
|
||||
fields=None,
|
||||
):
|
||||
@@ -261,6 +267,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
override_none_ordering=False,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
@@ -296,6 +303,7 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
override_none_ordering=override_none_ordering,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
@@ -320,6 +328,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
override_none_ordering=False,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supports several built-in options
|
||||
@@ -343,6 +352,8 @@ class GetMixin(PropsMixin):
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
:param allow_public: If True, objects marked as public (no associated company) are also queried.
|
||||
:param override_none_ordering: If True, then items with the None values in the first ordered field
|
||||
are always sorted in the end
|
||||
:return: A list of objects matching the query.
|
||||
"""
|
||||
if query_dict is not None:
|
||||
@@ -356,6 +367,15 @@ class GetMixin(PropsMixin):
|
||||
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if override_none_ordering:
|
||||
return cls._get_many_override_none_ordering(
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
query_dict=query_dict,
|
||||
query_options=query_options,
|
||||
override_projection=override_projection,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
@@ -428,6 +448,105 @@ class GetMixin(PropsMixin):
|
||||
return [obj.to_proper_dict(only=only) for obj in qs]
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
def _get_many_override_none_ordering(
|
||||
cls,
|
||||
query: Q = None,
|
||||
parameters: dict = None,
|
||||
query_dict: dict = None,
|
||||
query_options: QueryParameterOptions = None,
|
||||
override_projection: Collection[str] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
the None values are sorted in the end regardless of the sorting order.
|
||||
This is a company-less version for internal use. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
||||
|
||||
:param query: Query object (mongoengine.Q)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
|
||||
a query. The resulting query is AND'ed with the `query` parameter (if provided).
|
||||
:param query_options: query parameters options (see ParametersOptions)
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get("search_text")
|
||||
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
|
||||
query_sets = []
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if order_by:
|
||||
order_by = order_by if isinstance(order_by, list) else [order_by]
|
||||
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
|
||||
if not search_text and cls._text_score in order_by:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
order_field = first(
|
||||
field for field in order_by if not field.startswith("$")
|
||||
)
|
||||
if (
|
||||
order_field
|
||||
and not order_field.startswith("-")
|
||||
and (not query_dict or order_field not in query_dict)
|
||||
):
|
||||
empty_value = None
|
||||
if order_field in query_options.list_fields:
|
||||
empty_value = []
|
||||
elif order_field in query_options.pattern_fields:
|
||||
empty_value = ""
|
||||
mongo_field = order_field.replace(".", "__")
|
||||
non_empty = query & field_exists(mongo_field, empty_value=empty_value)
|
||||
empty = query & field_does_not_exist(
|
||||
mongo_field, empty_value=empty_value
|
||||
)
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
if not query_sets:
|
||||
query_sets = [cls.objects(query)]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if order_by:
|
||||
# add ordering
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
if only:
|
||||
# add projection
|
||||
query_sets = [qs.only(*only) for qs in query_sets]
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields())
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
|
||||
return ret
|
||||
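The paging arithmetic above can be sketched with plain lists (hypothetical data, no mongoengine involved): documents with a value in the order field come first, documents with an empty value follow, and a requested page may span the boundary between the two groups.

non_empty = [{"name": "a"}, {"name": "b"}, {"name": "c"}]   # order field present
empty = [{"name": None}, {"name": None}]                    # order field missing/empty
query_sets = [non_empty, empty]

def get_page(page, page_size):
    ret, start, remaining = [], page * page_size, page_size
    for qs in query_sets:
        if len(qs) < start:            # the whole set lies before the requested page
            start -= len(qs)
            continue
        ret.extend(qs[start:start + remaining])
        if len(ret) >= page_size:
            break
        start = 0
        remaining = page_size - len(ret)
    return ret

print(get_page(page=1, page_size=2))   # [{'name': 'c'}, {'name': None}]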
|
||||
@classmethod
|
||||
def get_for_writing(
|
||||
cls, *args, _only: Collection[str] = None, **kwargs
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
|
||||
from database.fields import StrippedStringField, SafeDictField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
@@ -48,7 +48,8 @@ class Model(DbModelMixin, Document):
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = SupportedURLField(default='', user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default='', user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import StringField, DateTimeField, ListField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import OutputDestinationField, StrippedStringField
|
||||
from database.fields import StrippedStringField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
@@ -9,7 +9,8 @@ from database.model.base import GetMixin
|
||||
class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"), list_fields=("tags", "id")
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
@@ -34,6 +35,7 @@ class Project(AttributedDocument):
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True), default=list)
|
||||
default_output_destination = OutputDestinationField()
|
||||
tags = ListField(StringField(required=True))
|
||||
system_tags = ListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField
|
||||
from mongoengine import EmbeddedDocument, StringField, DynamicField
|
||||
|
||||
|
||||
class MetricEvent(EmbeddedDocument):
|
||||
metric = StringField(required=True, )
|
||||
variant = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
timestamp = DateTimeField(default=0, required=True)
|
||||
iter = LongField()
|
||||
value = DynamicField(required=True)
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
'strict': False,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, **kwargs):
|
||||
return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
|
||||
metric = StringField(required=True)
|
||||
variant = StringField(required=True)
|
||||
value = DynamicField(required=True)
|
||||
min_value = DynamicField() # for backwards compatibility reasons
|
||||
max_value = DynamicField() # for backwards compatibility reasons
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import EmbeddedDocument, StringField
|
||||
from database.utils import get_options
|
||||
|
||||
from database.fields import OutputDestinationField
|
||||
from database.fields import StrippedStringField
|
||||
from database.utils import get_options
|
||||
|
||||
|
||||
class Result(object):
|
||||
@@ -10,7 +10,7 @@ class Result(object):
|
||||
|
||||
|
||||
class Output(EmbeddedDocument):
|
||||
destination = OutputDestinationField()
|
||||
destination = StrippedStringField()
|
||||
model = StringField(reference_field='Model')
|
||||
error = StringField(user_set_allowed=True)
|
||||
result = StringField(choices=get_options(Result))
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
EmbeddedDocumentField,
|
||||
@@ -7,10 +5,18 @@ from mongoengine import (
|
||||
DateTimeField,
|
||||
IntField,
|
||||
ListField,
|
||||
LongField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeMapField, SafeDictField
|
||||
from database.fields import (
|
||||
StrippedStringField,
|
||||
SafeMapField,
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
@@ -22,27 +28,27 @@ DEFAULT_LAST_ITERATION = 0
|
||||
|
||||
|
||||
class TaskStatus(object):
|
||||
created = 'created'
|
||||
in_progress = 'in_progress'
|
||||
stopped = 'stopped'
|
||||
publishing = 'publishing'
|
||||
published = 'published'
|
||||
closed = 'closed'
|
||||
failed = 'failed'
|
||||
completed = 'completed'
|
||||
unknown = 'unknown'
|
||||
created = "created"
|
||||
in_progress = "in_progress"
|
||||
stopped = "stopped"
|
||||
publishing = "publishing"
|
||||
published = "published"
|
||||
closed = "closed"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class TaskStatusMessage(object):
|
||||
stopping = 'stopping'
|
||||
stopping = "stopping"
|
||||
|
||||
|
||||
class TaskTags(object):
|
||||
development = 'development'
|
||||
class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument):
|
||||
binary = StringField(default='python')
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(required=True)
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
@@ -53,51 +59,70 @@ class Script(EmbeddedDocument):
|
||||
diff = StringField()
|
||||
|
||||
|
||||
class ArtifactTypeData(EmbeddedDocument):
|
||||
preview = StringField()
|
||||
content_type = StringField()
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=("input", "output"), default="output")
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
timestamp = LongField()
|
||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument):
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field='Model')
|
||||
model_desc = SafeMapField(StringField(default=''))
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
artifacts = EmbeddedDocumentSortedListField(Artifact)
|
||||
|
||||
queue = StringField()
|
||||
''' Queue ID where task was queued '''
|
||||
""" Queue ID where task was queued """
|
||||
|
||||
|
||||
class TaskType(object):
|
||||
training = 'training'
|
||||
testing = 'testing'
|
||||
training = "training"
|
||||
testing = "testing"
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
'indexes': [
|
||||
'created',
|
||||
'started',
|
||||
'completed',
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"created",
|
||||
"started",
|
||||
"completed",
|
||||
{
|
||||
'name': '%s.task.main_text_index' % Database.backend,
|
||||
'fields': [
|
||||
'$name',
|
||||
'$id',
|
||||
'$comment',
|
||||
'$execution.model',
|
||||
'$output.model',
|
||||
'$script.repository',
|
||||
'$script.entry_point',
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$execution.model",
|
||||
"$output.model",
|
||||
"$script.repository",
|
||||
"$script.entry_point",
|
||||
],
|
||||
'default_language': 'english',
|
||||
'weights': {
|
||||
'name': 10,
|
||||
'id': 10,
|
||||
'comment': 10,
|
||||
'execution.model': 2,
|
||||
'output.model': 2,
|
||||
'script.repository': 1,
|
||||
'script.entry_point': 1,
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"execution.model": 2,
|
||||
"output.model": 2,
|
||||
"script.repository": 1,
|
||||
"script.entry_point": 1,
|
||||
},
|
||||
},
|
||||
],
|
||||
@@ -123,12 +148,8 @@ class Task(AttributedDocument):
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
|
||||
|
||||
class TaskVisibility(Enum):
|
||||
active = 'active'
|
||||
archived = 'archived'
|
||||
|
||||
18
server/database/model/version.py
Normal file
18
server/database/model/version.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from mongoengine import Document, DateTimeField, StringField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class Version(DbModelMixin, Document):
|
||||
meta = {
|
||||
"collection": "versions", # custom collection name ('version' is not a proper collection name...)
|
||||
"db_alias": Database.backend, # although we'll use this model for all databases, a default must be defined
|
||||
"strict": strict,
|
||||
"indexes": [("-created", "-num")],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
num = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
desc = StringField()
|
||||
@@ -1,13 +1,17 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby, chain
|
||||
from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
|
||||
import dpath
|
||||
import dpath.path
|
||||
|
||||
from apierrors import errors
|
||||
from database.props import PropsMixin
|
||||
|
||||
SEP = "."
|
||||
|
||||
def project_dict(data, projection, separator='.'):
|
||||
|
||||
def project_dict(data, projection, separator=SEP):
|
||||
"""
|
||||
Project partial data from a dictionary into a new dictionary
|
||||
:param data: Input dictionary
|
||||
@@ -30,19 +34,27 @@ def project_dict(data, projection, separator='.'):
|
||||
if path_part not in dst:
|
||||
dst[path_part] = [{} for _ in range(len(src_part))]
|
||||
elif not isinstance(dst[path_part], (list, tuple)):
|
||||
raise TypeError('Incompatible destination type %s for %s (list expected)'
|
||||
% (type(dst), separator.join(path_parts[:depth + 1])))
|
||||
raise TypeError(
|
||||
"Incompatible destination type %s for %s (list expected)"
|
||||
% (type(dst), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
elif not len(dst[path_part]) == len(src_part):
|
||||
raise ValueError('Destination list length differs from source length for %s'
|
||||
% separator.join(path_parts[:depth + 1]))
|
||||
raise ValueError(
|
||||
"Destination list length differs from source length for %s"
|
||||
% separator.join(path_parts[: depth + 1])
|
||||
)
|
||||
|
||||
dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])]
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1:], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
return destination
|
||||
else:
|
||||
raise TypeError('Unsupported projection type %s for %s'
|
||||
% (type(src), separator.join(path_parts[:depth + 1])))
|
||||
raise TypeError(
|
||||
"Unsupported projection type %s for %s"
|
||||
% (type(src), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
|
||||
last_part = path_parts[-1]
|
||||
dst[last_part] = src[last_part]
|
||||
@@ -53,12 +65,35 @@ def project_dict(data, projection, separator='.'):
|
||||
|
||||
for projection_path in sorted(projection):
|
||||
copy_path(
|
||||
path_parts=projection_path.split(separator),
|
||||
source=data,
|
||||
destination=result)
|
||||
path_parts=projection_path.split(separator), source=data, destination=result
|
||||
)
|
||||
return result
|
||||
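A small usage sketch of project_dict (illustrative input only; the expected output assumes the nested-dictionary and list behavior described in the docstring above): projection paths select nested keys, and list values are projected element-wise.

data = {
    "id": "t1",
    "models": [{"id": "m1", "uri": "s3://bucket/a"}, {"id": "m2", "uri": "s3://bucket/b"}],
}

print(project_dict(data, ["id", "models.id"]))
# expected: {'id': 't1', 'models': [{'id': 'm1'}, {'id': 'm2'}]}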
|
||||
|
||||
class _ReferenceProxy(dict):
|
||||
def __init__(self, id):
|
||||
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
|
||||
|
||||
|
||||
class _ProxyManager:
|
||||
lock = threading.Lock()
|
||||
|
||||
def __init__(self):
|
||||
self._proxies: Dict[str, _ReferenceProxy] = {}
|
||||
|
||||
def add(self, id):
|
||||
with self.lock:
|
||||
proxy = self._proxies.get(id)
|
||||
if proxy is None:
|
||||
proxy = self._proxies[id] = _ReferenceProxy(id)
|
||||
return proxy
|
||||
|
||||
def update(self, result):
|
||||
proxy = self._proxies.get(result.get("id"))
|
||||
if proxy is not None:
|
||||
proxy.update(result)
|
||||
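The proxy mechanism can be sketched in isolation (a minimal illustration with made-up IDs, not the server code path): the same proxy dict is placed wherever a reference ID appears, so updating it once through the manager fills in every result that points at that ID.

manager = _ProxyManager()

# two result documents referencing the same project id
results = [{"project": manager.add("p1")}, {"project": manager.add("p1")}]

# once the joined projection data arrives, a single update reaches both results
manager.update({"id": "p1", "name": "Demo project"})
print(results[0]["project"]["name"])                   # Demo project
print(results[1]["project"] is results[0]["project"])  # True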
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
|
||||
@@ -72,6 +107,11 @@ class ProjectionHelper(object):
|
||||
self._doc_cls = doc_cls
|
||||
self._doc_projection = None
|
||||
self._ref_projection = None
|
||||
self._proxy_manager = _ProxyManager()
|
||||
|
||||
# Cached dpath paths for each of the result documents
|
||||
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
|
||||
|
||||
self._parse_projection(projection)
|
||||
|
||||
def _collect_projection_fields(self, doc_cls, projection):
|
||||
@@ -81,8 +121,12 @@ class ProjectionHelper(object):
|
||||
:param projection: List of projection fields
|
||||
:return: A tuple of document projection and reference fields information
|
||||
"""
|
||||
doc_projection = set() # Projection fields for this class (used in the main query)
|
||||
ref_projection_info = [] # Projection information for reference fields (used in join queries)
|
||||
doc_projection = (
|
||||
set()
|
||||
) # Projection fields for this class (used in the main query)
|
||||
ref_projection_info = (
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field.startswith(ref_field):
|
||||
@@ -93,7 +137,7 @@ class ProjectionHelper(object):
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field[len(ref_field):]
|
||||
if not subfield.startswith('.'):
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
@@ -103,10 +147,12 @@ class ProjectionHelper(object):
|
||||
# Not a reference field, just add to the top-level projection
|
||||
# We strip any trailing '*' since it means nothing for simple fields and for embedded documents
|
||||
orig_field = field
|
||||
if field.endswith('.*'):
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field:
|
||||
raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
doc_projection.add(field)
|
||||
return doc_projection, ref_projection_info
|
||||
|
||||
@@ -124,12 +170,14 @@ class ProjectionHelper(object):
|
||||
if not projection:
|
||||
return [], {}
|
||||
|
||||
doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)
|
||||
doc_projection, ref_projection_info = self._collect_projection_fields(
|
||||
doc_cls, projection
|
||||
)
|
||||
|
||||
def normalize_cls_projection(cls_, fields):
|
||||
""" Normalize projection for this class and group (expand *, for once) """
|
||||
if '*' in fields:
|
||||
return list(fields.difference('*').union(cls_.get_fields()))
|
||||
if "*" in fields:
|
||||
return list(fields.difference("*").union(cls_.get_fields()))
|
||||
return list(fields)
|
||||
|
||||
def compute_ref_cls_projection(cls_, group):
|
||||
@@ -143,12 +191,16 @@ class ProjectionHelper(object):
|
||||
# Aggregate by reference field. We'll leave out '*' from the projected items since
|
||||
ref_projection = {
|
||||
ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
|
||||
for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
|
||||
for (ref_field, ref_cls), g in groupby(
|
||||
sorted(ref_projection_info, key=sort_key), sort_key
|
||||
)
|
||||
}
|
||||
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection).union({"id"})
|
||||
)
|
||||
|
||||
# Make sure that in case one or more fields are subfields of another field, we only use the top-level field.
|
||||
# This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
|
||||
@@ -158,13 +210,20 @@ class ProjectionHelper(object):
|
||||
doc_projection = [
|
||||
field
|
||||
for field in doc_projection
|
||||
if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
|
||||
if not any(
|
||||
field.startswith(f"{other_field}.")
|
||||
for other_field in projection_set - {field}
|
||||
)
|
||||
]
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
|
||||
invalid_fields = [
|
||||
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)
|
||||
raise errors.bad_request.InvalidFields(
|
||||
fields=invalid_fields, object=doc_cls.__name__
|
||||
)
|
||||
|
||||
if ref_projection:
|
||||
# Join mode - use both normal projection fields and top-level reference fields
|
||||
@@ -178,11 +237,44 @@ class ProjectionHelper(object):
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
@staticmethod
|
||||
def _search(doc_cls, obj, path, only_values=True):
|
||||
""" Call dpath.search with yielded=True, collect result values """
|
||||
def _search(
|
||||
self,
|
||||
doc_cls: PropsMixin,
|
||||
obj: dict,
|
||||
path: str,
|
||||
factory: Callable[[str], dict] = None,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Search for a path in the given object, return the list of values found for the
|
||||
given path (multiple values may exist if the path is a glob expression)
|
||||
:param doc_cls: The document class represented by the object
|
||||
:param obj: Data object
|
||||
:param path: Path to a leaf in the data object ("." separated, may contain "*")
|
||||
(in case the path contains "*", there may be multiple values)
|
||||
:param factory: If provided, replace each value found with an instance provided by the factory.
|
||||
"""
|
||||
norm_path = doc_cls.get_dpath_translated_path(path)
|
||||
return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]
|
||||
globlist = norm_path.strip(SEP).split(SEP)
|
||||
|
||||
obj_paths = self._cached_results_paths.get(id(obj))
|
||||
if obj_paths is None:
|
||||
obj_paths = self._cached_results_paths[id(obj)] = list(
|
||||
dpath.path.paths(obj, dirs=True, skip=True)
|
||||
)
|
||||
|
||||
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
|
||||
|
||||
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
|
||||
parent = None
|
||||
target = obj
|
||||
for part in p:
|
||||
parent = target
|
||||
target = target[part[0]]
|
||||
if parent and factory:
|
||||
parent[p[-1][0]] = factory(target)
|
||||
return target
|
||||
|
||||
return [search_and_replace(p) for p in paths]
|
||||
|
||||
def project(self, results, projection_func):
|
||||
"""
|
||||
@@ -197,28 +289,50 @@ class ProjectionHelper(object):
|
||||
|
||||
if ref_projection:
|
||||
# Join mode - get results for each reference fields projection required (this is the join step)
|
||||
# Note: this is a recursive step, so we support nested reference fields
|
||||
# Note: this is a recursive step, so nested reference fields are supported
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data = item
|
||||
res = {}
|
||||
ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
|
||||
for res in results))))
|
||||
if ids:
|
||||
doc_type = data['cls']
|
||||
doc_only = list(filter(None, data['only']))
|
||||
doc_only = list({'id'} | set(doc_only)) if doc_only else None
|
||||
res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
|
||||
data['res'] = res
|
||||
def collect_ids(ref_field_name):
|
||||
"""
|
||||
Collect unique IDs for the given reference path from all result documents.
|
||||
All collected IDs are replaced in the result dictionaries with a reference proxy generated by the
|
||||
proxies manager to allow rapid update later on when projection results are obtained.
|
||||
"""
|
||||
all_ids = (
|
||||
self._search(
|
||||
cls, res, ref_field_name, factory=self._proxy_manager.add
|
||||
)
|
||||
for res in results
|
||||
)
|
||||
return list(filter(None, set(chain.from_iterable(all_ids))))
|
||||
|
||||
items = list(ref_projection.items())
|
||||
if len(ref_projection) == 1:
|
||||
do_projection(items[0])
|
||||
else:
|
||||
for _ in self.pool.map(do_projection, items):
|
||||
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
|
||||
# will be raised when its value is retrieved from the map() iterator
|
||||
pass
|
||||
items = [
|
||||
tup
|
||||
for tup in (
|
||||
(*item, collect_ids(item[0])) for item in ref_projection.items()
|
||||
)
|
||||
if tup[2]
|
||||
]
|
||||
|
||||
if items:
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
doc_type = data["cls"]
|
||||
doc_only = list(filter(None, data["only"]))
|
||||
doc_only = list({"id"} | set(doc_only)) if doc_only else None
|
||||
|
||||
for res in projection_func(
|
||||
doc_type=doc_type, projection=doc_only, ids=ids
|
||||
):
|
||||
self._proxy_manager.update(res)
|
||||
|
||||
if len(ref_projection) == 1:
|
||||
do_projection(items[0])
|
||||
else:
|
||||
for _ in self.pool.map(do_projection, items):
|
||||
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
|
||||
# will be raised when its value is retrieved from the map() iterator
|
||||
pass
|
||||
|
||||
def do_expand_reference_ids(result, skip_fields=None):
|
||||
ref_fields = cls.get_reference_fields()
|
||||
@@ -226,44 +340,18 @@ class ProjectionHelper(object):
|
||||
ref_fields = set(ref_fields) - set(skip_fields)
|
||||
self._expand_reference_fields(cls, result, ref_fields)
|
||||
|
||||
def merge_projection_result(result):
|
||||
for ref_field_name, data in ref_projection.items():
|
||||
res = data.get('res')
|
||||
if not res:
|
||||
self._expand_reference_fields(cls, result, [ref_field_name])
|
||||
continue
|
||||
ref_ids = self._search(cls, result, ref_field_name, only_values=False)
|
||||
if not ref_ids:
|
||||
continue
|
||||
for path, value in ref_ids:
|
||||
obj = res.get(value) or {'id': value}
|
||||
dpath.new(result, path, obj, separator='.')
|
||||
|
||||
# any reference field not projected should be expanded
|
||||
do_expand_reference_ids(result, skip_fields=list(ref_projection))
|
||||
|
||||
update_func = merge_projection_result if ref_projection else \
|
||||
do_expand_reference_ids if self._should_expand_reference_ids else None
|
||||
|
||||
if update_func:
|
||||
# any reference field not projected should be expanded
|
||||
if self._should_expand_reference_ids:
|
||||
for result in results:
|
||||
update_func(result)
|
||||
do_expand_reference_ids(
|
||||
result, skip_fields=list(ref_projection) if ref_projection else None
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def _expand_reference_fields(cls, doc_cls, result, fields):
|
||||
def _expand_reference_fields(self, doc_cls, result, fields):
|
||||
for ref_field_name in fields:
|
||||
ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
|
||||
if not ref_ids:
|
||||
continue
|
||||
for path, value in ref_ids:
|
||||
dpath.set(
|
||||
result,
|
||||
path,
|
||||
{'id': value} if value else {},
|
||||
separator='.')
|
||||
self._search(doc_cls, result, ref_field_name, factory=_ReferenceProxy)
|
||||
|
||||
@classmethod
|
||||
def expand_reference_ids(cls, doc_cls, result):
|
||||
cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
|
||||
def expand_reference_ids(self, doc_cls, result):
|
||||
self._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
from inspect import ismethod, getmembers
|
||||
from typing import Sequence, Tuple, Set, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
|
||||
@@ -12,9 +13,13 @@ def get_fields(cls, of_type=BaseField, return_instance=False):
|
||||
""" get field names from a class containing mongoengine fields """
|
||||
res = []
|
||||
for cls_ in reversed(cls.mro()):
|
||||
res.extend([k if not return_instance else (k, v)
|
||||
for k, v in vars(cls_).items()
|
||||
if isinstance(v, of_type)])
|
||||
res.extend(
|
||||
[
|
||||
k if not return_instance else (k, v)
|
||||
for k, v in vars(cls_).items()
|
||||
if isinstance(v, of_type)
|
||||
]
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
@@ -22,9 +27,13 @@ def get_fields_and_attr(cls, attr):
|
||||
""" get field names from a class containing mongoengine fields """
|
||||
res = {}
|
||||
for cls_ in reversed(cls.mro()):
|
||||
res.update({k: getattr(v, attr)
|
||||
for k, v in vars(cls_).items()
|
||||
if isinstance(v, BaseField) and hasattr(v, attr)})
|
||||
res.update(
|
||||
{
|
||||
k: getattr(v, attr)
|
||||
for k, v in vars(cls_).items()
|
||||
if isinstance(v, BaseField) and hasattr(v, attr)
|
||||
}
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
@@ -33,7 +42,7 @@ def _get_field_choices(name, field):
|
||||
if issubclass(field_t, EmbeddedDocumentField):
|
||||
obj = field.document_type_obj
|
||||
n, choices = _get_field_choices(field.name, obj.field)
|
||||
return '%s__%s' % (name, n), choices
|
||||
return "%s__%s" % (name, n), choices
|
||||
elif issubclass(type(field), ListField):
|
||||
return name, field.field.choices
|
||||
return name, field.choices
|
||||
@@ -46,8 +55,14 @@ def get_fields_with_attr(cls, attr, default=False):
|
||||
continue
|
||||
field_t = type(field)
|
||||
if issubclass(field_t, EmbeddedDocumentField):
|
||||
fields.extend((('%s__%s' % (field_name, name), choices)
|
||||
for name, choices in get_fields_with_attr(field.document_type, attr, default)))
|
||||
fields.extend(
|
||||
(
|
||||
("%s__%s" % (field_name, name), choices)
|
||||
for name, choices in get_fields_with_attr(
|
||||
field.document_type, attr, default
|
||||
)
|
||||
)
|
||||
)
|
||||
elif issubclass(type(field), ListField):
|
||||
fields.append((field_name, field.field.choices))
|
||||
else:
|
||||
@@ -58,11 +73,7 @@ def get_fields_with_attr(cls, attr, default=False):
|
||||
def get_items(cls):
|
||||
""" get key/value items from an enum-like class (members represent enumeration key/value) """
|
||||
|
||||
res = {
|
||||
k: v
|
||||
for k, v in getmembers(cls)
|
||||
if not (k.startswith("_") or ismethod(v))
|
||||
}
|
||||
res = {k: v for k, v in getmembers(cls) if not (k.startswith("_") or ismethod(v))}
|
||||
return res
|
||||
|
||||
|
||||
@@ -81,7 +92,7 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
fields = {k: None for k in fields}
|
||||
fields = {k: v for k, v in fields.items() if k in cls_fields}
|
||||
res = {}
|
||||
with translate_errors_context('parsing call data'):
|
||||
with translate_errors_context("parsing call data"):
|
||||
for field, desc in fields.items():
|
||||
value = call_data.get(field)
|
||||
if value is None:
|
||||
@@ -93,20 +104,34 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
if callable(desc):
|
||||
desc(value)
|
||||
else:
|
||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
|
||||
raise ParseCallError('expecting %s' % desc.__name__, field=field)
|
||||
if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
|
||||
raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
|
||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(
|
||||
value, desc
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s" % desc.__name__, field=field
|
||||
)
|
||||
if issubclass(desc, Document) and not desc.objects(id=value).only(
|
||||
"id"
|
||||
):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
res[field] = value
|
||||
return res
|
||||
|
||||
|
||||
def init_cls_from_base(cls, instance):
|
||||
return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in instance.to_mongo(use_db_field=False).to_dict().items()
|
||||
if k[0] != "_"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_company_or_none_constraint(company=None):
|
||||
return Q(company__in=(company, None, '')) | Q(company__exists=False)
|
||||
return Q(company__in=(company, None, "")) | Q(company__exists=False)
|
||||
|
||||
|
||||
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
|
||||
@@ -118,23 +143,40 @@ def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
|
||||
the length of the array will be used (len==0 means empty)
|
||||
:return:
|
||||
"""
|
||||
query = (Q(**{f"{field}__exists": False}) |
|
||||
Q(**{f"{field}__in": {empty_value, None}}))
|
||||
query = Q(**{f"{field}__exists": False}) | Q(
|
||||
**{f"{field}__in": {empty_value, None}}
|
||||
)
|
||||
if is_list:
|
||||
query |= Q(**{f"{field}__size": 0})
|
||||
return query
|
||||
|
||||
|
||||
def field_exists(field: str, empty_value=None) -> Q:
|
||||
"""
|
||||
Creates a query object used for finding a field that exists and is not None or empty.
|
||||
:param field: Field name
|
||||
:param empty_value: The empty value to test for (None means no specific empty value will be used).
|
||||
For lists pass [] for empty_value
|
||||
:return:
|
||||
"""
|
||||
query = Q(**{f"{field}__exists": True}) & Q(
|
||||
**{f"{field}__nin": {empty_value, None}}
|
||||
)
|
||||
return query
|
||||
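For illustration, these helpers might be combined as follows (a sketch; the field name "name" and the company ID are example values only). Both return plain mongoengine Q objects, so they compose with any other constraint.

# Q object matching documents whose "name" exists and is neither None nor ""
non_empty_name = field_exists("name", empty_value="")

# Q object matching documents whose "name" is missing, None or ""
no_name = field_does_not_exist("name", empty_value="")

# the helpers compose like any other mongoengine Q, e.g. with a company constraint
query = get_company_or_none_constraint("some-company-id") & non_empty_name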
|
||||
|
||||
def get_subkey(d, key_path, default=None):
|
||||
""" Get a key from a nested dictionary. kay_path is a '.' separated string of keys used to traverse
|
||||
the nested dictionary.
|
||||
"""
|
||||
keys = key_path.split('.')
|
||||
keys = key_path.split(".")
|
||||
for i, key in enumerate(keys):
|
||||
if not isinstance(d, dict):
|
||||
raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
|
||||
raise KeyError(
|
||||
"Expecting a dict (%s)" % (".".join(keys[:i]) if i else "bad input")
|
||||
)
|
||||
d = d.get(key)
|
||||
if key is None:
|
||||
if d is None:
|
||||
return default
|
||||
return d
|
||||
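A quick usage sketch for get_subkey (the dictionary and key paths are illustrative):

doc = {"execution": {"parameters": {"lr": 0.01}}}

print(get_subkey(doc, "execution.parameters.lr"))         # 0.01
print(get_subkey(doc, "execution.parameters.batch", 32))  # 32 (default, key missing)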
|
||||
@@ -158,3 +200,41 @@ def merge_dicts(*dicts):
|
||||
def filter_fields(cls, fields):
|
||||
"""From the fields dictionary return only the fields that match cls fields"""
|
||||
return {key: fields[key] for key in fields if key in get_fields(cls)}
|
||||
|
||||
|
||||
def _names_set(*names: str) -> Set[str]:
|
||||
"""
|
||||
Given a list of names return set with names and '-names'
|
||||
"""
|
||||
return set(names) | set(f"-{name}" for name in names)
|
||||
|
||||
|
||||
system_tag_names = {
|
||||
"model": _names_set("active", "archived"),
|
||||
"project": _names_set("archived", "public", "default"),
|
||||
"task": _names_set("active", "archived", "development"),
|
||||
}
|
||||
|
||||
system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
|
||||
|
||||
def partition_tags(
|
||||
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Partition the given tags sequence into system and user-defined tags
|
||||
:param entity: The name of the entity that defines the list of the system tags
|
||||
:param tags: The tags to partition
|
||||
:param system_tags: Optional. If passed then these tags are considered system together
|
||||
with those defined for the entity.
|
||||
:return: a tuple where the first element is the sequence of user-defined tags and
|
||||
the second element is the sequence of system tags
|
||||
"""
|
||||
tags = set(tags)
|
||||
system_tags = set(system_tags)
|
||||
system_tags |= tags & system_tag_names[entity]
|
||||
|
||||
prefixes = system_tag_prefixes.get(entity, [])
|
||||
system_tags |= {t for t in tags for p in prefixes if t.lower().startswith(p)}
|
||||
|
||||
return list(tags - system_tags), list(system_tags)
|
||||
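A short usage sketch (example tag values): tags that match the entity's reserved names or prefixes move to the system list, everything else stays user-defined.

user_tags, system_tags = partition_tags(
    "task", ["experiment-1", "archived", "Annotation round 2"]
)
print(sorted(user_tags))    # ['experiment-1']
print(sorted(system_tags))  # ['Annotation round 2', 'archived']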
|
||||
@@ -7,12 +7,17 @@ from config import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = "ELASTIC_SERVICE_SERVICE_HOST"
|
||||
OVERRIDE_HOST_ENV_KEY = ("ELASTIC_SERVICE_HOST", "ELASTIC_SERVICE_SERVICE_HOST")
|
||||
OVERRIDE_PORT_ENV_KEY = "ELASTIC_SERVICE_PORT"
|
||||
|
||||
OVERRIDE_HOST = getenv(OVERRIDE_HOST_ENV_KEY)
|
||||
OVERRIDE_HOST = next(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)), None)
|
||||
if OVERRIDE_HOST:
|
||||
log.info(f"Using override elastic host {OVERRIDE_HOST}")
|
||||
|
||||
OVERRIDE_PORT = getenv(OVERRIDE_PORT_ENV_KEY)
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
@@ -63,9 +68,15 @@ def get_cluster_config(cluster_name):
|
||||
if not cluster_config:
|
||||
raise MissingClusterConfiguration(cluster_name)
|
||||
|
||||
if OVERRIDE_HOST:
|
||||
def set_host_prop(key, value):
|
||||
for host in cluster_config.get('hosts', []):
|
||||
host["host"] = OVERRIDE_HOST
|
||||
host[key] = value
|
||||
|
||||
if OVERRIDE_HOST:
|
||||
set_host_prop("host", OVERRIDE_HOST)
|
||||
|
||||
if OVERRIDE_PORT:
|
||||
set_host_prop("port", OVERRIDE_PORT)
|
||||
|
||||
return cluster_config
|
||||
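As an illustration of the override behavior (example host values only, not part of the server code): the first of the two host environment variables that is set wins, and the chosen value is written into every configured host entry.

env = {"ELASTIC_SERVICE_HOST": "", "ELASTIC_SERVICE_SERVICE_HOST": "10.0.0.5"}   # example values
override_host = next(
    filter(None, (env.get(k) for k in ("ELASTIC_SERVICE_HOST", "ELASTIC_SERVICE_SERVICE_HOST"))),
    None,
)

cluster_config = {"hosts": [{"host": "es1", "port": 9200}, {"host": "es2", "port": 9200}]}
if override_host:
    for host in cluster_config.get("hosts", []):
        host["host"] = override_host

print([h["host"] for h in cluster_config["hosts"]])   # ['10.0.0.5', '10.0.0.5']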
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
import importlib.util
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import attr
|
||||
from furl import furl
|
||||
from mongoengine.connection import get_db
|
||||
from semantic_version import Version
|
||||
|
||||
import database.utils
|
||||
from config import config
|
||||
from database import Database
|
||||
from database.model.auth import Role
|
||||
from database.model.auth import User as AuthUser, Credentials
|
||||
from database.model.company import Company
|
||||
from database.model.user import User
|
||||
from database.model.version import Version as DatabaseVersion
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from es_factory import get_cluster_config
|
||||
from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
migration_dir = (Path(__file__) / "../../migration/mongodb").resolve()
|
||||
|
||||
|
||||
class MissingElasticConfiguration(Exception):
|
||||
"""
|
||||
@@ -102,8 +111,64 @@ def _ensure_user(user: FixedUser, company_id: str):
|
||||
).save()
|
||||
|
||||
|
||||
def _apply_migrations():
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
|
||||
try:
|
||||
previous_versions = sorted(
|
||||
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
|
||||
reverse=True,
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Invalid database version number encountered: {ex}")
|
||||
|
||||
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
|
||||
|
||||
try:
|
||||
new_scripts = {
|
||||
ver: path
|
||||
for ver, path in (
|
||||
(Version(f.stem), f) for f in migration_dir.glob("*.py")
|
||||
)
|
||||
if ver > last_version
|
||||
}
|
||||
except ValueError as ex:
|
||||
raise ValueError(f"Failed parsing migration version from file: {ex}")
|
||||
|
||||
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
|
||||
|
||||
migration_log = log.getChild("mongodb_migration")
|
||||
|
||||
for script_version in sorted(new_scripts.keys()):
|
||||
script = new_scripts[script_version]
|
||||
spec = importlib.util.spec_from_file_location(script.stem, str(script))
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for alias, func_name in dbs.items():
|
||||
func = getattr(module, func_name, None)
|
||||
if not func:
|
||||
continue
|
||||
try:
|
||||
migration_log.info(f"Applying {script.stem}/{func_name}()")
|
||||
func(get_db(alias))
|
||||
except Exception:
|
||||
migration_log.exception(f"Failed applying {script}:{func_name}()")
|
||||
raise ValueError("Migration failed, aborting. Please restore backup.")
|
||||
|
||||
DatabaseVersion(
|
||||
id=database.utils.id(),
|
||||
num=script.stem,
|
||||
created=datetime.utcnow(),
|
||||
desc="Applied on server startup",
|
||||
).save()
|
||||
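The version selection can be illustrated on its own (hypothetical script names): only scripts whose semantic version is greater than the last recorded one are picked up, and they run in ascending order.

from semantic_version import Version

applied = Version("0.12.0")                              # last version recorded in the "versions" collection
candidates = ["0.11.0", "0.12.0", "0.12.1", "0.13.0"]    # migration script stems

to_run = sorted(v for v in (Version(c) for c in candidates) if v > applied)
print([str(v) for v in to_run])   # ['0.12.1', '0.13.0']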
|
||||
|
||||
def init_mongo_data():
|
||||
try:
|
||||
_apply_migrations()
|
||||
|
||||
company_id = _ensure_company()
|
||||
users = [
|
||||
{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
|
||||
@@ -125,4 +190,4 @@ def init_mongo_data():
|
||||
except Exception as ex:
|
||||
log.error(f"Failed creating fixed user {user['name']}: {ex}")
|
||||
except Exception as ex:
|
||||
pass
|
||||
log.exception("Failed initializing mongodb")
|
||||
|
||||
@@ -15,6 +15,11 @@ _definitions {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
last_used {
|
||||
type: string
|
||||
description: ""
|
||||
format: "date-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
role {
|
||||
|
||||
@@ -149,6 +149,14 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_key_enum {
|
||||
type: string
|
||||
enum: [
|
||||
iter
|
||||
timestamp
|
||||
iso_time
|
||||
]
|
||||
}
|
||||
log_level_enum {
|
||||
type: string
|
||||
enum: [
|
||||
@@ -682,6 +690,19 @@
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
description: """
|
||||
Histogram x axis to use:
|
||||
iter - iteration number
|
||||
iso_time - event time as ISO formatted string
|
||||
timestamp - event timestamp as milliseconds since epoch
|
||||
"""
|
||||
"$ref": "#/definitions/scalar_key_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -715,7 +736,19 @@
|
||||
description: "List of task Task IDs"
|
||||
}
|
||||
}
|
||||
|
||||
samples {
|
||||
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
|
||||
type: integer
|
||||
}
|
||||
key {
|
||||
description: """
|
||||
Histogram x axis to use:
|
||||
iter - iteration number
|
||||
iso_time - event time as ISO formatted string
|
||||
timestamp - event timestamp as milliseconds since epoch
|
||||
"""
|
||||
"$ref": "#/definitions/scalar_key_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -57,9 +57,14 @@
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "Tags"
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||
type: string
|
||||
@@ -159,7 +164,12 @@
|
||||
type: boolean
|
||||
}
|
||||
tags {
|
||||
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
|
||||
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
@@ -263,10 +273,15 @@
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task."
|
||||
type: string
|
||||
@@ -325,10 +340,15 @@
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
type: string
|
||||
@@ -344,7 +364,7 @@
|
||||
additionalProperties { type: integer }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||
description: "Indication if the model is final and can be used by other tasks. Default is false."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
@@ -408,10 +428,15 @@
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
type: string
|
||||
@@ -485,10 +510,15 @@
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||
type: boolean
|
||||
|
||||
@@ -1,462 +1,523 @@
|
||||
{
|
||||
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
|
||||
_definitions {
|
||||
multi_field_pattern_data {
|
||||
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
|
||||
_definitions {
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
project {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company id"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Creation time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
last_update {
|
||||
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
stats_status_count {
|
||||
type: object
|
||||
properties {
|
||||
total_runtime {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
properties {
|
||||
created {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
in_progress {
|
||||
description: "Number of 'in_progress' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
stopped {
|
||||
description: "Number of 'stopped' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
published {
|
||||
description: "Number of 'published' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
closed {
|
||||
description: "Number of 'closed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
failed {
|
||||
description: "Number of 'failed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
unknown {
|
||||
description: "Number of 'unknown' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stats {
|
||||
type: object
|
||||
properties {
|
||||
active {
|
||||
description: "Stats for active tasks"
|
||||
"$ref": "#/definitions/stats_status_count"
|
||||
}
|
||||
archived {
|
||||
description: "Stats for archived tasks"
|
||||
"$ref": "#/definitions/stats_status_count"
|
||||
}
|
||||
}
|
||||
}
|
||||
projects_get_all_response_single {
|
||||
// copy-paste from project definition
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company id"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Creation time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
// extra properties
|
||||
stats: {
|
||||
description: "Additional project stats"
|
||||
"$ref": "#/definitions/stats"
|
||||
}
|
||||
}
|
||||
}
|
||||
metric_variant_result {
|
||||
type: object
|
||||
properties {
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
metric_hash {
|
||||
description: """Metric name hash. Used instead of the metric name when categorizing
|
||||
last metrics events in task objects."""
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Variant name"
|
||||
type: string
|
||||
}
|
||||
variant_hash {
|
||||
description: """Variant name hash. Used instead of the variant name when categorizing
|
||||
last metrics events in task objects."""
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create {
|
||||
"2.1" {
|
||||
description: "Create a new project"
|
||||
request {
|
||||
type: object
|
||||
required :[
|
||||
name
|
||||
description
|
||||
]
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
name {
|
||||
description: "Project name Unique within the company."
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
project_tags_enum {
|
||||
type: string
|
||||
enum: [ archived, public, default ]
|
||||
}
|
||||
project {
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company id"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Creation time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
description: "Tags"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/project_tags_enum" }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
last_update {
|
||||
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
stats_status_count {
|
||||
type: object
|
||||
properties {
|
||||
total_runtime {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
properties {
|
||||
created {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
in_progress {
|
||||
description: "Number of 'in_progress' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
stopped {
|
||||
description: "Number of 'stopped' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
published {
|
||||
description: "Number of 'published' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
closed {
|
||||
description: "Number of 'closed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
failed {
|
||||
description: "Number of 'failed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
unknown {
|
||||
description: "Number of 'unknown' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stats {
|
||||
type: object
|
||||
properties {
|
||||
active {
|
||||
description: "Stats for active tasks"
|
||||
"$ref": "#/definitions/stats_status_count"
|
||||
}
|
||||
archived {
|
||||
description: "Stats for archived tasks"
|
||||
"$ref": "#/definitions/stats_status_count"
|
||||
}
|
||||
}
|
||||
}
|
||||
projects_get_all_response_single {
|
||||
// copy-paste from project definition
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company id"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Creation time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
description: "Tags"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/project_tags_enum" }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
// extra properties
|
||||
stats: {
|
||||
description: "Additional project stats"
|
||||
"$ref": "#/definitions/stats"
|
||||
}
|
||||
}
|
||||
}
|
||||
metric_variant_result {
|
||||
type: object
|
||||
properties {
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
metric_hash {
|
||||
description: """Metric name hash. Used instead of the metric name when categorizing
|
||||
last metrics events in task objects."""
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Variant name"
|
||||
type: string
|
||||
}
|
||||
variant_hash {
|
||||
description: """Variant name hash. Used instead of the variant name when categorizing
|
||||
last metrics events in task objects."""
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create {
|
||||
"2.1" {
|
||||
description: "Create a new project"
|
||||
request {
|
||||
type: object
|
||||
required :[
|
||||
name
|
||||
description
|
||||
]
|
||||
properties {
|
||||
name {
|
||||
description: "Project name Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/project_tags_enum" }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
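For illustration, a create request per this schema could be built roughly like the minimal Python sketch below; the server URL, port, credentials and the response envelope are placeholders/assumptions, not part of this commit.

# Hypothetical sketch of calling projects.create as described by the schema above.
import requests

payload = {
    "name": "example project",                     # required, unique within the company
    "description": "created from the API sketch",  # required
    "tags": ["demo"],
    "default_output_destination": "s3://bucket/outputs",
}
resp = requests.post(
    "http://localhost:8008/projects.create",       # placeholder URL/port
    json=payload,
    auth=("ACCESS_KEY", "SECRET_KEY"),              # placeholder credentials
)
resp.raise_for_status()
project_id = resp.json()["data"]["id"]              # assumes the usual {"data": {...}} envelope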
get_by_id {
|
||||
"2.1" {
|
||||
description: ""
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
project {
|
||||
description: "Project info"
|
||||
"$ref": "#/definitions/project"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
description: "Get all the company's projects and all public projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "List of IDs to filter by"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
name {
|
||||
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of dataviews"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
page_size {
|
||||
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
|
||||
type: integer
|
||||
minimum: 1
|
||||
}
|
||||
search_text {
|
||||
description: "Free text search query"
|
||||
type: string
|
||||
}
|
||||
only_fields {
|
||||
description: "List of document's field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
_any_ {
|
||||
description: "Multi-field pattern condition (any field matches pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
description: "Projects list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_stats {
|
||||
description: "If true, include project statistic in response."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
stats_for_state {
|
||||
description: "Report stats include only statistics for tasks in the specified state. If Null is provided, stats for all task states will be returned."
|
||||
type: string
|
||||
enum: [ active, archived ]
|
||||
default: active
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
description: "Update project information"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
description: """If not true, fails if project has tasks.
|
||||
If true, and project has tasks, they will be unassigned"""
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Number of projects deleted (0 or 1)"
|
||||
type: integer
|
||||
}
|
||||
disassociated_tasks {
|
||||
description: "Number of tasks disassociated from the deleted project"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_unique_metric_variants {
|
||||
"2.1" {
|
||||
description: """Get all metric/variant pairs reported for tasks in a specific project.
|
||||
If no project is specified, metric/variant pairs reported for all tasks will be returned.
|
||||
If the project does not exist, an empty list will be returned."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
description: "A list of metric variants reported for tasks in this project"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/metric_variant_result" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_by_id {
|
||||
"2.1" {
|
||||
description: ""
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
project {
|
||||
description: "Project info"
|
||||
"$ref": "#/definitions/project"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
description: "Get all the company's projects and all public projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "List of IDs to filter by"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
name {
|
||||
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of dataviews"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
page_size {
|
||||
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
|
||||
type: integer
|
||||
minimum: 1
|
||||
}
|
||||
search_text {
|
||||
description: "Free text search query"
|
||||
type: string
|
||||
}
|
||||
only_fields {
|
||||
description: "List of document's field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
_any_ {
|
||||
description: "Multi-field pattern condition (any field matches pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
projects {
|
||||
description: "Projects list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
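As a rough illustration of the tag filters and multi-field pattern conditions above, a get_all request body could look like the sketch below; all values are made up, and the "_any_" shape follows the pattern/fields fields shown earlier in this diff.

# Illustrative projects.get_all request body (values are examples only).
query = {
    "tags": ["production", "-deprecated"],      # '-' prefix excludes a tag
    "system_tags": ["-archived"],               # skip archived projects
    "_any_": {                                  # multi-field pattern condition
        "pattern": "mnist",
        "fields": ["name", "description"],
    },
    "order_by": ["-last_update"],               # descending order via '-' prefix
    "page": 0,
    "page_size": 20,
    "only_fields": ["id", "name", "tags"],
}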
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_stats {
|
||||
description: "If true, include project statistic in response."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
stats_for_state {
|
||||
description: "Report stats include only statistics for tasks in the specified state. If Null is provided, stats for all task states will be returned."
|
||||
type: string
|
||||
enum: [ active, archived ]
|
||||
default: active
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
description: "Update project information"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of projects updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
description: """If not true, fails if project has tasks.
|
||||
If true, and project has tasks, they will be unassigned"""
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Number of projects deleted (0 or 1)"
|
||||
type: integer
|
||||
}
|
||||
disassociated_tasks {
|
||||
description: "Number of tasks disassociated from the deleted project"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_unique_metric_variants {
|
||||
"2.1" {
|
||||
description: """Get all metric/variant pairs reported for tasks in a specific project.
|
||||
If no project is specified, metric/variant pairs reported for all tasks will be returned.
|
||||
If the project does not exist, an empty list will be returned."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
description: "A list of metric variants reported for tasks in this project"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/metric_variant_result" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.2" {
|
||||
description: """Get a list of all hyper parameter names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
page {
|
||||
description: "Page number"
|
||||
default: 0
|
||||
type: integer
|
||||
}
|
||||
page_size {
|
||||
description: "Page size"
|
||||
default: 500
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
parameters {
|
||||
description: "A list of hyper parameter names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
type: integer
|
||||
}
|
||||
total {
|
||||
description: "Total number of results"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
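A paging loop over the new endpoint might look like this sketch; send_request stands in for whatever HTTP client wrapper is used and is not part of this commit.

# Sketch: page through projects.get_hyper_parameters (added in API version 2.2).
# 'send_request(endpoint, body) -> dict' is an assumed helper, not defined here.
def iter_hyper_parameters(send_request, project_id, page_size=500):
    page = 0
    while True:
        data = send_request(
            "projects.get_hyper_parameters",
            {"project": project_id, "page": page, "page_size": page_size},
        )
        yield from data["parameters"]        # list of hyper parameter names
        if data["remaining"] <= 0:           # nothing left beyond this page
            return
        page += 1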
|
||||
68
server/schema/services/server.conf
Normal file
68
server/schema/services/server.conf
Normal file
@@ -0,0 +1,68 @@
|
||||
_description: "server utilities"
|
||||
_default {
|
||||
internal: true
|
||||
allow_roles: ["root", "system"]
|
||||
}
|
||||
config {
|
||||
"2.1" {
|
||||
description: "Get server configuration. Secure section is not returned."
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
path {
|
||||
description: "Path of config value. Defaults to root"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info {
|
||||
authorize = false
|
||||
allow_roles = [ "*" ]
|
||||
"2.1" {
|
||||
description: "Get server information, including version and build number"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
version {
|
||||
description: "Version string"
|
||||
type: string
|
||||
}
|
||||
build {
|
||||
description: "Build number"
|
||||
type: string
|
||||
}
|
||||
commit {
|
||||
description: "VCS commit number"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
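Since info is not internal, requires no authorization and takes no parameters, a call reduces to something like the following sketch (send_request is again an assumed helper):

# Sketch: server.info takes an empty request and returns version/build/commit.
info = send_request("server.info", {})       # send_request is a placeholder helper
print(info["version"], info["build"], info["commit"])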
endpoints {
|
||||
"2.1" {
|
||||
description: "Show available endpoints"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,6 +120,76 @@ _definitions {
|
||||
frame_per_roi
|
||||
]
|
||||
}
|
||||
artifact_type_data {
|
||||
type: object
|
||||
properties {
|
||||
preview {
|
||||
description: "Description or textual data"
|
||||
type: string
|
||||
}
|
||||
content_type {
|
||||
description: "System defined raw data content type"
|
||||
type: string
|
||||
}
|
||||
data_hash {
|
||||
description: "Hash of raw data, without any headers or descriptive parts"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact {
|
||||
type: object
|
||||
required: [key, type]
|
||||
properties {
|
||||
key {
|
||||
description: "Entry key"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "System defined type"
|
||||
type: string
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
type: string
|
||||
enum: [
|
||||
input
|
||||
output
|
||||
]
|
||||
default: output
|
||||
}
|
||||
uri {
|
||||
description: "Raw data location"
|
||||
type: string
|
||||
}
|
||||
content_size {
|
||||
description: "Raw data length in bytes"
|
||||
type: integer
|
||||
}
|
||||
hash {
|
||||
description: "Hash of entire raw data"
|
||||
type: string
|
||||
}
|
||||
timestamp {
|
||||
description: "Epoch time when artifact was created"
|
||||
type: integer
|
||||
}
|
||||
type_data {
|
||||
description: "Additional fields defined by the system"
|
||||
"$ref": "#/definitions/artifact_type_data"
|
||||
}
|
||||
display_data {
|
||||
description: "User-defined list of key/value pairs, sorted"
|
||||
type: array
|
||||
items {
|
||||
type: array
|
||||
items {
|
||||
type: string # can also be a number... TODO: upgrade the generator
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
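To make the definition concrete, a single artifact entry conforming to the fields above might look like this (all values are illustrative):

# Hypothetical artifact entry matching the schema above; values are made up.
artifact = {
    "key": "confusion_matrix",          # required entry key
    "type": "table",                    # required, system defined type
    "mode": "output",                   # input/output indication, defaults to output
    "uri": "s3://bucket/artifacts/confusion_matrix.csv",
    "content_size": 2048,               # raw data length in bytes
    "hash": "9a0364b9e99bb480dd25e1f0284c8555",
    "timestamp": 1567296000,            # epoch time when the artifact was created
    "type_data": {
        "preview": "first few rows of the table",
        "content_type": "text/csv",
        "data_hash": "9a0364b9e99bb480dd25e1f0284c8555",
    },
    "display_data": [["rows", "100"], ["columns", "10"]],  # sorted key/value pairs
}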
execution {
|
||||
type: object
|
||||
properties {
|
||||
@@ -149,6 +219,11 @@ _definitions {
|
||||
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
|
||||
type: string
|
||||
}
|
||||
artifacts {
|
||||
description: "Task artifacts"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/artifact" }
|
||||
}
|
||||
}
|
||||
}
|
||||
task_status_enum {
|
||||
@@ -183,21 +258,16 @@ _definitions {
|
||||
description: "Variant name"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "Event type"
|
||||
type: string
|
||||
}
|
||||
timestamp {
|
||||
description: "Event report time (UTC)"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
iter {
|
||||
description: "Iteration number"
|
||||
type: integer
|
||||
}
|
||||
value {
|
||||
description: "Value"
|
||||
description: "Last value reported"
|
||||
type: number
|
||||
}
|
||||
min_value {
|
||||
description: "Minimum value reported"
|
||||
type: number
|
||||
}
|
||||
max_value {
|
||||
description: "Maximum value reported"
|
||||
type: number
|
||||
}
|
||||
}
|
||||
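With the new fields, a single last-metrics entry carries the last reported value together with the minimum and maximum seen so far, roughly as in this illustrative dict:

# Illustrative last-metrics entry; numbers are made up.
last_metric = {
    "metric": "validation",
    "variant": "accuracy",
    "value": 0.91,        # last value reported
    "min_value": 0.42,    # minimum value reported
    "max_value": 0.93,    # maximum value reported
}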
@@ -278,10 +348,15 @@ _definitions {
|
||||
"$ref": "#/definitions/script"
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
status_changed {
|
||||
description: "Last status change time"
|
||||
type: string
|
||||
@@ -392,7 +467,12 @@ get_all {
|
||||
items { type: string }
|
||||
}
|
||||
tags {
|
||||
description: "List of task tags. Use '-' prefix to exclude tags"
|
||||
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
@@ -467,10 +547,15 @@ create {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
type {
|
||||
description: "Type of task"
|
||||
"$ref": "#/definitions/task_type_enum"
|
||||
@@ -527,10 +612,15 @@ validate {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
type {
|
||||
description: "Type of task"
|
||||
"$ref": "#/definitions/task_type_enum"
|
||||
@@ -585,10 +675,15 @@ update {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
comment {
|
||||
description: "Free text comment "
|
||||
type: string
|
||||
@@ -667,10 +762,15 @@ edit {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don’t use it."
|
||||
items {type: string}
|
||||
}
|
||||
type {
|
||||
description: "Type of task"
|
||||
"$ref": "#/definitions/task_type_enum"
|
||||
@@ -1062,7 +1162,7 @@ completed {
|
||||
task
|
||||
]
|
||||
properties.force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is not created/in_progress/published"
|
||||
description: "If not true, call fails if the task status is not in_progress/stopped"
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
|
||||
@@ -107,7 +107,6 @@ def update_call_data(call, req):
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
# NOTE: dict() form data to make sure we won't pass along a MultiDict or some other nasty crap
|
||||
call.data = json_body or form or {}
|
||||
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class DataContainer(object):
|
||||
if self._batched_data:
|
||||
try:
|
||||
data_model = [cls(**item) for item in self._batched_data]
|
||||
except TypeError as ex:
|
||||
except (ValueError, TypeError) as ex:
|
||||
raise CallParsingError(str(ex))
|
||||
|
||||
for m in data_model:
|
||||
@@ -112,7 +112,7 @@ class DataContainer(object):
|
||||
else:
|
||||
try:
|
||||
data_model = cls(**self.data)
|
||||
except TypeError as ex:
|
||||
except (ValueError, TypeError) as ex:
|
||||
raise CallParsingError(str(ex))
|
||||
|
||||
if not self.schema_validator.enabled:
|
||||
@@ -182,8 +182,6 @@ class APICallResult(DataContainer):
|
||||
traceback=self._traceback,
|
||||
extra=self._extra,
|
||||
)
|
||||
if self.log_data:
|
||||
res["data"] = self.data
|
||||
return res
|
||||
|
||||
def copy_from(self, result):
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
from mongoengine import Q
|
||||
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.company import Company
|
||||
from database.utils import get_options
|
||||
from database.model.auth import User, Entities, Credentials
|
||||
from apierrors import errors
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User, Entities, Credentials
|
||||
from database.model.company import Company
|
||||
from database.utils import get_options
|
||||
from timing_context import TimingContext
|
||||
|
||||
from .payload import Payload, Token, Basic, AuthType
|
||||
from .identity import Identity
|
||||
from .fixed_user import FixedUser
|
||||
|
||||
from .identity import Identity
|
||||
from .payload import Payload, Token, Basic, AuthType
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -38,12 +38,16 @@ def authorize_token(jwt_token, *_, **__):
|
||||
return Token.from_encoded_token(jwt_token)
|
||||
|
||||
except jwt.exceptions.InvalidKeyError as ex:
|
||||
raise errors.unauthorized.InvalidToken('jwt invalid key error', reason=ex.args[0])
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
"jwt invalid key error", reason=ex.args[0]
|
||||
)
|
||||
except jwt.InvalidTokenError as ex:
|
||||
raise errors.unauthorized.InvalidToken('invalid jwt token', reason=ex.args[0])
|
||||
raise errors.unauthorized.InvalidToken("invalid jwt token", reason=ex.args[0])
|
||||
except ValueError as ex:
|
||||
log.exception('Failed while processing token: %s' % ex.args[0])
|
||||
raise errors.unauthorized.InvalidToken('failed processing token', reason=ex.args[0])
|
||||
log.exception("Failed while processing token: %s" % ex.args[0])
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
"failed processing token", reason=ex.args[0]
|
||||
)
|
||||
|
||||
|
||||
def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
@@ -67,9 +71,14 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
|
||||
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
|
||||
user = User.objects(query).first()
|
||||
if not user:
|
||||
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
|
||||
|
||||
if not user:
|
||||
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
|
||||
if not FixedUser.enabled():
|
||||
# In case these are proper credentials, update last used time
|
||||
User.objects(id=user.id, credentials__key=access_key).update(
|
||||
**{"set__credentials__$__last_used": datetime.utcnow()}
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "company_by_id"):
|
||||
company = Company.objects(id=user.company).only('id', 'name').first()
|
||||
@@ -85,13 +94,13 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
return basic
|
||||
|
||||
|
||||
def authorize_impersonation(user, identity, service, action, call_data_items):
|
||||
def authorize_impersonation(user, identity, service, action, call):
|
||||
""" Returns a new basic object (auth payload)"""
|
||||
if not user:
|
||||
raise ValueError('missing user')
|
||||
raise ValueError("missing user")
|
||||
|
||||
company = Company.objects(id=user.company).only('id', 'name').first()
|
||||
company = Company.objects(id=user.company).only("id", "name").first()
|
||||
if not company:
|
||||
raise errors.unauthorized.InvalidCredentials('invalid user company')
|
||||
raise errors.unauthorized.InvalidCredentials("invalid user company")
|
||||
|
||||
return Payload(auth_type=None, identity=identity)
|
||||
|
||||
@@ -30,6 +30,8 @@ def get_secret_key(length=50):
|
||||
Create a random secret key.
|
||||
|
||||
Taken from the Django project.
|
||||
NOTE: asterisk is not supported due to issues with environment variables containing
|
||||
asterisks (in case the secret key is stored in an environment variable)
|
||||
"""
|
||||
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*(-_=+)'
|
||||
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&(-_=+)'
|
||||
return get_random_string(length, chars)
|
||||
|
||||
@@ -76,7 +76,7 @@ class Endpoint(object):
|
||||
Provided endpoints and their schemas on a best-effort basis.
|
||||
"""
|
||||
d = {
|
||||
"min_version": self.min_version,
|
||||
"min_version": str(self.min_version),
|
||||
"required_fields": self.required_fields,
|
||||
"request_data_model": None,
|
||||
"response_data_model": None,
|
||||
|
||||
@@ -12,7 +12,7 @@ from config import config
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, auto_exc=True)
|
||||
@attr.s(auto_attribs=True, cmp=False)
|
||||
class FastValidationError(Exception):
|
||||
error: fastjsonschema.JsonSchemaException
|
||||
data: dict
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import re
|
||||
from importlib import import_module
|
||||
from itertools import chain
|
||||
from typing import cast, Iterable, List, MutableMapping
|
||||
from pathlib import Path
|
||||
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple
|
||||
|
||||
import jsonmodels.models
|
||||
from pathlib import Path
|
||||
|
||||
import timing_context
|
||||
from apierrors import APIError
|
||||
@@ -30,7 +30,11 @@ class ServiceRepo(object):
|
||||
_version_required = config.get("apiserver.version.required")
|
||||
""" If version is required, parsing will fail for endpoint paths that do not contain a valid version """
|
||||
|
||||
_max_version = PartialVersion("2.1")
|
||||
_check_max_version = config.get("apiserver.version.check_max_version")
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.3")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
@@ -133,7 +137,7 @@ class ServiceRepo(object):
|
||||
return cls._max_version
|
||||
|
||||
@classmethod
|
||||
def _get_endpoint(cls, name, version):
|
||||
def _get_endpoint(cls, name, version) -> Optional[Endpoint]:
|
||||
versions = cls._endpoints.get(name)
|
||||
if not versions:
|
||||
return None
|
||||
@@ -144,7 +148,7 @@ class ServiceRepo(object):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _resolve_endpoint_from_call(cls, call):
|
||||
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
|
||||
assert isinstance(call, APICall)
|
||||
endpoint = cls._get_endpoint(
|
||||
call.endpoint_name, call.requested_endpoint_version
|
||||
@@ -167,7 +171,7 @@ class ServiceRepo(object):
|
||||
return endpoint
|
||||
|
||||
@classmethod
|
||||
def parse_endpoint_path(cls, path):
|
||||
def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
|
||||
""" Parse endpoint version, service and action from request path. """
|
||||
m = cls._endpoint_exp.match(path)
|
||||
if not m:
|
||||
@@ -182,14 +186,14 @@ class ServiceRepo(object):
|
||||
version = PartialVersion(version)
|
||||
except ValueError as e:
|
||||
raise RequestPathHasInvalidVersion(version=version, reason=e)
|
||||
if version > cls._max_version:
|
||||
if cls._check_max_version and version > cls._max_version:
|
||||
raise InvalidVersionError(
|
||||
f"Invalid API version (max. supported version is {cls._max_version})"
|
||||
)
|
||||
return version, endpoint_name
|
||||
|
||||
@classmethod
|
||||
def _should_return_stack(cls, code, subcode):
|
||||
def _should_return_stack(cls, code: int, subcode: int) -> bool:
|
||||
if not cls._return_stack or code not in cls._return_stack_on_code:
|
||||
return False
|
||||
if subcode is None:
|
||||
@@ -202,7 +206,7 @@ class ServiceRepo(object):
|
||||
return subcode in subcode_list
|
||||
|
||||
@classmethod
|
||||
def _validate_call(cls, call):
|
||||
def _validate_call(cls, call: APICall) -> Optional[Endpoint]:
|
||||
endpoint = cls._resolve_endpoint_from_call(call)
|
||||
if call.failed:
|
||||
return
|
||||
@@ -210,11 +214,13 @@ class ServiceRepo(object):
|
||||
return endpoint
|
||||
|
||||
@classmethod
|
||||
def validate_call(cls, call):
|
||||
def validate_call(cls, call: APICall):
|
||||
cls._validate_call(call)
|
||||
|
||||
@classmethod
|
||||
def _get_company(cls, call, endpoint=None, ignore_error=False):
|
||||
def _get_company(
|
||||
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
||||
) -> Optional[str]:
|
||||
authorize = endpoint and endpoint.authorize
|
||||
if ignore_error or not authorize:
|
||||
try:
|
||||
@@ -224,7 +230,7 @@ class ServiceRepo(object):
|
||||
return call.identity.company
|
||||
|
||||
@classmethod
|
||||
def handle_call(cls, call):
|
||||
def handle_call(cls, call: APICall):
|
||||
try:
|
||||
assert isinstance(call, APICall)
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def validate_impersonation(endpoint, call):
|
||||
),
|
||||
service=service,
|
||||
action=action,
|
||||
call_data_items=call.batched_data,
|
||||
call=call,
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -143,7 +143,8 @@ def get_credentials(call):
|
||||
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
|
||||
call.result.data_model = GetCredentialsResponse(
|
||||
credentials=[
|
||||
CredentialsResponse(access_key=c.key) for c in user.credentials
|
||||
CredentialsResponse(access_key=c.key, last_used=c.last_used)
|
||||
for c in user.credentials
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,12 @@ from operator import itemgetter
|
||||
import six
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.events import (
|
||||
MultiTaskScalarMetricsIterHistogramRequest,
|
||||
ScalarMetricsIterHistogramRequest,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.event.event_metrics import EventMetrics
|
||||
from bll.task import TaskBLL
|
||||
from service_repo import APICall, endpoint
|
||||
from utilities import json
|
||||
@@ -17,11 +22,10 @@ event_bll = EventBLL()
|
||||
@endpoint("events.add")
|
||||
def add(call, company_id, req_model):
|
||||
assert isinstance(call, APICall)
|
||||
added, batch_errors = event_bll.add_events(company_id, [call.data.copy()], call.worker)
|
||||
call.result.data = dict(
|
||||
added=added,
|
||||
errors=len(batch_errors)
|
||||
added, batch_errors = event_bll.add_events(
|
||||
company_id, [call.data.copy()], call.worker
|
||||
)
|
||||
call.result.data = dict(added=added, errors=len(batch_errors))
|
||||
call.kpis["events"] = 1
|
||||
|
||||
|
||||
@@ -33,10 +37,7 @@ def add_batch(call, company_id, req_model):
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
|
||||
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
|
||||
call.result.data = dict(
|
||||
added=added,
|
||||
errors=len(batch_errors)
|
||||
)
|
||||
call.result.data = dict(added=added, errors=len(batch_errors))
|
||||
call.kpis["events"] = len(events)
|
||||
|
||||
|
||||
@@ -48,16 +49,16 @@ def get_task_log(call, company_id, req_model):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
company_id, task_id, order,
|
||||
company_id,
|
||||
task_id,
|
||||
order,
|
||||
event_type="log",
|
||||
batch_size=batch_size,
|
||||
scroll_id=scroll_id)
|
||||
call.result.data = dict(
|
||||
events=events,
|
||||
returned=len(events),
|
||||
total=total_events,
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
call.result.data = dict(
|
||||
events=events, returned=len(events), total=total_events, scroll_id=scroll_id
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
|
||||
@@ -70,7 +71,7 @@ def get_task_log_v1_7(call, company_id, req_model):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
batch_size = int(call.data.get("batch_size") or 500)
|
||||
|
||||
scroll_order = 'asc' if (from_ == 'head') else 'desc'
|
||||
scroll_order = "asc" if (from_ == "head") else "desc"
|
||||
|
||||
events, scroll_id, total_events = event_bll.scroll_task_events(
|
||||
company_id=company_id,
|
||||
@@ -78,54 +79,57 @@ def get_task_log_v1_7(call, company_id, req_model):
|
||||
order=scroll_order,
|
||||
event_type="log",
|
||||
batch_size=batch_size,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
if scroll_order != order:
|
||||
events = events[::-1]
|
||||
|
||||
call.result.data = dict(
|
||||
events=events,
|
||||
returned=len(events),
|
||||
total=total_events,
|
||||
scroll_id=scroll_id,
|
||||
events=events, returned=len(events), total=total_events, scroll_id=scroll_id
|
||||
)
|
||||
|
||||
|
||||
@endpoint('events.download_task_log', required_fields=['task'])
|
||||
@endpoint("events.download_task_log", required_fields=["task"])
|
||||
def download_task_log(call, company_id, req_model):
|
||||
task_id = call.data['task']
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
|
||||
line_type = call.data.get('line_type', 'json').lower()
|
||||
line_format = str(call.data.get('line_format', '{asctime} {worker} {level} {msg}'))
|
||||
line_type = call.data.get("line_type", "json").lower()
|
||||
line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
|
||||
|
||||
is_json = (line_type == 'json')
|
||||
is_json = line_type == "json"
|
||||
if not is_json:
|
||||
if not line_format:
|
||||
raise errors.bad_request.MissingRequiredFields('line_format is required for plain text lines')
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"line_format is required for plain text lines"
|
||||
)
|
||||
|
||||
# validate line format placeholders
|
||||
valid_task_log_fields = {'asctime', 'timestamp', 'level', 'worker', 'msg'}
|
||||
valid_task_log_fields = {"asctime", "timestamp", "level", "worker", "msg"}
|
||||
|
||||
invalid_placeholders = set()
|
||||
while True:
|
||||
try:
|
||||
line_format.format(**dict.fromkeys(valid_task_log_fields | invalid_placeholders))
|
||||
line_format.format(
|
||||
**dict.fromkeys(valid_task_log_fields | invalid_placeholders)
|
||||
)
|
||||
break
|
||||
except KeyError as e:
|
||||
invalid_placeholders.add(e.args[0])
|
||||
except Exception as e:
|
||||
raise errors.bad_request.FieldsValueError('invalid line format', error=e.args[0])
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"invalid line format", error=e.args[0]
|
||||
)
|
||||
|
||||
if invalid_placeholders:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
'undefined placeholders in line format',
|
||||
placeholders=invalid_placeholders
|
||||
"undefined placeholders in line format",
|
||||
placeholders=invalid_placeholders,
|
||||
)
|
||||
|
||||
# make sure line_format has a trailing newline
|
||||
line_format = line_format.rstrip('\n') + '\n'
|
||||
line_format = line_format.rstrip("\n") + "\n"
|
||||
|
||||
def generate():
|
||||
scroll_id = None
|
||||
@@ -137,30 +141,30 @@ def download_task_log(call, company_id, req_model):
|
||||
order="asc",
|
||||
event_type="log",
|
||||
batch_size=batch_size,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
if not log_events:
|
||||
break
|
||||
for ev in log_events:
|
||||
ev['asctime'] = ev.pop('@timestamp')
|
||||
ev["asctime"] = ev.pop("@timestamp")
|
||||
if is_json:
|
||||
ev.pop('type')
|
||||
ev.pop('task')
|
||||
yield json.dumps(ev) + '\n'
|
||||
ev.pop("type")
|
||||
ev.pop("task")
|
||||
yield json.dumps(ev) + "\n"
|
||||
else:
|
||||
try:
|
||||
yield line_format.format(**ev)
|
||||
except KeyError as ex:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
'undefined placeholders in line format',
|
||||
placeholders=[str(ex)]
|
||||
"undefined placeholders in line format",
|
||||
placeholders=[str(ex)],
|
||||
)
|
||||
|
||||
if len(log_events) < batch_size:
|
||||
break
|
||||
|
||||
call.result.filename = 'task_%s.log' % task_id
|
||||
call.result.content_type = 'text/plain'
|
||||
call.result.filename = "task_%s.log" % task_id
|
||||
call.result.content_type = "text/plain"
|
||||
call.result.raw_data = generate()
|
||||
|
||||
|
||||
@@ -169,7 +173,9 @@ def get_vector_metrics_and_variants(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_vector")
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
company_id, task_id, "training_stats_vector"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -178,23 +184,27 @@ def get_scalar_metrics_and_variants(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_scalar")
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
company_id, task_id, "training_stats_scalar"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
|
||||
@endpoint("events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"])
|
||||
@endpoint(
|
||||
"events.vector_metrics_iter_histogram",
|
||||
required_fields=["task", "metric", "variant"],
|
||||
)
|
||||
def vector_metrics_iter_histogram(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(company_id, task_id, metric, variant)
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
company_id, task_id, metric, variant
|
||||
)
|
||||
call.result.data = dict(
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
vectors=vectors,
|
||||
iterations=iterations
|
||||
metric=metric, variant=variant, vectors=vectors, iterations=iterations
|
||||
)
|
||||
|
||||
|
||||
@@ -207,10 +217,11 @@ def get_task_events(call, company_id, req_model):
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
company_id,
|
||||
task_id,
|
||||
sort=[{"timestamp": {"order": order}}],
|
||||
event_type=event_type,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@@ -229,11 +240,12 @@ def get_scalar_metric_data(call, company_id, req_model):
|
||||
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
company_id,
|
||||
task_id,
|
||||
event_type="training_stats_scalar",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
metric=metric,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@@ -248,35 +260,50 @@ def get_scalar_metric_data(call, company_id, req_model):
|
||||
def get_task_latest_scalar_values(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(company_id, task_id)
|
||||
es_index = EventBLL.get_index_name(company_id, "*")
|
||||
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
|
||||
company_id, task_id
|
||||
)
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
last_iter=last_iters[0] if last_iters else 0,
|
||||
name=task.name,
|
||||
status=task.status,
|
||||
last_timestamp=last_timestamp
|
||||
last_timestamp=last_timestamp,
|
||||
)
|
||||
|
||||
|
||||
# todo: should not repeat iter (x-axis) for each metric/variant, JS client should get raw data and fill gaps if needed
|
||||
@endpoint("events.scalar_metrics_iter_histogram", required_fields=["task"])
|
||||
def scalar_metrics_iter_histogram(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
metrics = event_bll.get_scalar_metrics_average_per_iter(company_id, task_id)
|
||||
@endpoint(
|
||||
"events.scalar_metrics_iter_histogram",
|
||||
request_data_model=ScalarMetricsIterHistogramRequest,
|
||||
)
|
||||
def scalar_metrics_iter_histogram(
|
||||
call, company_id, req_model: ScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task_bll.assert_exists(call.identity.company, req_model.task, allow_public=True)
|
||||
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
|
||||
company_id, task_id=req_model.task, samples=req_model.samples, key=req_model.key
|
||||
)
|
||||
call.result.data = metrics
|
||||
|
||||
|
||||
@endpoint("events.multi_task_scalar_metrics_iter_histogram", required_fields=["tasks"])
|
||||
def multi_task_scalar_metrics_iter_histogram(call, company_id, req_model):
|
||||
task_ids = call.data["tasks"]
|
||||
@endpoint(
|
||||
"events.multi_task_scalar_metrics_iter_histogram",
|
||||
request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
|
||||
)
|
||||
def multi_task_scalar_metrics_iter_histogram(
|
||||
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task_ids = req_model.tasks
|
||||
if isinstance(task_ids, six.string_types):
|
||||
task_ids = [s.strip() for s in task_ids.split(",")]
|
||||
# Note, bll already validates task ids as it needs their names
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.compare_scalar_metrics_average_per_iter(company_id, task_ids, allow_public=True)
|
||||
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
|
||||
company_id, task_ids=task_ids, samples=req_model.samples, allow_public=True, key=req_model.key
|
||||
)
|
||||
)
|
||||
|
||||
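The handlers above now parse their input through ScalarMetricsIterHistogramRequest and MultiTaskScalarMetricsIterHistogramRequest, so the request bodies gain optional 'samples' and 'key' fields; a rough sketch of the two payloads follows.

# Hedged sketch of the new histogram request bodies; only 'task'/'tasks',
# 'samples' and 'key' are taken from the handlers above, values are examples.
single_task_request = {
    "task": "<task-id>",
    "samples": 500,        # number of histogram samples (assumed meaning)
    "key": "iter",         # x-axis key; allowed values are not shown in this diff
}
multi_task_request = {
    "tasks": ["<task-id-1>", "<task-id-2>"],
    "samples": 500,
    "key": "iter",
}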
|
||||
@@ -287,21 +314,27 @@ def get_multi_task_plots_v1_7(call, company_id, req_model):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_ids,
|
||||
company_id,
|
||||
task_ids,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
size=10000,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
|
||||
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, tasks=tasks
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
@@ -318,20 +351,26 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_ids,
|
||||
company_id,
|
||||
task_ids,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
|
||||
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, tasks=tasks
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
@@ -357,11 +396,12 @@ def get_task_plots_v1_7(call, company_id, req_model):
|
||||
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
company_id,
|
||||
task_id,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
size=10000,
|
||||
scroll_id=scroll_id
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
|
||||
@@ -381,12 +421,13 @@ def get_task_plots(call, company_id, req_model):
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="plot",
result = event_bll.get_task_plots(
company_id,

tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
last_iterations_per_plot=iters,
scroll_id=scroll_id,
)

return_events = result.events
@@ -415,11 +456,12 @@ def get_debug_images_v1_7(call, company_id, req_model):

# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
scroll_id=scroll_id,
)

return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
@@ -441,11 +483,12 @@ def get_debug_images(call, company_id, req_model):

task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
scroll_id=scroll_id,
)

return_events = result.events
@@ -464,34 +507,32 @@ def delete_for_task(call, company_id, req_model):

task_id = call.data["task"]

task_bll.assert_exists(company_id, task_id)
call.result.data = dict(
deleted=event_bll.delete_task_events(company_id, task_id)
)
call.result.data = dict(deleted=event_bll.delete_task_events(company_id, task_id))

def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter('metric', 'variant', 'task', 'iter')
key = itemgetter("metric", "variant", "task", "iter")

unique_events = itertools.chain.from_iterable(
itertools.islice(group, max_iters)
for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key))
for _, group in itertools.groupby(
sorted(events, key=key, reverse=True), key=key
)
)

def collect(evs, fields):
if not fields:
evs = list(evs)
return {
'name': tasks.get(evs[0].get('task')),
'plots': evs
}
return {"name": tasks.get(evs[0].get("task")), "plots": evs}
return {
str(k): collect(group, fields[1:])
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
}

collect_fields = ('metric', 'variant', 'task', 'iter')
collect_fields = ("metric", "variant", "task", "iter")
return collect(
sorted(unique_events, key=itemgetter(*collect_fields), reverse=True),
collect_fields
collect_fields,
)

@@ -502,6 +543,8 @@ def _get_top_iter_unique_events(events, max_iters):
evs = top_unique_events[key]
if len(evs) < max_iters:
evs.append(e)
unique_events = list(itertools.chain.from_iterable(list(top_unique_events.values())))
unique_events = list(
itertools.chain.from_iterable(list(top_unique_events.values()))
)
unique_events.sort(key=lambda e: e["iter"], reverse=True)
return unique_events
@@ -1,5 +1,4 @@
from datetime import datetime
from urllib.parse import urlparse

from mongoengine import Q, EmbeddedDocument

@@ -16,7 +15,6 @@ from apimodels.models import (
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.fields import SupportedURLField
from database.model import validate_id
from database.model.model import Model
from database.model.project import Project
@@ -27,13 +25,23 @@ from database.utils import (
filter_fields,
)
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext

log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=("tags", "framework", "uri", "id", "project", "task", "parent"),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"project",
"task",
"parent",
),
)

@@ -43,20 +51,20 @@ def get_by_id(call):
model_id = call.data["model"]

with translate_errors_context():
res = Model.get_many(
models = Model.get_many(
company=call.identity.company,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not res:
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
)

call.result.data = {"model": res[0]}
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}

@endpoint("models.get_by_task_id", required_fields=["task"])
@@ -66,31 +74,32 @@ def get_by_task_id(call):

with translate_errors_context():
query = dict(id=task_id, company=call.identity.company)
res = Task.get(_only=["output"], **query)
if not res:
task = Task.get(_only=["output"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not res.output:
if not task.output:
raise errors.bad_request.MissingTaskFields(field="output")
if not res.output.model:
if not task.output.model:
raise errors.bad_request.MissingTaskFields(field="output.model")

model_id = res.output.model
res = Model.objects(
model_id = task.output.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
).first()
if not res:
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
)
call.result.data = {"model": res.to_proper_dict()}
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
call.result.data = {"model": model_dict}

@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call):
assert isinstance(call, APICall)

def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
@@ -99,14 +108,13 @@ def get_all_ex(call):
allow_public=True,
query_options=get_all_query_options,
)

conform_output_tags(call, models)
call.result.data = {"models": models}

@endpoint("models.get_all", required_fields=[])
def get_all(call):
assert isinstance(call, APICall)

def get_all(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
@@ -116,13 +124,14 @@ def get_all(call):
allow_public=True,
query_options=get_all_query_options,
)

conform_output_tags(call, models)
call.result.data = {"models": models}

create_fields = {
"name": None,
"tags": list,
"system_tags": list,
"task": Task,
"comment": None,
"uri": None,
@@ -134,22 +143,10 @@ create_fields = {
"ready": None,
}

schemes = list(SupportedURLField.schemes)

def _validate_uri(uri):
parsed_uri = urlparse(uri)
if parsed_uri.scheme not in schemes:
raise errors.bad_request.InvalidModelUri("unsupported scheme", uri=uri)
elif not parsed_uri.path:
raise errors.bad_request.InvalidModelUri("missing path", uri=uri)

def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
conform_tag_fields(call, fields)
return fields

@@ -251,15 +248,14 @@ def create(call, company, req_model):
if project:
validate_id(Project, company=company, project=project)

uri = req_model.uri
if uri:
_validate_uri(uri)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(call, req_data)

fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields)

# create and save model
model = Model(
id=database.utils.id(),
@@ -276,12 +272,12 @@ def create(call, company, req_model):
def prepare_update_fields(call, fields):
fields = fields.copy()
if "uri" in fields:
_validate_uri(fields["uri"])

# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(call, fields)

conform_tag_fields(call, fields)
return fields

@@ -290,8 +286,7 @@ def validate_task(call, fields):

@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call):
assert isinstance(call, APICall)
def edit(call: APICall):
identity = call.identity
model_id = call.data["model"]

@@ -327,13 +322,13 @@ def edit(call):

if fields:
updated = model.update(upsert=False, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)

def _update_model(call, model_id=None):
assert isinstance(call, APICall)
def _update_model(call: APICall, model_id=None):
identity = call.identity
model_id = model_id or call.data["model"]

@@ -358,6 +353,7 @@ def _update_model(call, model_id=None):
updated_count, updated_fields = Model.safe_update(
call.identity.company, model.id, data
)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -9,27 +9,32 @@ from mongoengine import Q
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus, TaskVisibility
from database.model.task.task import Task, TaskStatus
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext

task_bll = TaskBLL()
archived_tasks_cond = {"$in": [TaskVisibility.archived.value, "$tags"]}
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}

create_fields = {
"name": None,
"description": None,
"tags": list,
"system_tags": list,
"default_output_destination": None,
}

get_all_query_options = Project.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "id")
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
)

@@ -43,32 +48,39 @@ def get_by_id(call):
query = Q(id=project_id) & get_company_or_none_constraint(
call.identity.company
)
res = Project.objects(query).first()
if not res:
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)

res = res.to_proper_dict()
project_dict = project.to_proper_dict()
conform_output_tags(call, project_dict)

call.result.data = {"project": res}
call.result.data = {"project": project_dict}

def make_projects_get_all_pipelines(project_ids, specific_state=None):
archived = TaskVisibility.archived.value
status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
# make sure tags is always an array (required by subsequent $in in archived_tasks_cond)
{
archived = EntityVisibility.archived.value

def ensure_system_tags():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
return {
"$addFields": {
"tags": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$tags"}, "array"]},
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$tags",
"else": "$system_tags",
}
}
}
},
}

status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
ensure_system_tags(),
{
"$group": {
"_id": {
@@ -125,12 +137,12 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):

group_step = {"_id": "$project"}

for state in TaskVisibility:
for state in EntityVisibility:
if specific_state and state != specific_state:
continue
if state == TaskVisibility.active:
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
elif state == TaskVisibility.archived:
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(archived_tasks_cond)

runtime_pipeline = [
@@ -141,6 +153,7 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
"project": {"$in": project_ids},
}
},
ensure_system_tags(),
{
# for each project
"$group": group_step
@@ -151,32 +164,33 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):

@endpoint("projects.get_all_ex")
def get_all_ex(call):
assert isinstance(call, APICall)
def get_all_ex(call: APICall):
include_stats = call.data.get("include_stats")
stats_for_state = call.data.get("stats_for_state", TaskVisibility.active.value)
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)

if stats_for_state:
try:
specific_state = TaskVisibility(stats_for_state)
specific_state = EntityVisibility(stats_for_state)
except ValueError:
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
else:
specific_state = None

conform_tag_fields(call, call.data)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
res = Project.get_many_with_join(
projects = Project.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True,
)
conform_output_tags(call, projects)

if not include_stats:
call.result.data = {"projects": res}
call.result.data = {"projects": projects}
return

ids = [project["id"] for project in res]
ids = [project["id"] for project in projects]
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
ids, specific_state=specific_state
)
@@ -187,11 +201,11 @@ def get_all_ex(call):
return dict(default_counts, **entry)

status_count = defaultdict(lambda: {})
key = itemgetter(TaskVisibility.archived.value)
key = itemgetter(EntityVisibility.archived.value)
for result in Task.objects.aggregate(*status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
TaskVisibility.archived if k else TaskVisibility.active
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
@@ -219,32 +233,32 @@ def get_all_ex(call):
}

report_for_states = [
s for s in TaskVisibility if not specific_state or specific_state == s
s for s in EntityVisibility if not specific_state or specific_state == s
]

for project in res:
for project in projects:
project["stats"] = {
task_state.value: get_status_counts(project["id"], task_state.value)
for task_state in report_for_states
}

call.result.data = {"projects": res}
call.result.data = {"projects": projects}

@endpoint("projects.get_all")
def get_all(call):
assert isinstance(call, APICall)

def get_all(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
res = Project.get_many(
projects = Project.get_many(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
parameters=call.data,
allow_public=True,
)
conform_output_tags(call, projects)

call.result.data = {"projects": res}
call.result.data = {"projects": projects}

@endpoint("projects.create", required_fields=["name", "description"])
@@ -254,6 +268,7 @@ def create(call):

with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
conform_tag_fields(call, fields)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -271,7 +286,7 @@ def create(call):
@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call):
def update(call: APICall):
"""
update

@@ -280,7 +295,6 @@ def update(call):
:return: updated - `int` - number of projects updated
fields - `[string]` - updated fields
"""
assert isinstance(call, APICall)
project_id = call.data["project"]

with translate_errors_context():
@@ -291,9 +305,11 @@ def update(call):
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)

@@ -317,7 +333,7 @@ def delete(call):
(Model, errors.bad_request.ProjectHasModels),
):
res = cls.objects(
project=project_id, tags__nin=[TaskVisibility.archived.value]
project=project_id, system_tags__nin=[EntityVisibility.archived.value]
).only("id")
if res and not force:
raise error("use force=true to delete", id=project_id)
@@ -329,12 +345,33 @@ def delete(call):
call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}

@endpoint("projects.get_unique_metric_variants")
def get_unique_metric_variants(call, company_id, req_model):
project_id = call.data.get("project")
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):

metrics = task_bll.get_unique_metric_variants(
company_id, [project_id] if project_id else None
company_id, [request.project] if request.project else None
)

call.result.data = {"metrics": metrics}

@endpoint(
"projects.get_hyper_parameters",
min_version="2.2",
request_data_model=GetHyperParamReq,
response_data_model=GetHyperParamResp,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):

total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
company_id,
project_ids=[request.project] if request.project else None,
page=request.page,
page_size=request.page_size,
)

call.result.data = {
"total": total,
"remaining": remaining,
"parameters": parameters,
}
45
server/services/server/__init__.py
Normal file
@@ -0,0 +1,45 @@
from pyhocon.config_tree import NoneValue

from config import config
from config.info import get_version, get_build_number, get_commit_number
from service_repo import ServiceRepo, APICall, endpoint

@endpoint("server.config")
def get_config(call: APICall):
path = call.data.get("path")
if path:
c = dict(config.get(path))
else:
c = config.to_dict()

def remove_none_value(x):
"""
Pyhocon bug in Python 3: leaves dummy "NoneValue"s in tree,
see: https://github.com/chimpler/pyhocon/issues/111
"""
if isinstance(x, dict):
return {key: remove_none_value(value) for key, value in x.items()}
if isinstance(x, list):
return list(map(remove_none_value, x))
if isinstance(x, NoneValue):
return None
return x

c.pop("secure", None)

call.result.data = remove_none_value(c)

@endpoint("server.endpoints")
def get_endpoints(call: APICall):
call.result.data = ServiceRepo.endpoints_summary()

@endpoint("server.info")
def info(call: APICall):
call.result.data = {
"version": get_version(),
"build": get_build_number(),
"commit": get_commit_number(),
}
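For reference, a minimal sketch of what remove_none_value does to a parsed config tree (illustrative only, not part of this change; the input dictionary is made up and it assumes pyhocon's NoneValue can be instantiated directly):

# hypothetical tree as pyhocon may leave it under Python 3
raw = {"api": {"debug": NoneValue(), "workers": [4, NoneValue()]}}
clean = remove_none_value(raw)
# clean == {"api": {"debug": None, "workers": [4, None]}}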
@@ -33,13 +33,14 @@ from database.model.task.output import Output
from database.model.task.task import Task, TaskStatus, Script, DEFAULT_LAST_ITERATION
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
from utilities import safe_get

task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=("id", "user", "tags", "type", "status", "project"),
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
@@ -79,11 +80,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
req_model.task, company_id=company_id, allow_public=True
)
task_dict = task.to_proper_dict()
conform_output_tags(call, task_dict)
call.result.data = {"task": task_dict}

@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join(
@@ -91,13 +94,15 @@ def get_all_ex(call: APICall):
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True,  # required in case projection is requested for public dataset/versions
override_none_ordering=True,
)

conform_output_tags(call, tasks)
call.result.data = {"tasks": tasks}

@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
@@ -106,7 +111,9 @@ def get_all(call: APICall):
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True,  # required in case projection is requested for public dataset/versions
override_none_ordering=True,
)
conform_output_tags(call, tasks)
call.result.data = {"tasks": tasks}

@@ -188,6 +195,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
create_fields = {
"name": None,
"tags": list,
"system_tags": list,
"type": None,
"error": None,
"comment": None,
@@ -219,10 +227,7 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output

# Make sure there are no duplicate tags
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
conform_tag_fields(call, fields)

# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@@ -251,7 +256,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
task = task_bll.create(call, fields)

with TimingContext("code", "validate"):
task_bll.validate(task, force=call.data.get("force", False))
task_bll.validate(task)

return task

@@ -272,16 +277,14 @@ def create(call: APICall, company_id, req_model: CreateRequest):
call.result.data = {"id": task.id}

def prepare_update_fields(task, call_data):
def prepare_update_fields(call: APICall, task, call_data):
valid_fields = deepcopy(task.__class__.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
conform_tag_fields(call, fields)
return fields, valid_fields

@@ -296,7 +299,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)

partial_update_dict, valid_fields = prepare_update_fields(task, call.data)
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)

if not partial_update_dict:
return UpdateResponse(updated=0)
@@ -309,7 +312,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
)

update_project_time(updated_fields.get("project"))

conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)

@@ -364,7 +367,7 @@ def update_batch(call: APICall):

bulk_ops = []
for id, data in items.items():
fields, valid_fields = prepare_update_fields(tasks[id], data)
fields, valid_fields = prepare_update_fields(call, tasks[id], data)
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
@@ -421,7 +424,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
d.update(value)
fields[key] = d

task_bll.validate(task_bll.create(call, fields), force=force)
task_bll.validate(task_bll.create(call, fields))

# make sure field names do not end in mongoengine comparison operators
fixed_fields = {
@@ -434,6 +437,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project"))
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
@@ -463,6 +467,7 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
set__last_metrics={},
unset__output__result=1,
unset__output__model=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)

res = ResetResponse(
@@ -670,7 +675,10 @@ def publish(call: APICall, company_id, req_model: PublishRequest):

@endpoint(
"tasks.completed", min_version="2.2", request_data_model=UpdateRequest, response_data_model=UpdateResponse
"tasks.completed",
min_version="2.2",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
)
def completed(call: APICall, company_id, request: PublishRequest):
call.result.data_model = UpdateResponse(
@@ -688,4 +696,3 @@ def ping(_, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
)
52
server/services/utils.py
Normal file
@@ -0,0 +1,52 @@
from typing import Union, Sequence

from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion

def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
if call.requested_endpoint_version >= PartialVersion("2.3"):
return
if isinstance(documents, dict):
documents = [documents]
for doc in documents:
system_tags = doc.get("system_tags")
if system_tags:
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))

def conform_tag_fields(call: APICall, document: dict):
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if call.requested_endpoint_version < PartialVersion("2.3"):
service_name = call.endpoint_name.partition(".")[0]
upgrade_tags(
service_name[:-1] if service_name.endswith("s") else service_name, document
)
remove_duplicate_tags(document)

def upgrade_tags(entity: str, document: dict):
"""
If only 'tags' is present in the fields then extract
the system tags from it to a separate field 'system_tags'
"""
tags = document.get("tags")
if tags is not None and not document.get("system_tags"):
user_tags, system_tags = partition_tags(entity, tags)
document["tags"] = user_tags
document["system_tags"] = system_tags

def remove_duplicate_tags(document: dict):
"""
Remove duplicates from 'tags' and 'system_tags' fields
"""
for name in ("tags", "system_tags"):
values = document.get(name)
if values:
document[name] = list(set(values))
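A rough sketch of the backwards-compatibility path implemented by the helpers above (illustrative only, not part of this change; the payload is made up and it assumes partition_tags classifies "development" as a task system tag, as the tests below do):

# hypothetical legacy payload where everything arrives in "tags"
doc = {"tags": ["my-experiment", "development", "development"]}
upgrade_tags("task", doc)       # splits user tags from system tags
remove_duplicate_tags(doc)      # de-duplicates both fields in place
# doc is now roughly: {"tags": ["my-experiment"], "system_tags": ["development"]}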
@@ -1,159 +0,0 @@
"""
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import unittest

import es_factory
from tests.api_client import APIClient
from config import config

log = config.logger(__file__)

class TestDatasetsService(unittest.TestCase):

def setUp(self):
self.api = APIClient(base_url="http://localhost:5100/v1.0")
self.created_tasks = []

self.task = dict(
name="test task events",
type="training",
)
res, self.task_id = self.api.send('tasks.create', self.task, extract="id")
assert (res.meta.result_code == 200)
self.created_tasks.append(self.task_id)

def tearDown(self):
log.info("Cleanup...")
for task_id in self.created_tasks:
try:
self.api.send('tasks.delete', dict(task=task_id, force=True))
except Exception as ex:
log.exception(ex)

def create_task_event(self, type, iteration):
return {
"worker": "test",
"type": type,
"task": self.task_id,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis()
}

def copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj

def test_task_logs(self):
events = []
for iter in range(10):
log_event = self.create_task_event("log", iteration=iter)
events.append(self.copy_and_update(log_event, {
"msg": "This is a log message from test task iter " + str(iter)
}))
# sleep so timestamp is not the same
import time
time.sleep(0.01)
self.send_batch(events)

data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 10

self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 0

def test_task_plots(self):
event = self.create_task_event("plot", 0)
event["metric"] = "roc"
event.update({
"plot_str": json.dumps({
"data": [
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"text": ["Th=0.1", "Th=0.2", "Th=0.3", "Th=0.4", "Th=0.5", "Th=0.6", "Th=0.7", "Th=0.8"],
"name": 'class1'
},
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [2.0, 3.0, 5.0, 8.2, 6.4, 7.5, 9.2, 8.1, 10.0],
"text": ["Th=0.1", "Th=0.2", "Th=0.3", "Th=0.4", "Th=0.5", "Th=0.6", "Th=0.7", "Th=0.8"],
"name": 'class2',
}
],
"layout": {
"title": "ROC for iter 0",
"xaxis": {
"title": 'my x axis'
},
"yaxis": {
"title": 'my y axis'
}
}
})
})
self.send(event)

event = self.create_task_event("plot", 100)
event["metric"] = "confusion"
event.update({
"plot_str": json.dumps({
"data": [
{
"y": [
"lying",
"sitting",
"standing",
"people",
"backgroun"
],
"x": [
"lying",
"sitting",
"standing",
"people",
"backgroun"
],
"z": [
[758, 163, 0, 0, 23],
[63, 858, 3, 0, 0],
[0, 50, 188, 21, 35],
[0, 22, 8, 40, 4, ],
[12, 91, 26, 29, 368]
],
"type": "heatmap"
}
],
"layout": {
"title": "Confusion Matrix for iter 100",
"xaxis": {
"title": "Predicted value"
},
"yaxis": {
"title": "Real value"
}
}
})
})
self.send(event)

data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 2

self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 0

def send_batch(self, events):
self.api.send_batch('events.add_batch', events)

def send(self, event):
self.api.send('events.add', event)

if __name__ == '__main__':
unittest.main()
218
server/tests/automated/test_tags.py
Normal file
@@ -0,0 +1,218 @@
from time import sleep
from typing import Sequence

from apierrors.errors import bad_request
from database.utils import partition_tags
from tests.api_client import APIClient, AttrDict
from tests.automated import TestService
from config import config

log = config.logger(__file__)

class TestTags(TestService):
def setUp(self, version="2.3"):
super().setUp(version)

def testPartition(self):
tags, system_tags = partition_tags("project", ["test"])
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, [])

tags, system_tags = partition_tags("project", ["test", "archived"])
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, ["archived"])

tags, system_tags = partition_tags("project", ["test", "archived"], ["custom"])
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, ["archived", "custom"])

tags, system_tags = partition_tags(
"task", ["test", "development", "annotator20", "Annotation"]
)
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, ["development", "annotator20", "Annotation"])

def testBackwardsCompatibility(self):
new_api = self.api
self.api = APIClient(base_url="http://localhost:8008/v2.2")
entity_tags = {
"model": "archived",
"project": "public",
"task": "development",
}

for name, system_tag in entity_tags.items():
create_func = getattr(self, f"_temp_{name}")
_id = create_func(tags=[system_tag, "test"])
names = f"{name}s"

# when accessed through the old api all the tags are in the tags field
self.assertGetById(
service=names, entity=name, _id=_id, tags=[system_tag, "test"]
)
entities = self._send(
names, "get_all", name="Test tags", tags=[f"-{system_tag}"]
)[names]
self.assertNotFound(_id, entities)

# when accessed through the new api the tags are in tags and system_tags fields
self.assertGetById(
service=names,
entity=name,
_id=_id,
tags=["test"],
system_tags=[system_tag],
api=new_api,
)

# update operation, remove system tag through the old api
self._send(names, "update", tags=["test"], **{name: _id})
self.assertGetById(service=names, entity=name, _id=_id, tags=["test"])

def testProjectTags(self):
pr_id = self._temp_project(system_tags=["default"])

# Test getting project with system tags
projects = self.api.projects.get_all(name="Test tags").projects
self.assertFound(pr_id, ["default"], projects)

projects = self.api.projects.get_all(
name="Test tags", system_tags=["default"]
).projects
self.assertFound(pr_id, ["default"], projects)

projects = self.api.projects.get_all(
name="Test tags", system_tags=["-default"]
).projects
self.assertNotFound(pr_id, projects)

self.api.projects.update(project=pr_id, system_tags=[])
projects = self.api.projects.get_all(
name="Test tags", system_tags=["-default"]
).projects
self.assertFound(pr_id, [], projects)

# Test task statistics and delete
task1_id = self._temp_task(
name="Tags test1", project=pr_id, system_tags=["active"]
)
self._run_task(task1_id)
task2_id = self._temp_task(
name="Tags test2", project=pr_id, system_tags=["archived"]
)
projects = self.api.projects.get_all_ex(name="Test tags").projects
self.assertFound(pr_id, [], projects)

projects = self.api.projects.get_all_ex(
name="Test tags", include_stats=True
).projects
project = next(p for p in projects if p.id == pr_id)
self.assertProjectStats(project)

with self.api.raises(bad_request.ProjectHasTasks):
self.api.projects.delete(project=pr_id)
self.api.projects.delete(project=pr_id, force=True)

def testModelTags(self):
model_id = self._temp_model(system_tags=["default"])
models = self.api.models.get_all_ex(
name="Test tags", system_tags=["default"]
).models
self.assertFound(model_id, ["default"], models)

models = self.api.models.get_all_ex(
name="Test tags", system_tags=["-default"]
).models
self.assertNotFound(model_id, models)

self.api.models.update(model=model_id, system_tags=[])
models = self.api.models.get_all_ex(
name="Test tags", system_tags=["-default"]
).models
self.assertFound(model_id, [], models)

def testTaskTags(self):
task_id = self._temp_task(
name="Test tags", system_tags=["active"]
)
tasks = self.api.tasks.get_all_ex(
name="Test tags", system_tags=["active"]
).tasks
self.assertFound(task_id, ["active"], tasks)

tasks = self.api.tasks.get_all_ex(
name="Test tags", system_tags=["-active"]
).tasks
self.assertNotFound(task_id, tasks)

self.api.tasks.update(task=task_id, system_tags=[])
tasks = self.api.tasks.get_all_ex(
name="Test tags", system_tags=["-active"]
).tasks
self.assertFound(task_id, [], tasks)

# test development system tag
self.api.tasks.started(task=task_id)
self.api.tasks.stop(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "in_progress")
self.api.tasks.update(task=task_id, system_tags=["development"])
self.api.tasks.stop(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped")

def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)

def _run_task(self, task_id):
"""Imitate 1 second of running"""
self.api.tasks.started(task=task_id)
sleep(1)
self.api.tasks.stopped(task=task_id)

def _temp_project(self, **kwargs):
self._update_missing(kwargs, name="Test tags", description="test")
return self.create_temp("projects", **kwargs)

def _temp_model(self, **kwargs):
self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)

def _temp_task(self, **kwargs):
self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
return self.create_temp("tasks", **kwargs)

@staticmethod
def _update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})

def _send(self, service, action, **kwargs):
api = kwargs.pop("api", self.api)
return AttrDict(
api.send(f"{service}.{action}", kwargs)[1]
)

def assertGetById(self, service, entity, _id, tags, system_tags=None, **kwargs):
entity = self._send(service, "get_by_id", **{entity: _id}, **kwargs)[entity]
self.assertEqual(set(entity.tags), set(tags))
if system_tags is not None:
self.assertEqual(set(entity.system_tags), set(system_tags))

def assertFound(
self, _id: str, system_tags: Sequence[str], res: Sequence[AttrDict]
):
found = next((r for r in res if _id == r.id), None)
assert found
self.assertTagsEqual(found.system_tags, system_tags)

def assertNotFound(
self, _id: str, res: Sequence[AttrDict]
):
self.assertFalse(any(r for r in res if r.id == _id))

def assertTagsEqual(self, tags: Sequence[str], expected_tags: Sequence[str]):
self.assertEqual(set(tags), set(expected_tags))
240
server/tests/automated/test_task_events.py
Normal file
@@ -0,0 +1,240 @@
"""
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import unittest
from statistics import mean

import es_factory
from config import config
from tests.automated import TestService

log = config.logger(__file__)

class TestTaskEvents(TestService):
def setUp(self, version="1.7"):
super().setUp(version=version)

self.created_tasks = []

self.task = dict(
name="test task events",
type="training",
input=dict(mapping={}, view=dict(entries=[])),
)
res, self.task_id = self.api.send("tasks.create", self.task, extract="id")
assert res.meta.result_code == 200
self.created_tasks.append(self.task_id)

def tearDown(self):
log.info("Cleanup...")
for task_id in self.created_tasks:
try:
self.api.send("tasks.delete", dict(task=task_id, force=True))
except Exception as ex:
log.exception(ex)

def create_task_event(self, type, iteration):
return {
"worker": "test",
"type": type,
"task": self.task_id,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis()
}

def copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj

def test_task_logs(self):
events = []
for iter in range(10):
log_event = self.create_task_event("log", iteration=iter)
events.append(
self.copy_and_update(
log_event,
{"msg": "This is a log message from test task iter " + str(iter)},
)
)
# sleep so timestamp is not the same
import time

time.sleep(0.01)
self.send_batch(events)

data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 10

self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 0

def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
events = [
{
**self.create_task_event("training_stats_scalar", iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
for key in None, "iter", "timestamp", "iso_time":
with self.subTest(key=key):
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, key=key)
self.assertIn(metric, data)
self.assertIn(variant, data[metric])
self.assertIn("x", data[metric][variant])
self.assertIn("y", data[metric][variant])

def test_task_metric_value_intervals(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
events = [
{
**self.create_task_event("training_stats_scalar", iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)

data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id)
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)

data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=100)
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)

data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=10)
self._assert_metrics_histogram(data[metric][variant], iter_count, 10)

def _assert_metrics_histogram(self, data, iters, samples):
interval = iters // samples
self.assertEqual(len(data["x"]), samples)
self.assertEqual(len(data["y"]), samples)
for curr in range(samples):
self.assertEqual(data["x"][curr], curr * interval)
self.assertEqual(
data["y"][curr],
mean(v for v in range(curr * interval, (curr + 1) * interval)),
)

def test_task_plots(self):
event = self.create_task_event("plot", 0)
event["metric"] = "roc"
event.update(
{
"plot_str": json.dumps(
{
"data": [
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"text": [
"Th=0.1",
"Th=0.2",
"Th=0.3",
"Th=0.4",
"Th=0.5",
"Th=0.6",
"Th=0.7",
"Th=0.8",
],
"name": "class1",
},
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [2.0, 3.0, 5.0, 8.2, 6.4, 7.5, 9.2, 8.1, 10.0],
"text": [
"Th=0.1",
"Th=0.2",
"Th=0.3",
"Th=0.4",
"Th=0.5",
"Th=0.6",
"Th=0.7",
"Th=0.8",
],
"name": "class2",
},
],
"layout": {
"title": "ROC for iter 0",
"xaxis": {"title": "my x axis"},
"yaxis": {"title": "my y axis"},
},
}
)
}
)
self.send(event)

event = self.create_task_event("plot", 100)
event["metric"] = "confusion"
event.update(
{
"plot_str": json.dumps(
{
"data": [
{
"y": [
"lying",
"sitting",
"standing",
"people",
"backgroun",
],
"x": [
"lying",
"sitting",
"standing",
"people",
"backgroun",
],
"z": [
[758, 163, 0, 0, 23],
[63, 858, 3, 0, 0],
[0, 50, 188, 21, 35],
[0, 22, 8, 40, 4],
[12, 91, 26, 29, 368],
],
"type": "heatmap",
}
],
"layout": {
"title": "Confusion Matrix for iter 100",
"xaxis": {"title": "Predicted value"},
"yaxis": {"title": "Real value"},
},
}
)
}
)
self.send(event)

data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 2

self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 0

def send_batch(self, events):
self.api.send_batch("events.add_batch", events)

def send(self, event):
self.api.send("events.add", event)

if __name__ == "__main__":
unittest.main()
@@ -36,8 +36,8 @@ class TestTasksResetDelete(TestService):

TASK_CANNOT_BE_DELETED_CODES = (400, 123)

def setUp(self):
super(TestTasksResetDelete, self).setUp()
def setUp(self, version="1.7"):
super(TestTasksResetDelete, self).setUp(version=version)
self.tasks = self.api.tasks
self.models = self.api.models
108
server/tests/automated/test_tasks_ordering.py
Normal file
@@ -0,0 +1,108 @@
import operator
from time import sleep

from typing import Sequence

from tests.automated import TestService

class TestTasksOrdering(TestService):
test_comment = "Task ordering test"
only_fields = ["id", "started", "comment"]

def setUp(self, **kwargs):
super().setUp(**kwargs)
self.task_ids = self._create_tasks()

def test_order(self):
# test no ordering
self._assertGetTasksWithOrdering()

# sort ascending
self._assertGetTasksWithOrdering(order_by="started")

# sort descending
self._assertGetTasksWithOrdering(order_by="-started")

# sort by the same field that we use for the search
self._assertGetTasksWithOrdering(order_by="comment")

def test_order_with_paging(self):
order_field = "started"
# all results in one page
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)

field_vals = []
page_size = 2
num_pages = 5
for page in range(num_pages):
paged_tasks = self._get_page_tasks(
order_by=order_field, page=page, page_size=page_size
)
self.assertEqual(len(paged_tasks), page_size)
field_vals.extend(t.get(order_field) for t in paged_tasks)

paged_tasks = self._get_page_tasks(
order_by=order_field, page=num_pages, page_size=page_size
)
self.assertTrue(not paged_tasks)

self._assertSorted(field_vals)

def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
return self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
comment=self.test_comment,
page=page,
page_size=page_size,
).tasks

def _assertSorted(self, vals: Sequence, ascending=True):
"""
Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end
"""
if None in vals:
first_null_idx = vals.index(None)
none_tail = vals[first_null_idx:]
vals = vals[:first_null_idx]
self.assertTrue(all(val is None for val in none_tail))
self.assertTrue(all(val is not None for val in vals))

if ascending:
cmp = operator.le
else:
cmp = operator.ge
self.assertTrue(all(cmp(i, j) for i, j in zip(vals, vals[1:])))

def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
comment=self.test_comment,
**kwargs,
).tasks
self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
if order_by:
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [t.get(field_name) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))

def _create_tasks(self):
tasks = [self._temp_task() for _ in range(10)]
for _, task in zip(range(5), tasks):
self.api.tasks.started(task=task)
sleep(0.1)
return tasks

def _temp_task(self, **kwargs):
return self.create_temp(
"tasks",
name="test",
comment=self.test_comment,
type="testing",
input=dict(view=dict()),
**kwargs,
)
17
server/utilities/dicts.py
Normal file
@@ -0,0 +1,17 @@
from typing import Sequence, Tuple, Any

def flatten_nested_items(
dictionary: dict, nesting: int = None, include_leaves=None, prefix=None
) -> Sequence[Tuple[Tuple[str, ...], Any]]:
"""
iterate through dictionary and return with nested keys flattened into a tuple
"""
next_nesting = None if nesting is None else (nesting - 1)
prefix = prefix or ()
for key, value in dictionary.items():
path = prefix + (key,)
if isinstance(value, dict) and nesting != 0:
yield from flatten_nested_items(value, next_nesting, include_leaves, prefix=path)
elif include_leaves is None or key in include_leaves:
yield path, value
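A quick usage sketch for the helper above (illustrative only, not part of this change; the parameter dictionary is made up):

# hypothetical nested execution parameters, flattened into (path, value) pairs
params = {"execution": {"parameters": {"lr": 0.01, "batch_size": 32}}}
flat = dict(flatten_nested_items(params))
# {("execution", "parameters", "lr"): 0.01, ("execution", "parameters", "batch_size"): 32}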
Block a user