Add Redis cluster support

Fix for lru_cache usage
This commit is contained in:
allegroai 2022-02-13 19:48:26 +02:00
parent 970a32287a
commit c4001b4037
4 changed files with 52 additions and 147 deletions
apiserver

View File

@@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from functools import lru_cache
from os import getenv from os import getenv
from typing import Tuple, Optional from typing import Tuple, Optional
@@ -81,11 +82,7 @@ class ESFactory:
if not hosts: if not hosts:
raise InvalidClusterConfiguration(cluster_name) raise InvalidClusterConfiguration(cluster_name)
http_auth = ( http_auth = cls.get_credentials(cluster_name)
cls.get_credentials(cluster_name)
if cluster_config.get("secure", True)
else None
)
args = cluster_config.get("args", {}) args = cluster_config.get("args", {})
_instances[cluster_name] = Elasticsearch( _instances[cluster_name] = Elasticsearch(
@@ -95,7 +92,11 @@ class ESFactory:
return _instances[cluster_name] return _instances[cluster_name]
@classmethod @classmethod
def get_credentials(cls, cluster_name: str) -> Optional[Tuple[str, str]]: def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
if not cluster_config.get("secure", True):
return None
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None) elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
if not elastic_user: if not elastic_user:
return None return None
@@ -119,6 +120,7 @@ class ESFactory:
return OVERRIDE_HOST, OVERRIDE_PORT return OVERRIDE_HOST, OVERRIDE_PORT
@classmethod @classmethod
@lru_cache()
def get_cluster_config(cls, cluster_name): def get_cluster_config(cls, cluster_name):
""" """
Returns cluster config for the specified cluster path Returns cluster config for the specified cluster path

View File

@@ -1,10 +1,8 @@
import threading
from os import getenv from os import getenv
from time import sleep
from boltons.iterutils import first from boltons.iterutils import first
from redis import StrictRedis from redis import StrictRedis
from redis.sentinel import Sentinel, SentinelConnectionPool from rediscluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config from apiserver.config_repo import config
@@ -38,107 +36,15 @@ if OVERRIDE_PORT:
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY))) OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
class MyPubSubWorkerThread(threading.Thread):
    """Background thread subscribed to one sentinel's "+switch-master" pub/sub
    channel; invokes the supplied callback whenever a master switch is announced.

    Resubscribes automatically (with a retry loop) on connection errors.
    """

    def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
        super(MyPubSubWorkerThread, self).__init__()
        self.daemon = daemon
        self.sentinel = sentinel
        # callback invoked with the pub/sub message on master switch
        self.on_new_master = on_new_master
        self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
        # timeout (seconds) passed to get_message() in the polling loop
        self.msg_sleep_time = msg_sleep_time
        self._running = False
        self.pubsub = None

    def subscribe(self):
        """(Re)subscribe to the sentinel, retrying until subscribed."""
        # Tear down any previous subscription before creating a new one
        if self.pubsub:
            try:
                self.pubsub.unsubscribe()
                self.pubsub.punsubscribe()
            except Exception:
                pass
            finally:
                self.pubsub = None

        subscriptions = {"+switch-master": self.on_new_master}
        while not self.pubsub or not self.pubsub.subscribed:
            try:
                self.pubsub = self.sentinel.pubsub()
                self.pubsub.subscribe(**subscriptions)
            except Exception as ex:
                # fix: log.warn is a deprecated alias of log.warning;
                # fix: ex.args[0] raises IndexError for argument-less exceptions
                log.warning(
                    f"Error while subscribing to sentinel at {self.sentinel_host} ({ex}) Sleeping and retrying"
                )
                sleep(3)

        log.info(f"Subscribed to sentinel {self.sentinel_host}")

    def run(self):
        """Poll for pub/sub messages until unsubscribed (see stop())."""
        if self._running:
            return
        self._running = True
        self.subscribe()
        while self.pubsub.subscribed:
            try:
                self.pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=self.msg_sleep_time
                )
            except Exception as ex:
                log.warning(
                    f"Error while getting message from sentinel {self.sentinel_host} ({ex}) Resubscribing"
                )
                self.subscribe()
        self.pubsub.close()
        self._running = False

    def stop(self):
        # stopping simply unsubscribes from all channels and patterns.
        # the unsubscribe responses that are generated will short circuit
        # the loop in run(), calling pubsub.close() to clean up the connection
        self.pubsub.unsubscribe()
        self.pubsub.punsubscribe()
# todo,future - multi master clusters?
class RedisCluster(object):
    """Tracks the current Redis master via Sentinel.

    Discovers the master on construction and starts one listener thread per
    sentinel so the master reference is refreshed on "+switch-master" events.
    """

    def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
        self.service_name = service_name
        self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
        self.master = None
        self.master_host_port = None
        self.reconfigure()
        self.sentinel_threads = {}
        self.listen()

    def reconfigure(self):
        """Re-discover the current master; errors are logged, not raised."""
        try:
            self.master_host_port = self.sentinel.discover_master(self.service_name)
            self.master = self.sentinel.master_for(self.service_name)
            log.info(f"Reconfigured master to {self.master_host_port}")
        except Exception as ex:
            # fix: ex.args[0] raises IndexError for argument-less exceptions
            log.error(f"Error while reconfiguring. {ex}")

    def listen(self):
        """Start one pub/sub worker thread per sentinel host."""

        def on_new_master(worker_thread):
            self.reconfigure()

        for sentinel in self.sentinel.sentinels:
            sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
            self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
                sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
            )
            self.sentinel_threads[sentinel_host].start()
