clearml-server/apiserver/apimodels/base.py
2021-01-05 18:11:22 +02:00

277 lines
7.9 KiB
Python

from __future__ import absolute_import
from enum import Enum
from typing import Union, Type, Iterable
import jsonmodels.errors
import six
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from jsonmodels import models, fields
from jsonmodels.validators import Length
from mongoengine.base import BaseDocument
from apiserver.utilities.json import loads, dumps
def make_default(field_cls, default_value):
    """
    Build a subclass of ``field_cls`` with a fixed default value.

    :param field_cls: the field class to derive from.
    :param default_value: object returned by the subclass's
        ``get_default_value`` (the very same object on every call).
    :return: the generated ``_FieldWithDefault`` subclass.
    """

    class _FieldWithDefault(field_cls):
        def get_default_value(self):
            # Always hand back the object captured from the enclosing call
            return default_value

    return _FieldWithDefault
class ListField(fields.ListField):
    """
    jsonmodels ListField with a few extensions.

    - ``default`` may be a zero-argument callable (e.g. ``list``). The callable
      is invoked on every ``get_default_value`` call, so model instances do not
      share (and mutate) one default list object.
    - When casting an item raises ``TypeError`` and the field has a single
      ``Enum`` item type, the raw value is converted through that enum.
    - Items that are models are validated recursively.
    """

    def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
        self._default_factory = None
        if default is not NotSet and callable(default):
            # Keep the factory so every requested default is a fresh object;
            # the value built here is still handed to the base class as the
            # static default, exactly as before.
            self._default_factory = default
            default = default()
        super(ListField, self).__init__(items_types, *args, default=default, **kwargs)

    def get_default_value(self):
        """
        Return a fresh default produced by the factory, if one was given.

        Falls back to the base class default when there is no factory or the
        factory returned None.
        """
        if self._default_factory is not None:
            value = self._default_factory()
            if value is not None:
                return value
        return super(ListField, self).get_default_value()

    def _cast_value(self, value):
        """Cast a raw item to the list's item type."""
        try:
            return super(ListField, self)._cast_value(value)
        except TypeError:
            # NOTE(review): presumably raised because the base cast builds
            # items as ``item_type(**value)``, which fails for non-mappings.
            # For a single Enum item type, treat the value as an enum value.
            if len(self.items_types) == 1 and issubclass(self.items_types[0], Enum):
                return self.items_types[0](value)
            # Otherwise keep the raw value unchanged
            return value

    def validate_single_value(self, item):
        """Validate one item, also validating it recursively if it is a model."""
        super(ListField, self).validate_single_value(item)
        if isinstance(item, ModelBase):
            item.validate()
# Sentinel meaning "DictField value not set".
# An unset DictField and an empty one are indistinguishable by value (both are
# treated as an empty dict), so this particular dict object serves as a marker.
# It compares equal to {}, so it can presumably only be told apart from other
# empty dicts by identity (``is``).
DictFieldNotSet = dict()
class DictField(fields.BaseField):
    """
    Field holding a dict, optionally restricting the types of its values.

    :param value_types: a type, an iterable of types, or type names given as
        strings (resolved once the owning model class is known). When empty,
        dict values are not type-checked.
    """

    types = (dict,)

    def __init__(self, value_types=None, *args, **kwargs):
        self.value_types = self._assign_types(value_types)
        super(DictField, self).__init__(*args, **kwargs)

    def get_default_value(self):
        """
        Return the base default, or a new empty dict when that default is
        None and the field is not required.
        """
        default = super(DictField, self).get_default_value()
        if default is None and not self.required:
            return {}
        return default

    def _finish_initialization(self, owner):
        """Resolve value types given by name against the owning model class."""
        super(DictField, self)._finish_initialization(owner)
        # isinstance() cannot take _LazyType placeholders, so replace them with
        # the real classes before any value is validated.
        self.value_types = tuple(
            type_.evaluate(owner) if isinstance(type_, _LazyType) else type_
            for type_ in self.value_types
        )

    @staticmethod
    def _assign_types(value_types):
        """
        Normalize ``value_types`` to a tuple of types.

        A single type is wrapped in a tuple; falsy input gives an empty tuple;
        string entries become lazy placeholders to be resolved later.
        """
        if value_types:
            try:
                value_types = tuple(value_types)
            except TypeError:
                # A single (non-iterable) type was given
                value_types = (value_types,)
        else:
            value_types = tuple()
        return tuple(
            _LazyType(type_) if isinstance(type_, six.string_types) else type_
            for type_ in value_types
        )

    def validate(self, value):
        """Run base validation, then type-check every dict value."""
        super(DictField, self).validate(value)
        if not self.value_types:
            return
        if not value:
            return
        for item in value.values():
            self.validate_single_value(item)

    def validate_single_value(self, item):
        """
        :raise jsonmodels.errors.ValidationError: if ``item`` is not an
            instance of one of ``value_types``.
        """
        if not self.value_types:
            return
        if not isinstance(item, self.value_types):
            raise jsonmodels.errors.ValidationError(
                "All items must be instances "
                'of "{types}", and not "{type}".'.format(
                    types=", ".join([t.__name__ for t in self.value_types]),
                    type=type(item).__name__,
                )
            )
