Add Redis cluster support

Fix for lru_cache usage
This commit is contained in:
allegroai 2022-02-13 19:48:26 +02:00
parent 970a32287a
commit c4001b4037
4 changed files with 52 additions and 147 deletions
apiserver

View File

@ -1,4 +1,5 @@
from datetime import datetime
from functools import lru_cache
from os import getenv
from typing import Tuple, Optional
@ -81,11 +82,7 @@ class ESFactory:
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
http_auth = (
cls.get_credentials(cluster_name)
if cluster_config.get("secure", True)
else None
)
http_auth = cls.get_credentials(cluster_name)
args = cluster_config.get("args", {})
_instances[cluster_name] = Elasticsearch(
@ -95,7 +92,11 @@ class ESFactory:
return _instances[cluster_name]
@classmethod
def get_credentials(cls, cluster_name: str) -> Optional[Tuple[str, str]]:
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
if not cluster_config.get("secure", True):
return None
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
if not elastic_user:
return None
@ -119,6 +120,7 @@ class ESFactory:
return OVERRIDE_HOST, OVERRIDE_PORT
@classmethod
@lru_cache()
def get_cluster_config(cls, cluster_name):
"""
Returns cluster config for the specified cluster path

View File

@ -1,10 +1,8 @@
import threading
from os import getenv
from time import sleep
from boltons.iterutils import first
from redis import StrictRedis
from redis.sentinel import Sentinel, SentinelConnectionPool
from rediscluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config
@ -38,107 +36,15 @@ if OVERRIDE_PORT:
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
class MyPubSubWorkerThread(threading.Thread):
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
super(MyPubSubWorkerThread, self).__init__()
self.daemon = daemon
self.sentinel = sentinel
self.on_new_master = on_new_master
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.msg_sleep_time = msg_sleep_time
self._running = False
self.pubsub = None
def subscribe(self):
if self.pubsub:
try:
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
except Exception:
pass
finally:
self.pubsub = None
subscriptions = {"+switch-master": self.on_new_master}
while not self.pubsub or not self.pubsub.subscribed:
try:
self.pubsub = self.sentinel.pubsub()
self.pubsub.subscribe(**subscriptions)
except Exception as ex:
log.warn(
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
)
sleep(3)
log.info(f"Subscribed to sentinel {self.sentinel_host}")
def run(self):
if self._running:
return
self._running = True
self.subscribe()
while self.pubsub.subscribed:
try:
self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
)
except Exception as ex:
log.warn(
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
)
self.subscribe()
self.pubsub.close()
self._running = False
def stop(self):
# stopping simply unsubscribes from all channels and patterns.
# the unsubscribe responses that are generated will short circuit
# the loop in run(), calling pubsub.close() to clean up the connection
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
# todo,future - multi master clusters?
class RedisCluster(object):
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
self.service_name = service_name
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
self.master = None
self.master_host_port = None
self.reconfigure()
self.sentinel_threads = {}
self.listen()
def reconfigure(self):
try:
self.master_host_port = self.sentinel.discover_master(self.service_name)
self.master = self.sentinel.master_for(self.service_name)
log.info(f"Reconfigured master to {self.master_host_port}")
except Exception as ex:
log.error(f"Error while reconfiguring. {ex.args[0]}")
def listen(self):
def on_new_master(workerThread):
self.reconfigure()
for sentinel in self.sentinel.sentinels:
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
)
self.sentinel_threads[sentinel_host].start()
class RedisManager(object):
def __init__(self, redis_config_dict):
self.aliases = {}
for alias, alias_config in redis_config_dict.items():
alias_config = alias_config.as_plain_ordered_dict()
alias_config["password"] = config.get(f"secure.redis.{alias}.password", None)
alias_config["password"] = config.get(
f"secure.redis.{alias}.password", None
)
is_cluster = alias_config.get("cluster", False)
@ -154,34 +60,15 @@ class RedisManager(object):
if password:
alias_config["password"] = password
db = alias_config.get("db", 0)
sentinels = alias_config.get("sentinels", None)
service_name = alias_config.get("service_name", None)
if not is_cluster and sentinels:
raise ConfigError(
"Redis configuration is invalid. mixed regular and cluster mode",
alias=alias,
)
if is_cluster and (not sentinels or not service_name):
raise ConfigError(
"Redis configuration is invalid. missing sentinels or service_name",
alias=alias,
)
if not is_cluster and (not port or not host):
if not port or not host:
raise ConfigError(
"Redis configuration is invalid. missing port or host", alias=alias
)
if is_cluster:
# todo support all redis connection args via sentinel's connection_kwargs
del alias_config["sentinels"]
del alias_config["cluster"]
del alias_config["service_name"]
self.aliases[alias] = RedisCluster(
sentinels, service_name, **alias_config
)
del alias_config["db"]
self.aliases[alias] = RedisCluster(**alias_config)
else:
self.aliases[alias] = StrictRedis(**alias_config)
@ -189,27 +76,21 @@ class RedisManager(object):
obj = self.aliases.get(alias)
if not obj:
raise GeneralError(f"Invalid Redis alias {alias}")
if isinstance(obj, RedisCluster):
obj.master.get("health")
return obj.master
else:
obj.get("health")
return obj
obj.get("health")
return obj
def host(self, alias):
r = self.connection(alias)
pool = r.connection_pool
if isinstance(pool, SentinelConnectionPool):
connections = pool.connection_kwargs[
"connection_pool"
]._available_connections
if isinstance(r, RedisCluster):
connections = first(r.connection_pool._available_connections.values())
else:
connections = pool._available_connections
connections = r.connection_pool._available_connections
if len(connections) > 0:
return connections[0].host
else:
if not connections:
return None
return connections[0].host
redman = RedisManager(config.get("hosts.redis"))

View File

@ -7,7 +7,7 @@ from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
@ -52,6 +52,19 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies").copy()
if value is None:
# Removing a cookie
kwargs["max_age"] = 0
kwargs["expires"] = 0
value = ""
elif not company:
# Setting a cookie, let's try to figure out the company
# noinspection PyBroadException
try:
company = Token.decode_identity(value).company
except Exception:
pass
if company:
try:
# use no default value to allow setting a null domain as well
@ -59,11 +72,6 @@ class RequestHandlers:
except KeyError:
pass
if value is None:
kwargs["max_age"] = 0
kwargs["expires"] = 0
value = ""
response.set_cookie(key, value, **kwargs)
return response

View File

@ -12,6 +12,9 @@ from .payload import Payload
token_secret = config.get('secure.auth.token_secret')
log = config.logger(__file__)
class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
@ -94,3 +97,14 @@ class Token(Payload):
token.exp = now + timedelta(seconds=expiration_sec)
return token.encode(**extra_payload)
@classmethod
def decode_identity(cls, encoded_token):
# noinspection PyBroadException
try:
from ..auth import Identity
decoded = cls.decode(encoded_token, verify=False)
return Identity.from_dict(decoded.get("identity", {}))
except Exception as ex:
log.error(f"Failed parsing identity from encoded token: {ex}")