mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Stability and cleanups
This commit is contained in:
@@ -1,3 +1,2 @@
|
||||
from .version import __version__
|
||||
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
|
||||
from .config import load as load_config
|
||||
|
||||
@@ -1,19 +1,78 @@
|
||||
from .session import Session
|
||||
import importlib
|
||||
import pkgutil
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .session import Session
|
||||
from ..utilities.check_updates import Version
|
||||
|
||||
|
||||
class ApiServiceProxy(object):
|
||||
_main_services_module = "trains.backend_api.services"
|
||||
_max_available_version = None
|
||||
|
||||
def __init__(self, module):
|
||||
self.__wrapped_name__ = module
|
||||
self.__wrapped_version__ = Session.api_version
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in ['__wrapped_name__', '__wrapped__', '__wrapped_version__']:
|
||||
if attr in ["__wrapped_name__", "__wrapped__", "__wrapped_version__"]:
|
||||
return self.__dict__.get(attr)
|
||||
|
||||
if not self.__dict__.get('__wrapped__') or self.__dict__.get('__wrapped_version__') != Session.api_version:
|
||||
self.__dict__['__wrapped_version__'] = Session.api_version
|
||||
self.__dict__['__wrapped__'] = importlib.import_module('.v'+str(Session.api_version).replace('.', '_') +
|
||||
'.' + self.__dict__.get('__wrapped_name__'),
|
||||
package='trains.backend_api.services')
|
||||
return getattr(self.__dict__['__wrapped__'], attr)
|
||||
if not self.__dict__.get("__wrapped__") or self.__dict__.get("__wrapped_version__") != Session.api_version:
|
||||
if not ApiServiceProxy._max_available_version:
|
||||
from ..backend_api import services
|
||||
ApiServiceProxy._max_available_version = max([
|
||||
Version(name[1:].replace("_", "."))
|
||||
for name in [
|
||||
module_name
|
||||
for _, module_name, _ in pkgutil.iter_modules(services.__path__)
|
||||
if re.match(r"^v[0-9]+_[0-9]+$", module_name)
|
||||
]])
|
||||
|
||||
version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
|
||||
self.__dict__["__wrapped_version__"] = version
|
||||
name = ".v{}.{}".format(
|
||||
version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
|
||||
)
|
||||
self.__dict__["__wrapped__"] = self._import_module(name, self._main_services_module)
|
||||
|
||||
return getattr(self.__dict__["__wrapped__"], attr)
|
||||
|
||||
def _import_module(self, name, package):
|
||||
# type: (str, str) -> Any
|
||||
return importlib.import_module(name, package=package)
|
||||
|
||||
|
||||
class ExtApiServiceProxy(ApiServiceProxy):
|
||||
_extra_services_modules = []
|
||||
|
||||
def _import_module(self, name, _):
|
||||
# type: (str, str) -> Any
|
||||
for module_path in self._get_services_modules():
|
||||
try:
|
||||
return importlib.import_module(name, package=module_path)
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
raise ModuleNotFoundError(
|
||||
"No module '{}' in all predefined services module paths".format(name)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_services_module(cls, module_path):
|
||||
# type: (str) -> None
|
||||
"""
|
||||
Add an additional service module path to look in when importing types
|
||||
"""
|
||||
cls._extra_services_modules.append(module_path)
|
||||
|
||||
def _get_services_modules(self):
|
||||
"""
|
||||
Yield all services module paths.
|
||||
Paths are yielded in reverse order, so that users can add a services module that will override
|
||||
the built-in main service module path (e.g. in case a type defined in the built-in module was redefined)
|
||||
"""
|
||||
for path in reversed(self._extra_services_modules):
|
||||
yield path
|
||||
yield self._main_services_module
|
||||
|
||||
@@ -16,7 +16,7 @@ from .request import Request, BatchRequest
|
||||
from .token_manager import TokenManager
|
||||
from ..config import load
|
||||
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
||||
from ..version import __version__
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoginError(Exception):
|
||||
@@ -225,6 +225,10 @@ class Session(TokenManager):
|
||||
self._session_requests += 1
|
||||
return res
|
||||
|
||||
def add_auth_headers(self, headers):
|
||||
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
|
||||
return headers
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
service,
|
||||
@@ -249,8 +253,9 @@ class Session(TokenManager):
|
||||
:param async_enable: whether request is asynchronous
|
||||
:return: requests Response instance
|
||||
"""
|
||||
headers = headers.copy() if headers else {}
|
||||
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
|
||||
headers = self.add_auth_headers(
|
||||
headers.copy() if headers else {}
|
||||
)
|
||||
if async_enable:
|
||||
headers[self._ASYNC_HEADER] = "1"
|
||||
return self._send_request(
|
||||
@@ -493,6 +498,7 @@ class Session(TokenManager):
|
||||
)
|
||||
|
||||
auth = HTTPBasicAuth(self.access_key, self.secret_key)
|
||||
res = None
|
||||
try:
|
||||
data = {"expiration_sec": exp} if exp else {}
|
||||
res = self._send_request(
|
||||
@@ -518,8 +524,16 @@ class Session(TokenManager):
|
||||
return resp["data"]["token"]
|
||||
except LoginError:
|
||||
six.reraise(*sys.exc_info())
|
||||
except KeyError as ex:
|
||||
# check if this is a misconfigured api server (getting 200 without the data section)
|
||||
if res and res.status_code == 200:
|
||||
raise ValueError('It seems *api_server* is misconfigured. '
|
||||
'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
|
||||
else:
|
||||
raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
|
||||
"exception: {}".format(res, ex))
|
||||
except Exception as ex:
|
||||
raise LoginError(str(ex))
|
||||
raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
|
||||
|
||||
def __str__(self):
|
||||
return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
|
||||
|
||||
@@ -99,7 +99,7 @@ def get_http_session_with_retry(
|
||||
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
# update verify host certiface
|
||||
# update verify host certificate
|
||||
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
|
||||
if not session.verify and __disable_certificate_verification_warning < 2:
|
||||
# show warning
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
__version__ = '2.0.0'
|
||||
Reference in New Issue
Block a user