diff --git a/fileserver/auth.py b/fileserver/auth.py new file mode 100644 index 0000000..c1528a3 --- /dev/null +++ b/fileserver/auth.py @@ -0,0 +1,129 @@ +import json +from hashlib import sha256 +from typing import Optional + +import attr +from flask import abort, Response, Request +from redis import StrictRedis +from clearml_agent.backend_api import Session +from werkzeug.exceptions import HTTPException + +from config import config +from redis_manager import redman + +log = config.logger(__file__) + + +@attr.s(auto_attribs=True) +class TokenInfo: + company: str + user: str + + +class FileserverSession(Session): + @property + def client(self): + return "fileserver" + + @client.setter + def client(self, _): + # do not allow the base class to override the client + pass + + +class AuthHandler: + enabled = config.get("fileserver.auth.enabled", False) + _instance = None + + @classmethod + def instance(cls): + if not cls.enabled: + return None + if not cls._instance: + cls._instance = cls() + return cls._instance + + def __init__(self): + self.session = FileserverSession( + api_key=config.get("secure.credentials.fileserver.user_key"), + secret_key=config.get("secure.credentials.fileserver.user_secret"), + host=config.get("hosts.api_server"), + initialize_logging=False, + ) + + self.redis: StrictRedis = redman.connection("fileserver") + + def _validate_and_get_token_info(self, token: str) -> TokenInfo: + token_hash = sha256(token.encode()).hexdigest() if len(token) > 256 else token + key = f"token_{token_hash}" + token_data = self.redis.get(key) + if token_data: + return TokenInfo(**json.loads(token_data)) + + try: + res = self.session.send_request( + service="auth", action="validate_token", json={"token": token} + ) + + if res.status_code == 500: + log.error("Error validating token") + abort(Response(f"Internal error (status={res.status_code})", 500)) + elif res.status_code != 200: + log.error("Error validating token") + abort(res.status_code) + + data = res.json()["data"] + if not data["valid"]: + log.error(f"Error validating token: {data['msg']}") + abort(Response(data["msg"], 401)) + + info = TokenInfo( + company=data.get("company", "unknown"), + user=data.get("user"), + ) + + timeout_sec = config.get( + "fileserver.auth.tokens_cache_threshold_sec", 12 * 60 * 60 + ) + self.redis.setex(key, time=timeout_sec, value=json.dumps(attr.asdict(info))) + + return info + + except HTTPException: + raise + except Exception: + log.exception(f"Failed decoding token") + abort(500) + + def validate(self, request: Request): + token = self.get_token(request) + if not token: + log.error("Error getting token") + abort(401) + + self._validate_and_get_token_info(token) + + @staticmethod + def get_token(request: Request) -> Optional[str]: + auth_header = request.headers.get("Authorization") + if auth_header: + if not auth_header.startswith("Bearer "): + log.error("Only bearer token authorization is supported") + abort( + Response("Only bearer token authorization is supported", status=401) + ) + token = auth_header.partition(" ")[2] + return token + + last_ex = None + for cookie_name in config.get("fileserver.auth.cookie_names", []): + cookie = request.cookies.get(cookie_name) + if not cookie: + continue + try: + return cookie + except HTTPException as ex: + last_ex = ex + + if last_ex: + raise last_ex diff --git a/fileserver/config/default/fileserver.conf b/fileserver/config/default/fileserver.conf index ad1b7e2..ef0963d 100644 --- a/fileserver/config/default/fileserver.conf +++ b/fileserver/config/default/fileserver.conf @@ -12,4 +12,15 @@ delete { cors { origins: "*" +} + +auth { + # enable/disable auth validation on upload/download + enabled: false + + # names of cookies in which authorization token can be found + cookie_names: ["clearml_token_basic"] + + tokens_cache_threshold_sec: 43200 + } \ No newline at end of file diff --git a/fileserver/config/default/hosts.conf b/fileserver/config/default/hosts.conf new file mode 100644 index 0000000..deec888 --- /dev/null +++ b/fileserver/config/default/hosts.conf @@ -0,0 +1,7 @@ +redis { + fileserver { + host: "redis" + port: 6379 + db: 8 + } +} \ No newline at end of file diff --git a/fileserver/config/default/secure.conf b/fileserver/config/default/secure.conf new file mode 100644 index 0000000..7d09d2a --- /dev/null +++ b/fileserver/config/default/secure.conf @@ -0,0 +1,7 @@ +credentials { + # system credentials as they appear in the auth DB, used for intra-service communications + fileserver { + user_key: "" + user_secret: "" + } +} \ No newline at end of file diff --git a/fileserver/fileserver.py b/fileserver/fileserver.py index c6d15c7..2984b6c 100644 --- a/fileserver/fileserver.py +++ b/fileserver/fileserver.py @@ -15,6 +15,7 @@ from flask_cors import CORS from werkzeug.exceptions import NotFound from werkzeug.security import safe_join +from auth import AuthHandler from config import config from utils import get_env_bool @@ -34,10 +35,14 @@ app.config["UPLOAD_FOLDER"] = first( app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get( "fileserver.download.cache_timeout_sec", 5 * 60 ) +auth_handler = AuthHandler.instance() @app.route("/", methods=["GET"]) def ping(): + if auth_handler and auth_handler.get_token(request): + auth_handler.validate(request) + return "OK", 200 @@ -57,6 +62,9 @@ def after_request(response): @app.route("/", methods=["POST"]) def upload(): + if auth_handler: + auth_handler.validate(request) + results = [] for filename, file in request.files.items(): if not filename: @@ -76,6 +84,9 @@ def upload(): @app.route("/", methods=["GET"]) def download(path): + if auth_handler: + auth_handler.validate(request) + as_attachment = "download" in request.args _, encoding = mimetypes.guess_type(os.path.basename(path)) @@ -105,6 +116,9 @@ def _get_full_path(path: str) -> Path: @app.route("/", methods=["DELETE"]) def delete(path): + if auth_handler: + auth_handler.validate(request) + full_path = _get_full_path(path) if not full_path.exists() or not full_path.is_file(): log.error(f"Error deleting file {str(full_path)}. Not found or not a file") @@ -117,6 +131,9 @@ def delete(path): def batch_delete(): + if auth_handler: + auth_handler.validate(request) + body = request.get_json(force=True, silent=False) if not body: abort(Response("Json payload is missing", 400)) diff --git a/fileserver/redis_manager.py b/fileserver/redis_manager.py new file mode 100644 index 0000000..eda7ec0 --- /dev/null +++ b/fileserver/redis_manager.py @@ -0,0 +1,96 @@ +from os import getenv + +from boltons.iterutils import first +from redis import StrictRedis +from redis.cluster import RedisCluster + +from apiserver.apierrors.errors.server_error import ConfigError, GeneralError +from apiserver.config_repo import config + +log = config.logger(__file__) + +OVERRIDE_HOST_ENV_KEY = ( + "CLEARML_REDIS_SERVICE_HOST", + "TRAINS_REDIS_SERVICE_HOST", + "REDIS_SERVICE_HOST", +) +OVERRIDE_PORT_ENV_KEY = ( + "CLEARML_REDIS_SERVICE_PORT", + "TRAINS_REDIS_SERVICE_PORT", + "REDIS_SERVICE_PORT", +) +OVERRIDE_PASSWORD_ENV_KEY = ( + "CLEARML_REDIS_SERVICE_PASSWORD", + "TRAINS_REDIS_SERVICE_PASSWORD", + "REDIS_SERVICE_PASSWORD", +) + +OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY))) +if OVERRIDE_HOST: + log.info(f"Using override redis host {OVERRIDE_HOST}") + +OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY))) +if OVERRIDE_PORT: + log.info(f"Using override redis port {OVERRIDE_PORT}") + +OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY))) + + +class RedisManager(object): + def __init__(self, redis_config_dict): + self.aliases = {} + for alias, alias_config in redis_config_dict.items(): + + alias_config = alias_config.as_plain_ordered_dict() + alias_config["password"] = config.get( + f"secure.redis.{alias}.password", None + ) + + is_cluster = alias_config.get("cluster", False) + + host = OVERRIDE_HOST or alias_config.get("host", None) + if host: + alias_config["host"] = host + + port = OVERRIDE_PORT or alias_config.get("port", None) + if port: + alias_config["port"] = port + + password = OVERRIDE_PASSWORD or alias_config.get("password", None) + if password: + alias_config["password"] = password + + if not port or not host: + raise ConfigError( + "Redis configuration is invalid. missing port or host", alias=alias + ) + + if is_cluster: + del alias_config["cluster"] + del alias_config["db"] + self.aliases[alias] = RedisCluster(**alias_config) + else: + self.aliases[alias] = StrictRedis(**alias_config) + + def connection(self, alias) -> StrictRedis: + obj = self.aliases.get(alias) + if not obj: + raise GeneralError(f"Invalid Redis alias {alias}") + + obj.get("health") + return obj + + def host(self, alias): + r = self.connection(alias) + if isinstance(r, RedisCluster): + connections = r.get_default_node().redis_connection.connection_pool._available_connections + else: + connections = r.connection_pool._available_connections + + if not connections: + return None + + return connections[0].host + + +redman = RedisManager(config.get("hosts.redis")) diff --git a/fileserver/requirements.txt b/fileserver/requirements.txt index e990264..8395cab 100644 --- a/fileserver/requirements.txt +++ b/fileserver/requirements.txt @@ -1,9 +1,11 @@ boltons>=19.1.0 +clearml-agent>=1.5.2 flask-compress>=1.4.0 flask-cors>=3.0.5 flask>=2.3.3 gunicorn>=20.1.0 pyhocon>=0.3.35 +redis>=4.5.4,<5 setuptools>=65.5.1 urllib3>=1.26.18 werkzeug>=3.0.1 \ No newline at end of file