from contextlib import contextmanager
from typing import Optional, TypeVar, Generic, Type, Callable

from redis import StrictRedis

from apiserver import database

T = TypeVar("T")


def _do_nothing(_: T):
    return


class RedisCacheManager(Generic[T]):
    """
    Class for store/retrieve of state objects from redis

    self.state_class - class of the state
    self.redis - instance of redis
    self.expiration_interval - expiration interval in seconds
    """

    def __init__(
        self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
    ):
        self.state_class = state_class
        self.redis = redis
        self.expiration_interval = expiration_interval

    def set_state(self, state: T) -> None:
        redis_key = self._get_redis_key(state.id)
        self.redis.set(redis_key, state.to_json())
        self.redis.expire(redis_key, self.expiration_interval)

    def get_state(self, state_id) -> Optional[T]:
        redis_key = self._get_redis_key(state_id)
        response = self.redis.get(redis_key)
        if response:
            return self.state_class.from_json(response)

    def delete_state(self, state_id) -> None:
        self.redis.delete(self._get_redis_key(state_id))

    def _get_redis_key(self, state_id):
        return f"{self.state_class}/{state_id}"

    def get_or_create_state_core(
        self,
        state_id=None,
        init_state: Callable[[T], None] = _do_nothing,
        validate_state: Callable[[T], None] = _do_nothing,
    ) -> T:
        state = self.get_state(state_id) if state_id else None
        if state:
            validate_state(state)
        else:
            state = self.state_class(id=database.utils.id())
            init_state(state)

        return state

    @contextmanager
    def get_or_create_state(
        self,
        state_id=None,
        init_state: Callable[[T], None] = _do_nothing,
        validate_state: Callable[[T], None] = _do_nothing,
    ):
        """
        Try to retrieve state with the given id from the Redis cache if yes then validates it
        If no then create a new one with randomly generated id
        Yield the state and write it back to redis once the user code block exits
        :param state_id: id of the state to retrieve
        :param init_state: user callback to init the newly created state
        If not passed then no init except for the id generation is done
        :param validate_state: user callback to validate the state if retrieved from cache
        Should throw an exception if the state is not valid. If not passed then no validation is done
        """
        state = self.get_or_create_state_core(
            state_id=state_id, init_state=init_state, validate_state=validate_state
        )

        try:
            yield state
        finally:
            self.set_state(state)