Stability and cleanups

This commit is contained in:
allegroai 2019-09-03 12:58:01 +03:00
parent f8d3894e02
commit 64ba30df13
14 changed files with 149 additions and 79 deletions

View File

@ -1,3 +1,2 @@
from .version import __version__
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
from .config import load as load_config

View File

@ -1,19 +1,78 @@
from .session import Session
import importlib
import pkgutil
import re
from typing import Any
from .session import Session
from ..utilities.check_updates import Version
class ApiServiceProxy(object):
    """
    Lazy proxy for a versioned backend-api services sub-module.

    Attribute access is forwarded to ``trains.backend_api.services.v<X>_<Y>.<module>``,
    where the version is the minimum of ``Session.api_version`` (what the server
    supports) and the highest version bundled with this package. The wrapped module
    is (re-)imported whenever the session API version changes.
    """

    _main_services_module = "trains.backend_api.services"
    # Highest services version shipped with the package; discovered lazily once.
    _max_available_version = None

    def __init__(self, module):
        # Name of the services sub-module to proxy (e.g. "tasks")
        self.__wrapped_name__ = module
        self.__wrapped_version__ = Session.api_version

    def __getattr__(self, attr):
        if attr in ["__wrapped_name__", "__wrapped__", "__wrapped_version__"]:
            return self.__dict__.get(attr)
        # (Re)import the wrapped module on first use or when the API version changed
        if not self.__dict__.get("__wrapped__") or self.__dict__.get("__wrapped_version__") != Session.api_version:
            if not ApiServiceProxy._max_available_version:
                from ..backend_api import services
                # Scan the bundled service packages (v2_1, v2_3, ...) for the newest one
                ApiServiceProxy._max_available_version = max(
                    Version(module_name[1:].replace("_", "."))
                    for _, module_name, _ in pkgutil.iter_modules(services.__path__)
                    if re.match(r"^v[0-9]+_[0-9]+$", module_name)
                )
            # Never exceed the locally available version, even if the server is newer
            version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
            self.__dict__["__wrapped_version__"] = version
            name = ".v{}.{}".format(
                version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
            )
            self.__dict__["__wrapped__"] = self._import_module(name, self._main_services_module)
        return getattr(self.__dict__["__wrapped__"], attr)

    def _import_module(self, name, package):
        # type: (str, str) -> Any
        """Import the relative module *name* from *package* (overridable hook)."""
        return importlib.import_module(name, package=package)
class ExtApiServiceProxy(ApiServiceProxy):
    """
    An :class:`ApiServiceProxy` that can resolve service modules from additional,
    user-registered packages before falling back to the built-in services package.
    """

    # Extra package paths registered via add_services_module(), shared class-wide
    _extra_services_modules = []

    def _import_module(self, name, _):
        # type: (str, str) -> Any
        """Import *name* from the first registered services package that provides it."""
        for services_package in self._get_services_modules():
            try:
                return importlib.import_module(name, package=services_package)
            except ModuleNotFoundError:
                continue
        raise ModuleNotFoundError(
            "No module '{}' in all predefined services module paths".format(name)
        )

    @classmethod
    def add_services_module(cls, module_path):
        # type: (str) -> None
        """
        Add an additional service module path to look in when importing types
        """
        cls._extra_services_modules.append(module_path)

    def _get_services_modules(self):
        """
        Yield all services module paths.
        Later-registered user modules are yielded first, so they can override
        the built-in main services module (e.g. when a built-in type was redefined);
        the built-in path is always yielded last.
        """
        for module_path in self._extra_services_modules[::-1]:
            yield module_path
        yield self._main_services_module

View File

@ -16,7 +16,7 @@ from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ..version import __version__
from ...version import __version__
class LoginError(Exception):
@ -225,6 +225,10 @@ class Session(TokenManager):
self._session_requests += 1
return res
def add_auth_headers(self, headers):
    """Set the bearer-token Authorization header on *headers* (mutated in place) and return it."""
    headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
    return headers
