"""
Objects representing schema entities
"""
import json
import re
from operator import attrgetter
from pathlib import Path
from typing import Mapping, Sequence

import attr
from boltons.dictutils import subdict
from pyhocon import ConfigFactory

from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion


# Module-level logger obtained from the API server configuration
log = config.logger(__file__)

# Wildcard entry of an allow_roles list: a group listing it allows every role
ALL_ROLES = "*"


class EndpointSchema:
    """
    Schema of one endpoint at one API version
    """

    # Keys looked up in the endpoint configuration
    REQUEST_KEY = "request"
    RESPONSE_KEY = "response"
    BATCH_REQUEST_KEY = "batch_request"
    # Key under which the service definitions are attached to every schema
    DEFINITIONS_KEY = "definitions"

    def __init__(
        self,
        service_name: str,
        action_name: str,
        version: PartialVersion,
        schema: dict,
        definitions: dict = None,
    ):
        """
        Build the request (or batch request) and response schemas of an endpoint.
        Only one of request_schema / batch_request_schema is populated, the plain
        request taking precedence when both keys are present.
        :param service_name: name of containing service
        :param action_name: name of action
        :param version: endpoint version
        :param schema: endpoint schema
        :param definitions: service definitions
        :raise RuntimeError: if the schema has neither a request nor a batch request section
        """
        self.service_name = service_name
        self.action_name = action_name
        self.full_name = f"{service_name}.{action_name}"
        self.version = version
        self.definitions = definitions

        has_request = self.REQUEST_KEY in schema
        has_batch_request = self.BATCH_REQUEST_KEY in schema
        if not (has_request or has_batch_request):
            raise RuntimeError(
                f"endpoint {self.full_name} version {self.version} "
                "has no request or batch_request schema",
                schema,
            )

        self.request_schema = None
        self.batch_request_schema = None
        if has_request:
            self.request_schema = self._with_definitions(schema[self.REQUEST_KEY])
        else:
            self.batch_request_schema = self._with_definitions(
                schema[self.BATCH_REQUEST_KEY]
            )

        self.response_schema = self._with_definitions(schema[self.RESPONSE_KEY])

    def _with_definitions(self, section):
        """
        Return a copy of a schema section with the service definitions attached
        """
        return {**section, self.DEFINITIONS_KEY: self.definitions}


class EndpointVersionsGroup:
    """
    Multiple implementations of a single endpoint, discriminated by API version
    """

    # Endpoint schemas, sorted by ascending version
    endpoints: Sequence[EndpointSchema]
    # Roles allowed to call the endpoint (may contain the ALL_ROLES wildcard)
    allow_roles: Sequence[str]
    internal: bool
    authorize: bool

    def __repr__(self):
        return (
            f"{type(self).__name__}<{self.full_name}, "
            f"versions={tuple(e.version for e in self.endpoints)}>"
        )

    def __init__(
        self,
        service_name: str,
        action_name: str,
        conf: dict,
        definitions: dict = None,
        defaults: dict = None,
    ):
        """
        Represents multiple implementations of a single endpoint, discriminated by API version
        :param service_name: name of containing service
        :param action_name: name of action
        :param conf: mapping between minimum version to endpoint schema; may also hold
            the "internal", "allow_roles" and "authorize" attributes of the group
        :param definitions: service definitions
        :param defaults: service defaults, used for group attributes missing from conf
        :raise KeyError: if a group attribute appears neither in conf nor in defaults
        :raise ValueError: if conf holds a key that is not a "<major>.<minor>" version
        """
        self.service_name = service_name
        self.action_name = action_name
        self.full_name = f"{service_name}.{action_name}"
        self.definitions = definitions or {}
        self.defaults = defaults or {}

        # Extract the group attributes from a shallow copy, so that the caller's
        # configuration is left intact and can be parsed again with the same result
        conf = dict(conf)
        self.internal = self._pop_attr_with_default(conf, "internal")
        self.allow_roles = self._pop_attr_with_default(conf, "allow_roles")
        self.authorize = self._pop_attr_with_default(conf, "authorize")

        def parse_version(version):
            # Once the group attributes are removed every remaining key must be a version
            if not re.match(r"^\d+\.\d+$", version):
                raise ValueError(
                    f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
                )
            return PartialVersion(version)

        self.endpoints = sorted(
            (
                EndpointSchema(
                    service_name=self.service_name,
                    action_name=self.action_name,
                    version=parse_version(version),
                    schema=endpoint_conf,
                    definitions=self.definitions,
                )
                for version, endpoint_conf in conf.items()
            ),
            key=attrgetter("version"),
        )

    def allows(self, role):
        """
        Return whether the role may call the endpoint: either the role itself
        or the ALL_ROLES wildcard is listed in allow_roles
        """
        return ALL_ROLES in self.allow_roles or role in self.allow_roles

    def _pop_attr_with_default(self, conf, attr):
        """
        Remove attr from conf and return its value, falling back to the service defaults.
        The defaults are consulted only when conf lacks the attribute, so an explicit
        value never requires a matching default.
        :raise KeyError: if attr appears neither in conf nor in the defaults
        """
        if attr in conf:
            return conf.pop(attr)
        return self.defaults[attr]

    def get_for_version(self, min_version: PartialVersion):
        """
        Return the first endpoint schema, in ascending version order,
        whose version is not lower than min_version
        :raise ValueError: if the group has no endpoints or min_version is higher
            than the version of every endpoint in the group
        """
        if not self.endpoints:
            raise ValueError(f"endpoint group {self} has no versions")
        for endpoint in self.endpoints:
            if min_version <= endpoint.version:
                return endpoint
        raise ValueError(
            f"min_version {min_version} is higher than highest version in group {self}"
        )


