mirror of
https://github.com/clearml/clearml
synced 2025-04-10 15:35:51 +00:00
Stability and cleanups
This commit is contained in:
parent
f8d3894e02
commit
64ba30df13
@ -1,3 +1,2 @@
|
||||
from .version import __version__
|
||||
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
|
||||
from .config import load as load_config
|
||||
|
@ -1,19 +1,78 @@
|
||||
from .session import Session
|
||||
import importlib
|
||||
import pkgutil
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from .session import Session
|
||||
from ..utilities.check_updates import Version
|
||||
|
||||
|
||||
class ApiServiceProxy(object):
    """Lazy proxy to a versioned backend-API services module.

    Attribute access is forwarded to the services sub-module matching the
    current ``Session.api_version`` (e.g. ``...services.v2_4.tasks``); the
    wrapped module is (re)imported whenever the API version changes.
    """

    # Package under which the versioned service modules (v2_1, v2_4, ...) live
    _main_services_module = "trains.backend_api.services"
    # Cached highest service version bundled with this package (computed once)
    _max_available_version = None

    def __init__(self, module):
        # type: (str) -> None
        self.__wrapped_name__ = module
        self.__wrapped_version__ = Session.api_version

    def __getattr__(self, attr):
        # Internal bookkeeping attributes are served straight from __dict__
        if attr in ["__wrapped_name__", "__wrapped__", "__wrapped_version__"]:
            return self.__dict__.get(attr)

        # (Re)import the wrapped module if it was never loaded, or if the API
        # version reported by the server changed since the last import.
        if not self.__dict__.get("__wrapped__") or self.__dict__.get("__wrapped_version__") != Session.api_version:
            if not ApiServiceProxy._max_available_version:
                from ..backend_api import services
                # Highest vX_Y module shipped with the package.
                # (Single generator expression; the original nested a redundant
                # inner list-comprehension inside max([...]).)
                ApiServiceProxy._max_available_version = max(
                    Version(module_name[1:].replace("_", "."))
                    for _, module_name, _ in pkgutil.iter_modules(services.__path__)
                    if re.match(r"^v[0-9]+_[0-9]+$", module_name)
                )

            # Never try to import a version newer than what we ship locally
            version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
            self.__dict__["__wrapped_version__"] = version
            name = ".v{}.{}".format(
                version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
            )
            self.__dict__["__wrapped__"] = self._import_module(name, self._main_services_module)

        return getattr(self.__dict__["__wrapped__"], attr)

    def _import_module(self, name, package):
        # type: (str, str) -> Any
        """Import *name* relative to *package* (overridable by subclasses)."""
        return importlib.import_module(name, package=package)


class ExtApiServiceProxy(ApiServiceProxy):
    """ApiServiceProxy that also searches user-registered services modules."""

    # Additional module paths registered via add_services_module()
    _extra_services_modules = []

    def _import_module(self, name, _):
        # type: (str, str) -> Any
        """Import *name* from the first services module path that provides it.

        :raises ModuleNotFoundError: if no registered path contains the module
        """
        for module_path in self._get_services_modules():
            try:
                return importlib.import_module(name, package=module_path)
            except ModuleNotFoundError:
                pass

        raise ModuleNotFoundError(
            "No module '{}' in all predefined services module paths".format(name)
        )

    @classmethod
    def add_services_module(cls, module_path):
        # type: (str) -> None
        """
        Add an additional service module path to look in when importing types
        """
        cls._extra_services_modules.append(module_path)

    def _get_services_modules(self):
        """
        Yield all services module paths.
        Paths are yielded in reverse order, so that users can add a services module that will override
        the built-in main service module path (e.g. in case a type defined in the built-in module was redefined)
        """
        for path in reversed(self._extra_services_modules):
            yield path
        yield self._main_services_module
