mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add Redis cluster support
Fix for lru_cache usage
This commit is contained in:
parent
970a32287a
commit
c4001b4037
apiserver
@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import lru_cache
|
||||||
from os import getenv
|
from os import getenv
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
@ -81,11 +82,7 @@ class ESFactory:
|
|||||||
if not hosts:
|
if not hosts:
|
||||||
raise InvalidClusterConfiguration(cluster_name)
|
raise InvalidClusterConfiguration(cluster_name)
|
||||||
|
|
||||||
http_auth = (
|
http_auth = cls.get_credentials(cluster_name)
|
||||||
cls.get_credentials(cluster_name)
|
|
||||||
if cluster_config.get("secure", True)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
args = cluster_config.get("args", {})
|
args = cluster_config.get("args", {})
|
||||||
_instances[cluster_name] = Elasticsearch(
|
_instances[cluster_name] = Elasticsearch(
|
||||||
@ -95,7 +92,11 @@ class ESFactory:
|
|||||||
return _instances[cluster_name]
|
return _instances[cluster_name]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_credentials(cls, cluster_name: str) -> Optional[Tuple[str, str]]:
|
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
|
||||||
|
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
|
||||||
|
if not cluster_config.get("secure", True):
|
||||||
|
return None
|
||||||
|
|
||||||
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
||||||
if not elastic_user:
|
if not elastic_user:
|
||||||
return None
|
return None
|
||||||
@ -119,6 +120,7 @@ class ESFactory:
|
|||||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@lru_cache()
|
||||||
def get_cluster_config(cls, cluster_name):
|
def get_cluster_config(cls, cluster_name):
|
||||||
"""
|
"""
|
||||||
Returns cluster config for the specified cluster path
|
Returns cluster config for the specified cluster path
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import threading
|
|
||||||
from os import getenv
|
from os import getenv
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
from boltons.iterutils import first
|
from boltons.iterutils import first
|
||||||
from redis import StrictRedis
|
from redis import StrictRedis
|
||||||
from redis.sentinel import Sentinel, SentinelConnectionPool
|
from rediscluster import RedisCluster
|
||||||
|
|
||||||
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
@ -38,107 +36,15 @@ if OVERRIDE_PORT:
|
|||||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||||
|
|
||||||
|
|
||||||
class MyPubSubWorkerThread(threading.Thread):
|
|
||||||
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
|
|
||||||
super(MyPubSubWorkerThread, self).__init__()
|
|
||||||
self.daemon = daemon
|
|
||||||
self.sentinel = sentinel
|
|
||||||
self.on_new_master = on_new_master
|
|
||||||
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
|
||||||
self.msg_sleep_time = msg_sleep_time
|
|
||||||
self._running = False
|
|
||||||
self.pubsub = None
|
|
||||||
|
|
||||||
def subscribe(self):
|
|
||||||
if self.pubsub:
|
|
||||||
try:
|
|
||||||
self.pubsub.unsubscribe()
|
|
||||||
self.pubsub.punsubscribe()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
self.pubsub = None
|
|
||||||
|
|
||||||
subscriptions = {"+switch-master": self.on_new_master}
|
|
||||||
|
|
||||||
while not self.pubsub or not self.pubsub.subscribed:
|
|
||||||
try:
|
|
||||||
self.pubsub = self.sentinel.pubsub()
|
|
||||||
self.pubsub.subscribe(**subscriptions)
|
|
||||||
except Exception as ex:
|
|
||||||
log.warn(
|
|
||||||
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
|
|
||||||
)
|
|
||||||
sleep(3)
|
|
||||||
log.info(f"Subscribed to sentinel {self.sentinel_host}")
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
if self._running:
|
|
||||||
return
|
|
||||||
self._running = True
|
|
||||||
|
|
||||||
self.subscribe()
|
|
||||||
|
|
||||||
while self.pubsub.subscribed:
|
|
||||||
try:
|
|
||||||
self.pubsub.get_message(
|
|
||||||
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
|
|
||||||
)
|
|
||||||
except Exception as ex:
|
|
||||||
log.warn(
|
|
||||||
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
|
|
||||||
)
|
|
||||||
self.subscribe()
|
|
||||||
|
|
||||||
self.pubsub.close()
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
# stopping simply unsubscribes from all channels and patterns.
|
|
||||||
# the unsubscribe responses that are generated will short circuit
|
|
||||||
# the loop in run(), calling pubsub.close() to clean up the connection
|
|
||||||
self.pubsub.unsubscribe()
|
|
||||||
self.pubsub.punsubscribe()
|
|
||||||
|
|
||||||
|
|
||||||
# todo,future - multi master clusters?
|
|
||||||
class RedisCluster(object):
|
|
||||||
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
|
|
||||||
self.service_name = service_name
|
|
||||||
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
|
|
||||||
self.master = None
|
|
||||||
self.master_host_port = None
|
|
||||||
self.reconfigure()
|
|
||||||
self.sentinel_threads = {}
|
|
||||||
self.listen()
|
|
||||||
|
|
||||||
def reconfigure(self):
|
|
||||||
try:
|
|
||||||
self.master_host_port = self.sentinel.discover_master(self.service_name)
|
|
||||||
self.master = self.sentinel.master_for(self.service_name)
|
|
||||||
log.info(f"Reconfigured master to {self.master_host_port}")
|
|
||||||
except Exception as ex:
|
|
||||||
log.error(f"Error while reconfiguring. {ex.args[0]}")
|
|
||||||
|
|
||||||
def listen(self):
|
|
||||||
def on_new_master(workerThread):
|
|
||||||
self.reconfigure()
|
|
||||||
|
|
||||||
for sentinel in self.sentinel.sentinels:
|
|
||||||
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
|
||||||
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
|
|
||||||
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
|
|
||||||
)
|
|
||||||
self.sentinel_threads[sentinel_host].start()
|
|
||||||
|
|
||||||
|
|
||||||
class RedisManager(object):
|
class RedisManager(object):
|
||||||
def __init__(self, redis_config_dict):
|
def __init__(self, redis_config_dict):
|
||||||
self.aliases = {}
|
self.aliases = {}
|
||||||
for alias, alias_config in redis_config_dict.items():
|
for alias, alias_config in redis_config_dict.items():
|
||||||
|
|
||||||
alias_config = alias_config.as_plain_ordered_dict()
|
alias_config = alias_config.as_plain_ordered_dict()
|
||||||
alias_config["password"] = config.get(f"secure.redis.{alias}.password", None)
|
alias_config["password"] = config.get(
|
||||||
|
f"secure.redis.{alias}.password", None
|
||||||
|
)
|
||||||
|
|
||||||
is_cluster = alias_config.get("cluster", False)
|
is_cluster = alias_config.get("cluster", False)
|
||||||
|
|
||||||
@ -154,34 +60,15 @@ class RedisManager(object):
|
|||||||
if password:
|
if password:
|
||||||
alias_config["password"] = password
|
alias_config["password"] = password
|
||||||
|
|
||||||
db = alias_config.get("db", 0)
|
if not port or not host:
|
||||||
|
|
||||||
sentinels = alias_config.get("sentinels", None)
|
|
||||||
service_name = alias_config.get("service_name", None)
|
|
||||||
|
|
||||||
if not is_cluster and sentinels:
|
|
||||||
raise ConfigError(
|
|
||||||
"Redis configuration is invalid. mixed regular and cluster mode",
|
|
||||||
alias=alias,
|
|
||||||
)
|
|
||||||
if is_cluster and (not sentinels or not service_name):
|
|
||||||
raise ConfigError(
|
|
||||||
"Redis configuration is invalid. missing sentinels or service_name",
|
|
||||||
alias=alias,
|
|
||||||
)
|
|
||||||
if not is_cluster and (not port or not host):
|
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"Redis configuration is invalid. missing port or host", alias=alias
|
"Redis configuration is invalid. missing port or host", alias=alias
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_cluster:
|
if is_cluster:
|
||||||
# todo support all redis connection args via sentinel's connection_kwargs
|
|
||||||
del alias_config["sentinels"]
|
|
||||||
del alias_config["cluster"]
|
del alias_config["cluster"]
|
||||||
del alias_config["service_name"]
|
del alias_config["db"]
|
||||||
self.aliases[alias] = RedisCluster(
|
self.aliases[alias] = RedisCluster(**alias_config)
|
||||||
sentinels, service_name, **alias_config
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.aliases[alias] = StrictRedis(**alias_config)
|
self.aliases[alias] = StrictRedis(**alias_config)
|
||||||
|
|
||||||
@ -189,27 +76,21 @@ class RedisManager(object):
|
|||||||
obj = self.aliases.get(alias)
|
obj = self.aliases.get(alias)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise GeneralError(f"Invalid Redis alias {alias}")
|
raise GeneralError(f"Invalid Redis alias {alias}")
|
||||||
if isinstance(obj, RedisCluster):
|
|
||||||
obj.master.get("health")
|
|
||||||
return obj.master
|
|
||||||
else:
|
|
||||||
obj.get("health")
|
obj.get("health")
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def host(self, alias):
|
def host(self, alias):
|
||||||
r = self.connection(alias)
|
r = self.connection(alias)
|
||||||
pool = r.connection_pool
|
if isinstance(r, RedisCluster):
|
||||||
if isinstance(pool, SentinelConnectionPool):
|
connections = first(r.connection_pool._available_connections.values())
|
||||||
connections = pool.connection_kwargs[
|
|
||||||
"connection_pool"
|
|
||||||
]._available_connections
|
|
||||||
else:
|
else:
|
||||||
connections = pool._available_connections
|
connections = r.connection_pool._available_connections
|
||||||
|
|
||||||
if len(connections) > 0:
|
if not connections:
|
||||||
return connections[0].host
|
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
return connections[0].host
|
||||||
|
|
||||||
|
|
||||||
redman = RedisManager(config.get("hosts.redis"))
|
redman = RedisManager(config.get("hosts.redis"))
|
||||||
|
@ -7,7 +7,7 @@ from apiserver.apierrors import APIError
|
|||||||
from apiserver.apierrors.base import BaseError
|
from apiserver.apierrors.base import BaseError
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.service_repo import ServiceRepo, APICall
|
from apiserver.service_repo import ServiceRepo, APICall
|
||||||
from apiserver.service_repo.auth import AuthType
|
from apiserver.service_repo.auth import AuthType, Token
|
||||||
from apiserver.service_repo.errors import PathParsingError
|
from apiserver.service_repo.errors import PathParsingError
|
||||||
from apiserver.utilities import json
|
from apiserver.utilities import json
|
||||||
|
|
||||||
@ -52,6 +52,19 @@ class RequestHandlers:
|
|||||||
if call.result.cookies:
|
if call.result.cookies:
|
||||||
for key, value in call.result.cookies.items():
|
for key, value in call.result.cookies.items():
|
||||||
kwargs = config.get("apiserver.auth.cookies").copy()
|
kwargs = config.get("apiserver.auth.cookies").copy()
|
||||||
|
if value is None:
|
||||||
|
# Removing a cookie
|
||||||
|
kwargs["max_age"] = 0
|
||||||
|
kwargs["expires"] = 0
|
||||||
|
value = ""
|
||||||
|
elif not company:
|
||||||
|
# Setting a cookie, let's try to figure out the company
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
company = Token.decode_identity(value).company
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if company:
|
if company:
|
||||||
try:
|
try:
|
||||||
# use no default value to allow setting a null domain as well
|
# use no default value to allow setting a null domain as well
|
||||||
@ -59,11 +72,6 @@ class RequestHandlers:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if value is None:
|
|
||||||
kwargs["max_age"] = 0
|
|
||||||
kwargs["expires"] = 0
|
|
||||||
value = ""
|
|
||||||
|
|
||||||
response.set_cookie(key, value, **kwargs)
|
response.set_cookie(key, value, **kwargs)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -12,6 +12,9 @@ from .payload import Payload
|
|||||||
token_secret = config.get('secure.auth.token_secret')
|
token_secret = config.get('secure.auth.token_secret')
|
||||||
|
|
||||||
|
|
||||||
|
log = config.logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
class Token(Payload):
|
class Token(Payload):
|
||||||
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
|
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
|
||||||
|
|
||||||
@ -94,3 +97,14 @@ class Token(Payload):
|
|||||||
token.exp = now + timedelta(seconds=expiration_sec)
|
token.exp = now + timedelta(seconds=expiration_sec)
|
||||||
|
|
||||||
return token.encode(**extra_payload)
|
return token.encode(**extra_payload)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode_identity(cls, encoded_token):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from ..auth import Identity
|
||||||
|
|
||||||
|
decoded = cls.decode(encoded_token, verify=False)
|
||||||
|
return Identity.from_dict(decoded.get("identity", {}))
|
||||||
|
except Exception as ex:
|
||||||
|
log.error(f"Failed parsing identity from encoded token: {ex}")
|
||||||
|
Loading…
Reference in New Issue
Block a user