diff --git a/apiserver/bll/auth/__init__.py b/apiserver/bll/auth/__init__.py index e333dfb..c7d9815 100644 --- a/apiserver/bll/auth/__init__.py +++ b/apiserver/bll/auth/__init__.py @@ -64,7 +64,7 @@ class AuthBLL: feature_set="basic", ) - return GetTokenResponse(token=token.decode("ascii")) + return GetTokenResponse(token=token) @staticmethod def create_user(request: CreateUserRequest, call: APICall = None) -> str: diff --git a/apiserver/bll/event/events_iterator.py b/apiserver/bll/event/events_iterator.py index 62da7c0..ec4b2d5 100644 --- a/apiserver/bll/event/events_iterator.py +++ b/apiserver/bll/event/events_iterator.py @@ -188,7 +188,7 @@ class Scroll(jsonmodels.models.Base): key=config.get( "services.events.events_retrieval.scroll_id_key", "1234567890" ), - ).decode() + ) @classmethod def from_scroll_id(cls, scroll_id: str): @@ -199,6 +199,7 @@ class Scroll(jsonmodels.models.Base): key=config.get( "services.events.events_retrieval.scroll_id_key", "1234567890" ), + algorithms=["HS256"], ) ) except jwt.PyJWTError: diff --git a/apiserver/requirements.txt b/apiserver/requirements.txt index a929aa5..e77ca9c 100644 --- a/apiserver/requirements.txt +++ b/apiserver/requirements.txt @@ -21,7 +21,7 @@ nested_dict>=1.61 packaging==20.3 psutil>=5.6.5 pyhocon>=0.3.35 -pyjwt<2.0.0 +pyjwt>=2.4.0 pymongo[srv]==3.12.0 python-rapidjson>=0.6.3 redis==3.5.3 @@ -31,4 +31,4 @@ requests>=2.13.0 semantic_version>=2.8.3,<3 six tqdm -validators>=0.12.4 +validators>=0.12.4 \ No newline at end of file diff --git a/apiserver/service_repo/auth/payload/token.py b/apiserver/service_repo/auth/payload/token.py index a21e923..9fa879d 100644 --- a/apiserver/service_repo/auth/payload/token.py +++ b/apiserver/service_repo/auth/payload/token.py @@ -9,22 +9,25 @@ from apiserver.database.model.auth import Role from .auth_type import AuthType from .payload import Payload -token_secret = config.get('secure.auth.token_secret') +token_secret = config.get("secure.auth.token_secret") log = config.logger(__file__) class Token(Payload): - default_expiration_sec = config.get('apiserver.auth.default_expiration_sec') + default_expiration_sec = config.get("apiserver.auth.default_expiration_sec") - def __init__(self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_): + def __init__( + self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_ + ): super(Token, self).__init__( - AuthType.bearer_token, identity=identity, entities=entities) + AuthType.bearer_token, identity=identity, entities=entities + ) self.exp = exp self.iat = iat self.nbf = nbf - self._env = env or config.get('env', '') + self._env = env or config.get("env", "") @property def env(self): @@ -65,7 +68,12 @@ class Token(Payload): @classmethod def decode(cls, encoded_token, verify=True): - return jwt.decode(encoded_token, token_secret, verify=verify) + options = ( + {"verify_signature": False, "verify_exp": True} if not verify else None + ) + return jwt.decode( + encoded_token, token_secret, algorithms=["HS256"], options=options + ) @classmethod def from_encoded_token(cls, encoded_token, verify=True): @@ -74,23 +82,24 @@ class Token(Payload): token = Token.from_dict(decoded) assert isinstance(token, Token) if not token.identity: - raise errors.unauthorized.InvalidToken('token missing identity') + raise errors.unauthorized.InvalidToken("token missing identity") return token except Exception as e: - raise errors.unauthorized.InvalidToken('failed parsing token, %s' % e.args[0]) + raise errors.unauthorized.InvalidToken( + "failed parsing token, %s" % e.args[0] + ) @classmethod - def create_encoded_token(cls, identity, expiration_sec=None, entities=None, **extra_payload): + def create_encoded_token( + cls, identity, expiration_sec=None, entities=None, **extra_payload + ): if identity.role not in (Role.system,): # limit expiration time for all roles but an internal service expiration_sec = expiration_sec or cls.default_expiration_sec now = datetime.utcnow() - token = cls( - identity=identity, - entities=entities, - iat=now) + token = cls(identity=identity, entities=entities, iat=now) if expiration_sec: # add 'expiration' claim