class Service:
    """
    Schema of a single service: its defaults, definitions and endpoint groups
    """

    # Endpoint groups of the service, keyed by endpoint (action) name
    endpoint_groups: Mapping[str, EndpointVersionsGroup]

    def __init__(self, name: str, conf: dict, api_defaults: dict):
        """
        Represents schema of one service
        :param name: name of service
        :param conf: service configuration, containing endpoint groups and other details
        :param api_defaults: API-wide endpoint attributes default values
        """
        self.name = name

        # Work on a copy of the configuration without the documentation-only entries
        remaining = subdict(conf, drop=("_description", "_references"))

        # Service-level defaults take precedence over the API-wide ones
        self.defaults = {**api_defaults, **remaining.pop("_default", {})}
        self.definitions = remaining.pop("_definitions", None)

        # Every entry still left in the configuration describes an endpoint group
        groups = {}
        for action_name, group_conf in remaining.items():
            groups[action_name] = EndpointVersionsGroup(
                service_name=self.name,
                action_name=action_name,
                conf=group_conf,
                defaults=self.defaults,
                definitions=self.definitions,
            )
        self.endpoint_groups = groups


class Schema:
    """
    The entire API schema: one Service object per configured service
    """

    # Parsed services, keyed by service name
    services: Mapping[str, Service]

    def __init__(self, services: dict, api_defaults: dict):
        """
        Represents the entire API schema
        :param services: services schema
        :param api_defaults: default values of service configuration
        """
        self.api_defaults = api_defaults

        parsed_services = {}
        for service_name, service_conf in services.items():
            parsed_services[service_name] = Service(
                service_name, service_conf, api_defaults=self.api_defaults
            )
        self.services = parsed_services


@attr.s()
class SchemaReader:
    """
    Reads the API schema from the service configuration files found under root,
    keeping a JSON cache of the parsed configuration next to them
    """

    # Directory holding the "<service>.conf" files and "_api_defaults.conf"
    root = Path(__file__).parent / "services"
    # Location of the JSON cache file; set to root / "_cache.json" when left empty
    cache_path: Path = None

    def __attrs_post_init__(self):
        if not self.cache_path:
            self.cache_path = self.root / "_cache.json"

    @staticmethod
    def mod_time(path):
        """
        return file modification time
        """
        return path.stat().st_mtime

    @staticmethod
    def read_file(path):
        """
        Parse a configuration file into a plain ordered dict
        """
        return ConfigFactory.parse_file(path).as_plain_ordered_dict()

    def get_schema(self):
        """
        Parse the API schema to schema object.
        The cache file is used when it is at least as new as every service file
        and the API defaults file, and was generated for the same set of services.
        Otherwise the schema is loaded from the config files and written to
        the cache file if possible.
        :return: Schema object of the entire API
        """
        services = [
            service
            for service in self.root.glob("*.conf")
            if not service.name.startswith("_")
        ]
        api_defaults_path = self.root / "_api_defaults.conf"

        current_services_names = {path.stem for path in services}

        try:
            cache_time = self.mod_time(self.cache_path)
            # The cache holds the API defaults as well, so a change to their file
            # must invalidate it just like a change to any service file
            sources_time = max(map(self.mod_time, [*services, api_defaults_path]))
            if cache_time >= sources_time:
                log.info("loading schema from cache")
                result = json.loads(self.cache_path.read_text())
                cached_services_names = set(result.pop("services_names", []))
                if cached_services_names == current_services_names:
                    return Schema(**result)
                else:
                    log.info(
                        f"found services files changed: "
                        f"added: {list(current_services_names - cached_services_names)}, "
                        f"removed: {list(cached_services_names - current_services_names)}"
                    )
        except (IOError, KeyError, TypeError, ValueError, AttributeError) as ex:
            # A missing or unusable cache is not fatal, the schema is regenerated below
            log.warning(f"failed loading cache: {ex}")

        log.info("regenerating schema cache")
        services = {path.stem: self.read_file(path) for path in services}
        api_defaults = self.read_file(api_defaults_path)

        try:
            self.cache_path.write_text(
                json.dumps(
                    dict(
                        services_names=list(current_services_names),
                        services=services,
                        api_defaults=api_defaults,
                    )
                )
            )
        except IOError:
            # Writing the cache is best effort, the parsed schema is returned regardless
            log.exception(f"failed writing cache file to {self.cache_path}")

        return Schema(services, api_defaults)