Stability and cleanups

This commit is contained in:
allegroai
2019-09-03 12:58:01 +03:00
parent f8d3894e02
commit 64ba30df13
14 changed files with 149 additions and 79 deletions

View File

@@ -1,3 +1,2 @@
from .version import __version__
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
from .config import load as load_config

View File

@@ -1,19 +1,78 @@
from .session import Session
import importlib
import pkgutil
import re
from typing import Any
from .session import Session
from ..utilities.check_updates import Version
class ApiServiceProxy(object):
_main_services_module = "trains.backend_api.services"
_max_available_version = None
def __init__(self, module):
self.__wrapped_name__ = module
self.__wrapped_version__ = Session.api_version
def __getattr__(self, attr):
if attr in ['__wrapped_name__', '__wrapped__', '__wrapped_version__']:
if attr in ["__wrapped_name__", "__wrapped__", "__wrapped_version__"]:
return self.__dict__.get(attr)
if not self.__dict__.get('__wrapped__') or self.__dict__.get('__wrapped_version__') != Session.api_version:
self.__dict__['__wrapped_version__'] = Session.api_version
self.__dict__['__wrapped__'] = importlib.import_module('.v'+str(Session.api_version).replace('.', '_') +
'.' + self.__dict__.get('__wrapped_name__'),
package='trains.backend_api.services')
return getattr(self.__dict__['__wrapped__'], attr)
if not self.__dict__.get("__wrapped__") or self.__dict__.get("__wrapped_version__") != Session.api_version:
if not ApiServiceProxy._max_available_version:
from ..backend_api import services
ApiServiceProxy._max_available_version = max([
Version(name[1:].replace("_", "."))
for name in [
module_name
for _, module_name, _ in pkgutil.iter_modules(services.__path__)
if re.match(r"^v[0-9]+_[0-9]+$", module_name)
]])
version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
self.__dict__["__wrapped_version__"] = version
name = ".v{}.{}".format(
version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
)
self.__dict__["__wrapped__"] = self._import_module(name, self._main_services_module)
return getattr(self.__dict__["__wrapped__"], attr)
def _import_module(self, name, package):
# type: (str, str) -> Any
return importlib.import_module(name, package=package)
class ExtApiServiceProxy(ApiServiceProxy):
_extra_services_modules = []
def _import_module(self, name, _):
# type: (str, str) -> Any
for module_path in self._get_services_modules():
try:
return importlib.import_module(name, package=module_path)
except ModuleNotFoundError:
pass
raise ModuleNotFoundError(
"No module '{}' in all predefined services module paths".format(name)
)
@classmethod
def add_services_module(cls, module_path):
# type: (str) -> None
"""
Add an additional service module path to look in when importing types
"""
cls._extra_services_modules.append(module_path)
def _get_services_modules(self):
"""
Yield all services module paths.
Paths are yielded in reverse order, so that users can add a services module that will override
the built-in main service module path (e.g. in case a type defined in the built-in module was redefined)
"""
for path in reversed(self._extra_services_modules):
yield path
yield self._main_services_module

View File

@@ -16,7 +16,7 @@ from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ..version import __version__
from ...version import __version__
class LoginError(Exception):
@@ -225,6 +225,10 @@ class Session(TokenManager):
self._session_requests += 1
return res
def add_auth_headers(self, headers):
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
return headers
def send_request(
self,
service,
@@ -249,8 +253,9 @@ class Session(TokenManager):
:param async_enable: whether request is asynchronous
:return: requests Response instance
"""
headers = headers.copy() if headers else {}
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
headers = self.add_auth_headers(
headers.copy() if headers else {}
)
if async_enable:
headers[self._ASYNC_HEADER] = "1"
return self._send_request(
@@ -493,6 +498,7 @@ class Session(TokenManager):
)
auth = HTTPBasicAuth(self.access_key, self.secret_key)
res = None
try:
data = {"expiration_sec": exp} if exp else {}
res = self._send_request(
@@ -518,8 +524,16 @@ class Session(TokenManager):
return resp["data"]["token"]
except LoginError:
six.reraise(*sys.exc_info())
except KeyError as ex:
# check if this is a misconfigured api server (getting 200 without the data section)
if res and res.status_code == 200:
raise ValueError('It seems *api_server* is misconfigured. '
'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
else:
raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
"exception: {}".format(res, ex))
except Exception as ex:
raise LoginError(str(ex))
raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
def __str__(self):
return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(

View File

@@ -99,7 +99,7 @@ def get_http_session_with_retry(
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
session.mount('http://', adapter)
session.mount('https://', adapter)
# update verify host certiface
# update verify host certificate
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
if not session.verify and __disable_certificate_verification_warning < 2:
# show warning

View File

@@ -1 +0,0 @@
__version__ = '2.0.0'