2020-06-01 08:27:36 +00:00
|
|
|
from contextlib import contextmanager
|
|
|
|
from typing import Optional, TypeVar, Generic, Type, Callable
|
2020-03-01 16:00:07 +00:00
|
|
|
|
|
|
|
from redis import StrictRedis
|
|
|
|
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver import database
|
|
|
|
from apiserver.timing_context import TimingContext
|
2020-03-01 16:00:07 +00:00
|
|
|
|
|
|
|
# Type of the state objects handled by the generic RedisCacheManager below
T = TypeVar("T")


def _do_nothing(_: T):
    """Default callback: accept one argument, ignore it, and return None."""
    return None
|
|
|
|
|
|
|
|
|
2020-03-01 16:00:07 +00:00
|
|
|
class RedisCacheManager(Generic[T]):
    """
    Store and retrieve state objects in Redis

    self.state_class - class of the state; instances are expected to expose an
        ``id`` attribute and a ``to_json()`` method, the class itself a
        ``from_json()`` constructor, and it must accept an ``id`` keyword argument
    self.redis - instance of redis
    self.expiration_interval - expiration interval in seconds; should be a
        positive integer, since Redis rejects a non-positive ``SET ... EX`` value
    """

    def __init__(
        self, state_class: Type[T], redis: StrictRedis, expiration_interval: int
    ):
        self.state_class = state_class
        self.redis = redis
        self.expiration_interval = expiration_interval

    def set_state(self, state: T) -> None:
        """
        Serialize the state with ``state.to_json()`` and store it under a key
        derived from ``state.id``, expiring after ``expiration_interval`` seconds
        """
        redis_key = self._get_redis_key(state.id)
        with TimingContext("redis", "cache_set_state"):
            # SET with EX stores the value and its TTL in one atomic command, so
            # a failure can never leave the key behind without an expiration
            # (a separate SET followed by EXPIRE could, and costs two round trips)
            self.redis.set(redis_key, state.to_json(), ex=self.expiration_interval)

    def get_state(self, state_id) -> Optional[T]:
        """
        Return the state stored under ``state_id``, deserialized with
        ``state_class.from_json()``, or None when Redis returns no value for it
        """
        redis_key = self._get_redis_key(state_id)
        with TimingContext("redis", "cache_get_state"):
            response = self.redis.get(redis_key)
        if response:
            return self.state_class.from_json(response)
        return None

    def delete_state(self, state_id) -> None:
        """Delete the key holding the state with the given id from Redis"""
        with TimingContext("redis", "cache_delete_state"):
            self.redis.delete(self._get_redis_key(state_id))

    def _get_redis_key(self, state_id):
        # NOTE: the prefix is str() of the class object itself (for a plain class
        # something like "<class 'module.Name'>"), not just its name. Kept as-is
        # because changing the format would orphan keys already stored in Redis
        return f"{self.state_class}/{state_id}"

    def get_or_create_state_core(
        self,
        state_id=None,
        init_state: Callable[[T], None] = _do_nothing,
        validate_state: Callable[[T], None] = _do_nothing,
    ) -> T:
        """
        Return the cached state with the given id, or a newly created one

        If ``state_id`` is given and a state is found for it in Redis, that state
        is passed to ``validate_state`` (which should raise if it is not valid).
        Otherwise a new state is created with a newly generated id and passed to
        ``init_state``. Nothing is written back to Redis by this method.
        :param state_id: id of the state to retrieve (optional)
        :param init_state: user callback to init the newly created state
        :param validate_state: user callback to validate a state retrieved from cache
        :return: the retrieved or newly created state
        """
        state = self.get_state(state_id) if state_id else None
        if state:
            validate_state(state)
        else:
            state = self.state_class(id=database.utils.id())
            init_state(state)

        return state

    @contextmanager
    def get_or_create_state(
        self,
        state_id=None,
        init_state: Callable[[T], None] = _do_nothing,
        validate_state: Callable[[T], None] = _do_nothing,
    ):
        """
        Try to retrieve the state with the given id from the Redis cache and, if
        found, validate it. Otherwise create a new one with a newly generated id.
        Yield the state and write it back to Redis once the user code block exits,
        including when the block exits with an exception
        :param state_id: id of the state to retrieve
        :param init_state: user callback to init the newly created state
            If not passed then no init except for the id generation is done
        :param validate_state: user callback to validate the state if retrieved from cache
            Should throw an exception if the state is not valid. If not passed then no validation is done
        """
        state = self.get_or_create_state_core(
            state_id=state_id, init_state=init_state, validate_state=validate_state
        )

        try:
            yield state
        finally:
            # Persisting in `finally` saves the caller's changes even if the
            # with-block raises
            self.set_state(state)
|