From 7bf208eb085ea3de7bd603ab02822dae36ae166c Mon Sep 17 00:00:00 2001
From: Omer Moran <37040617+OmerM25@users.noreply.github.com>
Date: Mon, 2 Nov 2020 18:39:06 +0200
Subject: [PATCH] Added synchronous support for upload_artifact() (#231)

Add synchronous support for Artifacts.upload_artifact()
---
 trains/binding/artifacts.py | 23 ++++++++++++++++-------
 trains/task.py              |  8 ++++++--
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py
index 192866f1..a22788e2 100644
--- a/trains/binding/artifacts.py
+++ b/trains/binding/artifacts.py
@@ -305,7 +305,7 @@ class Artifacts(object):
         self.flush()
 
     def upload_artifact(self, name, artifact_object=None, metadata=None, preview=None,
-                        delete_after_upload=False, auto_pickle=True):
+                        delete_after_upload=False, auto_pickle=True, wait_on_upload=False):
         # type: (str, Optional[object], Optional[dict], Optional[str], bool, bool) -> bool
         if not Session.check_min_api_version('2.3'):
             LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
@@ -538,7 +538,8 @@ class Artifacts(object):
             uri = self._upload_local_file(local_filename, name,
                                           delete_after_upload=delete_after_upload,
                                           override_filename=override_filename_in_uri,
-                                          override_filename_ext=override_filename_ext_in_uri)
+                                          override_filename_ext=override_filename_ext_in_uri,
+                                          wait_on_upload=wait_on_upload)
 
         timestamp = int(time())
 
@@ -685,12 +686,15 @@ class Artifacts(object):
             self._task.set_artifacts(self._task_artifact_list)
 
     def _upload_local_file(
-            self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None
+            self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None,
+            wait_on_upload=False
     ):
-        # type: (str, str, bool, Optional[str], Optional[str]) -> str
+        # type: (str, str, bool, Optional[str], Optional[str], Optional[bool]) -> str
         """
         Upload local file and return uri of the uploaded file (uploading in the background)
         """
+        from trains.storage import StorageManager
+
         upload_uri = self._task.output_uri or self._task.get_logger().get_default_upload_destination()
         if not isinstance(local_file, Path):
             local_file = Path(local_file)
@@ -701,13 +705,18 @@ class Artifacts(object):
                          override_filename=override_filename,
                          override_filename_ext=override_filename_ext,
                          override_storage_key_prefix=self._get_storage_uri_prefix())
-        _, uri = ev.get_target_full_upload_uri(upload_uri)
+        _, uri = ev.get_target_full_upload_uri(upload_uri, quote_uri=False)
 
         # send for upload
         # noinspection PyProtectedMember
-        self._task._reporter._report(ev)
+        if wait_on_upload:
+            StorageManager.upload_file(local_file, uri)
+        else:
+            self._task._reporter._report(ev)
 
-        return uri
+        _, quoted_uri = ev.get_target_full_upload_uri(upload_uri)
+
+        return quoted_uri
 
     def _get_statistics(self, artifacts_dict=None):
         # type: (Optional[Dict[str, Artifact]]) -> str
diff --git a/trains/task.py b/trains/task.py
index 5c9e2107..2b67146d 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -1286,6 +1286,7 @@ class Task(_Task):
         delete_after_upload=False,  # type: bool
         auto_pickle=True,  # type: bool
         preview=None,  # type: Any
+        wait_on_upload=False,  # type: bool
     ):
         # type: (...) -> bool
         """
@@ -1320,6 +1321,9 @@ class Task(_Task):
 
         :param Any preview: The artifact preview
 
+        :param bool wait_on_upload: Whether or not the upload should be synchronous, forcing the upload to complete
+            before continuing.
+
         :return: The status of the upload.
 
         - ``True`` - Upload succeeded.
@@ -1328,8 +1332,8 @@ class Task(_Task):
         :raise: If the artifact object type is not supported, raise a ``ValueError``.
         """
         return self._artifacts_manager.upload_artifact(
-            name=name, artifact_object=artifact_object, metadata=metadata,
-            delete_after_upload=delete_after_upload, auto_pickle=auto_pickle, preview=preview)
+            name=name, artifact_object=artifact_object, metadata=metadata, delete_after_upload=delete_after_upload,
+            auto_pickle=auto_pickle, preview=preview, wait_on_upload=wait_on_upload)
 
     def get_models(self):
         # type: () -> Dict[str, Sequence[Model]]