mirror of
https://github.com/clearml/clearml-server
synced 2025-04-23 07:34:37 +00:00
Add Redis cluster support
Fix for lru_cache usage
This commit is contained in:
parent
970a32287a
commit
c4001b4037
apiserver
@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from typing import Tuple, Optional
|
||||
|
||||
@ -81,11 +82,7 @@ class ESFactory:
|
||||
if not hosts:
|
||||
raise InvalidClusterConfiguration(cluster_name)
|
||||
|
||||
http_auth = (
|
||||
cls.get_credentials(cluster_name)
|
||||
if cluster_config.get("secure", True)
|
||||
else None
|
||||
)
|
||||
http_auth = cls.get_credentials(cluster_name)
|
||||
|
||||
args = cluster_config.get("args", {})
|
||||
_instances[cluster_name] = Elasticsearch(
|
||||
@ -95,7 +92,11 @@ class ESFactory:
|
||||
return _instances[cluster_name]
|
||||
|
||||
@classmethod
|
||||
def get_credentials(cls, cluster_name: str) -> Optional[Tuple[str, str]]:
|
||||
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
|
||||
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
|
||||
if not cluster_config.get("secure", True):
|
||||
return None
|
||||
|
||||
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
||||
if not elastic_user:
|
||||
return None
|
||||
@ -119,6 +120,7 @@ class ESFactory:
|
||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def get_cluster_config(cls, cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
|
@ -1,10 +1,8 @@
|
||||
import threading
|
||||
from os import getenv
|
||||
from time import sleep
|
||||
|
||||
from boltons.iterutils import first
|
||||
from redis import StrictRedis
|
||||
from redis.sentinel import Sentinel, SentinelConnectionPool
|
||||
from rediscluster import RedisCluster
|
||||
|
||||
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
||||
from apiserver.config_repo import config
|
||||
@ -38,107 +36,15 @@ if OVERRIDE_PORT:
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
|
||||
|
||||
class MyPubSubWorkerThread(threading.Thread):
|
||||
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
|
||||
super(MyPubSubWorkerThread, self).__init__()
|
||||
self.daemon = daemon
|
||||
self.sentinel = sentinel
|
||||
self.on_new_master = on_new_master
|
||||
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.msg_sleep_time = msg_sleep_time
|
||||
self._running = False
|
||||
self.pubsub = None
|
||||
|
||||
def subscribe(self):
|
||||
if self.pubsub:
|
||||
try:
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self.pubsub = None
|
||||
|
||||
subscriptions = {"+switch-master": self.on_new_master}
|
||||
|
||||
while not self.pubsub or not self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub = self.sentinel.pubsub()
|
||||
self.pubsub.subscribe(**subscriptions)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
|
||||
)
|
||||
sleep(3)
|
||||
log.info(f"Subscribed to sentinel {self.sentinel_host}")
|
||||
|
||||
def run(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
|
||||
self.subscribe()
|
||||
|
||||
while self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
|
||||
)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
|
||||
)
|
||||
self.subscribe()
|
||||
|
||||
self.pubsub.close()
|
||||
self._running = False
|
||||
|
||||
def stop(self):
|
||||
# stopping simply unsubscribes from all channels and patterns.
|
||||
# the unsubscribe responses that are generated will short circuit
|
||||
# the loop in run(), calling pubsub.close() to clean up the connection
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
|
||||
|
||||
# todo,future - multi master clusters?
|
||||
class RedisCluster(object):
|
||||
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
|
||||
self.service_name = service_name
|
||||
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
|
||||
self.master = None
|
||||
self.master_host_port = None
|
||||
self.reconfigure()
|
||||
self.sentinel_threads = {}
|
||||
self.listen()
|
||||
|
||||
def reconfigure(self):
|
||||
try:
|
||||
self.master_host_port = self.sentinel.discover_master(self.service_name)
|
||||
self.master = self.sentinel.master_for(self.service_name)
|
||||
log.info(f"Reconfigured master to {self.master_host_port}")
|
||||
except Exception as ex:
|
||||
log.error(f"Error while reconfiguring. {ex.args[0]}")
|
||||
|
||||
def listen(self):
|
||||
def on_new_master(workerThread):
|
||||
self.reconfigure()
|
||||
|
||||
for sentinel in self.sentinel.sentinels:
|
||||
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
|
||||
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
|
||||
)
|
||||
self.sentinel_threads[sentinel_host].start()
|
||||
|
||||
|
||||
class RedisManager(object):
|
||||
def __init__(self, redis_config_dict):
|
||||
self.aliases = {}
|
||||
for alias, alias_config in redis_config_dict.items():
|
||||
|
||||
alias_config = alias_config.as_plain_ordered_dict()
|
||||
alias_config["password"] = config.get(f"secure.redis.{alias}.password", None)
|
||||
alias_config["password"] = config.get(
|
||||
f"secure.redis.{alias}.password", None
|
||||
)
|
||||
|
||||
is_cluster = alias_config.get("cluster", False)
|
||||
|
||||
@ -154,34 +60,15 @@ class RedisManager(object):
|
||||
if password:
|
||||
alias_config["password"] = password
|
||||
|
||||
db = alias_config.get("db", 0)
|
||||
|
||||
sentinels = alias_config.get("sentinels", None)
|
||||
service_name = alias_config.get("service_name", None)
|
||||
|
||||
if not is_cluster and sentinels:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. mixed regular and cluster mode",
|
||||
alias=alias,
|
||||
)
|
||||
if is_cluster and (not sentinels or not service_name):
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing sentinels or service_name",
|
||||
alias=alias,
|
||||
)
|
||||
if not is_cluster and (not port or not host):
|
||||
if not port or not host:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing port or host", alias=alias
|
||||
)
|
||||
|
||||
if is_cluster:
|
||||
# todo support all redis connection args via sentinel's connection_kwargs
|
||||
del alias_config["sentinels"]
|
||||
del alias_config["cluster"]
|
||||
del alias_config["service_name"]
|
||||
self.aliases[alias] = RedisCluster(
|
||||
sentinels, service_name, **alias_config
|
||||
)
|
||||
del alias_config["db"]
|
||||
self.aliases[alias] = RedisCluster(**alias_config)
|
||||
else:
|
||||
self.aliases[alias] = StrictRedis(**alias_config)
|
||||
|
||||
@ -189,27 +76,21 @@ class RedisManager(object):
|
||||
obj = self.aliases.get(alias)
|
||||
if not obj:
|
||||
raise GeneralError(f"Invalid Redis alias {alias}")
|
||||
if isinstance(obj, RedisCluster):
|
||||
obj.master.get("health")
|
||||
return obj.master
|
||||
else:
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
def host(self, alias):
|
||||
r = self.connection(alias)
|
||||
pool = r.connection_pool
|
||||
if isinstance(pool, SentinelConnectionPool):
|
||||
connections = pool.connection_kwargs[
|
||||
"connection_pool"
|
||||
]._available_connections
|
||||
if isinstance(r, RedisCluster):
|
||||
connections = first(r.connection_pool._available_connections.values())
|
||||
else:
|
||||
connections = pool._available_connections
|
||||
connections = r.connection_pool._available_connections
|
||||
|
||||
if len(connections) > 0:
|
||||
return connections[0].host
|
||||
else:
|
||||
if not connections:
|
||||
return None
|
||||
|
||||
return connections[0].host
|
||||
|
||||
|
||||
redman = RedisManager(config.get("hosts.redis"))
|
||||
|
@ -7,7 +7,7 @@ from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
from apiserver.service_repo.auth import AuthType
|
||||
from apiserver.service_repo.auth import AuthType, Token
|
||||
from apiserver.service_repo.errors import PathParsingError
|
||||
from apiserver.utilities import json
|
||||
|
||||
@ -52,6 +52,19 @@ class RequestHandlers:
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
kwargs = config.get("apiserver.auth.cookies").copy()
|
||||
if value is None:
|
||||
# Removing a cookie
|
||||
kwargs["max_age"] = 0
|
||||
kwargs["expires"] = 0
|
||||
value = ""
|
||||
elif not company:
|
||||
# Setting a cookie, let's try to figure out the company
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
company = Token.decode_identity(value).company
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if company:
|
||||
try:
|
||||
# use no default value to allow setting a null domain as well
|
||||
@ -59,11 +72,6 @@ class RequestHandlers:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if value is None:
|
||||
kwargs["max_age"] = 0
|
||||
kwargs["expires"] = 0
|
||||
value = ""
|
||||
|
||||
response.set_cookie(key, value, **kwargs)
|
||||
|
||||
return response
|
||||
|
@ -12,6 +12,9 @@ from .payload import Payload
|
||||
token_secret = config.get('secure.auth.token_secret')
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Token(Payload):
|
||||
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
|
||||
|
||||
@ -94,3 +97,14 @@ class Token(Payload):
|
||||
token.exp = now + timedelta(seconds=expiration_sec)
|
||||
|
||||
return token.encode(**extra_payload)
|
||||
|
||||
@classmethod
|
||||
def decode_identity(cls, encoded_token):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from ..auth import Identity
|
||||
|
||||
decoded = cls.decode(encoded_token, verify=False)
|
||||
return Identity.from_dict(decoded.get("identity", {}))
|
||||
except Exception as ex:
|
||||
log.error(f"Failed parsing identity from encoded token: {ex}")
|
||||
|
Loading…
Reference in New Issue
Block a user