import json
import os
import time
from contextlib import contextmanager
from typing import Sequence
from typing import Union, Tuple, Type

import requests
import six
from boltons.iterutils import remap
from boltons.typeutils import issubclass
from requests.adapters import HTTPAdapter
from requests.auth import HTTPBasicAuth
from requests.packages.urllib3.util.retry import Retry

from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config

log = config.logger(__file__)


class APICallResult:
    def __init__(self, res):
        if isinstance(res, str):
            self._dict = json.loads(res)
        else:
            self._dict = res

        self.data = self._dict["data"]
        self.meta = APICallResultMeta(self._dict["meta"])

    def as_dict(self):
        return self._dict


class APICallResultMeta:
    def __init__(self, meta_dict):
        m = meta_dict
        self._dict = m
        self.call_id = m["id"]
        self.trx_id = m["trx"]
        self.result_code = m["result_code"]
        self.result_msg = m["result_msg"]
        self.result_subcode = m["result_subcode"]
        self.error_stack = m.get("error_stack") or ""
        self.endpoint_name = m["endpoint"]["name"]
        self.endpoint_version = m["endpoint"]["actual_version"]
        self.requested_endpoint_version = m["endpoint"]["requested_version"]
        self.codes = self.result_code, self.result_subcode


def format_duration(duration):
    return "" if duration is None else f" ({int(duration)} sec)"


class APIError(Exception):
    def __init__(self, result: APICallResult, duration=None):
        super().__init__(result)
        self.result = result
        self.duration = duration

    def __str__(self):
        meta = self.result.meta
        message = (
            f"APIClient got {meta.result_code}/{meta.result_subcode} from {meta.endpoint_name}: "
            f"{meta.result_msg}{format_duration(self.duration)}"
        )
        if meta.error_stack:
            header = "\n--- SERVER ERROR {} ---\n"
            formatted_traceback = "\n{}{}{}\n".format(header.format("START"), meta.error_stack, header.format("END"))
            message += formatted_traceback
        return message


class AttrDict(dict):
    """
    ``dict`` which supports attribute lookup syntax.
    Use to implement polymorphism over ``dict``s and database objects, which don't support subscription syntax.
    """

    def __init__(self, dct=None, **kwargs):
        super().__init__(
            remap(
                {**(dct or {}), **kwargs},
                lambda _, key, value: (
                    key,
                    type(self)(value)
                    if isinstance(value, dict) and not isinstance(value, type(self))
                    else value,
                ),
            )
        )

    def __getattr__(self, item):
        return self[item]


