Add API version 2.4 with new trains-server capabilities including DevOps and scheduling

This commit is contained in:
allegroai
2019-10-25 15:36:58 +03:00
parent 2ea25e498f
commit 1a732ccd8e
54 changed files with 4964 additions and 341 deletions

View File

@@ -1,12 +1,12 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence
from typing import Collection, Sequence, Union
from boltons.iterutils import first
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
from apierrors import errors
from config import config
@@ -16,9 +16,10 @@ from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import (
get_company_or_none_constraint,
get_fields_with_attr,
field_exists,
get_fields_choices,
field_does_not_exist,
field_exists,
get_fields,
)
log = config.logger("dbmodel")
@@ -62,6 +63,7 @@ class GetMixin(PropsMixin):
_text_score = "$text_score"
_ordering_key = "order_by"
_search_text_key = "search_text"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
@@ -221,6 +223,24 @@ class GetMixin(PropsMixin):
return get_company_or_none_constraint(company)
return Q(company=company)
@classmethod
def validate_order_by(cls, parameters, search_text) -> Sequence:
"""
Validate and extract order_by params as a list
"""
order_by = parameters.get(cls._ordering_key)
if not order_by:
return []
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
if not search_text and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
return order_by
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
@@ -267,7 +287,6 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
override_none_ordering=False,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -303,7 +322,6 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
override_none_ordering=override_none_ordering,
)
def projection_func(doc_type, projection, ids):
@@ -328,7 +346,6 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
override_none_ordering=False,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -341,8 +358,9 @@ class GetMixin(PropsMixin):
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
be raised.
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
requested, each contains only the requested projection). If False, a QuerySet object is returned
(lazy evaluated). If return_dicts is requested then the entities with the None value in order_by field
are returned last in the ordering.
:param company: Company ID (required)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
@@ -352,8 +370,6 @@ class GetMixin(PropsMixin):
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:param override_none_ordering: If True, then items with the None values in the first ordered field
are always sorted in the end
:return: A list of objects matching the query.
"""
if query_dict is not None:
@@ -367,26 +383,19 @@ class GetMixin(PropsMixin):
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
if override_none_ordering:
if return_dicts:
return cls._get_many_override_none_ordering(
query=_query,
parameters=parameters,
query_dict=query_dict,
query_options=query_options,
override_projection=override_projection,
)
return cls._get_many_no_company(
query=_query,
parameters=parameters,
override_projection=override_projection,
return_dicts=return_dicts,
query=_query, parameters=parameters, override_projection=override_projection
)
@classmethod
def _get_many_no_company(
cls, query, parameters=None, override_projection=None, return_dicts=True
):
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
@@ -395,44 +404,25 @@ class GetMixin(PropsMixin):
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
:param query: Query object (mongoengine.Q)
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
"""
parameters = parameters or {}
if not query:
raise ValueError("query or call_data must be provided")
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
order_by = parameters.get(cls._ordering_key)
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
search_text = parameters.get("search_text")
only = cls.get_projection(parameters, override_projection)
if not search_text and order_by and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
qs = cls.objects(query)
if search_text:
qs = qs.search_text(search_text)
if order_by:
# add ordering
qs = (
qs.order_by(order_by)
if isinstance(order_by, string_types)
else qs.order_by(*order_by)
)
qs = qs.order_by(*order_by)
if only:
# add projection
qs = qs.only(*only)
@@ -444,17 +434,13 @@ class GetMixin(PropsMixin):
# add paging
qs = qs.skip(page * page_size).limit(page_size)
if return_dicts:
return [obj.to_proper_dict(only=only) for obj in qs]
return qs
@classmethod
def _get_many_override_none_ordering(
cls,
cls: Union[Document, "GetMixin"],
query: Q = None,
parameters: dict = None,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
override_projection: Collection[str] = None,
) -> Sequence[dict]:
"""
@@ -467,57 +453,45 @@ class GetMixin(PropsMixin):
:param query: Query object (mongoengine.Q)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
a query. The resulting query is AND'ed with the `query` parameter (if provided).
:param query_options: query parameters options (see ParametersOptions)
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
"""
if not query:
raise ValueError("query or call_data must be provided")
parameters = parameters or {}
search_text = parameters.get("search_text")
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
query_sets = []
order_by = parameters.get(cls._ordering_key)
query_sets = [cls.objects(query)]
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
if not search_text and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
order_field = first(
field for field in order_by if not field.startswith("$")
)
if (
order_field
and not order_field.startswith("-")
and (not query_dict or order_field not in query_dict)
and "[" not in order_field
):
empty_value = None
if order_field in query_options.list_fields:
empty_value = []
elif order_field in query_options.pattern_fields:
empty_value = ""
params = {}
mongo_field = order_field.replace(".", "__")
non_empty = query & field_exists(mongo_field, empty_value=empty_value)
empty = query & field_does_not_exist(
mongo_field, empty_value=empty_value
)
if mongo_field in get_fields(cls, of_type=ListField, subfields=True):
params["is_list"] = True
elif mongo_field in get_fields(
cls, of_type=StringField, subfields=True
):
params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params)
query_sets = [cls.objects(non_empty), cls.objects(empty)]
if not query_sets:
query_sets = [cls.objects(query)]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if order_by:
# add ordering
query_sets = [qs.order_by(*order_by) for qs in query_sets]
only = cls.get_projection(parameters, override_projection)
if only:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
@@ -583,8 +557,8 @@ class UpdateMixin(object):
def user_set_allowed(cls):
res = getattr(cls, "__user_set_allowed_fields", None)
if res is None:
res = cls.__user_set_allowed_fields = dict(
get_fields_with_attr(cls, "user_set_allowed")
res = cls.__user_set_allowed_fields = get_fields_choices(
cls, "user_set_allowed"
)
return res
@@ -622,7 +596,24 @@ class UpdateMixin(object):
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
pass
@classmethod
def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
) -> CommandCursor:
"""
Aggregate objects of this document class according to the provided pipeline.
:param pipeline: a list of dictionaries describing the pipeline stages
:param allow_disk_use: if True, allow the server to use disk space if aggregation query cannot fit in memory.
If None, default behavior will be used (see apiserver.conf/mongo/aggregate/allow_disk_use)
:param kwargs: additional keyword arguments passed to mongoengine
:return:
"""
kwargs.update(
allowDiskUse=allow_disk_use
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
)
return cls.objects.aggregate(*pipeline, **kwargs)
def validate_id(cls, company, **kwargs):