def send_request(
self,
service,
@ -249,8 +253,9 @@ class Session(TokenManager):
:param async_enable: whether request is asynchronous
:return: requests Response instance
"""
headers = headers.copy() if headers else {}
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
headers = self.add_auth_headers(
headers.copy() if headers else {}
)
if async_enable:
headers[self._ASYNC_HEADER] = "1"
return self._send_request(
@ -493,6 +498,7 @@ class Session(TokenManager):
)
auth = HTTPBasicAuth(self.access_key, self.secret_key)
res = None
try:
data = {"expiration_sec": exp} if exp else {}
res = self._send_request(
@ -518,8 +524,16 @@ class Session(TokenManager):
return resp["data"]["token"]
except LoginError:
six.reraise(*sys.exc_info())
except KeyError as ex:
# check if this is a misconfigured api server (getting 200 without the data section)
if res and res.status_code == 200:
raise ValueError('It seems *api_server* is misconfigured. '
'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
else:
raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
"exception: {}".format(res, ex))
except Exception as ex:
raise LoginError(str(ex))
raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
def __str__(self):
return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(

View File

@ -99,7 +99,7 @@ def get_http_session_with_retry(
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
session.mount('http://', adapter)
session.mount('https://', adapter)
# update verify host certiface
# update verify host certificate
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
if not session.verify and __disable_certificate_verification_warning < 2:
# show warning

View File

@ -1 +0,0 @@
__version__ = '2.0.0'

View File

@ -12,7 +12,7 @@ from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
from ..config import config_obj
from ..config.defs import LOG_LEVEL_ENV_VAR
from ..debugging import get_logger
from ..backend_api.version import __version__
from ..version import __version__
from .session import SendError, SessionInterface

View File

@ -348,6 +348,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def download_model_weights(self):
""" Download the model weights into a local file in our cache """
uri = self.data.uri
if not uri or not uri.strip():
return None
helper = StorageHelper.get(uri, logger=self._log, verbose=True)
filename = uri.split('/')[-1]
ext = '.'.join(filename.split('.')[1:])

View File

@ -724,31 +724,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(script=script)
@classmethod
def create_new_task(cls, session, task_entry, log=None):
    """
    Create a new task on the API server and return a Task instance wrapping it.

    :param session: Session object used for sending requests to the API
    :type session: Session
    :param task_entry: A task entry instance, or a dict of tasks.CreateRequest fields
    :type task_entry: tasks.CreateRequest
    :param log: Optional log
    :type log: logging.Logger
    :return: A new Task instance (constructed from the ID returned by the server)
    """
    # Allow passing a plain dict of CreateRequest fields for convenience
    if isinstance(task_entry, dict):
        task_entry = tasks.CreateRequest(**task_entry)
    assert isinstance(task_entry, tasks.CreateRequest)
    res = cls._send(session=session, req=task_entry, log=log)
    return cls(session, task_id=res.response.id)
@classmethod
def clone_task(cls, cloned_task_id, name, comment=None, execution_overrides=None,
def clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
tags=None, parent=None, project=None, log=None, session=None):
"""
Clone a task
:param session: Session object used for sending requests to the API
:type session: Session
:param cloned_task_id: Task ID for the task to be cloned
:type cloned_task_id: str
:param name: New name for the new task
@ -760,13 +739,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:type execution_overrides: dict
:param tags: Optional updated model tags
:type tags: [str]
:param parent: Optional parent ID of the new task.
:param parent: Optional parent Task ID of the new task.
:type parent: str
:param project: Optional project ID of the new task.
If None, the new task will inherit the cloned task's project.
:type parent: str
:type project: str
:param log: Log object used by the infrastructure.
:type log: logging.Logger
:param session: Session object used for sending requests to the API
:type session: Session
:return: The new tasks's ID
"""
@ -781,7 +762,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
execution = ConfigTree.merge_configs(ConfigFactory.from_dict(execution),
ConfigFactory.from_dict(execution_overrides or {}))
req = tasks.CreateRequest(
name=name,
name=name or task.name,
type=task.type,
input=task.input,
tags=tags if tags is not None else task.tags,
@ -796,27 +777,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return res.response.id
@classmethod
def enqueue_task(cls, task_id, session=None, queue_id=None, log=None):
    """
    Enqueue a task for execution

    :param task_id: ID of the task to be enqueued
    :type task_id: str
    :param session: Session object used for sending requests to the API
    :type session: Session
    :param queue_id: ID of the queue in which to enqueue the task. If not provided, the default queue will be used.
    :type queue_id: str
    :param log: Log object
    :type log: logging.Logger
    :return: enqueue response
    """
    assert isinstance(task_id, six.string_types)
    enqueue_req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
    result = cls._send(session=session, req=enqueue_req, log=log)
    return result.response
@classmethod
def get_all(cls, session=None, log=None, **kwargs):
    """
    Get all tasks

    :param session: Session object used for sending requests to the API.
        If not provided, the default session is used.
    :type session: Session
    :param log: Log object
    :type log: logging.Logger
    :param kwargs: Keyword args passed to the tasks.GetAllRequest (e.g. filters)
    :type kwargs: dict
    :return: API response
    """
    # Fall back to the default session so callers no longer must supply one
    session = session if session else cls._get_default_session()
    req = tasks.GetAllRequest(**kwargs)
    res = cls._send(session=session, req=req, log=log)
    return res

View File

@ -62,7 +62,7 @@ class PostImportHookPatching(object):
@staticmethod
def add_on_import(name, func):
    """Register *func* to be called when module *name* is imported (no duplicate registrations)."""
    PostImportHookPatching._init_hook()
    # Only append if this (name, func) pair was not already registered
    if name not in PostImportHookPatching._post_import_hooks or \
            func not in PostImportHookPatching._post_import_hooks[name]:
        PostImportHookPatching._post_import_hooks[name].append(func)

View File

@ -107,14 +107,11 @@ class PatchedMatplotlib:
except Exception:
pass
return True
@staticmethod
def update_current_task(task):
if PatchedMatplotlib.patch_matplotlib():
PatchedMatplotlib._current_task = task
# update api version
from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
# create plotly renderer
try:
from plotly import optional_imports
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
@ -122,6 +119,13 @@ class PatchedMatplotlib:
except Exception:
pass
return True
@staticmethod
def update_current_task(task):
    # Record the active task only after matplotlib was successfully patched,
    # so captured figures are reported to it.
    if PatchedMatplotlib.patch_matplotlib():
        PatchedMatplotlib._current_task = task
@staticmethod
def patched_imshow(*args, **kw):
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)

View File

@ -370,6 +370,8 @@ class InputModel(BaseModel):
"""
config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
weights_url = StorageHelper.conform_url(weights_url)
if not weights_url:
raise ValueError("Please provide a valid weights_url parameter")
result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url],
only_fields=["id", "name"],