class APIClient:
    def __init__(
        self,
        api_key=None,
        secret_key=None,
        base_url=None,
        impersonated_user_id=None,
        session_token=None,
    ):
        if not session_token:
            self.api_key = (
                api_key
                or os.environ.get("SM_API_KEY")
                or config.get("apiclient.api_key")
            )
            if not self.api_key:
                raise ValueError("APIClient requires api_key in constructor or config")

            self.secret_key = (
                secret_key
                or os.environ.get("SM_API_SECRET")
                or config.get("apiclient.secret_key")
            )
            if not self.secret_key:
                raise ValueError(
                    "APIClient requires secret_key in constructor or config"
                )

        self.base_url = (
            base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
        )
        if not self.base_url:
            raise ValueError("APIClient requires base_url in constructor or config")

        if self.base_url.endswith("/"):
            self.base_url = self.base_url[:-1]

        self.session_token = session_token

        # create http session
        self.http_session = requests.session()
        retries = config.get("apiclient.retries", 7)
        backoff_factor = config.get("apiclient.backoff_factor", 0.3)
        status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
        retry = Retry(
            total=retries,
            read=retries,
            connect=retries,
            backoff_factor=backoff_factor,
            status_forcelist=status_forcelist,
        )
        adapter = HTTPAdapter(max_retries=retry)
        self.http_session.mount("http://", adapter)
        self.http_session.mount("https://", adapter)

        if impersonated_user_id:
            self.http_session.headers["X-ClearML-Impersonate-As"] = impersonated_user_id

        if not self.session_token:
            self.login()

    def login(self):
        res, self.session_token = self.send("auth.login", extract="token")

    def impersonate(self, user_id):
        return type(self)(
            impersonated_user_id=user_id,
            **{
                key: getattr(self, key)
                for key in ("api_key", "secret_key", "base_url", "session_token")
            },
        )

    def send_batch(
        self,
        endpoint,
        items: Sequence[dict],
        headers_overrides=None,
        raise_errors=True,
        log_response=False,
        extract=None,
    ):
        headers_overrides = headers_overrides or {}
        headers_overrides.update({"Content-Type": "application/json-lines"})
        assert isinstance(items, (list, tuple)), type(items)
        json_lines = (json.dumps(item) for item in items)
        data = "\n".join(json_lines)
        return self.send(
            endpoint,
            data=data,
            headers_overrides=headers_overrides,
            raise_errors=raise_errors,
            log_response=log_response,
            extract=extract,
        )

    def send(
        self,
        endpoint,
        data=None,
        headers_overrides=None,
        raise_errors=True,
        log_response=False,
        is_async=False,
        extract=None,
    ):
        if headers_overrides is None:
            headers_overrides = {}
        if data is None:
            data = {}
        headers = {"Content-Type": "application/json"}
        headers.update(headers_overrides)
        if is_async:
            headers["X-ClearML-Async"] = "1"

        if not isinstance(data, six.string_types):
            data = json.dumps(data)

        if not self.session_token:
            auth = HTTPBasicAuth(self.api_key, self.secret_key)
        else:
            auth = None
            headers["Authorization"] = "Bearer %s" % self.session_token

        url = "%s/%s" % (self.base_url, endpoint)
        start = time.time()

        http_res = self.http_session.post(url, headers=headers, data=data, auth=auth)
        if not http_res.text:
            msg = "APIClient got non standard response from %s. http_status=%s " % (
                url,
                http_res.status_code,
            )
            log.error(msg)
            raise Exception(msg)

        res = APICallResult(http_res.text)
        if res.meta.result_code == 202:
            # poll server for async result
            got_result = False
            call_id = res.meta.call_id
            async_res_url = "%s/async.result?id=%s" % (self.base_url, call_id)
            async_res_headers = headers.copy()
            async_res_headers.pop("X-ClearML-Async")
            while not got_result:
                log.info("Got 202. Checking async result for %s (%s)" % (url, call_id))
                http_res = self.http_session.get(
                    async_res_url, headers=async_res_headers
                )
                if http_res.status_code == 202:
                    time.sleep(5)
                else:
                    res = APICallResult(http_res.text)
                    got_result = True

        duration = time.time() - start

        if res.meta.result_code != 200:
            error = APIError(res, duration)
            log.error(error)
            if raise_errors:
                raise error
        else:
            msg = "APIClient got {} from {}{}".format(
                res.meta.result_code, url, format_duration(duration)
            )
            log.info(msg)
            if log_response:
                log.debug(res.as_dict())

        if extract is None:
            return res, res.data
        else:
            return res, res.data.get(extract)

    class Service(object):
        def __init__(self, api, name):
            self.api = api
            self.name = name

        def __getattr__(self, item):
            return lambda **kwargs: AttrDict(
                self.api.send("{}.{}".format(self.name, item), kwargs)[1]
            )

    def __getattr__(self, item):
        return self.Service(self, item)

    @staticmethod
    @contextmanager
    def raises(codes: Union[Type[BaseError], Tuple[int, int]]):
        if issubclass(codes, BaseError):
            codes = codes.codes
        log.info("Expecting failure: %s, %s", *codes)
        try:
            yield
        except APIError as e:
            if e.result.meta.codes != codes:
                raise
        else:
            assert False, "call should fail"