import sys
from abc import ABCMeta, abstractmethod
from time import time

import jwt
from jwt.algorithms import get_default_algorithms
import six


@six.add_metaclass(ABCMeta)
class TokenManager(object):
    """
    Abstract holder for a JWT that refreshes itself when it is about to expire.

    Subclasses implement ``_do_refresh_token``. Reading ``token`` returns the
    current token, calling ``refresh_token()`` first when the token's remaining
    valid period (its "exp" claim minus the expiration threshold) has run out.
    """

    # Default safety margin (12 hours): a token is treated as expired once
    # fewer than this many seconds remain before its "exp" claim.
    _default_token_exp_threshold_sec = 12 * 60 * 60
    # Default expiration requested on refresh; None is forwarded to
    # _do_refresh_token as exp=None.
    _default_req_token_expiration_sec = None

    @property
    def token_expiration_threshold_sec(self):
        """Seconds before a token's "exp" at which it is considered expired."""
        return self.__token_expiration_threshold_sec

    @token_expiration_threshold_sec.setter
    def token_expiration_threshold_sec(self, value):
        self.__token_expiration_threshold_sec = value

    @property
    def req_token_expiration_sec(self):
        """ Token expiration sec requested when refreshing token """
        return self.__req_token_expiration_sec

    @req_token_expiration_sec.setter
    def req_token_expiration_sec(self, value):
        # Must be None or an int. NOTE: enforced with assert, so the check is
        # skipped when Python runs with -O.
        assert isinstance(value, (type(None), int))
        self.__req_token_expiration_sec = value

    @property
    def token_expiration_sec(self):
        """
        Expiration of the current token: its "exp" claim, sys.maxsize if the
        token has no "exp" claim, or 0 when no token is set.
        """
        return self.__token_expiration_sec

    @property
    def token(self):
        """Current token, refreshed first if its valid period has elapsed."""
        return self._get_token()

    @property
    def raw_token(self):
        """Current token as stored, without any expiration check or refresh."""
        return self.__token

    def __init__(
        self,
        token=None,
        req_token_expiration_sec=None,
        token_history=None,
        token_expiration_threshold_sec=None,
        config=None,
        **kwargs
    ):
        """
        :param token: Initial token, or None to start without one.
        :param req_token_expiration_sec: Expiration (in seconds) requested when
            refreshing. Falsy values fall back to the
            "api.auth.request_token_expiration_sec" config key, then to the
            class default.
        :param token_history: Must be None or a dict (only validated here).
        :param token_expiration_threshold_sec: Safety margin in seconds. None
            falls back to the "api.auth.token_expiration_threshold_sec" config
            key, then to the class default. An explicit 0 is honored.
        :param config: Optional mapping-like object exposing ``get(key, default)``.
        :param kwargs: Accepted and ignored.
        """
        super(TokenManager, self).__init__()
        assert isinstance(token_history, (type(None), dict))

        if config:
            req_token_expiration_sec = req_token_expiration_sec or config.get(
                "api.auth.request_token_expiration_sec", None
            )
            # Use "is None" rather than "or" so an explicit threshold of 0
            # (treat the token as valid until its actual expiration) is not
            # silently replaced by the config value.
            if token_expiration_threshold_sec is None:
                token_expiration_threshold_sec = config.get(
                    "api.auth.token_expiration_threshold_sec", None
                )

        if token_expiration_threshold_sec is None:
            token_expiration_threshold_sec = self._default_token_exp_threshold_sec
        self.token_expiration_threshold_sec = token_expiration_threshold_sec

        # Kept as "or": a requested expiration of 0 is meaningless, so falsy
        # values continue to map to the class default.
        self.req_token_expiration_sec = (
            req_token_expiration_sec or self._default_req_token_expiration_sec
        )
        self._set_token(token)

    def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
        """
        Return how many seconds ``token`` remains usable.

        The result is ``exp - now - margin``, clamped at 0. ``margin`` is the
        larger of ``at_least_sec`` and the expiration threshold, or just the
        threshold when ``at_least_sec`` is falsy. ``exp`` is read from the
        token when not supplied (or falsy).

        Returns 0 when ``token`` is falsy or when anything fails while
        computing the period.
        """
        if token:
            try:
                exp = exp or self._get_token_exp(token)
                if at_least_sec:
                    at_least_sec = max(
                        at_least_sec, self.token_expiration_threshold_sec
                    )
                else:
                    at_least_sec = self.token_expiration_threshold_sec
                return max(0, (exp - time() - at_least_sec))
            except Exception:
                # Best effort: any failure (e.g. a token that cannot be
                # decoded) yields 0, which makes _get_token() refresh.
                pass
        return 0

    @classmethod
    def get_decoded_token(cls, token, verify=False):
        """
        Decode ``token`` and return its payload as returned by ``jwt.decode``.

        ``verify`` is passed as PyJWT 1.x's ``verify`` argument, or as PyJWT
        2.x's ``options["verify_signature"]``. All of PyJWT's default
        algorithms are accepted.

        NOTE(review): no key is passed to ``jwt.decode``, so ``verify=True``
        presumably fails for signed tokens. Confirm before relying on it.
        """
        version = str(getattr(jwt, "__version__", "") or "")
        # Compare the major version component, not just the first character,
        # so a hypothetical "10.x" is not mistaken for 1.x.
        if version.split(".")[0] == "1":
            # PyJWT 1.x API: "verify" keyword.
            return jwt.decode(
                token,
                verify=verify,
                algorithms=get_default_algorithms(),
            )
        # PyJWT 2.x API: "verify" was replaced by options["verify_signature"].
        return jwt.decode(
            token,
            options=dict(verify_signature=verify),
            algorithms=get_default_algorithms(),
        )

    @classmethod
    def _get_token_exp(cls, token):
        """ Get token expiration time. If not present, assume forever """
        return cls.get_decoded_token(token).get("exp", sys.maxsize)

    def _set_token(self, token):
        """Store ``token`` and its expiration; a falsy token clears both."""
        if token:
            self.__token = token
            self.__token_expiration_sec = self._get_token_exp(token)
        else:
            self.__token = None
            self.__token_expiration_sec = 0

    def get_token_valid_period_sec(self):
        """Seconds the current token remains usable (0 if none or expired)."""
        return self._calc_token_valid_period_sec(
            self.__token, self.token_expiration_sec
        )

    def _get_token(self):
        """Return the current token, refreshing it first if no valid period remains."""
        if self.get_token_valid_period_sec() <= 0:
            self.refresh_token()
        return self.__token

    @abstractmethod
    def _do_refresh_token(self, old_token, exp=None):
        """
        Obtain and return a new token.

        :param old_token: The currently stored token (may be None).
        :param exp: The requested expiration, i.e. ``req_token_expiration_sec``.
        """
        pass

    def refresh_token(self):
        """Replace the stored token with one obtained from ``_do_refresh_token``."""
        self._set_token(
            self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec)
        )