Mirror of https://github.com/clearml/clearml (synced 2025-06-26 18:16:07 +00:00)
Stability and cleanups
Commit 64ba30df13 (parent f8d3894e02)
@@ -1,3 +1,2 @@
-from .version import __version__
 from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
 from .config import load as load_config
@@ -1,19 +1,78 @@
-from .session import Session
 import importlib
+import pkgutil
+import re
+from typing import Any
+
+from .session import Session
+from ..utilities.check_updates import Version
 
 
 class ApiServiceProxy(object):
+    _main_services_module = "trains.backend_api.services"
+    _max_available_version = None
+
     def __init__(self, module):
         self.__wrapped_name__ = module
         self.__wrapped_version__ = Session.api_version
 
     def __getattr__(self, attr):
-        if attr in ['__wrapped_name__', '__wrapped__', '__wrapped_version__']:
+        if attr in ["__wrapped_name__", "__wrapped__", "__wrapped_version__"]:
             return self.__dict__.get(attr)
 
-        if not self.__dict__.get('__wrapped__') or self.__dict__.get('__wrapped_version__') != Session.api_version:
-            self.__dict__['__wrapped_version__'] = Session.api_version
-            self.__dict__['__wrapped__'] = importlib.import_module('.v'+str(Session.api_version).replace('.', '_') +
-                                                                   '.' + self.__dict__.get('__wrapped_name__'),
-                                                                   package='trains.backend_api.services')
-        return getattr(self.__dict__['__wrapped__'], attr)
+        if not self.__dict__.get("__wrapped__") or self.__dict__.get("__wrapped_version__") != Session.api_version:
+            if not ApiServiceProxy._max_available_version:
+                from ..backend_api import services
+                ApiServiceProxy._max_available_version = max([
+                    Version(name[1:].replace("_", "."))
+                    for name in [
+                        module_name
+                        for _, module_name, _ in pkgutil.iter_modules(services.__path__)
+                        if re.match(r"^v[0-9]+_[0-9]+$", module_name)
+                    ]])
+
+            version = str(min(Version(Session.api_version), ApiServiceProxy._max_available_version))
+            self.__dict__["__wrapped_version__"] = version
+            name = ".v{}.{}".format(
+                version.replace(".", "_"), self.__dict__.get("__wrapped_name__")
+            )
+            self.__dict__["__wrapped__"] = self._import_module(name, self._main_services_module)
+
+        return getattr(self.__dict__["__wrapped__"], attr)
+
+    def _import_module(self, name, package):
+        # type: (str, str) -> Any
+        return importlib.import_module(name, package=package)
+
+
+class ExtApiServiceProxy(ApiServiceProxy):
+    _extra_services_modules = []
+
+    def _import_module(self, name, _):
+        # type: (str, str) -> Any
+        for module_path in self._get_services_modules():
+            try:
+                return importlib.import_module(name, package=module_path)
+            except ModuleNotFoundError:
+                pass
+
+        raise ModuleNotFoundError(
+            "No module '{}' in all predefined services module paths".format(name)
+        )
+
+    @classmethod
+    def add_services_module(cls, module_path):
+        # type: (str) -> None
+        """
+        Add an additional service module path to look in when importing types
+        """
+        cls._extra_services_modules.append(module_path)
+
+    def _get_services_modules(self):
+        """
+        Yield all services module paths.
+        Paths are yielded in reverse order, so that users can add a services module that will override
+        the built-in main service module path (e.g. in case a type defined in the built-in module was redefined)
+        """
+        for path in reversed(self._extra_services_modules):
+            yield path
+        yield self._main_services_module
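
The rewritten proxy now discovers which API versions actually ship with the package and never imports a services module newer than the newest one available. That selection rule can be shown in isolation; the following is an illustrative sketch with plain tuples standing in for the Version class (pick_services_version and its inputs are made up for this example, not part of the commit):

import re

def pick_services_version(requested, available_module_names):
    # Collect "vX_Y"-style module names as comparable (major, minor) tuples
    available = [
        tuple(int(part) for part in name[1:].split("_"))
        for name in available_module_names
        if re.match(r"^v[0-9]+_[0-9]+$", name)
    ]
    wanted = tuple(int(part) for part in requested.split("."))
    # Cap the requested version at the highest version actually shipped
    return "v{}_{}".format(*min(wanted, max(available)))

assert pick_services_version("2.4", ["v1_5", "v2_3", "session"]) == "v2_3"
assert pick_services_version("2.1", ["v1_5", "v2_3"]) == "v2_1"

ExtApiServiceProxy extends the same lookup across user-registered module paths, trying them in reverse registration order so that later additions can shadow the built-in services.
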
@@ -16,7 +16,7 @@ from .request import Request, BatchRequest
 from .token_manager import TokenManager
 from ..config import load
 from ..utils import get_http_session_with_retry, urllib_log_warning_setup
-from ..version import __version__
+from ...version import __version__
 
 
 class LoginError(Exception):
@@ -225,6 +225,10 @@ class Session(TokenManager):
         self._session_requests += 1
         return res
 
+    def add_auth_headers(self, headers):
+        headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
+        return headers
+
     def send_request(
         self,
         service,
@@ -249,8 +253,9 @@ class Session(TokenManager):
         :param async_enable: whether request is asynchronous
         :return: requests Response instance
         """
-        headers = headers.copy() if headers else {}
-        headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
+        headers = self.add_auth_headers(
+            headers.copy() if headers else {}
+        )
         if async_enable:
             headers[self._ASYNC_HEADER] = "1"
         return self._send_request(
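
add_auth_headers gives every consumer a single place that knows how to turn the session token into an Authorization header; send_request is the first caller, and the HTTP storage driver later in this commit reuses it. A toy sketch of the pattern (TinySession is a stand-in, not the real Session class):

class TinySession(object):
    _AUTHORIZATION_HEADER = "Authorization"

    def __init__(self, token):
        self.token = token

    def add_auth_headers(self, headers):
        # Mutates and returns the dict, exactly like the method added above
        headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
        return headers

print(TinySession("abc123").add_auth_headers({}))  # {'Authorization': 'Bearer abc123'}
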
@@ -493,6 +498,7 @@ class Session(TokenManager):
         )
 
         auth = HTTPBasicAuth(self.access_key, self.secret_key)
+        res = None
         try:
             data = {"expiration_sec": exp} if exp else {}
             res = self._send_request(
@@ -518,8 +524,16 @@ class Session(TokenManager):
             return resp["data"]["token"]
         except LoginError:
             six.reraise(*sys.exc_info())
+        except KeyError as ex:
+            # check if this is a misconfigured api server (getting 200 without the data section)
+            if res and res.status_code == 200:
+                raise ValueError('It seems *api_server* is misconfigured. '
+                                 'Is this the TRAINS API server {} ?'.format(self.get_api_server_host()))
+            else:
+                raise LoginError("Response data mismatch: No 'token' in 'data' value from res, receive : {}, "
+                                 "exception: {}".format(res, ex))
         except Exception as ex:
-            raise LoginError(str(ex))
+            raise LoginError('Unrecognized Authentication Error: {} {}'.format(type(ex), ex))
 
     def __str__(self):
         return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
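
The widened except clauses turn a silent KeyError into an actionable message: an HTTP 200 whose JSON payload lacks data.token almost always means the configured api_server URL points at something other than the API server. A runnable sketch of the same triage, assuming Python 3 (FakeResponse stands in for a requests response):

import collections

FakeResponse = collections.namedtuple("FakeResponse", "status_code")

def classify_login_failure(res, ex):
    # Mirrors the new except-clause ordering: KeyError with a 200 means wrong URL
    if isinstance(ex, KeyError):
        if res and res.status_code == 200:
            return "misconfigured api_server (200 OK but no token in payload)"
        return "response data mismatch: {}".format(ex)
    return "unrecognized authentication error: {} {}".format(type(ex), ex)

print(classify_login_failure(FakeResponse(200), KeyError("token")))
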
@@ -99,7 +99,7 @@ def get_http_session_with_retry(
     adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
     session.mount('http://', adapter)
     session.mount('https://', adapter)
-    # update verify host certiface
+    # update verify host certificate
     session.verify = ENV_HOST_VERIFY_CERT.get(default=get_config().get('api.verify_certificate', True))
     if not session.verify and __disable_certificate_verification_warning < 2:
         # show warning
@@ -1 +0,0 @@
-__version__ = '2.0.0'
@@ -12,7 +12,7 @@ from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
 from ..config import config_obj
 from ..config.defs import LOG_LEVEL_ENV_VAR
 from ..debugging import get_logger
-from ..backend_api.version import __version__
+from ..version import __version__
 from .session import SendError, SessionInterface
 
 
@@ -348,6 +348,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
     def download_model_weights(self):
         """ Download the model weights into a local file in our cache """
         uri = self.data.uri
+        if not uri or not uri.strip():
+            return None
+
         helper = StorageHelper.get(uri, logger=self._log, verbose=True)
         filename = uri.split('/')[-1]
         ext = '.'.join(filename.split('.')[1:])
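
The guard keeps an unregistered or empty weights URI from reaching StorageHelper.get() and failing deep inside the storage layer; callers get None instead. A sketch of the behavior (download_weights is a stand-in for the method):

def download_weights(uri):
    if not uri or not uri.strip():
        return None                      # nothing to fetch: fail fast and quietly
    return "downloaded {}".format(uri)

assert download_weights("   ") is None
assert download_weights("s3://bucket/model.bin") == "downloaded s3://bucket/model.bin"
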
@@ -724,31 +724,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         self._edit(script=script)
 
     @classmethod
-    def create_new_task(cls, session, task_entry, log=None):
-        """
-        Create a new task
-        :param session: Session object used for sending requests to the API
-        :type session: Session
-        :param task_entry: A task entry instance
-        :type task_entry: tasks.CreateRequest
-        :param log: Optional log
-        :type log: logging.Logger
-        :return: A new Task instance
-        """
-        if isinstance(task_entry, dict):
-            task_entry = tasks.CreateRequest(**task_entry)
-
-        assert isinstance(task_entry, tasks.CreateRequest)
-        res = cls._send(session=session, req=task_entry, log=log)
-        return cls(session, task_id=res.response.id)
-
-    @classmethod
-    def clone_task(cls, cloned_task_id, name, comment=None, execution_overrides=None,
+    def clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
                   tags=None, parent=None, project=None, log=None, session=None):
         """
         Clone a task
-        :param session: Session object used for sending requests to the API
-        :type session: Session
         :param cloned_task_id: Task ID for the task to be cloned
         :type cloned_task_id: str
         :param name: New for the new task
@@ -760,13 +739,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :type execution_overrides: dict
         :param tags: Optional updated model tags
         :type tags: [str]
-        :param parent: Optional parent ID of the new task.
+        :param parent: Optional parent Task ID of the new task.
         :type parent: str
         :param project: Optional project ID of the new task.
             If None, the new task will inherit the cloned task's project.
-        :type parent: str
+        :type project: str
         :param log: Log object used by the infrastructure.
        :type log: logging.Logger
+        :param session: Session object used for sending requests to the API
+        :type session: Session
         :return: The new tasks's ID
         """
 
@@ -781,7 +762,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         execution = ConfigTree.merge_configs(ConfigFactory.from_dict(execution),
                                              ConfigFactory.from_dict(execution_overrides or {}))
         req = tasks.CreateRequest(
-            name=name,
+            name=name or task.name,
             type=task.type,
             input=task.input,
             tags=tags if tags is not None else task.tags,
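
With name now optional and name=name or task.name in the request, a clone keeps the source task's name unless told otherwise. Hypothetical usage (the task ID is a placeholder):

# name omitted: the clone is created with the source task's own name
new_task_id = Task.clone_task(cloned_task_id="<source-task-id>")

# explicit overrides still behave as before
new_task_id = Task.clone_task(
    cloned_task_id="<source-task-id>",
    name="experiment-42-retry",
    tags=["retry"],
)
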
@@ -796,27 +777,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         return res.response.id
 
     @classmethod
-    def enqueue_task(cls, task_id, session=None, queue_id=None, log=None):
-        """
-        Enqueue a task for execution
-        :param session: Session object used for sending requests to the API
-        :type session: Session
-        :param task_id: ID of the task to be enqueued
-        :type task_id: str
-        :param queue_id: ID of the queue in which to enqueue the task. If not provided, the default queue will be used.
-        :type queue_id: str
-        :param log: Log object
-        :type log: logging.Logger
-        :return: enqueue response
-        """
-        assert isinstance(task_id, six.string_types)
-        req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
-        res = cls._send(session=session, req=req, log=log)
-        resp = res.response
-        return resp
-
-    @classmethod
-    def get_all(cls, session, log=None, **kwargs):
+    def get_all(cls, session=None, log=None, **kwargs):
         """
         Get all tasks
         :param session: Session object used for sending requests to the API
@@ -827,6 +788,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :type kwargs: dict
         :return: API response
         """
+        session = session if session else cls._get_default_session()
         req = tasks.GetAllRequest(**kwargs)
         res = cls._send(session=session, req=req, log=log)
         return res
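
Because get_all now falls back to cls._get_default_session(), callers no longer have to construct and hand over a Session themselves. Hypothetical usage (kwargs are passed straight through to tasks.GetAllRequest, so any field it accepts should work; status is shown as one plausible filter, and my_session stands for a Session built elsewhere):

res = Task.get_all(status=["published"])   # uses the default session
res = Task.get_all(session=my_session)     # explicit session still supported
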
@@ -62,7 +62,7 @@ class PostImportHookPatching(object):
     @staticmethod
     def add_on_import(name, func):
         PostImportHookPatching._init_hook()
-        if not name in PostImportHookPatching._post_import_hooks or \
+        if name not in PostImportHookPatching._post_import_hooks or \
                 func not in PostImportHookPatching._post_import_hooks[name]:
             PostImportHookPatching._post_import_hooks[name].append(func)
 
@@ -107,14 +107,11 @@ class PatchedMatplotlib:
         except Exception:
             pass
 
-        return True
-
-    @staticmethod
-    def update_current_task(task):
-        if PatchedMatplotlib.patch_matplotlib():
-            PatchedMatplotlib._current_task = task
+        # update api version
         from ..backend_api import Session
         PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
+
+        # create plotly renderer
         try:
             from plotly import optional_imports
             PatchedMatplotlib._matplotlylib = optional_imports.get_module('plotly.matplotlylib')
@@ -122,6 +119,13 @@ class PatchedMatplotlib:
         except Exception:
             pass
 
+        return True
+
+    @staticmethod
+    def update_current_task(task):
+        if PatchedMatplotlib.patch_matplotlib():
+            PatchedMatplotlib._current_task = task
+
     @staticmethod
     def patched_imshow(*args, **kw):
         ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
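
Net effect of the two matplotlib hunks: patch_matplotlib() now finishes its setup (API-version check, plotly renderer) before reporting success, and update_current_task() is reduced to patch-on-first-use plus bookkeeping. The shape of that pattern, reduced to a toy class (Patcher is illustrative, not the real PatchedMatplotlib):

class Patcher(object):
    _patched = False
    _current_task = None

    @classmethod
    def patch(cls):
        if not cls._patched:
            # ... install hooks, detect api version, build renderer ...
            cls._patched = True
        return True          # reached only after every setup step ran

    @classmethod
    def update_current_task(cls, task):
        if cls.patch():      # lazily patch on first use
            cls._current_task = task
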
@@ -370,6 +370,8 @@ class InputModel(BaseModel):
         """
         config_text = cls._resolve_config(config_text=config_text, config_dict=config_dict)
         weights_url = StorageHelper.conform_url(weights_url)
+        if not weights_url:
+            raise ValueError("Please provide a valid weights_url parameter")
         result = _Model._get_default_session().send(models.GetAllRequest(
             uri=[weights_url],
             only_fields=["id", "name"],
@@ -1,5 +1,4 @@
 import getpass
-import io
 import json
 import os
 import threading
@@ -20,6 +19,7 @@ from libcloud.common.types import ProviderError, LibcloudError
 from libcloud.storage.providers import get_driver
 from libcloud.storage.types import Provider
 from pathlib2 import Path
+from requests.exceptions import ConnectionError
 from six import binary_type
 from six.moves.queue import Queue, Empty
 from six.moves.urllib.parse import urlparse
@@ -47,6 +47,10 @@ class StorageError(Exception):
     pass
 
 
+class DownloadError(Exception):
+    pass
+
+
 class _DownloadProgressReport(object):
     def __init__(self, total_size, verbose, remote_path, report_chunk_size_mb, log):
         self._total_size = total_size
@@ -700,6 +704,8 @@ class StorageHelper(object):
             self._log.info(
                 'Downloaded %.2f MB successfully from %s , saved to %s' % (dl_total_mb, remote_path, local_path))
             return local_path
+        except DownloadError as e:
+            raise
         except Exception as e:
             self._log.error("Could not download %s , err: %s " % (remote_path, str(e)))
             if delete_on_failure:
@@ -715,6 +721,8 @@ class StorageHelper(object):
         try:
             obj = self._get_object(remote_path)
             return self._driver.download_object_as_stream(obj, chunk_size=chunk_size)
+        except DownloadError as e:
+            raise
         except Exception as e:
             self._log.error("Could not download file : %s, err:%s " % (remote_path, str(e)))
             return None
@@ -902,6 +910,8 @@ class StorageHelper(object):
             return self._driver.get_object(container_name=self._container.name, object_name=object_name)
         except ProviderError:
             raise
+        except ConnectionError as ex:
+            raise DownloadError
         except Exception as e:
             self.log.exception('Storage helper problem for {}'.format(str(object_name)))
             return None
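
Taken together, the storage hunks build an escalation path: _get_object() converts a ConnectionError into the new DownloadError, and the two download methods re-raise it instead of logging and returning None as they do for every other exception. A condensed, runnable sketch of that flow, assuming Python 3 builtins (function names are simplified stand-ins):

class DownloadError(Exception):
    pass

def get_object(name, fail=False):
    try:
        if fail:
            raise ConnectionError("connection reset")
        return "object:{}".format(name)
    except ConnectionError:
        raise DownloadError()

def download(name, fail=False):
    try:
        return get_object(name, fail)
    except DownloadError:
        raise                # network trouble propagates to the caller
    except Exception:
        return None          # anything else is still swallowed as before

print(download("model.bin"))  # object:model.bin
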
@@ -912,10 +922,20 @@ class _HttpDriver(object):
     timeout = (5.0, 30.)
 
     class _Container(object):
+        _default_backend_session = None
+
         def __init__(self, name, retries=5, **kwargs):
             self.name = name
             self.session = get_http_session_with_retry(total=retries)
+
+        def get_headers(self, url):
+            if not self._default_backend_session:
+                from ..backend_interface.base import InterfaceBase
+                self._default_backend_session = InterfaceBase._get_default_session()
+            if url.startswith(self._default_backend_session.get_files_server_host()):
+                return self._default_backend_session.add_auth_headers({})
+            return None
 
     def __init__(self, retries=5):
         self._retries = retries
         self._containers = {}
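
get_headers() makes the HTTP driver credentials-aware without authenticating every request: only URLs under the backend's files server receive the session's Bearer token (via the add_auth_headers helper added earlier in this commit), while arbitrary http(s) links stay anonymous. A self-contained sketch of the routing rule (the host and session are placeholders):

FILES_SERVER = "https://files.example.com"       # placeholder host

class TinySession(object):
    def add_auth_headers(self, headers):
        headers["Authorization"] = "Bearer abc123"
        return headers

def get_headers(url, session):
    if url.startswith(FILES_SERVER):
        return session.add_auth_headers({})      # our files server: attach token
    return None                                  # anyone else: no credentials

print(get_headers("https://files.example.com/x", TinySession()))  # Bearer header
print(get_headers("https://example.org/x", TinySession()))        # None
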
@@ -928,7 +948,9 @@ class _HttpDriver(object):
     def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
         url = object_name[:object_name.index('/')]
         url_path = object_name[len(url)+1:]
-        res = container.session.post(container.name+url, files={url_path: iterator}, timeout=self.timeout)
+        full_url = container.name+url
+        res = container.session.post(full_url, files={url_path: iterator}, timeout=self.timeout,
+                                     headers=container.get_headers(full_url))
         if res.status_code != requests.codes.ok:
             raise ValueError('Failed uploading object %s (%d): %s' % (object_name, res.status_code, res.text))
         return res
@@ -943,7 +965,8 @@ class _HttpDriver(object):
         container = self._containers[container_name]
         # set stream flag before get request
         container.session.stream = kwargs.get('stream', True)
-        res = container.session.get(''.join((container_name, object_name.lstrip('/'))), timeout=self.timeout)
+        url = ''.join((container_name, object_name.lstrip('/')))
+        res = container.session.get(url, timeout=self.timeout, headers=container.get_headers(url))
         if res.status_code != requests.codes.ok:
             raise ValueError('Failed getting object %s (%d): %s' % (object_name, res.status_code, res.text))
         return res
@@ -1138,7 +1161,7 @@ class _Boto3Driver(object):
         return self._containers[container_name]
 
     def upload_object_via_stream(self, iterator, container, object_name, extra=None, **kwargs):
-        import boto3
+        import boto3.s3.transfer
         stream = _Stream(iterator)
         try:
             container.bucket.upload_fileobj(stream, object_name, Config=boto3.s3.transfer.TransferConfig(
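
The switch from import boto3 to import boto3.s3.transfer (here and in the three hunks below) is not cosmetic: importing the top-level package does not guarantee the s3.transfer submodule has been loaded, so boto3.s3.transfer.TransferConfig could raise AttributeError depending on what happened to be imported earlier in the process. Importing the submodule explicitly pins it down:

import boto3.s3.transfer            # loads the submodule, not just the package

config = boto3.s3.transfer.TransferConfig(
    use_threads=True,               # standard TransferConfig options
    max_concurrency=4,
)
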
@@ -1151,7 +1174,7 @@ class _Boto3Driver(object):
         return True
 
     def upload_object(self, file_path, container, object_name, extra=None, **kwargs):
-        import boto3
+        import boto3.s3.transfer
         try:
             container.bucket.upload_file(file_path, object_name, Config=boto3.s3.transfer.TransferConfig(
                 use_threads=container.config.multipart,
@@ -1188,7 +1211,7 @@ class _Boto3Driver(object):
                 log.error('Failed downloading: %s' % ex)
                 a_stream.close()
 
-        import boto3
+        import boto3.s3.transfer
         # return iterable object
         stream = _Stream()
         container = self._containers[obj.container_name]
@@ -1201,7 +1224,7 @@ class _Boto3Driver(object):
         return stream
 
     def download_object(self, obj, local_path, overwrite_existing=True, delete_on_failure=True, callback=None):
-        import boto3
+        import boto3.s3.transfer
         p = Path(local_path)
         if not overwrite_existing and p.is_file():
             log.warn('failed saving after download: overwrite=False and file exists (%s)' % str(p))
@@ -1613,6 +1636,6 @@ class _AzureBlobServiceStorageDriver(object):
                 name,
             )
 
-            return f.path.segments[0], join(*f.path.segments[1:])
+            return f.path.segments[0], os.path.join(*f.path.segments[1:])
 
         return name
@@ -459,7 +459,7 @@ class Task(_Task):
         Returns Task object based on either, task_id (system uuid) or task name
 
         :param task_id: unique task id string (if exists other parameters are ignored)
-        :param project_name: project name (str) the task belogs to
+        :param project_name: project name (str) the task belongs to
         :param task_name: task name (str) in within the selected project
         :return: Task object
         """
@@ -24,6 +24,9 @@ _Version = collections.namedtuple(
 
 
 class _BaseVersion(object):
+    def __init__(self, key):
+        self._key = key
+
     def __hash__(self):
         return hash(self._key)
 
@@ -105,7 +108,7 @@ class Version(_BaseVersion):
         )
 
         # Generate a key which will be used for sorting
-        self._key = self._cmpkey(
+        key = self._cmpkey(
             self._version.epoch,
             self._version.release,
             self._version.pre,
@@ -114,6 +117,8 @@ class Version(_BaseVersion):
             self._version.local,
         )
 
+        super(Version, self).__init__(key)
+
     def __repr__(self):
         return "<Version({0})>".format(repr(str(self)))
 
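
The Version refactor moves ownership of the comparison key into _BaseVersion: subclasses compute their key and pass it up through super().__init__, so every comparison operator defined on the base class works against a single, consistently initialized _key. A toy reduction of the idea (SimpleVersion is illustrative, not the real Version class):

class _BaseVersion(object):
    def __init__(self, key):
        self._key = key

    def __hash__(self):
        return hash(self._key)

    def __lt__(self, other):
        return self._key < other._key

class SimpleVersion(_BaseVersion):
    def __init__(self, text):
        key = tuple(int(part) for part in text.split("."))
        super(SimpleVersion, self).__init__(key)

assert SimpleVersion("2.3") < SimpleVersion("2.10")   # numeric, not lexicographic
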