import json as json_lib import sys import types from socket import gethostname import requests import six from pyhocon import ConfigTree from requests.auth import HTTPBasicAuth from .callresult import CallResult from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY from .request import Request, BatchRequest from .token_manager import TokenManager from ..config import load from ..utils import get_http_session_with_retry from ..version import __version__ class LoginError(Exception): pass class Session(TokenManager): """ TRAINS API Session class. """ _AUTHORIZATION_HEADER = "Authorization" _WORKER_HEADER = "X-Trains-Worker" _ASYNC_HEADER = "X-Trains-Async" _CLIENT_HEADER = "X-Trains-Client" _async_status_code = 202 _session_requests = 0 _session_initial_timeout = (1.0, 10) _session_timeout = (5.0, None) # TODO: add requests.codes.gateway_timeout once we support async commits _retry_codes = [ requests.codes.bad_gateway, requests.codes.service_unavailable, requests.codes.bandwidth_limit_exceeded, requests.codes.too_many_requests, ] @property def access_key(self): return self.__access_key @property def secret_key(self): return self.__secret_key @property def host(self): return self.__host @property def worker(self): return self.__worker def __init__( self, worker=None, api_key=None, secret_key=None, host=None, logger=None, verbose=None, initialize_logging=True, client=None, config=None, **kwargs ): if config is not None: self.config = config else: self.config = load() if initialize_logging: self.config.initialize_logging() token_expiration_threshold_sec = self.config.get( "auth.token_expiration_threshold_sec", 60 ) super(Session, self).__init__( token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs ) self._verbose = verbose if verbose is not None else ENV_VERBOSE.get() self._logger = logger self.__access_key = api_key or ENV_ACCESS_KEY.get( default=self.config.get("api.credentials.access_key", None) ) if not self.access_key: raise ValueError( "Missing access_key. Please set in configuration file or pass in session init." ) self.__secret_key = secret_key or ENV_SECRET_KEY.get( default=self.config.get("api.credentials.secret_key", None) ) if not self.secret_key: raise ValueError( "Missing secret_key. Please set in configuration file or pass in session init." ) host = host or ENV_HOST.get(default=self.config.get("api.host")) if not host: raise ValueError("host is required in init or config") self.__host = host.strip("/") http_retries_config = self.config.get( "api.http.retries", ConfigTree() ).as_plain_ordered_dict() http_retries_config["status_forcelist"] = self._retry_codes self.__http_session = get_http_session_with_retry(**http_retries_config) self.__worker = worker or gethostname() self.__max_req_size = self.config.get("api.http.max_req_size") if not self.__max_req_size: raise ValueError("missing max request size") self.client = client or "api-{}".format(__version__) self.refresh_token() def _send_request( self, service, action, version=None, method="get", headers=None, auth=None, data=None, json=None, refresh_token_if_unauthorized=True, ): """ Internal implementation for making a raw API request. - Constructs the api endpoint name - Injects the worker id into the headers - Allows custom authorization using a requests auth object - Intercepts `Unauthorized` responses and automatically attempts to refresh the session token once in this case (only once). This is done since permissions are embedded in the token, and addresses a case where server-side permissions have changed but are not reflected in the current token. Refreshing the token will generate a token with the updated permissions. """ host = self.host headers = headers.copy() if headers else {} headers[self._WORKER_HEADER] = self.worker headers[self._CLIENT_HEADER] = self.client token_refreshed_on_error = False url = ( "{host}/v{version}/{service}.{action}" if version else "{host}/{service}.{action}" ).format(**locals()) while True: res = self.__http_session.request( method, url, headers=headers, auth=auth, data=data, json=json, timeout=self._session_initial_timeout if self._session_requests < 1 else self._session_timeout, ) if ( refresh_token_if_unauthorized and res.status_code == requests.codes.unauthorized and not token_refreshed_on_error ): # it seems we're unauthorized, so we'll try to refresh our token once in case permissions changed since # the last time we got the token, and try again self.refresh_token() token_refreshed_on_error = True # try again continue if ( res.status_code == requests.codes.service_unavailable and self.config.get("api.http.wait_on_maintenance_forever", True) ): self._logger.warn( "Service unavailable: {} is undergoing maintenance, retrying...".format( host ) ) continue break self._session_requests += 1 return res def send_request( self, service, action, version=None, method="get", headers=None, data=None, json=None, async_enable=False, ): """ Send a raw API request. :param service: service name :param action: action name :param version: version number (default is the preconfigured api version) :param method: method type (default is 'get') :param headers: request headers (authorization and content type headers will be automatically added) :param json: json to send in the request body (jsonable object or builtin types construct. if used, content type will be application/json) :param data: Dictionary, bytes, or file-like object to send in the request body :param async_enable: whether request is asynchronous :return: requests Response instance """ headers = headers.copy() if headers else {} headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token) if async_enable: headers[self._ASYNC_HEADER] = "1" return self._send_request( service=service, action=action, version=version, method=method, headers=headers, data=data, json=json, ) def send_request_batch( self, service, action, version=None, headers=None, data=None, json=None, method="get", ): """ Send a raw batch API request. Batch requests always use application/json-lines content type. :param service: service name :param action: action name :param version: version number (default is the preconfigured api version) :param headers: request headers (authorization and content type headers will be automatically added) :param json: iterable of json items (batched items, jsonable objects or builtin types constructs). These will be sent as a multi-line payload in the request body. :param data: iterable of bytes objects (batched items). These will be sent as a multi-line payload in the request body. :param method: HTTP method :return: requests Response instance """ if not all( isinstance(x, (list, tuple, type(None), types.GeneratorType)) for x in (data, json) ): raise ValueError("Expecting list, tuple or generator in 'data' or 'json'") if not data and not json: raise ValueError( "Missing data (data or json), batch requests are meaningless without it." ) headers = headers.copy() if headers else {} headers["Content-Type"] = "application/json-lines" if data: req_data = "\n".join(data) else: req_data = "\n".join(json_lib.dumps(x) for x in json) cur = 0 results = [] while True: size = self.__max_req_size slice = req_data[cur : cur + size] if not slice: break if len(slice) < size: # this is the remainder, no need to search for newline pass elif slice[-1] != "\n": # search for the last newline in order to send a coherent request size = slice.rfind("\n") + 1 # readjust the slice slice = req_data[cur : cur + size] res = self.send_request( method=method, service=service, action=action, data=slice, headers=headers, version=version, ) results.append(res) if res.status_code != requests.codes.ok: break cur += size return results def validate_request(self, req_obj): """ Validate an API request against the current version and the request's schema """ try: # make sure we're using a compatible version for this request # validate the request (checks required fields and specific field version restrictions) validate = req_obj.validate except AttributeError: raise TypeError( '"req_obj" parameter must be an backend_api.session.Request object' ) validate() def send_async(self, req_obj): """ Asynchronously sends an API request using a request object. :param req_obj: The request object :type req_obj: Request :return: CallResult object containing the raw response, response metadata and parsed response object. """ return self.send(req_obj=req_obj, async_enable=True) def send(self, req_obj, async_enable=False, headers=None): """ Sends an API request using a request object. :param req_obj: The request object :type req_obj: Request :param async_enable: Request this method be executed in an asynchronous manner :param headers: Additional headers to send with request :return: CallResult object containing the raw response, response metadata and parsed response object. """ self.validate_request(req_obj) if isinstance(req_obj, BatchRequest): # TODO: support async for batch requests as well if async_enable: raise NotImplementedError( "Async behavior is currently not implemented for batch requests" ) json_data = req_obj.get_json() res = self.send_request_batch( service=req_obj._service, action=req_obj._action, version=req_obj._version, json=json_data, method=req_obj._method, headers=headers, ) # TODO: handle multiple results in this case try: res = next(r for r in res if r.status_code != 200) except StopIteration: # all are 200 res = res[0] else: res = self.send_request( service=req_obj._service, action=req_obj._action, version=req_obj._version, json=req_obj.to_dict(), method=req_obj._method, async_enable=async_enable, headers=headers, ) call_result = CallResult.from_result( res=res, request_cls=req_obj.__class__, logger=self._logger, service=req_obj._service, action=req_obj._action, session=self, ) return call_result def _do_refresh_token(self, old_token, exp=None): """ TokenManager abstract method implementation. Here we ignore the old token and simply obtain a new token. """ verbose = self._verbose and self._logger if verbose: self._logger.info( "Refreshing token from {} (access_key={}, exp={})".format( self.host, self.access_key, exp ) ) auth = HTTPBasicAuth(self.access_key, self.secret_key) try: data = {"expiration_sec": exp} if exp else {} res = self._send_request( service="auth", action="login", auth=auth, json=data, refresh_token_if_unauthorized=False, ) try: resp = res.json() except ValueError: resp = {} if res.status_code != 200: msg = resp.get("meta", {}).get("result_msg", res.reason) raise LoginError( "Failed getting token (error {} from {}): {}".format( res.status_code, self.host, msg ) ) if verbose: self._logger.info("Received new token") return resp["data"]["token"] except LoginError: six.reraise(*sys.exc_info()) except Exception as ex: raise LoginError(str(ex)) def __str__(self): return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format( self=self, secret_key=self.secret_key[:5] + "*" * (len(self.secret_key) - 5) )