View File

@ -1,5 +1,4 @@
import getpass
import io
import json
import os
import threading
@ -20,6 +19,7 @@ from libcloud.common.types import ProviderError, LibcloudError
from libcloud.storage.providers import get_driver
from libcloud.storage.types import Provider
from pathlib2 import Path
from requests.exceptions import ConnectionError
from six import binary_type
from six.moves.queue import Queue, Empty
from six.moves.urllib.parse import urlparse
@ -47,6 +47,10 @@ class StorageError(Exception):
pass
class DownloadError(Exception):
pass
class _DownloadProgressReport(object):
def __init__(self, total_size, verbose, remote_path, report_chunk_size_mb, log):
self._total_size = total_size
@ -700,6 +704,8 @@ class StorageHelper(object):
self._log.info(
'Downloaded %.2f MB successfully from %s , saved to %s' % (dl_total_mb, remote_path, local_path))
return local_path
except DownloadError as e:
raise
except Exception as e:
self._log.error("Could not download %s , err: %s " % (remote_path, str(e)))
if delete_on_failure:
@ -715,6 +721,8 @@ class StorageHelper(object):
try:
obj = self._get_object(remote_path)
return self._driver.download_object_as_stream(obj, chunk_size=chunk_size)
except DownloadError as e:
raise
except Exception as e:
self._log.error("Could not download file : %s, err:%s " % (remote_path, str(e)))
return None
@ -902,6 +910,8 @@ class StorageHelper(object):
return self._driver.get_object(container_name=self._container.name, object_name=object_name)
except ProviderError:
raise
except ConnectionError as ex:
raise DownloadError
except Exception as e:
self.log.exception('Storage helper problem for {}'.format(str(object_name)))
return None
@ -912,10 +922,20 @@ class _HttpDriver(object):
timeout = (5.0, 30.)
class _Container(object):
_default_backend_session = None
def __init__(self, name, retries=5, **kwargs):
self.name = name
self.session = get_http_session_with_retry(total=retries)
def get_headers(self, url):
    # Lazily obtain the default backend session (cached on the instance after
    # first use) so requests to the files server can carry authentication.
    if not self._default_backend_session:
        from ..backend_interface.base import InterfaceBase
        self._default_backend_session = InterfaceBase._get_default_session()
    # Only URLs under the trains files server get auth headers
    if url.startswith(self._default_backend_session.get_files_server_host()):
        return self._default_backend_session.add_auth_headers({})
    # Any other host: no extra headers
    return None
