Add support for OneOfEmbeddedField

This commit is contained in:
clearml 2024-12-05 22:27:20 +02:00
parent 543c579a2e
commit a3b303fa28

View File

@ -1,10 +1,10 @@
from enum import Enum from enum import Enum
from typing import Union, Type, Iterable from typing import Union, Type, Iterable, Mapping
import jsonmodels.errors import jsonmodels.errors
import six import six
from jsonmodels import fields from jsonmodels import fields
from jsonmodels.fields import _LazyType, NotSet from jsonmodels.fields import _LazyType, NotSet, EmbeddedField
from jsonmodels.models import Base as ModelBase from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator from jsonmodels.validators import Enum as EnumValidator
from mongoengine.base import BaseDocument from mongoengine.base import BaseDocument
@ -40,6 +40,34 @@ def make_default(field_cls, default_value):
return _FieldWithDefault return _FieldWithDefault
class OneOfEmbeddedField(EmbeddedField):
def __init__(
self,
*args,
discriminator_property: str,
discriminator_mapping: Mapping[str, type],
**kwargs,
):
self.discriminator_property = discriminator_property
self.discriminator_mapping = discriminator_mapping
model_types = tuple(set(self.discriminator_mapping.values()))
super().__init__(model_types, *args, **kwargs)
def parse_value(self, value):
"""Parse value to proper model type."""
if not isinstance(value, dict) or self.discriminator_property not in value:
return super().parse_value(value)
property_value = value.get(self.discriminator_property)
embed_type = self.discriminator_mapping.get(property_value)
if not embed_type:
raise jsonmodels.errors.ValidationError(
f"Could not find type matching discriminator property value: {property_value}"
)
return embed_type(**value)
class ListField(fields.ListField): class ListField(fields.ListField):
def __init__(self, items_types=None, *args, default=NotSet, **kwargs): def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
if default is not NotSet and callable(default): if default is not NotSet and callable(default):
@ -115,9 +143,7 @@ class DictField(fields.BaseField):
if len(self.value_types) != 1: if len(self.value_types) != 1:
tpl = 'Cannot decide which type to choose from "{types}".' tpl = 'Cannot decide which type to choose from "{types}".'
raise jsonmodels.errors.ValidationError( raise jsonmodels.errors.ValidationError(
tpl.format( tpl.format(types=", ".join([t.__name__ for t in self.value_types]))
types=', '.join([t.__name__ for t in self.value_types])
)
) )
return self.value_types[0](**value) return self.value_types[0](**value)
@ -179,7 +205,7 @@ class EnumField(fields.StringField):
*args, *args,
required=False, required=False,
default=None, default=None,
**kwargs **kwargs,
): ):
choices = list(map(self.parse_value, values_or_type)) choices = list(map(self.parse_value, values_or_type))
validator_cls = EnumValidator if required else NullableEnumValidator validator_cls = EnumValidator if required else NullableEnumValidator
@ -202,7 +228,7 @@ class ActualEnumField(fields.StringField):
validators=None, validators=None,
required=False, required=False,
default=None, default=None,
**kwargs **kwargs,
): ):
self.__enum = enum_class self.__enum = enum_class
self.types = (enum_class,) self.types = (enum_class,)
@ -215,7 +241,7 @@ class ActualEnumField(fields.StringField):
*args, *args,
required=required, required=required,
validators=validators, validators=validators,
**kwargs **kwargs,
) )
def parse_value(self, value): def parse_value(self, value):