Add Task.upload_artifact support for external URLs

This commit is contained in:
allegroai 2020-01-06 17:16:51 +02:00
parent 7820e0d14a
commit a169b43885
2 changed files with 44 additions and 17 deletions

View File

@ -274,6 +274,7 @@ class Artifacts(object):
artifact_type_data = tasks.ArtifactTypeData()
override_filename_in_uri = None
override_filename_ext_in_uri = None
uri = None
if np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy'
artifact_type_data.content_type = 'application/numpy'
@ -316,10 +317,15 @@ class Artifacts(object):
os.close(fd)
artifact_type_data.preview = preview
delete_after_upload = True
elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
elif isinstance(artifact_object, six.string_types) and urlparse(artifact_object).scheme in remote_driver_schemes:
# we should not upload this, just register
local_filename = None
uri = artifact_object
artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
elif isinstance(artifact_object, six.string_types + (Path,)):
# check if single file
if isinstance(artifact_object, six.string_types):
artifact_object = Path(artifact_object)
artifact_object = Path(artifact_object)
artifact_object.expanduser().absolute()
try:
@ -388,21 +394,26 @@ class Artifacts(object):
self._task_artifact_list.remove(artifact)
break
# check that the file to upload exists
local_filename = Path(local_filename).absolute()
if not local_filename.exists() or not local_filename.is_file():
LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
local_filename.as_posix()))
return False
if not local_filename:
file_size = None
file_hash = None
else:
# check that the file to upload exists
local_filename = Path(local_filename).absolute()
if not local_filename.exists() or not local_filename.is_file():
LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
local_filename.as_posix()))
return False
file_hash, _ = self.sha256sum(local_filename.as_posix())
file_size = local_filename.stat().st_size
uri = self._upload_local_file(local_filename, name,
delete_after_upload=delete_after_upload,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
file_hash, _ = self.sha256sum(local_filename.as_posix())
timestamp = int(time())
file_size = local_filename.stat().st_size
uri = self._upload_local_file(local_filename, name,
delete_after_upload=delete_after_upload,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
artifact = tasks.Artifact(key=name, type=artifact_type,
uri=uri,

View File

@ -2,6 +2,7 @@ from __future__ import with_statement
import errno
import getpass
import itertools
import json
import os
import shutil
@ -342,7 +343,7 @@ class StorageHelper(object):
config=self._conf
)
elif self._scheme in ('http', 'https'):
elif self._scheme in _HttpDriver.schemes:
self._driver = _HttpDriver(retries=retries)
self._container = self._driver.get_container(container_name=self._base_url)
else: # elif self._scheme == 'file':
@ -950,6 +951,8 @@ class _HttpDriver(_Driver):
timeout = (5.0, 30.)
min_kbps_speed = 50
schemes = ('http', 'https')
class _Container(object):
_default_backend_session = None
_default_files_server_host = None
@ -2235,3 +2238,16 @@ class _FileStorageDriver(_Driver):
# Upload self-test hook (shown under the _FileStorageDriver hunk): it ignores
# test_path, config and any extra keyword arguments, and unconditionally
# reports success by returning True.
def test_upload(self, test_path, config, **kwargs):
return True
# Every URL scheme advertised by a direct _Driver subclass, taken from either
# a single `scheme` attribute or a `schemes` collection. Falsy entries (such
# as the None produced when a class has no `scheme`) are dropped.
driver_schemes = {
    declared
    for driver_cls in _Driver.__subclasses__()
    for declared in itertools.chain(
        (getattr(driver_cls, "scheme", None),),
        getattr(driver_cls, "schemes", []),
    )
    if declared
}

# All collected schemes except _FileStorageDriver's own scheme, i.e. the ones
# treated as remote.
remote_driver_schemes = driver_schemes - {_FileStorageDriver.scheme}