def __init__(self, retries=5):
self._retries = retries
self._containers = {}
@ -928,7 +948,9 @@ class _HttpDriver(object):
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
    """
    Upload a stream via HTTP POST.

    *object_name* is expected to be "<host path>/<remote file path>"; the part up to
    the first '/' is appended to the container base URL, the remainder is the form
    field name for the uploaded file.

    :raises ValueError: if the server does not answer with HTTP 200
    :return: the requests Response
    """
    url = object_name[:object_name.index('/')]
    url_path = object_name[len(url)+1:]
    full_url = container.name+url
    # Attach auth headers when posting to the files server (see _Container.get_headers)
    res = container.session.post(full_url, files={url_path: iterator}, timeout=self.timeout,
                                 headers=container.get_headers(full_url))
    if res.status_code != requests.codes.ok:
        raise ValueError('Failed uploading object %s (%d): %s' % (object_name, res.status_code, res.text))
    return res
@ -943,7 +965,8 @@ class _HttpDriver(object):
container = self._containers[container_name]
# set stream flag before get request
container.session.stream = kwargs.get('stream', True)
res = container.session.get(''.join((container_name, object_name.lstrip('/'))), timeout=self.timeout)
url = ''.join((container_name, object_name.lstrip('/')))
res = container.session.get(url, timeout=self.timeout, headers=container.get_headers(url))
if res.status_code != requests.codes.ok:
raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text))
return res
@ -1138,7 +1161,7 @@ class _Boto3Driver(object):
return self._containers[container_name]
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
import boto3
import boto3.s3.transfer
stream = _Stream(iterator)
try:
container.bucket.upload_fileobj(stream, object_name, Config=boto3.s3.transfer.TransferConfig(
@ -1151,7 +1174,7 @@ class _Boto3Driver(object):
return True
def upload_object(self, file_path, container, object_name, extra=None, **kwargs):
import boto3
import boto3.s3.transfer
try:
container.bucket.upload_file(file_path, object_name, Config=boto3.s3.transfer.TransferConfig(
use_threads=container.config.multipart,
@ -1188,7 +1211,7 @@ class _Boto3Driver(object):
log.error('Failed downloading: %s' % ex)
a_stream.close()
import boto3
import boto3.s3.transfer
# return iterable object
stream = _Stream()
container = self._containers[obj.container_name]
@ -1201,7 +1224,7 @@ class _Boto3Driver(object):
return stream
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
import boto3
import boto3.s3.transfer
p = Path(local_path)
if not overwrite_existing and p.is_file():
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
@ -1613,6 +1636,6 @@ class _AzureBlobServiceStorageDriver(object):
name,
)
return f.path.segments[0], join(*f.path.segments[1:])
return f.path.segments[0], os.path.join(*f.path.segments[1:])
return name

View File

@ -459,7 +459,7 @@ class Task(_Task):
Returns Task object based on either, task_id (system uuid) or task name
:param task_id: unique task id string (if exists other parameters are ignored)
:param project_name: project name (str) the task belogs to
:param project_name: project name (str) the task belongs to
:param task_name: task name (str) in within the selected project
:return: Task object
"""

View File

@ -24,6 +24,9 @@ _Version = collections.namedtuple(
class _BaseVersion(object):
def __init__(self, key):
    # Normalized comparison key; all rich comparisons and hashing delegate to it
    self._key = key
def __hash__(self):
    # Hash by the same normalized key used for sorting/comparison,
    # keeping __hash__ consistent with __eq__
    return hash(self._key)
@ -105,7 +108,7 @@ class Version(_BaseVersion):
)
# Generate a key which will be used for sorting
self._key = self._cmpkey(
key = self._cmpkey(
self._version.epoch,
self._version.release,
self._version.pre,
@ -114,6 +117,8 @@ class Version(_BaseVersion):
self._version.local,
)
super(Version, self).__init__(key)
def __repr__(self):
    # Debug-friendly form, e.g. <Version('1.2.3')>
    return "<Version({0})>".format(repr(str(self)))