diff --git a/apiserver/apimodels/__init__.py b/apiserver/apimodels/__init__.py index 9c3d804..9038b96 100644 --- a/apiserver/apimodels/__init__.py +++ b/apiserver/apimodels/__init__.py @@ -1,10 +1,10 @@ from enum import Enum -from typing import Union, Type, Iterable +from typing import Union, Type, Iterable, Mapping import jsonmodels.errors import six from jsonmodels import fields -from jsonmodels.fields import _LazyType, NotSet +from jsonmodels.fields import _LazyType, NotSet, EmbeddedField from jsonmodels.models import Base as ModelBase from jsonmodels.validators import Enum as EnumValidator from mongoengine.base import BaseDocument @@ -40,6 +40,34 @@ def make_default(field_cls, default_value): return _FieldWithDefault +class OneOfEmbeddedField(EmbeddedField): + def __init__( + self, + *args, + discriminator_property: str, + discriminator_mapping: Mapping[str, type], + **kwargs, + ): + self.discriminator_property = discriminator_property + self.discriminator_mapping = discriminator_mapping + model_types = tuple(set(self.discriminator_mapping.values())) + + super().__init__(model_types, *args, **kwargs) + + def parse_value(self, value): + """Parse value to proper model type.""" + if not isinstance(value, dict) or self.discriminator_property not in value: + return super().parse_value(value) + + property_value = value.get(self.discriminator_property) + embed_type = self.discriminator_mapping.get(property_value) + if not embed_type: + raise jsonmodels.errors.ValidationError( + f"Could not find type matching discriminator property value: {property_value}" + ) + return embed_type(**value) + + class ListField(fields.ListField): def __init__(self, items_types=None, *args, default=NotSet, **kwargs): if default is not NotSet and callable(default): @@ -115,9 +143,7 @@ class DictField(fields.BaseField): if len(self.value_types) != 1: tpl = 'Cannot decide which type to choose from "{types}".' raise jsonmodels.errors.ValidationError( - tpl.format( - types=', '.join([t.__name__ for t in self.value_types]) - ) + tpl.format(types=", ".join([t.__name__ for t in self.value_types])) ) return self.value_types[0](**value) @@ -179,7 +205,7 @@ class EnumField(fields.StringField): *args, required=False, default=None, - **kwargs + **kwargs, ): choices = list(map(self.parse_value, values_or_type)) validator_cls = EnumValidator if required else NullableEnumValidator @@ -202,7 +228,7 @@ class ActualEnumField(fields.StringField): validators=None, required=False, default=None, - **kwargs + **kwargs, ): self.__enum = enum_class self.types = (enum_class,) @@ -215,7 +241,7 @@ class ActualEnumField(fields.StringField): *args, required=required, validators=validators, - **kwargs + **kwargs, ) def parse_value(self, value):