class IntField(fields.IntField):
    """IntField whose parsing never raises: unparsable input is returned unchanged."""

    def parse_value(self, value):
        # Keep the raw input when conversion fails; presumably field
        # validation is then relied upon to report the bad type.
        try:
            parsed = super().parse_value(value)
        except (TypeError, ValueError):
            parsed = value
        return parsed
class NullableEnumValidator(EnumValidator):
    """EnumValidator variant that lets None through without checking."""

    def validate(self, value):
        # None is always accepted; anything else goes through the enum check
        if value is None:
            return
        super(NullableEnumValidator, self).validate(value)
class EnumField(fields.StringField):
    """
    String field restricted to a fixed set of allowed values.

    :param values_or_type: iterable of allowed values, or an ``Enum`` class.
        Each entry is normalized with ``parse_value`` (Enum members become
        ``str(member.value)``).
    :param required: when False, None is accepted in addition to the allowed
        values.
    :param default: default value, normalized with ``parse_value``.
    :param validators: optional extra validators (a list, tuple or a single
        validator); the enum validator is appended to a copy of them.
    """

    def __init__(
        self,
        values_or_type: Union[Iterable, Type[Enum]],
        *args,
        required=False,
        default=None,
        **kwargs
    ):
        choices = list(map(self.parse_value, values_or_type))
        validator_cls = EnumValidator if required else NullableEnumValidator
        # Build a new list instead of appending to the caller's one, which may
        # be shared between several fields. Also accept None or a single
        # validator object instead of a list.
        validators = kwargs.pop("validators", None) or []
        if not isinstance(validators, (list, tuple)):
            validators = [validators]
        validators = [*validators, validator_cls(*choices)]
        super().__init__(
            *args,
            default=self.parse_value(default),
            required=required,
            validators=validators,
            **kwargs
        )

    def parse_value(self, value):
        """Convert Enum members to ``str(member.value)``; other values go to the base parser."""
        if isinstance(value, Enum):
            return str(value.value)
        return super().parse_value(value)
class ActualEnumField(fields.StringField):
    """
    Field whose parsed value is a member of ``enum_class``.

    Raw values are converted to enum members by ``parse_value``. Values that
    are not valid members are kept as-is, left for the enum validator added
    here to reject. ``to_struct`` emits the member's ``value``.

    :param enum_class: the Enum class providing the allowed members.
    :param validators: optional extra validators (a list, tuple or a single
        validator); the enum validator is appended to them.
    :param required: when False, None is also accepted.
    :param default: default value (a member or a raw member value); None
        means "no default".
    """

    def __init__(
        self,
        enum_class: Type[Enum],
        *args,
        validators=None,
        required=False,
        default=None,
        **kwargs
    ):
        self.__enum = enum_class
        # Only members of the enum class pass the field's type check
        self.types = (enum_class,)
        # noinspection PyTypeChecker
        choices = list(enum_class)
        validator_cls = EnumValidator if required else NullableEnumValidator
        # Accept a list/tuple of validators as well as a single validator
        if validators is None:
            validators = []
        elif not isinstance(validators, (list, tuple)):
            validators = [validators]
        validators = [*validators, validator_cls(*choices)]
        super().__init__(
            *args,
            # Compare with None instead of relying on truthiness: an enum
            # member can be falsy (e.g. an IntEnum member whose value is 0).
            default=NotSet if default is None else self.parse_value(default),
            required=required,
            validators=validators,
            **kwargs
        )

    def parse_value(self, value):
        """
        Convert ``value`` to a member of the enum class.

        None on a non-required field yields the field's default. Values that
        are not valid members are returned unchanged.
        """
        if value is None and not self.required:
            return self.get_default_value()
        try:
            # noinspection PyArgumentList
            return self.__enum(value)
        except ValueError:
            return value

    def to_struct(self, value):
        """Serialize an enum member as its ``value``; other values pass through."""
        if isinstance(value, Enum):
            value = value.value
        return super().to_struct(value)
