Invalidate token on user logoff

This commit is contained in:
allegroai 2024-03-18 15:38:44 +02:00
parent 88abf28287
commit d3013ac285
5 changed files with 60 additions and 12 deletions

View File

@ -1,4 +1,4 @@
from .auth import get_auth_func, authorize_impersonation
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
from .payload import Token, Basic, AuthType, Payload
from .identity import Identity
from .utils import get_client_id, get_secret_key

View File

@ -1,5 +1,6 @@
import base64
from datetime import datetime
from time import time
import bcrypt
import jwt
@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.redis_manager import redman
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
log = config.logger(__file__)
entity_keys = set(get_options(Entities))
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
_revoked_tokens_key = "revoked_tokens"
redis = redman.connection("apiserver")
def get_auth_func(auth_type):
@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
log.error(f"{msg} Call info: {info}")
try:
return Token.from_encoded_token(jwt_token)
token = Token.from_encoded_token(jwt_token)
if is_token_revoked(token):
raise errors.unauthorized.InvalidToken("revoked token")
return token
except jwt.exceptions.InvalidKeyError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken(
@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
return bcrypt.checkpw(
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
)
def is_token_revoked(token: Token) -> bool:
if not isinstance(token, Token) or not token.session_id:
return False
return redis.zscore(_revoked_tokens_key, token.session_id) is not None
def revoke_auth_token(token: Token):
if not isinstance(token, Token) or not token.session_id:
return
timestamp_now = int(time())
expiration_timestamp = token.exp
if not expiration_timestamp:
expiration_timestamp = timestamp_now + Token.default_expiration_sec
redis.zadd(_revoked_tokens_key, {token.session_id: expiration_timestamp})
redis.zremrangebyscore(_revoked_tokens_key, min=0, max=timestamp_now)

View File

@ -1,3 +1,5 @@
from uuid import uuid4
import jwt
from datetime import datetime, timedelta
@ -20,7 +22,15 @@ class Token(Payload):
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
def __init__(
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
self,
exp=None,
iat=None,
nbf=None,
env=None,
identity=None,
session_id=None,
entities=None,
**_,
):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities
@ -28,8 +38,13 @@ class Token(Payload):
self.exp = exp
self.iat = iat
self.nbf = nbf
self._session_id = session_id
self._env = env or config.get("env", "<unknown>")
@property
def session_id(self):
return self._session_id
@property
def env(self):
return self._env
@ -102,8 +117,11 @@ class Token(Payload):
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
session_id = uuid4().hex
token = cls(identity=identity, entities=entities, iat=now)
token = cls(
identity=identity, entities=entities, iat=now, session_id=session_id
)
if expiration_sec:
# add 'expiration' claim

View File

@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Role
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Token
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)
@ -86,7 +88,9 @@ def validate_token_endpoint(call: APICall, _, __):
# if invalid, decoding will fail
token = Token.from_encoded_token(call.data_model.token)
call.result.data_model = ValidateResponse(
valid=True, user=token.identity.user, company=token.identity.company
valid=not is_token_revoked(token),
user=token.identity.user,
company=token.identity.company,
)
except Exception as e:
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])

View File

@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
)
from apiserver.config import info
from apiserver.service_repo import endpoint, APICall
from apiserver.service_repo.auth import revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
@endpoint("login.logout", min_version="2.13")
def logout(call: APICall, _, __):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)