Invalidate token on user logoff

This commit is contained in:
allegroai 2024-03-18 15:38:44 +02:00
parent 88abf28287
commit d3013ac285
5 changed files with 60 additions and 12 deletions

View File

@ -1,4 +1,4 @@
from .auth import get_auth_func, authorize_impersonation
from .auth import get_auth_func, authorize_impersonation, revoke_auth_token
from .payload import Token, Basic, AuthType, Payload
from .identity import Identity
from .utils import get_client_id, get_secret_key

View File

@ -1,5 +1,6 @@
import base64
from datetime import datetime
from time import time
import bcrypt
import jwt
@ -11,15 +12,16 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.redis_manager import redman
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
log = config.logger(__file__)
entity_keys = set(get_options(Entities))
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
_revoked_tokens_key = "revoked_tokens"
redis = redman.connection("apiserver")
def get_auth_func(auth_type):
@ -41,8 +43,10 @@ def authorize_token(jwt_token, service, action, call):
log.error(f"{msg} Call info: {info}")
try:
return Token.from_encoded_token(jwt_token)
token = Token.from_encoded_token(jwt_token)
if is_token_revoked(token):
raise errors.unauthorized.InvalidToken("revoked token")
return token
except jwt.exceptions.InvalidKeyError as ex:
log_error("Failed parsing token.")
raise errors.unauthorized.InvalidToken(
@ -154,3 +158,23 @@ def compare_secret_key_hash(secret_key: str, hashed_secret: str) -> bool:
return bcrypt.checkpw(
secret_key.encode(), base64.b64decode(hashed_secret.encode("ascii"))
)
def is_token_revoked(token: Token) -> bool:
if not isinstance(token, Token) or not token.session_id:
return False
return redis.zscore(_revoked_tokens_key, token.session_id) is not None
def revoke_auth_token(token: Token):
if not isinstance(token, Token) or not token.session_id:
return
timestamp_now = int(time())
expiration_timestamp = token.exp
if not expiration_timestamp:
expiration_timestamp = timestamp_now + Token.default_expiration_sec
redis.zadd(_revoked_tokens_key, {token.session_id: expiration_timestamp})
redis.zremrangebyscore(_revoked_tokens_key, min=0, max=timestamp_now)

View File

@ -1,3 +1,5 @@
from uuid import uuid4
import jwt
from datetime import datetime, timedelta
@ -20,7 +22,15 @@ class Token(Payload):
default_expiration_sec = config.get("apiserver.auth.default_expiration_sec")
def __init__(
self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
self,
exp=None,
iat=None,
nbf=None,
env=None,
identity=None,
session_id=None,
entities=None,
**_,
):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities
@ -28,8 +38,13 @@ class Token(Payload):
self.exp = exp
self.iat = iat
self.nbf = nbf
self._session_id = session_id
self._env = env or config.get("env", "<unknown>")
@property
def session_id(self):
return self._session_id
@property
def env(self):
return self._env
@ -102,8 +117,11 @@ class Token(Payload):
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
session_id = uuid4().hex
token = cls(identity=identity, entities=entities, iat=now)
token = cls(
identity=identity, entities=entities, iat=now, session_id=session_id
)
if expiration_sec:
# add 'expiration' claim

View File

@ -24,6 +24,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Role
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Token
from apiserver.service_repo.auth.auth import is_token_revoked, revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
@ -35,7 +36,7 @@ log = config.logger(__file__)
response_data_model=GetTokenResponse,
)
def login(call: APICall, *_, **__):
""" Generates a token based on the authenticated user (intended for use with credentials) """
"""Generates a token based on the authenticated user (intended for use with credentials)"""
call.result.data_model = AuthBLL.get_token_for_user(
user_id=call.identity.user,
company_id=call.identity.company,
@ -48,6 +49,7 @@ def login(call: APICall, *_, **__):
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)
@ -57,7 +59,7 @@ def logout(call: APICall, *_, **__):
response_data_model=GetTokenResponse,
)
def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """
"""Generates a token based on a requested user and company. INTERNAL."""
if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId(
@ -81,12 +83,14 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
response_data_model=ValidateResponse,
)
def validate_token_endpoint(call: APICall, _, __):
""" Validate a token and return identity if valid. INTERNAL. """
"""Validate a token and return identity if valid. INTERNAL."""
try:
# if invalid, decoding will fail
token = Token.from_encoded_token(call.data_model.token)
call.result.data_model = ValidateResponse(
valid=True, user=token.identity.user, company=token.identity.company
valid=not is_token_revoked(token),
user=token.identity.user,
company=token.identity.company,
)
except Exception as e:
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
@ -98,7 +102,7 @@ def validate_token_endpoint(call: APICall, _, __):
response_data_model=CreateUserResponse,
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """
"""Create a user from. INTERNAL."""
if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company

View File

@ -7,6 +7,7 @@ from apiserver.apimodels.login import (
)
from apiserver.config import info
from apiserver.service_repo import endpoint, APICall
from apiserver.service_repo.auth import revoke_auth_token
from apiserver.service_repo.auth.fixed_user import FixedUser
@ -37,4 +38,5 @@ def supported_modes(call: APICall, _, __: GetSupportedModesRequest):
@endpoint("login.logout", min_version="2.13")
def logout(call: APICall, _, __):
revoke_auth_token(call.auth)
call.result.set_auth_cookie(None)