class RedisManager(object): class RedisManager(object):
def __init__(self, redis_config_dict): def __init__(self, redis_config_dict):
self.aliases = {} self.aliases = {}
for alias, alias_config in redis_config_dict.items(): for alias, alias_config in redis_config_dict.items():
alias_config = alias_config.as_plain_ordered_dict() alias_config = alias_config.as_plain_ordered_dict()
alias_config["password"] = config.get(f"secure.redis.{alias}.password", None) alias_config["password"] = config.get(
f"secure.redis.{alias}.password", None
)
is_cluster = alias_config.get("cluster", False) is_cluster = alias_config.get("cluster", False)
@@ -154,34 +60,15 @@ class RedisManager(object):
if password: if password:
alias_config["password"] = password alias_config["password"] = password
db = alias_config.get("db", 0) if not port or not host:
sentinels = alias_config.get("sentinels", None)
service_name = alias_config.get("service_name", None)
if not is_cluster and sentinels:
raise ConfigError(
"Redis configuration is invalid. mixed regular and cluster mode",
alias=alias,
)
if is_cluster and (not sentinels or not service_name):
raise ConfigError(
"Redis configuration is invalid. missing sentinels or service_name",
alias=alias,
)
if not is_cluster and (not port or not host):
raise ConfigError( raise ConfigError(
"Redis configuration is invalid. missing port or host", alias=alias "Redis configuration is invalid. missing port or host", alias=alias
) )
if is_cluster: if is_cluster:
# todo support all redis connection args via sentinel's connection_kwargs
del alias_config["sentinels"]
del alias_config["cluster"] del alias_config["cluster"]
del alias_config["service_name"] del alias_config["db"]
self.aliases[alias] = RedisCluster( self.aliases[alias] = RedisCluster(**alias_config)
sentinels, service_name, **alias_config
)
else: else:
self.aliases[alias] = StrictRedis(**alias_config) self.aliases[alias] = StrictRedis(**alias_config)
@@ -189,27 +76,21 @@ class RedisManager(object):
obj = self.aliases.get(alias) obj = self.aliases.get(alias)
if not obj: if not obj:
raise GeneralError(f"Invalid Redis alias {alias}") raise GeneralError(f"Invalid Redis alias {alias}")
if isinstance(obj, RedisCluster):
obj.master.get("health")
return obj.master
else:
obj.get("health") obj.get("health")
return obj return obj
def host(self, alias): def host(self, alias):
r = self.connection(alias) r = self.connection(alias)
pool = r.connection_pool if isinstance(r, RedisCluster):
if isinstance(pool, SentinelConnectionPool): connections = first(r.connection_pool._available_connections.values())
connections = pool.connection_kwargs[
"connection_pool"
]._available_connections
else: else:
connections = pool._available_connections connections = r.connection_pool._available_connections
if len(connections) > 0: if not connections:
return connections[0].host
else:
return None return None
return connections[0].host
redman = RedisManager(config.get("hosts.redis")) redman = RedisManager(config.get("hosts.redis"))

View File

@@ -7,7 +7,7 @@ from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json from apiserver.utilities import json
@@ -52,6 +52,19 @@ class RequestHandlers:
if call.result.cookies: if call.result.cookies:
for key, value in call.result.cookies.items(): for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies").copy() kwargs = config.get("apiserver.auth.cookies").copy()
if value is None:
# Removing a cookie
kwargs["max_age"] = 0
kwargs["expires"] = 0
value = ""
elif not company:
# Setting a cookie, let's try to figure out the company
# noinspection PyBroadException
try:
company = Token.decode_identity(value).company
except Exception:
pass
if company: if company:
try: try:
# use no default value to allow setting a null domain as well # use no default value to allow setting a null domain as well
@@ -59,11 +72,6 @@ class RequestHandlers:
except KeyError: except KeyError:
pass pass
if value is None:
kwargs["max_age"] = 0
kwargs["expires"] = 0
value = ""
response.set_cookie(key, value, **kwargs) response.set_cookie(key, value, **kwargs)
return response return response

View File

@@ -12,6 +12,9 @@ from .payload import Payload
token_secret = config.get('secure.auth.token_secret') token_secret = config.get('secure.auth.token_secret')
log = config.logger(__file__)
class Token(Payload): class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec') default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
@@ -94,3 +97,14 @@ class Token(Payload):
token.exp = now + timedelta(seconds=expiration_sec) token.exp = now + timedelta(seconds=expiration_sec)
return token.encode(**extra_payload) return token.encode(**extra_payload)
@classmethod
def decode_identity(cls, encoded_token):
    """Best-effort extraction of the Identity embedded in an encoded token.

    Decodes without signature verification and returns the Identity parsed
    from the token's "identity" claim, or None if decoding/parsing fails
    (the failure is logged, never raised).
    """
    # noinspection PyBroadException
    try:
        from ..auth import Identity

        payload = cls.decode(encoded_token, verify=False)
        identity_dict = payload.get("identity", {})
        return Identity.from_dict(identity_dict)
    except Exception as ex:
        log.error(f"Failed parsing identity from encoded token: {ex}")
        return None