mirror of
https://github.com/clearml/clearml-server
synced 2025-04-07 14:34:12 +00:00
Invalidate token on user logoff
This commit is contained in:
parent
88abf28287
commit
d3013ac285
@ -1,4 +1,4 @@
|
||||
from .auth import get_auth_func, authorize_impersonation
|
||||
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
|
||||
from .payload import Token, Basic, AuthType, Payload
|
||||
from .identity import Identity
|
||||
from .utils import get_client_id, get_secret_key
|
||||
|
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Entities, Credentials
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.redis_manager import redman
|
||||
from .fixed_user import FixedUser
|
||||
from .identity import Identity
|
||||
from .payload import Payload, Token, Basic, AuthType
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
entity_keys = set(get_options(Entities))
|
||||
|
||||
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
|
||||
_revoked_tokens_key = "revoked_tokens"
|
||||
redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
def get_auth_func(auth_type):
|
||||
@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
|
||||
log.error(f"{msg} Call info: {info}")
|
||||
|
||||
try:
|
||||
return Token.from_encoded_token(jwt_token)
|
||||
|
||||
token = Token.from_encoded_token(jwt_token)
|
||||
if is_token_revoked(token):
|
||||
raise errors.unauthorized.InvalidToken("revoked token")
|
||||
return token
|
||||
except jwt.exceptions.InvalidKeyError as ex:
|
||||
log_error("Failed parsing token.")
|
||||
raise errors.unauthorized.InvalidToken(
|
||||
@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
|
||||
return bcrypt.checkpw(
|
||||
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
|
||||
)
|
||||
|
||||
|
||||
def is_token_revoked(token: Token) -> bool:
|
||||
if not isinstance(token, Token) or not token.session_id:
|
||||
return False
|
||||
|
||||
return redis.zscore(_revoked_tokens_key, token.session_id) is not None
|
||||
|
||||
|
||||
def revoke_auth_token(token: Token):
|
||||
if not isinstance(token, Token) or not token.session_id:
|
||||
return
|
||||
|
||||
timestamp_now = int(time())
|
||||
expiration_timestamp = token.exp
|
||||
if not expiration_timestamp:
|
||||
expiration_timestamp = timestamp_now + Token.default_expiration_sec
|
||||
|
||||
redis.zadd(_revoked_tokens_key, {token.session_id: expiration_timestamp})
|
||||
redis.zremrangebyscore(_revoked_tokens_key, min=0, max=timestamp_now)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import jwt
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
@ -20,7 +22,15 @@ class Token(Payload):
|
||||
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
|
||||
|
||||
def __init__(
|
||||
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
|
||||
self,
|
||||
exp=None,
|
||||
iat=None,
|
||||
nbf=None,
|
||||
env=None,
|
||||
identity=None,
|
||||
session_id=None,
|
||||
entities=None,
|
||||
**_,
|
||||
):
|
||||
super(Token, self).__init__(
|
||||
AuthType.bearer_token, identity=identity, entities=entities
|
||||
@ -28,8 +38,13 @@ class Token(Payload):
|
||||
self.exp = exp
|
||||
self.iat = iat
|
||||
self.nbf = nbf
|
||||
self._session_id = session_id
|
||||
self._env = env or config.get("env", "<unknown>")
|
||||
|
||||
@property
|
||||
def session_id(self):
|
||||
return self._session_id
|
||||
|
||||
@property
|
||||
def env(self):
|
||||
return self._env
|
||||
@ -102,8 +117,11 @@ class Token(Payload):
|
||||
expiration_sec = expiration_sec or cls.default_expiration_sec
|
||||
|
||||
now = datetime.utcnow()
|
||||
session_id = uuid4().hex
|
||||
|
||||
token = cls(identity=identity, entities=entities, iat=now)
|
||||
token = cls(
|
||||
identity=identity, entities=entities, iat=now, session_id=session_id
|
||||
)
|
||||
|
||||
if expiration_sec:
|
||||
# add 'expiration' claim
|
||||
|
@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.auth import User, Role
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Token
|
||||
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -35,7 +36,7 @@ log = config.logger(__file__)
|
||||
response_data_model=GetTokenResponse,
|
||||
)
|
||||
def login(call: APICall, *_, **__):
|
||||
""" Generates a token based on the authenticated user (intended for use with credentials) """
|
||||
"""Generates a token based on the authenticated user (intended for use with credentials)"""
|
||||
call.result.data_model = AuthBLL.get_token_for_user(
|
||||
user_id=call.identity.user,
|
||||
company_id=call.identity.company,
|
||||
@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
|
||||
|
||||
@endpoint("auth.logout", min_version="2.2")
|
||||
def logout(call: APICall, *_, **__):
|
||||
revoke_auth_token(call.auth)
|
||||
call.result.set_auth_cookie(None)
|
||||
|
||||
|
||||
@ -57,7 +59,7 @@ def logout(call: APICall, *_, **__):
|
||||
response_data_model=GetTokenResponse,
|
||||
)
|
||||
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
""" Generates a token based on a requested user and company. INTERNAL. """
|
||||
"""Generates a token based on a requested user and company. INTERNAL."""
|
||||
if call.identity.role not in Role.get_system_roles():
|
||||
if call.identity.role != Role.admin and call.identity.user != request.user:
|
||||
raise errors.bad_request.InvalidUserId(
|
||||
@ -81,12 +83,14 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
response_data_model=ValidateResponse,
|
||||
)
|
||||
def validate_token_endpoint(call: APICall, _, __):
|
||||
""" Validate a token and return identity if valid. INTERNAL. """
|
||||
"""Validate a token and return identity if valid. INTERNAL."""
|
||||
try:
|
||||
# if invalid, decoding will fail
|
||||
token = Token.from_encoded_token(call.data_model.token)
|
||||
call.result.data_model = ValidateResponse(
|
||||
valid=True, user=token.identity.user, company=token.identity.company
|
||||
valid=not is_token_revoked(token),
|
||||
user=token.identity.user,
|
||||
company=token.identity.company,
|
||||
)
|
||||
except Exception as e:
|
||||
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
|
||||
@ -98,7 +102,7 @@ def validate_token_endpoint(call: APICall, _, __):
|
||||
response_data_model=CreateUserResponse,
|
||||
)
|
||||
def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
""" Create a user from. INTERNAL. """
|
||||
"""Create a user from. INTERNAL."""
|
||||
if (
|
||||
call.identity.role not in Role.get_system_roles()
|
||||
and request.company != call.identity.company
|
||||
|
@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
|
||||
)
|
||||
from apiserver.config import info
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.service_repo.auth import revoke_auth_token
|
||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
|
||||
|
||||
@endpoint("login.logout", min_version="2.13")
|
||||
def logout(call: APICall, _, __):
|
||||
revoke_auth_token(call.auth)
|
||||
call.result.set_auth_cookie(None)
|
||||
|
Loading…
Reference in New Issue
Block a user