From b97a6084ce6467f68e93c84f010102afe6c0cc9a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 18:25:18 +0200 Subject: [PATCH] Refactor configuration infrastructure Remove untracked files left from previous commit --- apiserver/apierrors/autogen/__init__.py | 4 - apiserver/apierrors/autogen/__main__.py | 6 - apiserver/apierrors/autogen/generator.py | 96 ------ .../apierrors/autogen/templates/error.jinja2 | 6 - .../apierrors/autogen/templates/init.jinja2 | 14 - .../autogen/templates/section.jinja2 | 9 - apiserver/apimodels/__init__.py | 308 +++++++++++++++++- apiserver/apimodels/base.py | 284 +--------------- .../apimodels/custom_validators/__init__.py | 34 ++ apiserver/config/__init__.py | 11 +- apiserver/config/basic.py | 135 +++++--- apiserver/config_repo.py | 4 +- 12 files changed, 425 insertions(+), 486 deletions(-) delete mode 100644 apiserver/apierrors/autogen/__init__.py delete mode 100644 apiserver/apierrors/autogen/__main__.py delete mode 100644 apiserver/apierrors/autogen/generator.py delete mode 100644 apiserver/apierrors/autogen/templates/error.jinja2 delete mode 100644 apiserver/apierrors/autogen/templates/init.jinja2 delete mode 100644 apiserver/apierrors/autogen/templates/section.jinja2 create mode 100644 apiserver/apimodels/custom_validators/__init__.py diff --git a/apiserver/apierrors/autogen/__init__.py b/apiserver/apierrors/autogen/__init__.py deleted file mode 100644 index b0d6c17..0000000 --- a/apiserver/apierrors/autogen/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -def generate(path, error_codes): - from .generator import Generator - from pathlib import Path - Generator(Path(path) / 'errors', format_pep8=False).make_errors(error_codes) diff --git a/apiserver/apierrors/autogen/__main__.py b/apiserver/apierrors/autogen/__main__.py deleted file mode 100644 index f39c9f3..0000000 --- a/apiserver/apierrors/autogen/__main__.py +++ /dev/null @@ -1,6 +0,0 @@ -if __name__ == '__main__': - from pathlib import Path - from apiserver.apierrors import _error_codes - from apiserver.apierrors.autogen import generate - - generate(Path(__file__).parent.parent, _error_codes) diff --git a/apiserver/apierrors/autogen/generator.py b/apiserver/apierrors/autogen/generator.py deleted file mode 100644 index 90100d2..0000000 --- a/apiserver/apierrors/autogen/generator.py +++ /dev/null @@ -1,96 +0,0 @@ -import re -import json -import jinja2 -import hashlib - -from pathlib import Path - - -env = jinja2.Environment( - loader=jinja2.FileSystemLoader(str(Path(__file__).parent)), - autoescape=jinja2.select_autoescape( - disabled_extensions=("py",), default_for_string=False - ), - trim_blocks=True, - lstrip_blocks=True, -) - - -def env_filter(name=None): - return lambda func: env.filters.setdefault(name or func.__name__, func) - - -@env_filter() -def cls_name(name): - delims = list(map(re.escape, (" ", "_"))) - parts = re.split("|".join(delims), name) - return "".join(x.capitalize() for x in parts) - - -class Generator(object): - _base_class_name = "BaseError" - _base_class_module = "apiserver.apierrors.base" - - def __init__(self, path, format_pep8=True, use_md5=True): - self._use_md5 = use_md5 - self._format_pep8 = format_pep8 - self._path = Path(path) - self._path.mkdir(parents=True, exist_ok=True) - - def _make_init_file(self, path): - (self._path / path / "__init__.py").write_bytes("") - - def _do_render(self, file, template, context): - with file.open("w") as f: - result = template.render( - base_class_name=self._base_class_name, - base_class_module=self._base_class_module, - **context - ) - if self._format_pep8: - import autopep8 - - result = autopep8.fix_code( - result, - options={"aggressive": 1, "verbose": 0, "max_line_length": 120}, - ) - f.write(result) - - def _make_section(self, name, code, subcodes): - self._do_render( - file=(self._path / name).with_suffix(".py"), - template=env.get_template("templates/section.jinja2"), - context=dict(code=code, subcodes=list(subcodes.items()),), - ) - - def _make_init(self, sections): - self._do_render( - file=(self._path / "__init__.py"), - template=env.get_template("templates/init.jinja2"), - context=dict(sections=sections,), - ) - - def _key_to_str(self, data): - if isinstance(data, dict): - return {str(k): self._key_to_str(v) for k, v in data.items()} - return data - - def _calc_digest(self, data): - data = json.dumps(self._key_to_str(data), sort_keys=True) - return hashlib.md5(data.encode("utf8")).hexdigest() - - def make_errors(self, errors): - digest = None - digest_file = self._path / "digest.md5" - if self._use_md5: - digest = self._calc_digest(errors) - if digest_file.is_file(): - if digest_file.read_text() == digest: - return - - self._make_init(errors) - for (code, section_name), subcodes in errors.items(): - self._make_section(section_name, code, subcodes) - - if self._use_md5: - digest_file.write_text(digest) diff --git a/apiserver/apierrors/autogen/templates/error.jinja2 b/apiserver/apierrors/autogen/templates/error.jinja2 deleted file mode 100644 index 01d6689..0000000 --- a/apiserver/apierrors/autogen/templates/error.jinja2 +++ /dev/null @@ -1,6 +0,0 @@ -{% macro error_class(name, msg, code, subcode=0) %} -class {{ name }}({{ base_class_name }}): - _default_code = {{ code }} - _default_subcode = {{ subcode }} - _default_msg = "{{ msg|capitalize }}" -{% endmacro -%} \ No newline at end of file diff --git a/apiserver/apierrors/autogen/templates/init.jinja2 b/apiserver/apierrors/autogen/templates/init.jinja2 deleted file mode 100644 index e16f23f..0000000 --- a/apiserver/apierrors/autogen/templates/init.jinja2 +++ /dev/null @@ -1,14 +0,0 @@ -{% from 'templates/error.jinja2' import error_class with context %} -{% if sections %} -from {{ base_class_module }} import {{ base_class_name }} -{% endif %} - -{% for _, name in sections %} -from . import {{ name }} -{% endfor %} - - -{% for code, name in sections %} -{{ error_class(name|cls_name, name|replace('_', ' '), code) }} - -{% endfor %} diff --git a/apiserver/apierrors/autogen/templates/section.jinja2 b/apiserver/apierrors/autogen/templates/section.jinja2 deleted file mode 100644 index 11263cb..0000000 --- a/apiserver/apierrors/autogen/templates/section.jinja2 +++ /dev/null @@ -1,9 +0,0 @@ -{% from 'templates/error.jinja2' import error_class with context %} -{% if subcodes %} -from {{ base_class_module }} import {{ base_class_name }} -{% endif %} -{% for subcode, (name, msg) in subcodes %} - - -{{ error_class(name|cls_name, msg, code, subcode) -}} -{% endfor %} \ No newline at end of file diff --git a/apiserver/apimodels/__init__.py b/apiserver/apimodels/__init__.py index 6ac4e8c..1718e9e 100644 --- a/apiserver/apimodels/__init__.py +++ b/apiserver/apimodels/__init__.py @@ -1,13 +1,38 @@ -from __future__ import absolute_import - +from enum import Enum from textwrap import shorten +from typing import Union, Type, Iterable +import jsonmodels.errors +import six +from jsonmodels import fields +from jsonmodels.fields import _LazyType, NotSet +from jsonmodels.models import Base as ModelBase +from jsonmodels.validators import Enum as EnumValidator from luqum.exceptions import ParseError from luqum.parser import parser +from mongoengine.base import BaseDocument from validators import email as email_validator, domain as domain_validator from apiserver.apierrors import errors -from .base import * +from apiserver.utilities.json import loads, dumps + + +class EmailField(fields.StringField): + def validate(self, value): + super().validate(value) + if value is None: + return + if email_validator(value) is not True: + raise errors.bad_request.InvalidEmailAddress() + + +class DomainField(fields.StringField): + def validate(self, value): + super().validate(value) + if value is None: + return + if domain_validator(value) is not True: + raise errors.bad_request.InvalidDomainName() def validate_lucene_query(value): @@ -29,19 +54,272 @@ class LuceneQueryField(fields.StringField): validate_lucene_query(value) -class EmailField(fields.StringField): - def validate(self, value): - super().validate(value) - if value is None: - return - if email_validator(value) is not True: - raise errors.bad_request.InvalidEmailAddress() +def make_default(field_cls, default_value): + class _FieldWithDefault(field_cls): + def get_default_value(self): + return default_value + + return _FieldWithDefault -class DomainField(fields.StringField): +class ListField(fields.ListField): + def __init__(self, items_types=None, *args, default=NotSet, **kwargs): + if default is not NotSet and callable(default): + default = default() + + super(ListField, self).__init__(items_types, *args, default=default, **kwargs) + + def _cast_value(self, value): + try: + return super(ListField, self)._cast_value(value) + except TypeError: + if len(self.items_types) == 1 and issubclass(self.items_types[0], Enum): + return self.items_types[0](value) + return value + + def validate_single_value(self, item): + super(ListField, self).validate_single_value(item) + if isinstance(item, ModelBase): + item.validate() + + +# since there is no distinction between None and empty DictField +# this value can be used as sentinel in order to distinguish +# between not set and empty DictField +DictFieldNotSet = {} + + +class DictField(fields.BaseField): + types = (dict,) + + def __init__(self, value_types=None, *args, **kwargs): + self.value_types = self._assign_types(value_types) + super(DictField, self).__init__(*args, **kwargs) + + def get_default_value(self): + default = super(DictField, self).get_default_value() + if default is None and not self.required: + return {} + return default + + @staticmethod + def _assign_types(value_types): + if value_types: + try: + value_types = tuple(value_types) + except TypeError: + value_types = (value_types,) + else: + value_types = tuple() + + return tuple( + _LazyType(type_) if isinstance(type_, six.string_types) else type_ + for type_ in value_types + ) + + def parse_value(self, values): + """Cast value to proper collection.""" + result = self.get_default_value() + + if values is None: + return result + + if not self.value_types or not isinstance(values, dict): + return values + + return {key: self._cast_value(value) for key, value in values.items()} + + def _cast_value(self, value): + if isinstance(value, self.value_types): + return value + else: + if len(self.value_types) != 1: + tpl = 'Cannot decide which type to choose from "{types}".' + raise jsonmodels.errors.ValidationError( + tpl.format( + types=', '.join([t.__name__ for t in self.value_types]) + ) + ) + return self.value_types[0](**value) + def validate(self, value): - super().validate(value) - if value is None: + super(DictField, self).validate(value) + + if not self.value_types: return - if domain_validator(value) is not True: - raise errors.bad_request.InvalidDomainName() + + if not value: + return + + for item in value.values(): + self.validate_single_value(item) + + def validate_single_value(self, item): + if not self.value_types: + return + + if not isinstance(item, self.value_types): + raise jsonmodels.errors.ValidationError( + "All items must be instances " + 'of "{types}", and not "{type}".'.format( + types=", ".join([t.__name__ for t in self.value_types]), + type=type(item).__name__, + ) + ) + + def _elem_to_struct(self, value): + try: + return value.to_struct() + except AttributeError: + return value + + def to_struct(self, values): + return {k: self._elem_to_struct(v) for k, v in values.items()} + + +class IntField(fields.IntField): + def parse_value(self, value): + try: + return super(IntField, self).parse_value(value) + except (ValueError, TypeError): + return value + + +class NullableEnumValidator(EnumValidator): + """Validator for enums that allows a None value.""" + + def validate(self, value): + if value is not None: + super(NullableEnumValidator, self).validate(value) + + +class EnumField(fields.StringField): + def __init__( + self, + values_or_type: Union[Iterable, Type[Enum]], + *args, + required=False, + default=None, + **kwargs + ): + choices = list(map(self.parse_value, values_or_type)) + validator_cls = EnumValidator if required else NullableEnumValidator + kwargs.setdefault("validators", []).append(validator_cls(*choices)) + super().__init__( + default=self.parse_value(default), required=required, *args, **kwargs + ) + + def parse_value(self, value): + if isinstance(value, Enum): + return str(value.value) + return super().parse_value(value) + + +class ActualEnumField(fields.StringField): + def __init__( + self, + enum_class: Type[Enum], + *args, + validators=None, + required=False, + default=None, + **kwargs + ): + self.__enum = enum_class + self.types = (enum_class,) + # noinspection PyTypeChecker + choices = list(enum_class) + validator_cls = EnumValidator if required else NullableEnumValidator + validators = [*(validators or []), validator_cls(*choices)] + super().__init__( + default=self.parse_value(default) if default else NotSet, + *args, + required=required, + validators=validators, + **kwargs + ) + + def parse_value(self, value): + if value is None and not self.required: + return self.get_default_value() + try: + # noinspection PyArgumentList + return self.__enum(value) + except ValueError: + return value + + def to_struct(self, value): + return super().to_struct(value.value) + + +class JsonSerializableMixin: + def to_json(self: ModelBase): + return dumps(self.to_struct()) + + @classmethod + def from_json(cls: Type[ModelBase], s): + return cls(**loads(s)) + + +def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]: + class _Wrapped(cls): + _callable_default = None + + def get_default_value(self): + if self._callable_default: + return self._callable_default() + return super(_Wrapped, self).get_default_value() + + def __init__(self, *args, default=None, **kwargs): + if default and callable(default): + self._callable_default = default + default = default() + super(_Wrapped, self).__init__(*args, default=default, **kwargs) + + return _Wrapped + + +class MongoengineFieldsDict(DictField): + """ + DictField representing mongoengine field names/value mapping. + Used to convert mongoengine-style field/subfield notation to user-presentable syntax, including handling update + operators. + """ + + mongoengine_update_operators = ( + "inc", + "dec", + "push", + "push_all", + "pop", + "pull", + "pull_all", + "add_to_set", + ) + + @staticmethod + def _normalize_mongo_value(value): + if isinstance(value, BaseDocument): + return value.to_mongo() + return value + + @classmethod + def _normalize_mongo_field_path(cls, path, value): + parts = path.split("__") + if len(parts) > 1: + if parts[0] == "set": + parts = parts[1:] + elif parts[0] == "unset": + parts = parts[1:] + value = None + elif parts[0] in cls.mongoengine_update_operators: + return None, None + return ".".join(parts), cls._normalize_mongo_value(value) + + def parse_value(self, value): + value = super(MongoengineFieldsDict, self).parse_value(value) + return { + k: v + for k, v in (self._normalize_mongo_field_path(*p) for p in value.items()) + if k is not None + } diff --git a/apiserver/apimodels/base.py b/apiserver/apimodels/base.py index d4db847..e0cdd3c 100644 --- a/apiserver/apimodels/base.py +++ b/apiserver/apimodels/base.py @@ -1,289 +1,7 @@ -from __future__ import absolute_import - -from enum import Enum -from typing import Union, Type, Iterable - -import jsonmodels.errors -import six -from jsonmodels.fields import _LazyType, NotSet -from jsonmodels.models import Base as ModelBase -from jsonmodels.validators import Enum as EnumValidator - from jsonmodels import models, fields from jsonmodels.validators import Length -from mongoengine.base import BaseDocument -from apiserver.utilities.json import loads, dumps - -def make_default(field_cls, default_value): - class _FieldWithDefault(field_cls): - def get_default_value(self): - return default_value - - return _FieldWithDefault - - -class ListField(fields.ListField): - def __init__(self, items_types=None, *args, default=NotSet, **kwargs): - if default is not NotSet and callable(default): - default = default() - - super(ListField, self).__init__(items_types, *args, default=default, **kwargs) - - def _cast_value(self, value): - try: - return super(ListField, self)._cast_value(value) - except TypeError: - if len(self.items_types) == 1 and issubclass(self.items_types[0], Enum): - return self.items_types[0](value) - return value - - def validate_single_value(self, item): - super(ListField, self).validate_single_value(item) - if isinstance(item, ModelBase): - item.validate() - - -# since there is no distinction between None and empty DictField -# this value can be used as sentinel in order to distinguish -# between not set and empty DictField -DictFieldNotSet = {} - - -class DictField(fields.BaseField): - types = (dict,) - - def __init__(self, value_types=None, *args, **kwargs): - self.value_types = self._assign_types(value_types) - super(DictField, self).__init__(*args, **kwargs) - - def get_default_value(self): - default = super(DictField, self).get_default_value() - if default is None and not self.required: - return {} - return default - - @staticmethod - def _assign_types(value_types): - if value_types: - try: - value_types = tuple(value_types) - except TypeError: - value_types = (value_types,) - else: - value_types = tuple() - - return tuple( - _LazyType(type_) if isinstance(type_, six.string_types) else type_ - for type_ in value_types - ) - - def parse_value(self, values): - """Cast value to proper collection.""" - result = self.get_default_value() - - if values is None: - return result - - if not self.value_types or not isinstance(values, dict): - return values - - return {key: self._cast_value(value) for key, value in values.items()} - - def _cast_value(self, value): - if isinstance(value, self.value_types): - return value - else: - if len(self.value_types) != 1: - tpl = 'Cannot decide which type to choose from "{types}".' - raise jsonmodels.errors.ValidationError( - tpl.format( - types=', '.join([t.__name__ for t in self.value_types]) - ) - ) - return self.value_types[0](**value) - - def validate(self, value): - super(DictField, self).validate(value) - - if not self.value_types: - return - - if not value: - return - - for item in value.values(): - self.validate_single_value(item) - - def validate_single_value(self, item): - if not self.value_types: - return - - if not isinstance(item, self.value_types): - raise jsonmodels.errors.ValidationError( - "All items must be instances " - 'of "{types}", and not "{type}".'.format( - types=", ".join([t.__name__ for t in self.value_types]), - type=type(item).__name__, - ) - ) - - def _elem_to_struct(self, value): - try: - return value.to_struct() - except AttributeError: - return value - - def to_struct(self, values): - return {k: self._elem_to_struct(v) for k, v in values.items()} - - -class IntField(fields.IntField): - def parse_value(self, value): - try: - return super(IntField, self).parse_value(value) - except (ValueError, TypeError): - return value - - -class NullableEnumValidator(EnumValidator): - """Validator for enums that allows a None value.""" - - def validate(self, value): - if value is not None: - super(NullableEnumValidator, self).validate(value) - - -class EnumField(fields.StringField): - def __init__( - self, - values_or_type: Union[Iterable, Type[Enum]], - *args, - required=False, - default=None, - **kwargs - ): - choices = list(map(self.parse_value, values_or_type)) - validator_cls = EnumValidator if required else NullableEnumValidator - kwargs.setdefault("validators", []).append(validator_cls(*choices)) - super().__init__( - default=self.parse_value(default), required=required, *args, **kwargs - ) - - def parse_value(self, value): - if isinstance(value, Enum): - return str(value.value) - return super().parse_value(value) - - -class ActualEnumField(fields.StringField): - def __init__( - self, - enum_class: Type[Enum], - *args, - validators=None, - required=False, - default=None, - **kwargs - ): - self.__enum = enum_class - self.types = (enum_class,) - # noinspection PyTypeChecker - choices = list(enum_class) - validator_cls = EnumValidator if required else NullableEnumValidator - validators = [*(validators or []), validator_cls(*choices)] - super().__init__( - default=self.parse_value(default) if default else NotSet, - *args, - required=required, - validators=validators, - **kwargs - ) - - def parse_value(self, value): - if value is None and not self.required: - return self.get_default_value() - try: - # noinspection PyArgumentList - return self.__enum(value) - except ValueError: - return value - - def to_struct(self, value): - return super().to_struct(value.value) - - -class JsonSerializableMixin: - def to_json(self: ModelBase): - return dumps(self.to_struct()) - - @classmethod - def from_json(cls: Type[ModelBase], s): - return cls(**loads(s)) - - -def callable_default(cls: Type[fields.BaseField]) -> Type[fields.BaseField]: - class _Wrapped(cls): - _callable_default = None - - def get_default_value(self): - if self._callable_default: - return self._callable_default() - return super(_Wrapped, self).get_default_value() - - def __init__(self, *args, default=None, **kwargs): - if default and callable(default): - self._callable_default = default - default = default() - super(_Wrapped, self).__init__(*args, default=default, **kwargs) - - return _Wrapped - - -class MongoengineFieldsDict(DictField): - """ - DictField representing mongoengine field names/value mapping. - Used to convert mongoengine-style field/subfield notation to user-presentable syntax, including handling update - operators. - """ - - mongoengine_update_operators = ( - "inc", - "dec", - "push", - "push_all", - "pop", - "pull", - "pull_all", - "add_to_set", - ) - - @staticmethod - def _normalize_mongo_value(value): - if isinstance(value, BaseDocument): - return value.to_mongo() - return value - - @classmethod - def _normalize_mongo_field_path(cls, path, value): - parts = path.split("__") - if len(parts) > 1: - if parts[0] == "set": - parts = parts[1:] - elif parts[0] == "unset": - parts = parts[1:] - value = None - elif parts[0] in cls.mongoengine_update_operators: - return None, None - return ".".join(parts), cls._normalize_mongo_value(value) - - def parse_value(self, value): - value = super(MongoengineFieldsDict, self).parse_value(value) - return { - k: v - for k, v in (self._normalize_mongo_field_path(*p) for p in value.items()) - if k is not None - } +from apiserver.apimodels import MongoengineFieldsDict, ListField class UpdateResponse(models.Base): diff --git a/apiserver/apimodels/custom_validators/__init__.py b/apiserver/apimodels/custom_validators/__init__.py new file mode 100644 index 0000000..efa2602 --- /dev/null +++ b/apiserver/apimodels/custom_validators/__init__.py @@ -0,0 +1,34 @@ +import validators +from jsonmodels.errors import ValidationError + + +class ForEach(object): + def __init__(self, validator): + self.validator = validator + + def validate(self, values): + for value in values: + self.validator.validate(value) + + def modify_schema(self, field_schema): + return self.validator.modify_schema(field_schema) + + +class Hostname(object): + + def validate(self, value): + if validators.domain(value) is not True: + raise ValidationError(f"Value '{value}' is not a valid hostname") + + def modify_schema(self, field_schema): + field_schema["format"] = "hostname" + + +class Email(object): + + def validate(self, value): + if validators.email(value) is not True: + raise ValidationError(f"Value '{value}' is not a valid email address") + + def modify_schema(self, field_schema): + field_schema["format"] = "email" diff --git a/apiserver/config/__init__.py b/apiserver/config/__init__.py index 96e2744..3b65241 100644 --- a/apiserver/config/__init__.py +++ b/apiserver/config/__init__.py @@ -1,10 +1 @@ -import logging.config -from pathlib import Path - -from .basic import BasicConfig - - -def load_config(): - config = BasicConfig(Path(__file__).with_name("default")) - logging.config.dictConfig(config.get("logging")) - return config +from .basic import BasicConfig, ConfigurationError, Factory diff --git a/apiserver/config/basic.py b/apiserver/config/basic.py index 7a589c8..d9ec908 100644 --- a/apiserver/config/basic.py +++ b/apiserver/config/basic.py @@ -1,10 +1,12 @@ import logging +import logging.config import os import platform from functools import reduce from os import getenv from os.path import expandvars from pathlib import Path +from typing import List, Any, Type, TypeVar from pyhocon import ConfigTree, ConfigFactory from pyparsing import ( @@ -14,30 +16,41 @@ from pyparsing import ( ParseSyntaxException, ) -DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config" -EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR" -EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';' +from apiserver.utilities import json -EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__" -EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}" +EXTRA_CONFIG_PATHS = ("/opt/trains/config",) +EXTRA_CONFIG_PATH_OVERRIDE_VAR = "TRAINS_CONFIG_DIR" +EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ";" class BasicConfig: NotSet = object() - def __init__(self, folder): - self.folder = Path(folder) - if not self.folder.is_dir(): + extra_config_values_env_key_sep = "__" + default_config_dir = "default" + + def __init__( + self, folder: str = None, verbose: bool = True, prefix: str = "trains" + ): + folder = ( + Path(folder) + if folder + else Path(__file__).with_name(self.default_config_dir) + ) + if not folder.is_dir(): raise ValueError("Invalid configuration folder") - self.prefix = "trains" + self.verbose = verbose + self.prefix = prefix + self.extra_config_values_env_key_prefix = f"{self.prefix.upper()}__" - self._load() + self._paths = [folder, *self._get_paths()] + self._config = self._reload() def __getitem__(self, key): return self._config[key] - def get(self, key, default=NotSet): + def get(self, key: str, default: Any = NotSet) -> Any: value = self._config.get(key, default) if value is self.NotSet and not default: raise KeyError( @@ -45,51 +58,62 @@ class BasicConfig: ) return value - def logger(self, name): + def to_dict(self) -> dict: + return self._config.as_plain_ordered_dict() + + def as_json(self) -> str: + return json.dumps(self.to_dict(), indent=2) + + def logger(self, name: str) -> logging.Logger: if Path(name).is_file(): name = Path(name).stem path = ".".join((self.prefix, name)) return logging.getLogger(path) - def _read_extra_env_config_values(self): + def _read_extra_env_config_values(self) -> ConfigTree: """ Loads extra configuration from environment-injected values """ result = ConfigTree() - prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX + prefix = self.extra_config_values_env_key_prefix keys = sorted(k for k in os.environ if k.startswith(prefix)) for key in keys: - path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower() + path = ( + key[len(prefix) :] + .replace(self.extra_config_values_env_key_sep, ".") + .lower() + ) result = ConfigTree.merge_configs( result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}") ) return result - def _read_env_paths(self, key): - value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH) - if value is None: - return + def _get_paths(self) -> List[Path]: + default_paths = EXTRA_CONFIG_PATH_SEP.join(EXTRA_CONFIG_PATHS) + value = getenv(EXTRA_CONFIG_PATH_OVERRIDE_VAR, default_paths) + paths = [ Path(expandvars(v)).expanduser() for v in value.split(EXTRA_CONFIG_PATH_SEP) ] - invalid = [ - path - for path in paths - if not path.is_dir() and str(path) != DEFAULT_EXTRA_CONFIG_PATH - ] - if invalid: - print(f"WARNING: Invalid paths in {key} env var: {' '.join(map(str, invalid))}") + + if value is not default_paths: + invalid = [path for path in paths if not path.is_dir()] + if invalid: + print( + f"WARNING: Invalid paths in {EXTRA_CONFIG_PATH_OVERRIDE_VAR} env var: {' '.join(map(str, invalid))}" + ) + return [path for path in paths if path.is_dir()] - def _load(self, verbose=True): - extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or [] - extra_config_values = self._read_extra_env_config_values() - configs = [ - self._read_recursive(path, verbose=verbose) - for path in [self.folder] + extra_config_paths - ] + def reload(self): + self._config = self._reload() - self._config = reduce( + def _reload(self) -> ConfigTree: + extra_config_values = self._read_extra_env_config_values() + + configs = [self._read_recursive(path) for path in self._paths] + + return reduce( lambda last, config: ConfigTree.merge_configs( last, config, copy_trees=True ), @@ -97,32 +121,31 @@ class BasicConfig: ConfigTree(), ) - def _read_recursive(self, conf_root, verbose=True): + def _read_recursive(self, conf_root) -> ConfigTree: conf = ConfigTree() if not conf_root: return conf if not conf_root.is_dir(): - if verbose: + if self.verbose: if not conf_root.exists(): print(f"No config in {conf_root}") else: print(f"Not a directory: {conf_root}") return conf - if verbose: + if self.verbose: print(f"Loading config from {conf_root}") for file in conf_root.rglob("*.conf"): key = ".".join(file.relative_to(conf_root).with_suffix("").parts) - conf.put(key, self._read_single_file(file, verbose=verbose)) + conf.put(key, self._read_single_file(file)) return conf - @staticmethod - def _read_single_file(file_path, verbose=True): - if verbose: + def _read_single_file(self, file_path): + if self.verbose: print(f"Loading config from file {file_path}") try: @@ -137,8 +160,38 @@ class BasicConfig: print(f"Failed loading {file_path}: {ex}") raise + def initialize_logging(self): + logging_config = self.get("logging", None) + if not logging_config: + return + logging.config.dictConfig(logging_config) + class ConfigurationError(Exception): def __init__(self, msg, file_path=None, *args): super(ConfigurationError, self).__init__(msg, *args) self.file_path = file_path + + +ConfigType = TypeVar("ConfigType", bound=BasicConfig) + + +class Factory: + _config_cls: Type[ConfigType] = BasicConfig + + @classmethod + def get(cls) -> BasicConfig: + config = cls._config_cls() + config.initialize_logging() + return config + + @classmethod + def set_cls(cls, cls_: Type[ConfigType]): + cls._config_cls = cls_ + + +__all__ = [ + "Factory", + "BasicConfig", + "ConfigurationError", +] diff --git a/apiserver/config_repo.py b/apiserver/config_repo.py index b91369f..b72a1f6 100644 --- a/apiserver/config_repo.py +++ b/apiserver/config_repo.py @@ -1,3 +1,3 @@ -from apiserver.config import load_config +from apiserver.config import Factory -config = load_config() +config = Factory.get()