mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
|
import json
|
||
|
from hashlib import sha256
|
||
|
from typing import Optional
|
||
|
|
||
|
import attr
|
||
|
from flask import abort, Response, Request
|
||
|
from redis import StrictRedis
|
||
|
from clearml_agent.backend_api import Session
|
||
|
from werkzeug.exceptions import HTTPException
|
||
|
|
||
|
from config import config
|
||
|
from redis_manager import redman
|
||
|
|
||
|
# Module-level logger, obtained from the shared config object for this file.
log = config.logger(__file__)
|
||
|
|
||
|
|
||
|
@attr.s(auto_attribs=True)
class TokenInfo:
    """Identity details recorded for a validated token.

    Instances are serialized with ``attr.asdict`` before being written to the
    token cache and rebuilt from the decoded JSON mapping when read back.
    """

    # company taken from the token validation response
    company: str
    # user taken from the token validation response
    user: str
|
||
|
|
||
|
|
||
|
class FileserverSession(Session):
    """Session subclass that always reports ``"fileserver"`` as its client."""

    @property
    def client(self):
        """Return the fixed client identifier of this session."""
        return "fileserver"

    @client.setter
    def client(self, _value):
        """Discard the assignment so the base class cannot override the client."""
|
||
|
|
||
|
|
||
|
class AuthHandler:
    """Validate fileserver request tokens against the API server.

    Successful validation results are cached in Redis, so repeated requests
    carrying the same token do not trigger another ``auth.validate_token``
    call until the cache entry expires. Obtain the shared handler through
    :meth:`instance`.
    """

    # when this setting is false, ``instance()`` returns None
    enabled = config.get("fileserver.auth.enabled", False)
    _instance = None

    @classmethod
    def instance(cls):
        """Return the shared handler, or ``None`` when auth is disabled.

        The handler is created on first use and reused afterwards.
        """
        if not cls.enabled:
            return None
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Create the API session and the Redis connection used for caching."""
        self.session = FileserverSession(
            api_key=config.get("secure.credentials.fileserver.user_key"),
            secret_key=config.get("secure.credentials.fileserver.user_secret"),
            host=config.get("hosts.api_server"),
            initialize_logging=False,
        )

        self.redis: StrictRedis = redman.connection("fileserver")

    def _validate_and_get_token_info(self, token: str) -> TokenInfo:
        """Return the ``TokenInfo`` for ``token``, validating it if not cached.

        The cache is consulted first. On a miss, the token is sent to the
        ``auth.validate_token`` endpoint and a successful result is cached
        for ``fileserver.auth.tokens_cache_threshold_sec`` seconds
        (12 hours when the setting is absent).

        Aborts the request (raising ``HTTPException``) with:
          * 500 when the validation call itself returns 500, or when any
            unexpected error occurs while validating;
          * the validation call's status code when it is neither 200 nor 500;
          * 401 when the response reports the token as invalid.
        """
        # tokens longer than 256 characters are hashed to keep the cache key short
        token_hash = sha256(token.encode()).hexdigest() if len(token) > 256 else token
        key = f"token_{token_hash}"
        token_data = self.redis.get(key)
        if token_data:
            return TokenInfo(**json.loads(token_data))

        try:
            res = self.session.send_request(
                service="auth", action="validate_token", json={"token": token}
            )

            if res.status_code == 500:
                log.error("Error validating token")
                abort(Response(f"Internal error (status={res.status_code})", 500))
            elif res.status_code != 200:
                log.error("Error validating token")
                abort(res.status_code)

            data = res.json()["data"]
            if not data["valid"]:
                log.error(f"Error validating token: {data['msg']}")
                abort(Response(data["msg"], 401))

            info = TokenInfo(
                company=data.get("company", "unknown"),
                user=data.get("user"),
            )

            timeout_sec = config.get(
                "fileserver.auth.tokens_cache_threshold_sec", 12 * 60 * 60
            )
            self.redis.setex(key, time=timeout_sec, value=json.dumps(attr.asdict(info)))

            return info

        except HTTPException:
            # raised by the abort() calls above; let it reach the caller untouched
            raise
        except Exception:
            log.exception("Failed decoding token")
            abort(500)

    def validate(self, request: Request):
        """Abort with 401 if ``request`` carries no token, else validate it.

        Validation failures abort the request as described in
        ``_validate_and_get_token_info``; nothing is returned on success.
        """
        token = self.get_token(request)
        if not token:
            log.error("Error getting token")
            abort(401)

        self._validate_and_get_token_info(token)

    @staticmethod
    def get_token(request: Request) -> Optional[str]:
        """Extract the token carried by ``request``.

        An ``Authorization`` header takes precedence and must use the
        ``Bearer`` scheme; the text after the first space is returned (which
        is an empty string for a header of just ``"Bearer "``). Without that
        header, the cookies named by ``fileserver.auth.cookie_names`` are
        checked in order and the first non-empty value is returned.

        Returns ``None`` when the request carries no token. Aborts with 401
        when an ``Authorization`` header uses a scheme other than ``Bearer``.
        """
        auth_header = request.headers.get("Authorization")
        if auth_header:
            if not auth_header.startswith("Bearer "):
                log.error("Only bearer token authorization is supported")
                abort(
                    Response("Only bearer token authorization is supported", status=401)
                )
            return auth_header.partition(" ")[2]

        # Returning a cookie value cannot raise, so no exception bookkeeping
        # is needed around this loop.
        for cookie_name in config.get("fileserver.auth.cookie_names", []):
            cookie = request.cookies.get(cookie_name)
            if cookie:
                return cookie

        return None
|