View File

@@ -0,0 +1,47 @@
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.company import Company
from database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
added = DateTimeField(required=True)
''' Added to the queue '''
class Queue(DbModelMixin, Document):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
}
id = StringField(primary_key=True)
name = StrippedStringField(
required=True, unique_with="company", min_length=3, user_set_allowed=True
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()

View File

@@ -29,6 +29,7 @@ DEFAULT_LAST_ITERATION = 0
class TaskStatus(object):
created = "created"
queued = "queued"
in_progress = "in_progress"
stopped = "stopped"
publishing = "publishing"
@@ -85,7 +86,7 @@ class Execution(EmbeddedDocument):
model_labels = ModelLabels()
framework = StringField()
artifacts = EmbeddedDocumentSortedListField(Artifact)
docker_cmd = StringField()
queue = StringField()
""" Queue ID where task was queued """
@@ -150,6 +151,8 @@ class Task(AttributedDocument):
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))

View File

@@ -1,6 +1,7 @@
from collections import OrderedDict
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
@@ -11,7 +12,7 @@ from database.fields import (
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_and_attr
from database.utils import get_fields, get_fields_attr
class PropsMixin(object):
@@ -42,7 +43,7 @@ class PropsMixin(object):
@staticmethod
def _get_fields_with_attr(cls_, attr):
""" Get all fields with the specified attribute (supports nested fields) """
res = get_fields_and_attr(cls_, attr=attr)
res = get_fields_attr(cls_, attr=attr)
def resolve_doc(v):
if not isinstance(v, six.string_types):
@@ -122,6 +123,14 @@ class PropsMixin(object):
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@classmethod
def get_extra_projection(cls, fields: Sequence) -> tuple:
if isinstance(fields, str):
fields = [fields]
return tuple(
set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
)
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
@@ -140,3 +149,18 @@ class PropsMixin(object):
result = separator.join(translated)
cls.__cached_dpath_computed_fields[path] = result
return cls.__cached_dpath_computed_fields[path]
def get_field_value(self, field_path: str, default=None):
"""
Return the document field_path value by the field_path name.
The path may contain '.'. If on any level the path is
not found then the default value is returned
"""
path_elements = field_path.split(".")
current = self
for name in path_elements:
current = getattr(current, name, default)
if current == default:
break
return current

View File

@@ -1,6 +1,6 @@
import hashlib
from inspect import ismethod, getmembers
from typing import Sequence, Tuple, Set, Optional
from typing import Sequence, Tuple, Set, Optional, Callable, Any
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
@@ -9,64 +9,58 @@ from mongoengine.base import BaseField
from .errors import translate_errors_context, ParseCallError
def get_fields(cls, of_type=BaseField, return_instance=False):
def get_fields(cls, of_type=BaseField, return_instance=False, subfields=False):
return _get_fields(
cls,
of_type=of_type,
subfields=subfields,
selector=lambda k, v: (k, v) if return_instance else k,
)
def get_fields_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = []
for cls_ in reversed(cls.mro()):
res.extend(
[
k if not return_instance else (k, v)
for k, v in vars(cls_).items()
if isinstance(v, of_type)
]
)
return res
return dict(
_get_fields(cls, with_attr=attr, selector=lambda k, v: (k, getattr(v, attr)))
)
def get_fields_and_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = {}
for cls_ in reversed(cls.mro()):
res.update(
{
k: getattr(v, attr)
for k, v in vars(cls_).items()
if isinstance(v, BaseField) and hasattr(v, attr)
}
)
return res
def get_fields_choices(cls, attr):
def get_choices(field_name: str, field: BaseField) -> Tuple:
if isinstance(field, ListField):
return field_name, field.field.choices
return field_name, field.choices
return dict(_get_fields(cls, with_attr=attr, subfields=True, selector=get_choices))
def _get_field_choices(name, field):
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
obj = field.document_type_obj
n, choices = _get_field_choices(field.name, obj.field)
return "%s__%s" % (name, n), choices
elif issubclass(type(field), ListField):
return name, field.field.choices
return name, field.choices
def get_fields_with_attr(cls, attr, default=False):
def _get_fields(
cls,
with_attr=None,
of_type=BaseField,
subfields=False,
selector: Optional[Callable[[str, BaseField], Any]] = None,
path: Tuple[str, ...] = (),
):
fields = []
for field_name, field in cls._fields.items():
if not getattr(field, attr, default):
continue
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
field_path = path + (field_name,)
if isinstance(field, of_type) and (not with_attr or hasattr(field, with_attr)):
full_name = "__".join(field_path)
fields.append(selector(full_name, field) if selector else full_name)
if subfields and isinstance(field, EmbeddedDocumentField):
fields.extend(
(
("%s__%s" % (field_name, name), choices)
for name, choices in get_fields_with_attr(
field.document_type, attr, default
)
_get_fields(
field.document_type,
with_attr=with_attr,
of_type=of_type,
subfields=subfields,
selector=selector,
path=field_path,
)
)
elif issubclass(type(field), ListField):
fields.append((field_name, field.field.choices))
else:
fields.append((field_name, field.choices))
return fields
@@ -151,17 +145,20 @@ def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
return query
def field_exists(field: str, empty_value=None) -> Q:
def field_exists(field: str, empty_value=None, is_list=False) -> Q:
"""
Creates a query object used for finding a field that exists and is not None or empty.
:param field: Field name
:param empty_value: The empty value to test for (None means no specific empty value will be used).
For lists pass [] for empty_value
:param empty_value: The empty value to test for (None means no specific empty value will be used)
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
the length of the array will be used (len==0 means empty)
:return:
"""
query = Q(**{f"{field}__exists": True}) & Q(
**{f"{field}__nin": {empty_value, None}}
)
if is_list:
query &= Q(**{f"{field}__not__size": 0})
return query
@@ -213,6 +210,7 @@ system_tag_names = {
"model": _names_set("active", "archived"),
"project": _names_set("archived", "public", "default"),
"task": _names_set("active", "archived", "development"),
"queue": _names_set("default"),
}
system_tag_prefixes = {"task": _names_set("annotat")}