Added synchronous support for upload_artifact() ()

Add synchronous support for Artifacts.upload_artifact()
This commit is contained in:
Omer Moran 2020-11-02 18:39:06 +02:00 committed by GitHub
parent 4d1582d077
commit 7bf208eb08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 9 deletions
trains

View File

@ -305,7 +305,7 @@ class Artifacts(object):
self.flush() self.flush()
def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None, def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None,
delete_after_upload=False, auto_pickle=True): delete_after_upload=False, auto_pickle=True, wait_on_upload=False):
# type: (str, Optional[object], Optional[dict], Optional[str], bool, bool) -> bool # type: (str, Optional[object], Optional[dict], Optional[str], bool, bool) -> bool
if not Session.check_min_api_version('2.3'): if not Session.check_min_api_version('2.3'):
LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, ' LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
@ -538,7 +538,8 @@ class Artifacts(object):
uri = self._upload_local_file(local_filename, name, uri = self._upload_local_file(local_filename, name,
delete_after_upload=delete_after_upload, delete_after_upload=delete_after_upload,
override_filename=override_filename_in_uri, override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri) override_filename_ext=override_filename_ext_in_uri,
wait_on_upload=wait_on_upload)
timestamp = int(time()) timestamp = int(time())
@ -685,12 +686,15 @@ class Artifacts(object):
self._task.set_artifacts(self._task_artifact_list) self._task.set_artifacts(self._task_artifact_list)
def _upload_local_file( def _upload_local_file(
self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None,
wait_on_upload=False
): ):
# type: (str, str, bool, Optional[str], Optional[str]) -> str # type: (str, str, bool, Optional[str], Optional[str], Optional[bool]) -> str
""" """
Upload local file and return uri of the uploaded file (uploading in the background) Upload local file and return uri of the uploaded file (uploading in the background)
""" """
from trains.storage import StorageManager
upload_uri = self._task.output_uri or self._task.get_logger().get_default_upload_destination() upload_uri = self._task.output_uri or self._task.get_logger().get_default_upload_destination()
if not isinstance(local_file, Path): if not isinstance(local_file, Path):
local_file = Path(local_file) local_file = Path(local_file)
@ -701,13 +705,18 @@ class Artifacts(object):
override_filename=override_filename, override_filename=override_filename,
override_filename_ext=override_filename_ext, override_filename_ext=override_filename_ext,
override_storage_key_prefix=self._get_storage_uri_prefix()) override_storage_key_prefix=self._get_storage_uri_prefix())
_, uri = ev.get_target_full_upload_uri(upload_uri) _, uri = ev.get_target_full_upload_uri(upload_uri, quote_uri=False)
# send for upload # send for upload
# noinspection PyProtectedMember # noinspection PyProtectedMember
if wait_on_upload:
StorageManager.upload_file(local_file, uri)
else:
self._task._reporter._report(ev) self._task._reporter._report(ev)
return uri _, quoted_uri = ev.get_target_full_upload_uri(upload_uri)
return quoted_uri
def _get_statistics(self, artifacts_dict=None): def _get_statistics(self, artifacts_dict=None):
# type: (Optional[Dict[str, Artifact]]) -> str # type: (Optional[Dict[str, Artifact]]) -> str

View File

@ -1286,6 +1286,7 @@ class Task(_Task):
delete_after_upload=False, # type: bool delete_after_upload=False, # type: bool
auto_pickle=True, # type: bool auto_pickle=True, # type: bool
preview=None, # type: Any preview=None, # type: Any
wait_on_upload=False, # type: bool
): ):
# type: (...) -> bool # type: (...) -> bool
""" """
@ -1320,6 +1321,9 @@ class Task(_Task):
:param Any preview: The artifact preview :param Any preview: The artifact preview
:param bool wait_on_upload: Whether or not the upload should be synchronous, forcing the upload to complete
before continuing.
:return: The status of the upload. :return: The status of the upload.
- ``True`` - Upload succeeded. - ``True`` - Upload succeeded.
@ -1328,8 +1332,8 @@ class Task(_Task):
:raise: If the artifact object type is not supported, raise a ``ValueError``. :raise: If the artifact object type is not supported, raise a ``ValueError``.
""" """
return self._artifacts_manager.upload_artifact( return self._artifacts_manager.upload_artifact(
name=name, artifact_object=artifact_object, metadata=metadata, name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
delete_after_upload=delete_after_upload, auto_pickle=auto_pickle, preview=preview) auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload)
def get_models(self): def get_models(self):
# type: () -> Dict[str, Sequence[Model]] # type: () -> Dict[str, Sequence[Model]]