From a169b43885117a9fc0f39f525b08c45f55a9759d Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 6 Jan 2020 17:16:51 +0200 Subject: [PATCH] Add Task.upload_artifact support for external URLs --- trains/binding/artifacts.py | 43 +++++++++++++++++++++++-------------- trains/storage/helper.py | 18 +++++++++++++++- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py index df0ae615..151a48aa 100644 --- a/trains/binding/artifacts.py +++ b/trains/binding/artifacts.py @@ -274,6 +274,7 @@ class Artifacts(object): artifact_type_data = tasks.ArtifactTypeData() override_filename_in_uri = None override_filename_ext_in_uri = None + uri = None if np and isinstance(artifact_object, np.ndarray): artifact_type = 'numpy' artifact_type_data.content_type = 'application/numpy' @@ -316,10 +317,15 @@ class Artifacts(object): os.close(fd) artifact_type_data.preview = preview delete_after_upload = True - elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path): + elif isinstance(artifact_object, six.string_types) and urlparse(artifact_object).scheme in remote_driver_schemes: + # we should not upload this, just register + local_filename = None + uri = artifact_object + artifact_type = 'custom' + artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0] + elif isinstance(artifact_object, six.string_types + (Path,)): # check if single file - if isinstance(artifact_object, six.string_types): - artifact_object = Path(artifact_object) + artifact_object = Path(artifact_object) artifact_object.expanduser().absolute() try: @@ -388,21 +394,26 @@ class Artifacts(object): self._task_artifact_list.remove(artifact) break - # check that the file to upload exists - local_filename = Path(local_filename).absolute() - if not local_filename.exists() or not local_filename.is_file(): - LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format( - local_filename.as_posix())) - return False + if not local_filename: + file_size = None + file_hash = None + else: + # check that the file to upload exists + local_filename = Path(local_filename).absolute() + if not local_filename.exists() or not local_filename.is_file(): + LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format( + local_filename.as_posix())) + return False + + file_hash, _ = self.sha256sum(local_filename.as_posix()) + file_size = local_filename.stat().st_size + + uri = self._upload_local_file(local_filename, name, + delete_after_upload=delete_after_upload, + override_filename=override_filename_in_uri, + override_filename_ext=override_filename_ext_in_uri) - file_hash, _ = self.sha256sum(local_filename.as_posix()) timestamp = int(time()) - file_size = local_filename.stat().st_size - - uri = self._upload_local_file(local_filename, name, - delete_after_upload=delete_after_upload, - override_filename=override_filename_in_uri, - override_filename_ext=override_filename_ext_in_uri) artifact = tasks.Artifact(key=name, type=artifact_type, uri=uri, diff --git a/trains/storage/helper.py b/trains/storage/helper.py index 8fa0c844..f3bfe7a5 100644 --- a/trains/storage/helper.py +++ b/trains/storage/helper.py @@ -2,6 +2,7 @@ from __future__ import with_statement import errno import getpass +import itertools import json import os import shutil @@ -342,7 +343,7 @@ class StorageHelper(object): config=self._conf ) - elif self._scheme in ('http', 'https'): + elif self._scheme in _HttpDriver.schemes: self._driver = _HttpDriver(retries=retries) self._container = self._driver.get_container(container_name=self._base_url) else: # elif self._scheme == 'file': @@ -950,6 +951,8 @@ class _HttpDriver(_Driver): timeout = (5.0, 30.) min_kbps_speed = 50 + schemes = ('http', 'https') + class _Container(object): _default_backend_session = None _default_files_server_host = None @@ -2235,3 +2238,16 @@ class _FileStorageDriver(_Driver): def test_upload(self, test_path, config, **kwargs): return True + + +driver_schemes = set( + filter( + None, + itertools.chain( + (getattr(cls, "scheme", None) for cls in _Driver.__subclasses__()), + *(getattr(cls, "schemes", []) for cls in _Driver.__subclasses__()) + ) + ) +) + +remote_driver_schemes = driver_schemes - {_FileStorageDriver.scheme}