mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
277 lines
7.9 KiB
Python
277 lines
7.9 KiB
Python
from __future__ import absolute_import
|
|
|
|
from enum import Enum
|
|
from typing import Union, Type, Iterable
|
|
|
|
import jsonmodels.errors
|
|
import six
|
|
from jsonmodels.fields import _LazyType, NotSet
|
|
from jsonmodels.models import Base as ModelBase
|
|
from jsonmodels.validators import Enum as EnumValidator
|
|
|
|
from jsonmodels import models, fields
|
|
from jsonmodels.validators import Length
|
|
from mongoengine.base import BaseDocument
|
|
from apiserver.utilities.json import loads, dumps
|
|
|
|
|
|
def make_default(field_cls, default_value):
    """
    Derive a field class from ``field_cls`` whose default value is fixed.

    :param field_cls: class to derive from
    :param default_value: object returned by ``get_default_value()`` of the
        derived class
    :return: the derived class (not an instance of it)
    """

    class _FieldWithDefault(field_cls):
        def get_default_value(self):
            # The very same object is handed out on every call; no copy is made
            return default_value

    return _FieldWithDefault
|
|
|
|
|
|
class ListField(fields.ListField):
    """
    List field that accepts a callable default, can cast raw values to a single
    ``Enum`` item type and validates items that are models themselves.
    """

    def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
        """
        :param items_types: allowed item type(s), forwarded to the base class
        :param default: default value; when it is callable it is invoked once,
            here, and its result is used as the default
        """
        use_factory = default is not NotSet and callable(default)
        resolved_default = default() if use_factory else default
        super().__init__(items_types, *args, default=resolved_default, **kwargs)

    def _cast_value(self, value):
        """
        Cast ``value`` using the base class. When that raises ``TypeError`` and
        the only declared item type is an ``Enum`` subclass, build the enum
        member from ``value``; otherwise return ``value`` unchanged.
        """
        try:
            return super()._cast_value(value)
        except TypeError:
            declared = self.items_types
            if len(declared) == 1 and issubclass(declared[0], Enum):
                return declared[0](value)
            return value

    def validate_single_value(self, item):
        """Run the base item validation, then validate ``item`` itself if it is a model."""
        super().validate_single_value(item)
        if not isinstance(item, ModelBase):
            return
        item.validate()
|
|
|
|
|
|
# A DictField makes no distinction between None and an empty dictionary.
# This module-level object can serve as a sentinel to tell "not set" apart
# from "set to an empty dict". It is equal to any other empty dict, so it
# only works as a sentinel when compared by identity (``is``).
DictFieldNotSet = dict()
|
|
|
|
|
|
class DictField(fields.BaseField):
    """Dictionary field with optional type checking of the dictionary values."""

    types = (dict,)

    def __init__(self, value_types=None, *args, **kwargs):
        """
        :param value_types: a type, a type name, or an iterable of those, that
            dictionary values must be instances of; falsy disables the check
        """
        self.value_types = self._assign_types(value_types)
        super().__init__(*args, **kwargs)

    def get_default_value(self):
        """Return the base default, substituting ``{}`` for ``None`` on non-required fields."""
        base_default = super().get_default_value()
        if self.required or base_default is not None:
            return base_default
        return {}

    @staticmethod
    def _assign_types(value_types):
        """
        Normalise ``value_types`` into a tuple, wrapping string entries in
        ``_LazyType``.
        """
        if not value_types:
            candidates = tuple()
        else:
            try:
                candidates = tuple(value_types)
            except TypeError:
                # Not iterable: treat it as a single type
                candidates = (value_types,)

        normalised = []
        for candidate in candidates:
            if isinstance(candidate, six.string_types):
                candidate = _LazyType(candidate)
            normalised.append(candidate)
        return tuple(normalised)

    def validate(self, value):
        """Run the base validation, then check every dictionary value (if types were declared)."""
        super().validate(value)

        # Nothing to check without declared types or without any content
        if not (self.value_types and value):
            return

        for entry in value.values():
            self.validate_single_value(entry)

    def validate_single_value(self, item):
        """
        Check a single dictionary value against the declared value types.

        :raise jsonmodels.errors.ValidationError: ``item`` is not an instance
            of any declared value type
        """
        allowed = self.value_types
        if not allowed or isinstance(item, allowed):
            return

        raise jsonmodels.errors.ValidationError(
            "All items must be instances "
            'of "{types}", and not "{type}".'.format(
                types=", ".join([t.__name__ for t in allowed]),
                type=type(item).__name__,
            )
        )