|
||||
|
@ -16,7 +16,7 @@ from .request import Request, BatchRequest
|
||||
from .token_manager import TokenManager
|
||||
from ..config import load
|
||||
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
||||
from ..version import __version__
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoginError(Exception):
|
||||
@ -225,6 +225,10 @@ class Session(TokenManager):
|
||||
self._session_requests += 1
|
||||
return res
|
||||
|
||||
def add_auth_headers(self, headers):
    # type: (dict) -> dict
    """Set the Authorization bearer-token header on *headers* and return it.

    The passed-in dict is mutated in place and returned for convenience.
    """
    headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
    return headers
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
service,
|
||||
@ -249,8 +253,9 @@ class Session(TokenManager):
|
||||
:param async_enable: whether request is asynchronous
|
||||
:return: requests Response instance
|
||||
"""
|
||||
headers = headers.copy() if headers else {}
|
||||
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
|
||||
headers = self.add_auth_headers(
|
||||
headers.copy() if headers else {}
|
||||
)
|
||||
if async_enable:
|
||||
headers[self._ASYNC_HEADER] = "1"
|
||||
return self._send_request(
|
||||
@ -493,6 +498,7 @@ class Session(TokenManager):
|
||||
)
|
||||
|
||||
auth = HTTPBasicAuth(self.access_key, self.secret_key)
|
||||
res = None
|
||||
try:
|
||||
data = {"expiration_sec": exp} if exp else {}
|
||||
res = self._send_request(
|
||||
@ -518,8 +524,16 @@ class Session(TokenManager):
|
||||
return resp["data"]["token"]
|
||||
except LoginError:
|
||||
six.reraise(*sys.exc_info())
|
||||
except KeyError as ex:
|
||||
# check if this is a misconfigured api server (getting 200 without the data section)
|
||||
if res and res.status_code == 200:
|
||||
raise ValueError('It seems *api_server* is misconfigured. '
|
||||
'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
|
||||
else:
|
||||
raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
|
||||
"exception: {}".format(res, ex))
|
||||
except Exception as ex:
|
||||
raise LoginError(str(ex))
|
||||
raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
|
||||
|
||||
def __str__(self):
|
||||
return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
|
||||
|
@ -99,7 +99,7 @@ def get_http_session_with_retry(
|
||||
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
# update verify host certiface
|
||||
# update verify host certificate
|
||||
session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
|
||||
if not session.verify and __disable_certificate_verification_warning < 2:
|
||||
# show warning
|
||||
|
@ -1 +0,0 @@
|
||||
__version__ = '2.0.0'
|
@ -12,7 +12,7 @@ from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
|
||||
from ..config import config_obj
|
||||
from ..config.defs import LOG_LEVEL_ENV_VAR
|
||||
from ..debugging import get_logger
|
||||
from ..backend_api.version import __version__
|
||||
from ..version import __version__
|
||||
from .session import SendError, SessionInterface
|
||||
|
||||
|
||||
|
@ -348,6 +348,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
def download_model_weights(self):
|
||||
""" Download the model weights into a local file in our cache """
|
||||
uri = self.data.uri
|
||||
if not uri or not uri.strip():
|
||||
return None
|
||||
|
||||
helper = StorageHelper.get(uri, logger=self._log, verbose=True)
|
||||
filename = uri.split('/')[-1]
|
||||
ext = '.'.join(filename.split('.')[1:])
|
||||
|
@ -724,31 +724,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
self._edit(script=script)
|
||||
|
||||
@classmethod
|
||||
def create_new_task(cls, session, task_entry, log=None):
|
||||
"""
|
||||
Create a new task
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param task_entry: A task entry instance
|
||||
:type task_entry: tasks.CreateRequest
|
||||
:param log: Optional log
|
||||
:type log: logging.Logger
|
||||
:return: A new Task instance
|
||||
"""
|
||||
if isinstance(task_entry, dict):
|
||||
task_entry = tasks.CreateRequest(**task_entry)
|
||||
|
||||
assert isinstance(task_entry, tasks.CreateRequest)
|
||||
res = cls._send(session=session, req=task_entry, log=log)
|
||||
return cls(session, task_id=res.response.id)
|
||||
|
||||
@classmethod
|
||||
def clone_task(cls, cloned_task_id, name, comment=None, execution_overrides=None,
|
||||
def clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
|
||||
tags=None, parent=None, project=None, log=None, session=None):
|
||||
"""
|
||||
Clone a task
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param cloned_task_id: Task ID for the task to be cloned
|
||||
:type cloned_task_id: str
|
||||
:param name: New name for the new task
|
||||
@ -760,13 +739,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
:type execution_overrides: dict
|
||||
:param tags: Optional updated model tags
|
||||
:type tags: [str]
|
||||
:param parent: Optional parent ID of the new task.
|
||||
:param parent: Optional parent Task ID of the new task.
|
||||
:type parent: str
|
||||
:param project: Optional project ID of the new task.
|
||||
If None, the new task will inherit the cloned task's project.
|
||||
:type parent: str
|
||||
:type project: str
|
||||
:param log: Log object used by the infrastructure.
|
||||
:type log: logging.Logger
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:return: The new task's ID
|
||||
"""
|
||||
|
||||
@ -781,7 +762,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
execution = ConfigTree.merge_configs(ConfigFactory.from_dict(execution),
|
||||
ConfigFactory.from_dict(execution_overrides or {}))
|
||||
req = tasks.CreateRequest(
|
||||
name=name,
|
||||
name=name or task.name,
|
||||
type=task.type,
|
||||
input=task.input,
|
||||
tags=tags if tags is not None else task.tags,
|
||||
@ -796,27 +777,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
return res.response.id
|
||||
|
||||
@classmethod
|
||||
def enqueue_task(cls, task_id, session=None, queue_id=None, log=None):
|
||||
"""
|
||||
Enqueue a task for execution
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param task_id: ID of the task to be enqueued
|
||||
:type task_id: str
|
||||
:param queue_id: ID of the queue in which to enqueue the task. If not provided, the default queue will be used.
|
||||
:type queue_id: str
|
||||
:param log: Log object
|
||||
:type log: logging.Logger
|
||||
:return: enqueue response
|
||||
"""
|
||||
assert isinstance(task_id, six.string_types)
|
||||
req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
|
||||
res = cls._send(session=session, req=req, log=log)
|
||||
resp = res.response
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, session, log=None, **kwargs):
|
||||
def get_all(cls, session=None, log=None, **kwargs):
|
||||
"""
|
||||
Get all tasks
|
||||
:param session: Session object used for sending requests to the API
|
||||
@ -827,6 +788,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
:type kwargs: dict
|
||||
:return: API response
|
||||
"""
|
||||
session = session if session else cls._get_default_session()
|
||||
req = tasks.GetAllRequest(**kwargs)
|
||||
res = cls._send(session=session, req=req, log=log)
|
||||
return res
|
||||
|
@ -62,7 +62,7 @@ class PostImportHookPatching(object):
|
||||
@staticmethod
|
||||
def add_on_import(name, func):
|
||||
PostImportHookPatching._init_hook()
|
||||
if not name in PostImportHookPatching._post_import_hooks or \
|
||||
if name not in PostImportHookPatching._post_import_hooks or \
|
||||
func not in PostImportHookPatching._post_import_hooks[name]:
|
||||
PostImportHookPatching._post_import_hooks[name].append(func)
|
||||
|
||||
|
@ -107,14 +107,11 @@ class PatchedMatplotlib:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task):
|
||||
if PatchedMatplotlib.patch_matplotlib():
|
||||
PatchedMatplotlib._current_task = task
|
||||
# update api version
|
||||
from ..backend_api import Session
|
||||
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
|
||||
|
||||
# create plotly renderer
|
||||
try:
|
||||
from plotly import optional_imports
|
||||
PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
|
||||
@ -122,6 +119,13 @@ class PatchedMatplotlib:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task):
|
||||
if PatchedMatplotlib.patch_matplotlib():
|
||||
PatchedMatplotlib._current_task = task
|
||||
|
||||
@staticmethod
|
||||
def patched_imshow(*args, **kw):
|
||||
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
|
||||
|
@ -370,6 +370,8 @@ class InputModel(BaseModel):
|
||||
"""
|
||||
config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
|
||||
weights_url = StorageHelper.conform_url(weights_url)
|
||||
if not weights_url:
|
||||
raise ValueError("Please provide a valid weights_url parameter")
|
||||
result = _Model._get_default_session().send(models.GetAllRequest(
|
||||
uri=[weights_url],
|
||||
only_fields=["id", "name"],
|
||||
|
@ -1,5 +1,4 @@
|
||||
import getpass
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
@ -20,6 +19,7 @@ from libcloud.common.types import ProviderError, LibcloudError
|
||||
from libcloud.storage.providers import get_driver
|
||||
from libcloud.storage.types import Provider
|
||||
from pathlib2 import Path
|
||||
from requests.exceptions import ConnectionError
|
||||
from six import binary_type
|
||||
from six.moves.queue import Queue, Empty
|
||||
from six.moves.urllib.parse import urlparse
|
||||
@ -47,6 +47,10 @@ class StorageError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _DownloadProgressReport(object):
|
||||
def __init__(self, total_size, verbose, remote_path, report_chunk_size_mb, log):
|
||||
self._total_size = total_size
|
||||
@ -700,6 +704,8 @@ class StorageHelper(object):
|
||||
self._log.info(
|
||||
'Downloaded %.2f MB successfully from %s , saved to %s' % (dl_total_mb, remote_path, local_path))
|
||||
return local_path
|
||||
except DownloadError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
self._log.error("Could not download %s , err: %s " % (remote_path, str(e)))
|
||||
if delete_on_failure:
|
||||
@ -715,6 +721,8 @@ class StorageHelper(object):
|
||||
try:
|
||||
obj = self._get_object(remote_path)
|
||||
return self._driver.download_object_as_stream(obj, chunk_size=chunk_size)
|
||||
except DownloadError as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
self._log.error("Could not download file : %s, err:%s " % (remote_path, str(e)))
|
||||
return None
|
||||
@ -902,6 +910,8 @@ class StorageHelper(object):
|
||||
return self._driver.get_object(container_name=self._container.name, object_name=object_name)
|
||||
except ProviderError:
|
||||
raise
|
||||
except ConnectionError as ex:
|
||||
raise DownloadError
|
||||
except Exception as e:
|
||||
self.log.exception('Storage helper problem for {}'.format(str(object_name)))
|
||||
return None
|
||||
@ -912,10 +922,20 @@ class _HttpDriver(object):
|
||||
timeout = (5.0, 30.)
|
||||
|
||||
class _Container(object):
    """Holds a retrying HTTP session for a single base URL (``name``)."""

    # Lazily-initialized backend session, used only to obtain auth headers
    # for requests that target the files-server host.
    _default_backend_session = None

    def __init__(self, name, retries=5, **kwargs):
        # *name* is the base URL this container serves; extra kwargs are ignored
        self.name = name
        self.session = get_http_session_with_retry(total=retries)

    def get_headers(self, url):
        """Return auth headers for *url* if it targets the files server, else None."""
        if not self._default_backend_session:
            from ..backend_interface.base import InterfaceBase
            # NOTE(review): this assignment creates an *instance* attribute that
            # shadows the class-level default, so each container caches its own
            # backend session — confirm this per-instance caching is intended.
            self._default_backend_session = InterfaceBase._get_default_session()
        if url.startswith(self._default_backend_session.get_files_server_host()):
            return self._default_backend_session.add_auth_headers({})
        return None
|
||||
|
||||
def __init__(self, retries=5):
|
||||
self._retries = retries
|
||||
self._containers = {}
|
||||
@ -928,7 +948,9 @@ class _HttpDriver(object):
|
||||
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
|
||||
url = object_name[:object_name.index('/')]
|
||||
url_path = object_name[len(url)+1:]
|
||||
res = container.session.post(container.name+url, files={url_path: iterator}, timeout=self.timeout)
|
||||
full_url = container.name+url
|
||||
res = container.session.post(full_url, files={url_path: iterator}, timeout=self.timeout,
|
||||
headers=container.get_headers(full_url))
|
||||
if res.status_code != requests.codes.ok:
|
||||
raise ValueError('Failed uploading object %s (%d): %s' % (object_name, res.status_code, res.text))
|
||||
return res
|
||||
@ -943,7 +965,8 @@ class _HttpDriver(object):
|
||||
container = self._containers[container_name]
|
||||
# set stream flag before get request
|
||||
container.session.stream = kwargs.get('stream', True)
|
||||
res = container.session.get(''.join((container_name, object_name.lstrip('/'))), timeout=self.timeout)
|
||||
url = ''.join((container_name, object_name.lstrip('/')))
|
||||
res = container.session.get(url, timeout=self.timeout, headers=container.get_headers(url))
|
||||
if res.status_code != requests.codes.ok:
|
||||
raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text))
|
||||
return res
|
||||
@ -1138,7 +1161,7 @@ class _Boto3Driver(object):
|
||||
return self._containers[container_name]
|
||||
|
||||
def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
|
||||
import boto3
|
||||
import boto3.s3.transfer
|
||||
stream = _Stream(iterator)
|
||||
try:
|
||||
container.bucket.upload_fileobj(stream, object_name, Config=boto3.s3.transfer.TransferConfig(
|
||||
@ -1151,7 +1174,7 @@ class _Boto3Driver(object):
|
||||
return True
|
||||
|
||||
def upload_object(self, file_path, container, object_name, extra=None, **kwargs):
|
||||
import boto3
|
||||
import boto3.s3.transfer
|
||||
try:
|
||||
container.bucket.upload_file(file_path, object_name, Config=boto3.s3.transfer.TransferConfig(
|
||||
use_threads=container.config.multipart,
|
||||
@ -1188,7 +1211,7 @@ class _Boto3Driver(object):
|
||||
log.error('Failed downloading: %s' % ex)
|
||||
a_stream.close()
|
||||
|
||||
import boto3
|
||||
import boto3.s3.transfer
|
||||
# return iterable object
|
||||
stream = _Stream()
|
||||
container = self._containers[obj.container_name]
|
||||
@ -1201,7 +1224,7 @@ class _Boto3Driver(object):
|
||||
return stream
|
||||
|
||||
def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
|
||||
import boto3
|
||||
import boto3.s3.transfer
|
||||
p = Path(local_path)
|
||||
if not overwrite_existing and p.is_file():
|
||||
log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
|
||||
@ -1613,6 +1636,6 @@ class _AzureBlobServiceStorageDriver(object):
|
||||
name,
|
||||
)
|
||||
|
||||
return f.path.segments[0], join(*f.path.segments[1:])
|
||||
return f.path.segments[0], os.path.join(*f.path.segments[1:])
|
||||
|
||||
return name
|
||||
|
@ -459,7 +459,7 @@ class Task(_Task):
|
||||
Returns Task object based on either, task_id (system uuid) or task name
|
||||
|
||||
:param task_id: unique task id string (if exists other parameters are ignored)
|
||||
:param project_name: project name (str) the task belogs to
|
||||
:param project_name: project name (str) the task belongs to
|
||||
:param task_name: task name (str) in within the selected project
|
||||
:return: Task object
|
||||
"""
|
||||
|
@ -24,6 +24,9 @@ _Version = collections.namedtuple(
|
||||
|
||||
|
||||
class _BaseVersion(object):
|
||||
def __init__(self, key):
|
||||
self._key = key
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._key)
|
||||
|
||||
@ -105,7 +108,7 @@ class Version(_BaseVersion):
|
||||
)
|
||||
|
||||
# Generate a key which will be used for sorting
|
||||
self._key = self._cmpkey(
|
||||
key = self._cmpkey(
|
||||
self._version.epoch,
|
||||
self._version.release,
|
||||
self._version.pre,
|
||||
@ -114,6 +117,8 @@ class Version(_BaseVersion):
|
||||
self._version.local,
|
||||
)
|
||||
|
||||
super(Version, self).__init__(key)
|
||||
|
||||
def __repr__(self):
|
||||
return "<Version({0})>".format(repr(str(self)))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user