import json
from hashlib import sha256
from typing import Optional

import attr
from flask import abort, Response, Request
from redis import StrictRedis
from clearml_agent.backend_api import Session
from werkzeug.exceptions import HTTPException

from config import config
from redis_manager import redman

log = config.logger(__file__)


@attr.s(auto_attribs=True)
class TokenInfo:
    """Identity resolved from a validated token.

    Instances are serialized with ``attr.asdict`` and cached in Redis as JSON,
    then rebuilt from that JSON on cache hits.
    """

    company: str
    user: str


class FileserverSession(Session):
    """API session whose ``client`` is always reported as "fileserver"."""

    @property
    def client(self):
        return "fileserver"

    @client.setter
    def client(self, _):
        # do not allow the base class to override the client
        pass


class AuthHandler:
    """Authenticates fileserver requests via the API server's auth service.

    Tokens are read from a bearer ``Authorization`` header or from one of the
    configured cookies. Successful validation results are cached in Redis for
    ``fileserver.auth.tokens_cache_threshold_sec`` seconds (default 12 hours).
    """

    # Evaluated once, when the class body runs at module import time.
    enabled = config.get("fileserver.auth.enabled", False)
    _instance = None

    @classmethod
    def instance(cls):
        """Return the shared handler, or None when auth is disabled.

        The handler is constructed lazily on the first call while enabled.
        """
        if not cls.enabled:
            return None
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Credentials and API host all come from the fileserver config.
        self.session = FileserverSession(
            api_key=config.get("secure.credentials.fileserver.user_key"),
            secret_key=config.get("secure.credentials.fileserver.user_secret"),
            host=config.get("hosts.api_server"),
            initialize_logging=False,
        )
        self.redis: StrictRedis = redman.connection("fileserver")

    def _validate_and_get_token_info(self, token: str) -> TokenInfo:
        """Validate ``token`` and return the identity it belongs to.

        A cached result in Redis is returned directly. Otherwise the API
        server's ``auth.validate_token`` endpoint is called and a successful
        result is cached.

        :param token: raw token string taken from the request.
        :return: the company/user the token belongs to.
        :raises werkzeug.exceptions.HTTPException: via ``abort`` — 401 for an
            invalid token, the API server's status for other non-200 replies,
            500 for internal or unexpected failures.
        """
        # Always hash the token: this keeps raw credentials out of the Redis
        # keyspace and gives fixed-length keys regardless of token size.
        key = f"token_{sha256(token.encode()).hexdigest()}"
        token_data = self.redis.get(key)
        if token_data:
            return TokenInfo(**json.loads(token_data))

        try:
            res = self.session.send_request(
                service="auth", action="validate_token", json={"token": token}
            )
            if res.status_code == 500:
                log.error("Error validating token")
                abort(Response(f"Internal error (status={res.status_code})", 500))
            elif res.status_code != 200:
                log.error("Error validating token")
                # werkzeug raises LookupError for codes it has no exception
                # for; the generic handler below turns that into a 500.
                abort(res.status_code)

            data = res.json()["data"]
            if not data["valid"]:
                # "msg" may be absent; that must not turn a 401 into a 500.
                msg = data.get("msg") or "Invalid token"
                log.error(f"Error validating token: {msg}")
                abort(Response(msg, 401))

            info = TokenInfo(
                company=data.get("company", "unknown"),
                user=data.get("user"),
            )
            timeout_sec = config.get(
                "fileserver.auth.tokens_cache_threshold_sec", 12 * 60 * 60
            )
            self.redis.setex(key, time=timeout_sec, value=json.dumps(attr.asdict(info)))
            return info
        except HTTPException:
            # Aborts raised above must reach Flask unchanged.
            raise
        except Exception:
            log.exception("Failed decoding token")
            abort(500)

    def validate(self, request: Request):
        """Abort the current request unless it carries a valid token.

        :param request: the incoming Flask request.
        :raises werkzeug.exceptions.HTTPException: 401 when no token is found,
            or whatever ``_validate_and_get_token_info`` aborts with.
        """
        token = self.get_token(request)
        if not token:
            log.error("Error getting token")
            abort(401)
        self._validate_and_get_token_info(token)

    @staticmethod
    def get_token(request: Request) -> Optional[str]:
        """Extract the auth token from ``request``.

        The ``Authorization`` header takes precedence and must use the Bearer
        scheme; the scheme name is matched case-insensitively (RFC 7235).
        Without the header, the first non-empty cookie among
        ``fileserver.auth.cookie_names`` is returned.

        :param request: the incoming Flask request.
        :return: the token string, or None when none is present.
        :raises werkzeug.exceptions.HTTPException: 401 when the Authorization
            header uses a non-Bearer scheme.
        """
        auth_header = request.headers.get("Authorization")
        if auth_header:
            scheme, sep, token = auth_header.partition(" ")
            if not sep or scheme.lower() != "bearer":
                log.error("Only bearer token authorization is supported")
                abort(
                    Response("Only bearer token authorization is supported", status=401)
                )
            return token

        for cookie_name in config.get("fileserver.auth.cookie_names", []) or []:
            cookie = request.cookies.get(cookie_name)
            if cookie:
                return cookie

        return None