mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 02:46:53 +00:00
Add token authorization to fileserver
This commit is contained in:
parent
5456ee4ebf
commit
7c9889605a
129
fileserver/auth.py
Normal file
129
fileserver/auth.py
Normal file
@ -0,0 +1,129 @@
|
||||
import json
|
||||
from hashlib import sha256
|
||||
from typing import Optional
|
||||
|
||||
import attr
|
||||
from flask import abort, Response, Request
|
||||
from redis import StrictRedis
|
||||
from clearml_agent.backend_api import Session
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from config import config
|
||||
from redis_manager import redman
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TokenInfo:
|
||||
company: str
|
||||
user: str
|
||||
|
||||
|
||||
class FileserverSession(Session):
|
||||
@property
|
||||
def client(self):
|
||||
return "fileserver"
|
||||
|
||||
@client.setter
|
||||
def client(self, _):
|
||||
# do not allow the base class to override the client
|
||||
pass
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
enabled = config.get("fileserver.auth.enabled", False)
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls):
|
||||
if not cls.enabled:
|
||||
return None
|
||||
if not cls._instance:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.session = FileserverSession(
|
||||
api_key=config.get("secure.credentials.fileserver.user_key"),
|
||||
secret_key=config.get("secure.credentials.fileserver.user_secret"),
|
||||
host=config.get("hosts.api_server"),
|
||||
initialize_logging=False,
|
||||
)
|
||||
|
||||
self.redis: StrictRedis = redman.connection("fileserver")
|
||||
|
||||
def _validate_and_get_token_info(self, token: str) -> TokenInfo:
|
||||
token_hash = sha256(token.encode()).hexdigest() if len(token) > 256 else token
|
||||
key = f"token_{token_hash}"
|
||||
token_data = self.redis.get(key)
|
||||
if token_data:
|
||||
return TokenInfo(**json.loads(token_data))
|
||||
|
||||
try:
|
||||
res = self.session.send_request(
|
||||
service="auth", action="validate_token", json={"token": token}
|
||||
)
|
||||
|
||||
if res.status_code == 500:
|
||||
log.error("Error validating token")
|
||||
abort(Response(f"Internal error (status={res.status_code})", 500))
|
||||
elif res.status_code != 200:
|
||||
log.error("Error validating token")
|
||||
abort(res.status_code)
|
||||
|
||||
data = res.json()["data"]
|
||||
if not data["valid"]:
|
||||
log.error(f"Error validating token: {data['msg']}")
|
||||
abort(Response(data["msg"], 401))
|
||||
|
||||
info = TokenInfo(
|
||||
company=data.get("company", "unknown"),
|
||||
user=data.get("user"),
|
||||
)
|
||||
|
||||
timeout_sec = config.get(
|
||||
"fileserver.auth.tokens_cache_threshold_sec", 12 * 60 * 60
|
||||
)
|
||||
self.redis.setex(key, time=timeout_sec, value=json.dumps(attr.asdict(info)))
|
||||
|
||||
return info
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
log.exception(f"Failed decoding token")
|
||||
abort(500)
|
||||
|
||||
def validate(self, request: Request):
|
||||
token = self.get_token(request)
|
||||
if not token:
|
||||
log.error("Error getting token")
|
||||
abort(401)
|
||||
|
||||
self._validate_and_get_token_info(token)
|
||||
|
||||
@staticmethod
|
||||
def get_token(request: Request) -> Optional[str]:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
if not auth_header.startswith("Bearer "):
|
||||
log.error("Only bearer token authorization is supported")
|
||||
abort(
|
||||
Response("Only bearer token authorization is supported", status=401)
|
||||
)
|
||||
token = auth_header.partition(" ")[2]
|
||||
return token
|
||||
|
||||
last_ex = None
|
||||
for cookie_name in config.get("fileserver.auth.cookie_names", []):
|
||||
cookie = request.cookies.get(cookie_name)
|
||||
if not cookie:
|
||||
continue
|
||||
try:
|
||||
return cookie
|
||||
except HTTPException as ex:
|
||||
last_ex = ex
|
||||
|
||||
if last_ex:
|
||||
raise last_ex
|
@ -12,4 +12,15 @@ delete {
|
||||
|
||||
cors {
|
||||
origins: "*"
|
||||
}
|
||||
|
||||
auth {
|
||||
# enable/disable auth validation on upload/download
|
||||
enabled: false
|
||||
|
||||
# names of cookies in which authorization token can be found
|
||||
cookie_names: ["clearml_token_basic"]
|
||||
|
||||
tokens_cache_threshold_sec: 43200
|
||||
|
||||
}
|
7
fileserver/config/default/hosts.conf
Normal file
7
fileserver/config/default/hosts.conf
Normal file
@ -0,0 +1,7 @@
|
||||
redis {
|
||||
fileserver {
|
||||
host: "redis"
|
||||
port: 6379
|
||||
db: 8
|
||||
}
|
||||
}
|
7
fileserver/config/default/secure.conf
Normal file
7
fileserver/config/default/secure.conf
Normal file
@ -0,0 +1,7 @@
|
||||
credentials {
|
||||
# system credentials as they appear in the auth DB, used for intra-service communications
|
||||
fileserver {
|
||||
user_key: ""
|
||||
user_secret: ""
|
||||
}
|
||||
}
|
@ -15,6 +15,7 @@ from flask_cors import CORS
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.security import safe_join
|
||||
|
||||
from auth import AuthHandler
|
||||
from config import config
|
||||
from utils import get_env_bool
|
||||
|
||||
@ -34,10 +35,14 @@ app.config["UPLOAD_FOLDER"] = first(
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
|
||||
"fileserver.download.cache_timeout_sec", 5 * 60
|
||||
)
|
||||
auth_handler = AuthHandler.instance()
|
||||
|
||||
|
||||
@app.route("/", methods=["GET"])
|
||||
def ping():
|
||||
if auth_handler and auth_handler.get_token(request):
|
||||
auth_handler.validate(request)
|
||||
|
||||
return "OK", 200
|
||||
|
||||
|
||||
@ -57,6 +62,9 @@ def after_request(response):
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def upload():
|
||||
if auth_handler:
|
||||
auth_handler.validate(request)
|
||||
|
||||
results = []
|
||||
for filename, file in request.files.items():
|
||||
if not filename:
|
||||
@ -76,6 +84,9 @@ def upload():
|
||||
|
||||
@app.route("/<path:path>", methods=["GET"])
|
||||
def download(path):
|
||||
if auth_handler:
|
||||
auth_handler.validate(request)
|
||||
|
||||
as_attachment = "download" in request.args
|
||||
|
||||
_, encoding = mimetypes.guess_type(os.path.basename(path))
|
||||
@ -105,6 +116,9 @@ def _get_full_path(path: str) -> Path:
|
||||
|
||||
@app.route("/<path:path>", methods=["DELETE"])
|
||||
def delete(path):
|
||||
if auth_handler:
|
||||
auth_handler.validate(request)
|
||||
|
||||
full_path = _get_full_path(path)
|
||||
if not full_path.exists() or not full_path.is_file():
|
||||
log.error(f"Error deleting file {str(full_path)}. Not found or not a file")
|
||||
@ -117,6 +131,9 @@ def delete(path):
|
||||
|
||||
|
||||
def batch_delete():
|
||||
if auth_handler:
|
||||
auth_handler.validate(request)
|
||||
|
||||
body = request.get_json(force=True, silent=False)
|
||||
if not body:
|
||||
abort(Response("Json payload is missing", 400))
|
||||
|
96
fileserver/redis_manager.py
Normal file
96
fileserver/redis_manager.py
Normal file
@ -0,0 +1,96 @@
|
||||
from os import getenv
|
||||
|
||||
from boltons.iterutils import first
|
||||
from redis import StrictRedis
|
||||
from redis.cluster import RedisCluster
|
||||
|
||||
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
||||
from apiserver.config_repo import config
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_HOST",
|
||||
"TRAINS_REDIS_SERVICE_HOST",
|
||||
"REDIS_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_PORT",
|
||||
"TRAINS_REDIS_SERVICE_PORT",
|
||||
"REDIS_SERVICE_PORT",
|
||||
)
|
||||
OVERRIDE_PASSWORD_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_PASSWORD",
|
||||
"TRAINS_REDIS_SERVICE_PASSWORD",
|
||||
"REDIS_SERVICE_PASSWORD",
|
||||
)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
log.info(f"Using override redis host {OVERRIDE_HOST}")
|
||||
|
||||
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override redis port {OVERRIDE_PORT}")
|
||||
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
|
||||
|
||||
class RedisManager(object):
|
||||
def __init__(self, redis_config_dict):
|
||||
self.aliases = {}
|
||||
for alias, alias_config in redis_config_dict.items():
|
||||
|
||||
alias_config = alias_config.as_plain_ordered_dict()
|
||||
alias_config["password"] = config.get(
|
||||
f"secure.redis.{alias}.password", None
|
||||
)
|
||||
|
||||
is_cluster = alias_config.get("cluster", False)
|
||||
|
||||
host = OVERRIDE_HOST or alias_config.get("host", None)
|
||||
if host:
|
||||
alias_config["host"] = host
|
||||
|
||||
port = OVERRIDE_PORT or alias_config.get("port", None)
|
||||
if port:
|
||||
alias_config["port"] = port
|
||||
|
||||
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
|
||||
if password:
|
||||
alias_config["password"] = password
|
||||
|
||||
if not port or not host:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing port or host", alias=alias
|
||||
)
|
||||
|
||||
if is_cluster:
|
||||
del alias_config["cluster"]
|
||||
del alias_config["db"]
|
||||
self.aliases[alias] = RedisCluster(**alias_config)
|
||||
else:
|
||||
self.aliases[alias] = StrictRedis(**alias_config)
|
||||
|
||||
def connection(self, alias) -> StrictRedis:
|
||||
obj = self.aliases.get(alias)
|
||||
if not obj:
|
||||
raise GeneralError(f"Invalid Redis alias {alias}")
|
||||
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
def host(self, alias):
|
||||
r = self.connection(alias)
|
||||
if isinstance(r, RedisCluster):
|
||||
connections = r.get_default_node().redis_connection.connection_pool._available_connections
|
||||
else:
|
||||
connections = r.connection_pool._available_connections
|
||||
|
||||
if not connections:
|
||||
return None
|
||||
|
||||
return connections[0].host
|
||||
|
||||
|
||||
redman = RedisManager(config.get("hosts.redis"))
|
@ -1,9 +1,11 @@
|
||||
boltons>=19.1.0
|
||||
clearml-agent>=1.5.2
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
flask>=2.3.3
|
||||
gunicorn>=20.1.0
|
||||
pyhocon>=0.3.35
|
||||
redis>=4.5.4,<5
|
||||
setuptools>=65.5.1
|
||||
urllib3>=1.26.18
|
||||
werkzeug>=3.0.1
|
Loading…
Reference in New Issue
Block a user