mirror of
https://github.com/clearml/clearml-server
synced 2025-04-22 23:24:24 +00:00
Invalidate token on user logoff
This commit is contained in:
parent
88abf28287
commit
d3013ac285
@ -1,4 +1,4 @@
|
|||||||
from .auth import get_auth_func, authorize_impersonation
|
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
|
||||||
from .payload import Token, Basic, AuthType, Payload
|
from .payload import Token, Basic, AuthType, Payload
|
||||||
from .identity import Identity
|
from .identity import Identity
|
||||||
from .utils import get_client_id, get_secret_key
|
from .utils import get_client_id, get_secret_key
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import base64
|
import base64
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from time import time
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import jwt
|
import jwt
|
||||||
@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
|
|||||||
from apiserver.database.model.auth import User, Entities, Credentials
|
from apiserver.database.model.auth import User, Entities, Credentials
|
||||||
from apiserver.database.model.company import Company
|
from apiserver.database.model.company import Company
|
||||||
from apiserver.database.utils import get_options
|
from apiserver.database.utils import get_options
|
||||||
|
from apiserver.redis_manager import redman
|
||||||
from .fixed_user import FixedUser
|
from .fixed_user import FixedUser
|
||||||
from .identity import Identity
|
from .identity import Identity
|
||||||
from .payload import Payload, Token, Basic, AuthType
|
from .payload import Payload, Token, Basic, AuthType
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
|
|
||||||
entity_keys = set(get_options(Entities))
|
entity_keys = set(get_options(Entities))
|
||||||
|
|
||||||
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
|
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
|
||||||
|
_revoked_tokens_key = "revoked_tokens"
|
||||||
|
redis = redman.connection("apiserver")
|
||||||
|
|
||||||
|
|
||||||
def get_auth_func(auth_type):
|
def get_auth_func(auth_type):
|
||||||
@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
|
|||||||
log.error(f"{msg} Call info: {info}")
|
log.error(f"{msg} Call info: {info}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return Token.from_encoded_token(jwt_token)
|
token = Token.from_encoded_token(jwt_token)
|
||||||
|
if is_token_revoked(token):
|
||||||
|
raise errors.unauthorized.InvalidToken("revoked token")
|
||||||
|
return token
|
||||||
except jwt.exceptions.InvalidKeyError as ex:
|
except jwt.exceptions.InvalidKeyError as ex:
|
||||||
log_error("Failed parsing token.")
|
log_error("Failed parsing token.")
|
||||||
raise errors.unauthorized.InvalidToken(
|
raise errors.unauthorized.InvalidToken(
|
||||||
@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
|
|||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
|
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_token_revoked(token: Token) -> bool:
|
||||||
|
if not isinstance(token, Token) or not token.session_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return redis.zscore(_revoked_tokens_key, token.session_id) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_auth_token(token: Token):
|
||||||
|
if not isinstance(token, Token) or not token.session_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp_now = int(time())
|
||||||
|
expiration_timestamp = token.exp
|
||||||
|
if not expiration_timestamp:
|
||||||
|
expiration_timestamp = timestamp_now + Token.default_expiration_sec
|
||||||
|
|
||||||
|
redis.zadd(_revoked_tokens_key, {token.session_id: expiration_timestamp})
|
||||||
|
redis.zremrangebyscore(_revoked_tokens_key, min=0, max=timestamp_now)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@ -20,7 +22,15 @@ class Token(Payload):
|
|||||||
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
|
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
|
self,
|
||||||
|
exp=None,
|
||||||
|
iat=None,
|
||||||
|
nbf=None,
|
||||||
|
env=None,
|
||||||
|
identity=None,
|
||||||
|
session_id=None,
|
||||||
|
entities=None,
|
||||||
|
**_,
|
||||||
):
|
):
|
||||||
super(Token, self).__init__(
|
super(Token, self).__init__(
|
||||||
AuthType.bearer_token, identity=identity, entities=entities
|
AuthType.bearer_token, identity=identity, entities=entities
|
||||||
@ -28,8 +38,13 @@ class Token(Payload):
|
|||||||
self.exp = exp
|
self.exp = exp
|
||||||
self.iat = iat
|
self.iat = iat
|
||||||
self.nbf = nbf
|
self.nbf = nbf
|
||||||
|
self._session_id = session_id
|
||||||
self._env = env or config.get("env", "<unknown>")
|
self._env = env or config.get("env", "<unknown>")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_id(self):
|
||||||
|
return self._session_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def env(self):
|
def env(self):
|
||||||
return self._env
|
return self._env
|
||||||
@ -102,8 +117,11 @@ class Token(Payload):
|
|||||||
expiration_sec = expiration_sec or cls.default_expiration_sec
|
expiration_sec = expiration_sec or cls.default_expiration_sec
|
||||||
|
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
session_id = uuid4().hex
|
||||||
|
|
||||||
token = cls(identity=identity, entities=entities, iat=now)
|
token = cls(
|
||||||
|
identity=identity, entities=entities, iat=now, session_id=session_id
|
||||||
|
)
|
||||||
|
|
||||||
if expiration_sec:
|
if expiration_sec:
|
||||||
# add 'expiration' claim
|
# add 'expiration' claim
|
||||||
|
@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
|
|||||||
from apiserver.database.model.auth import User, Role
|
from apiserver.database.model.auth import User, Role
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
from apiserver.service_repo.auth import Token
|
from apiserver.service_repo.auth import Token
|
||||||
|
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
|
||||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
@ -35,7 +36,7 @@ log = config.logger(__file__)
|
|||||||
response_data_model=GetTokenResponse,
|
response_data_model=GetTokenResponse,
|
||||||
)
|
)
|
||||||
def login(call: APICall, *_, **__):
|
def login(call: APICall, *_, **__):
|
||||||
""" Generates a token based on the authenticated user (intended for use with credentials) """
|
"""Generates a token based on the authenticated user (intended for use with credentials)"""
|
||||||
call.result.data_model = AuthBLL.get_token_for_user(
|
call.result.data_model = AuthBLL.get_token_for_user(
|
||||||
user_id=call.identity.user,
|
user_id=call.identity.user,
|
||||||
company_id=call.identity.company,
|
company_id=call.identity.company,
|
||||||
@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
|
|||||||
|
|
||||||
@endpoint("auth.logout", min_version="2.2")
|
@endpoint("auth.logout", min_version="2.2")
|
||||||
def logout(call: APICall, *_, **__):
|
def logout(call: APICall, *_, **__):
|
||||||
|
revoke_auth_token(call.auth)
|
||||||
call.result.set_auth_cookie(None)
|
call.result.set_auth_cookie(None)
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ def logout(call: APICall, *_, **__):
|
|||||||
response_data_model=GetTokenResponse,
|
response_data_model=GetTokenResponse,
|
||||||
)
|
)
|
||||||
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||||
""" Generates a token based on a requested user and company. INTERNAL. """
|
"""Generates a token based on a requested user and company. INTERNAL."""
|
||||||
if call.identity.role not in Role.get_system_roles():
|
if call.identity.role not in Role.get_system_roles():
|
||||||
if call.identity.role != Role.admin and call.identity.user != request.user:
|
if call.identity.role != Role.admin and call.identity.user != request.user:
|
||||||
raise errors.bad_request.InvalidUserId(
|
raise errors.bad_request.InvalidUserId(
|
||||||
@ -81,12 +83,14 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
|||||||
response_data_model=ValidateResponse,
|
response_data_model=ValidateResponse,
|
||||||
)
|
)
|
||||||
def validate_token_endpoint(call: APICall, _, __):
|
def validate_token_endpoint(call: APICall, _, __):
|
||||||
""" Validate a token and return identity if valid. INTERNAL. """
|
"""Validate a token and return identity if valid. INTERNAL."""
|
||||||
try:
|
try:
|
||||||
# if invalid, decoding will fail
|
# if invalid, decoding will fail
|
||||||
token = Token.from_encoded_token(call.data_model.token)
|
token = Token.from_encoded_token(call.data_model.token)
|
||||||
call.result.data_model = ValidateResponse(
|
call.result.data_model = ValidateResponse(
|
||||||
valid=True, user=token.identity.user, company=token.identity.company
|
valid=not is_token_revoked(token),
|
||||||
|
user=token.identity.user,
|
||||||
|
company=token.identity.company,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
|
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
|
||||||
@ -98,7 +102,7 @@ def validate_token_endpoint(call: APICall, _, __):
|
|||||||
response_data_model=CreateUserResponse,
|
response_data_model=CreateUserResponse,
|
||||||
)
|
)
|
||||||
def create_user(call: APICall, _, request: CreateUserRequest):
|
def create_user(call: APICall, _, request: CreateUserRequest):
|
||||||
""" Create a user from. INTERNAL. """
|
"""Create a user from. INTERNAL."""
|
||||||
if (
|
if (
|
||||||
call.identity.role not in Role.get_system_roles()
|
call.identity.role not in Role.get_system_roles()
|
||||||
and request.company != call.identity.company
|
and request.company != call.identity.company
|
||||||
|
@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
|
|||||||
)
|
)
|
||||||
from apiserver.config import info
|
from apiserver.config import info
|
||||||
from apiserver.service_repo import endpoint, APICall
|
from apiserver.service_repo import endpoint, APICall
|
||||||
|
from apiserver.service_repo.auth import revoke_auth_token
|
||||||
from apiserver.service_repo.auth.fixed_user import FixedUser
|
from apiserver.service_repo.auth.fixed_user import FixedUser
|
||||||
|
|
||||||
|
|
||||||
@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
|
|||||||
|
|
||||||
@endpoint("login.logout", min_version="2.13")
|
@endpoint("login.logout", min_version="2.13")
|
||||||
def logout(call: APICall, _, __):
|
def logout(call: APICall, _, __):
|
||||||
|
revoke_auth_token(call.auth)
|
||||||
call.result.set_auth_cookie(None)
|
call.result.set_auth_cookie(None)
|
||||||
|
Loading…
Reference in New Issue
Block a user