|
|
|
|
|
|
class IntField(fields.IntField):
    """Integer field that keeps the raw value when it cannot be parsed."""

    def parse_value(self, value):
        """
        Parse ``value`` with the base class; if that raises ``ValueError`` or
        ``TypeError`` return ``value`` untouched instead of propagating.
        """
        try:
            parsed = super().parse_value(value)
        except (ValueError, TypeError):
            parsed = value
        return parsed
|
|
|
|
|
|
class NullableEnumValidator(EnumValidator):
    """Enum validator that accepts ``None`` in addition to the allowed choices."""

    def validate(self, value):
        """Skip validation for ``None``; delegate every other value to the base validator."""
        if value is None:
            return
        super().validate(value)
|
|
|
|
|
|
class EnumField(fields.StringField):
    """
    String field restricted to a fixed set of choices.

    Choices are given either as an iterable of values or as an ``Enum`` class.
    ``Enum`` members are converted to the string form of their ``value``.
    """

    def __init__(
        self,
        values_or_type: Union[Iterable, Type[Enum]],
        *args,
        required=False,
        default=None,
        **kwargs
    ):
        """
        :param values_or_type: iterable of allowed values, or an ``Enum`` class
        :param required: when true the choices validator rejects ``None``;
            otherwise ``None`` is accepted
        :param default: default value; parsed the same way as assigned values
        :param kwargs: forwarded to the base field. ``validators`` may be
            omitted, ``None``, a single validator or a list/tuple of validators;
            the choices validator is added after them.
        """
        choices = list(map(self.parse_value, values_or_type))
        validator_cls = EnumValidator if required else NullableEnumValidator

        # Build a fresh list rather than appending to the caller's object: the
        # caller's list must not be mutated, and a ``None`` or single-validator
        # argument has no ``append`` at all.
        validators = kwargs.pop("validators", None)
        if validators is None:
            validators = []
        elif not isinstance(validators, (list, tuple)):
            validators = [validators]
        kwargs["validators"] = [*validators, validator_cls(*choices)]

        super().__init__(
            *args, default=self.parse_value(default), required=required, **kwargs
        )

    def parse_value(self, value):
        """Return ``str(value.value)`` for ``Enum`` members, else defer to the base class."""
        if isinstance(value, Enum):
            return str(value.value)
        return super().parse_value(value)
|
|
|
|
|
|
class ActualEnumField(fields.StringField):
    """Field whose parsed values are members of the given ``Enum`` class."""

    def __init__(
        self,
        enum_class: Type[Enum],
        *args,
        validators=None,
        required=False,
        default=None,
        **kwargs
    ):
        """
        :param enum_class: ``Enum`` class whose members are the allowed values
        :param validators: optional extra validators; the enum membership
            validator is added after them
        :param required: when true the membership validator rejects ``None``
        :param default: default value; a falsy default is treated as "not set"
        """
        self.__enum = enum_class
        self.types = (enum_class,)

        # noinspection PyTypeChecker
        members = list(enum_class)
        validator_cls = NullableEnumValidator if not required else EnumValidator
        combined_validators = list(validators or [])
        combined_validators.append(validator_cls(*members))

        initial_default = self.parse_value(default) if default else NotSet
        super().__init__(
            *args,
            default=initial_default,
            required=required,
            validators=combined_validators,
            **kwargs
        )

    def parse_value(self, value):
        """
        Convert ``value`` into a member of the enum class.

        ``None`` on a non-required field yields the field default. A value the
        enum class rejects with ``ValueError`` is returned unchanged.
        """
        if value is None and not self.required:
            return self.get_default_value()

        try:
            # noinspection PyArgumentList
            member = self.__enum(value)
        except ValueError:
            return value
        return member

    def to_struct(self, value):
        """Serialise through the base class using the ``value`` attribute of ``value``."""
        raw = value.value
        return super().to_struct(raw)
|
|
|
|
|
|
class JsonSerializableMixin:
    """Mixin adding JSON string conversion on top of ``to_struct()`` and the constructor."""

    def to_json(self: ModelBase):
        """Return the result of ``dumps`` applied to this object's ``to_struct()`` output."""
        struct = self.to_struct()
        return dumps(struct)

    @classmethod
    def from_json(cls: Type[ModelBase], s):
        """Build an instance by passing the items decoded from ``s`` as keyword arguments."""
        decoded = loads(s)
        return cls(**decoded)