class JsonSerializableMixin:
    """
    Mixin for jsonmodels models adding JSON round-tripping through the
    project's ``dumps``/``loads`` helpers.
    """

    def to_json(self: ModelBase):
        """Return the model's ``to_struct()`` output serialized with ``dumps``."""
        struct = self.to_struct()
        return dumps(struct)

    @classmethod
    def from_json(cls: Type[ModelBase], s):
        """Create an instance of ``cls`` from the keyword arguments decoded from ``s``."""
        field_values = loads(s)
        return cls(**field_values)
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
    """
    Derive a field class whose ``default`` may be a zero-argument callable.

    A truthy callable passed as ``default`` is stored and invoked:
    once in ``__init__`` (its result is forwarded to ``cls`` as the default),
    and again on every ``get_default_value`` call, so each call returns a
    freshly built value. Any other ``default`` is forwarded unchanged.
    """

    class _Wrapped(cls):
        # Factory for defaults; stays None when no callable default was given
        _callable_default = None

        def __init__(self, *args, default=None, **kwargs):
            use_factory = bool(default) and callable(default)
            if use_factory:
                self._callable_default = default
                default = default()
            super(_Wrapped, self).__init__(*args, default=default, **kwargs)

        def get_default_value(self):
            factory = self._callable_default
            if not factory:
                return super(_Wrapped, self).get_default_value()
            return factory()

    return _Wrapped
class MongoengineFieldsDict(DictField):
    """
    DictField representing mongoengine field names/value mapping.
    Used to convert mongoengine-style field/subfield notation to
    user-presentable syntax, including handling update operators.
    """

    # Update-operator prefixes whose entries are dropped from the mapping
    # ("set" and "unset" are handled separately and kept).
    mongoengine_update_operators = (
        "inc",
        "dec",
        "push",
        "push_all",
        "pop",
        "pull",
        "pull_all",
        "add_to_set",
    )

    @staticmethod
    def _normalize_mongo_value(value):
        """Convert mongoengine documents with ``to_mongo()``; return other values unchanged."""
        if isinstance(value, BaseDocument):
            return value.to_mongo()
        return value

    @classmethod
    def _normalize_mongo_field_path(cls, path, value):
        """
        Translate one mongoengine-style key/value pair.

        - ``set__a__b`` -> (``"a.b"``, value)
        - ``unset__a`` -> (``"a"``, None)
        - keys starting with one of ``mongoengine_update_operators`` ->
          (None, None), signalling the caller to drop the entry
        - any other key -> ``"__"`` replaced by ``"."``

        :return: (dotted path, normalized value) tuple.
        """
        parts = path.split("__")
        if len(parts) > 1:
            if parts[0] == "set":
                parts = parts[1:]
            elif parts[0] == "unset":
                parts = parts[1:]
                value = None
            elif parts[0] in cls.mongoengine_update_operators:
                return None, None
        return ".".join(parts), cls._normalize_mongo_value(value)

    def parse_value(self, value):
        """
        Normalize every key/value pair of the mapping.

        None is returned unchanged (previously it raised AttributeError on
        ``.items()``).
        """
        value = super(MongoengineFieldsDict, self).parse_value(value)
        if value is None:
            return value
        normalized = (self._normalize_mongo_field_path(*p) for p in value.items())
        return {k: v for k, v in normalized if k is not None}
class UpdateResponse(models.Base):
    """
    Response model for update calls.

    ``updated`` is a required int (presumably the number of updated objects;
    confirm against the handlers). ``fields`` maps updated field paths to
    their values, normalized by MongoengineFieldsDict.
    """

    updated = fields.IntField(required=True)
    # NOTE: this attribute shadows the ``fields`` module for the rest of the
    # class body, so fields declared via ``fields.<Type>`` must come above it.
    fields = MongoengineFieldsDict()
class PagedRequest(models.Base):
    """
    Request model with two optional int parameters.

    Presumably ``page`` is the page index and ``page_size`` the number of
    items per page. TODO confirm the semantics (e.g. 0- or 1-based) with the
    request handlers.
    """

    page = fields.IntField()
    page_size = fields.IntField()
class IdResponse(models.Base):
    """Response model carrying a single required string ``id``."""

    id = fields.StringField(required=True)
class MakePublicRequest(models.Base):
    """Request model holding a non-empty list of string ``ids``."""

    ids = ListField(items_types=[str], validators=[Length(minimum_value=1)])
class MoveRequest(models.Base):
    """
    Request model with a non-empty list of string ``ids`` and two optional
    strings, ``project`` and ``project_name``.

    Presumably the ids are moved to a target project identified either by id
    (``project``) or by name (``project_name``); confirm against the handler.
    """

    ids = ListField(items_types=[str], validators=[Length(minimum_value=1)])
    project = fields.StringField()
    project_name = fields.StringField()