|
|
|
|
|
|
def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]:
    """
    Wrap a field class so that its ``default`` argument may be a callable.

    A truthy callable ``default`` is remembered on the instance and invoked
    again on every ``get_default_value()`` call. It is also invoked once at
    construction time, and that result is passed to the wrapped class as its
    default. Any other ``default`` is passed through as is.

    :param cls: field class to wrap
    :return: the wrapping subclass
    """

    class _Wrapped(cls):
        # Factory producing the default; stays None when a plain value was given
        _callable_default = None

        def __init__(self, *args, default=None, **kwargs):
            if bool(default) and callable(default):
                self._callable_default = default
                default = default()
            super().__init__(*args, default=default, **kwargs)

        def get_default_value(self):
            factory = self._callable_default
            if not factory:
                return super().get_default_value()
            return factory()

    return _Wrapped
|
|
|
|
|
|
class MongoengineFieldsDict(DictField):
    """
    DictField representing mongoengine field names/value mapping.
    Used to convert mongoengine-style field/subfield notation to user-presentable syntax, including handling update
    operators.
    """

    # Leading path components that cause the whole entry to be dropped
    mongoengine_update_operators = (
        "inc",
        "dec",
        "push",
        "push_all",
        "pop",
        "pull",
        "pull_all",
        "add_to_set",
    )

    @staticmethod
    def _normalize_mongo_value(value):
        """Return ``value.to_mongo()`` for mongoengine documents, any other value unchanged."""
        if isinstance(value, BaseDocument):
            return value.to_mongo()
        return value

    @classmethod
    def _normalize_mongo_field_path(cls, path, value):
        """
        Convert a ``__``-separated path into dotted notation.

        A leading ``set`` component is stripped. A leading ``unset`` component
        is stripped and the value replaced by ``None``. A leading component
        listed in ``mongoengine_update_operators`` drops the entry. Paths with
        a single component are never treated as operators.

        :return: ``(dotted_path, normalized_value)``, or ``(None, None)`` when
            the entry should be dropped
        """
        parts = path.split("__")
        if len(parts) > 1:
            if parts[0] == "set":
                parts = parts[1:]
            elif parts[0] == "unset":
                parts = parts[1:]
                value = None
            elif parts[0] in cls.mongoengine_update_operators:
                return None, None
        return ".".join(parts), cls._normalize_mongo_value(value)

    def parse_value(self, value):
        """
        Parse ``value`` with the base class and normalise every key/value pair.

        ``None`` is passed through untouched, so that assigning ``None`` to the
        field does not fail on ``None.items()``.
        """
        value = super(MongoengineFieldsDict, self).parse_value(value)
        if value is None:
            return value
        return {
            k: v
            for k, v in (self._normalize_mongo_field_path(*p) for p in value.items())
            if k is not None
        }
|
|
|
|
|
|
class UpdateResponse(models.Base):
    """Model with a required ``updated`` integer and a ``fields`` mapping."""

    updated = fields.IntField(required=True)
    # NOTE: from here on, ``fields`` inside this class body names the attribute
    # below and no longer the imported ``fields`` module, so any declaration
    # written as ``fields.<Type>(...)`` has to stay above this line.
    fields = MongoengineFieldsDict()
|
|
|
|
|
|
class PagedRequest(models.Base):
    """Model with two optional integers: ``page`` and ``page_size``."""

    page = fields.IntField()
    page_size = fields.IntField()
|
|
|
|
|
|
class IdResponse(models.Base):
    """Model carrying a single required string ``id``."""

    id = fields.StringField(required=True)
|
|
|
|
|
|
class MakePublicRequest(models.Base):
    """Model with an ``ids`` list of strings, validated with ``Length(minimum_value=1)``."""

    ids = ListField(str, validators=[Length(minimum_value=1)])
|
|
|
|
|
|
class MoveRequest(models.Base):
    """
    Model with an ``ids`` list of strings (validated with
    ``Length(minimum_value=1)``) and optional ``project`` / ``project_name``
    strings.
    """

    ids = ListField(items_types=[str], validators=Length(minimum_value=1))
    project = fields.StringField()
    project_name = fields.StringField()
|