Add raise_on_error (default=False) argument to Artifact.get_local_copy()
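
With raise_on_error left at its default of False, a failed download returns None and logs a warning; passing raise_on_error=True raises a ValueError instead. A minimal usage sketch (the task id and artifact name are placeholders, not part of this commit):

    from trains import Task

    task = Task.get_task(task_id='<task-id>')   # placeholder task id
    artifact = task.artifacts['my_artifact']    # placeholder artifact name

    # default behaviour: returns None and logs a warning if the download fails
    local_path = artifact.get_local_copy()

    # opt in to an exception instead of a silent None
    try:
        local_path = artifact.get_local_copy(raise_on_error=True)
    except ValueError as ex:
        print('failed to fetch artifact: {}'.format(ex))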

allegroai 2020-05-22 11:13:13 +03:00
parent 03bf764dc7
commit 2393ac5f7f


@@ -17,6 +17,7 @@ import six
from PIL import Image
from pathlib2 import Path
from six.moves.urllib.parse import urlparse
from typing import Dict, Union, Optional, Any
from ..backend_api import Session
from ..backend_api.services import tasks
@@ -41,6 +42,7 @@ class Artifact(object):
@property
def url(self):
# type: () -> str
"""
:return: url of uploaded artifact
"""
@@ -48,6 +50,7 @@ class Artifact(object):
@property
def name(self):
# type: () -> str
"""
:return: name of artifact
"""
@@ -55,6 +58,7 @@ class Artifact(object):
@property
def size(self):
# type: () -> int
"""
:return: size in bytes of artifact
"""
@@ -62,6 +66,7 @@ class Artifact(object):
@property
def type(self):
# type: () -> str
"""
:return: type (str) of artifact
"""
@@ -69,6 +74,7 @@ class Artifact(object):
@property
def mode(self):
# type: () -> Union["input", "output"]
"""
:return: mode (str) of artifact, either "input" or "output"
"""
@@ -76,6 +82,7 @@ class Artifact(object):
@property
def hash(self):
# type: () -> str
"""
:return: SHA2 hash (str) of artifact content.
"""
@@ -83,6 +90,7 @@ class Artifact(object):
@property
def timestamp(self):
# type: () -> datetime
"""
:return: Timestamp (datetime) of uploaded artifact.
"""
@@ -90,6 +98,7 @@ class Artifact(object):
@property
def metadata(self):
# type: () -> Optional[Dict[str, str]]
"""
:return: Key/Value dictionary attached to artifact.
"""
@@ -97,6 +106,7 @@ class Artifact(object):
@property
def preview(self):
# type: () -> str
"""
:return: string (str) representation of the artifact.
"""
@@ -120,6 +130,7 @@ class Artifact(object):
self._object = None
def get(self):
# type: () -> Any
"""
Return an object constructed from the artifact file
@@ -131,7 +142,7 @@
if self._object:
return self._object
local_file = self.get_local_copy()
local_file = self.get_local_copy(raise_on_error=True)
if self.type == 'numpy' and np:
self._object = np.load(local_file)[self.name]
@@ -150,18 +161,26 @@
return self._object
def get_local_copy(self, extract_archive=True):
def get_local_copy(self, extract_archive=True, raise_on_error=False):
# type: (bool, bool) -> Optional[str]
"""
:param bool extract_archive: If True and the artifact is of type 'archive' (compressed folder),
the returned path will be a temporary folder containing the archive content
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise a ValueError; otherwise return None on failure and log a warning.
:return: a local path to a downloaded copy of the artifact
"""
from trains.storage import StorageManager
return StorageManager.get_local_copy(
local_copy = StorageManager.get_local_copy(
remote_url=self.url,
extract_archive=extract_archive and self.type == 'archive',
name=self.name
)
if raise_on_error and local_copy is None:
raise ValueError(
"Could not retrieve a local copy of artifact {}, failed downloading {}".format(self.name, self.url))
return local_copy
def __repr__(self):
return str({'name': self.name, 'size': self.size, 'type': self.type, 'mode': self.mode, 'url': self.url,
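
Note that Artifact.get() (in the hunk above) now requests its local copy with raise_on_error=True, so a failed download surfaces as a ValueError instead of a later failure when deserializing a missing file. A sketch of the resulting calling pattern, reusing the placeholder task from the example in the commit message:

    try:
        obj = task.artifacts['my_artifact'].get()   # deserialize the downloaded artifact (e.g. numpy/pandas)
    except ValueError:
        # the artifact file could not be downloaded
        raise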
@@ -216,10 +235,12 @@ class Artifacts(object):
@property
def registered_artifacts(self):
# type: () -> Dict[str, Artifact]
return self._artifacts_container
@property
def summary(self):
# type: () -> str
return self._summary
def __init__(self, task):
@@ -239,6 +260,7 @@ class Artifacts(object):
self._storage_prefix = None
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, object, Optional[dict], bool) -> ()
"""
:param str name: name of the artifact. Note: it will override a previous artifact if the name already exists.
:param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
@@ -249,17 +271,21 @@ class Artifacts(object):
# currently we support pandas.DataFrame (which we will upload as csv.gz)
if name in self._artifacts_container:
LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
self._artifacts_container.add_hash_columns(name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns))
self._artifacts_container.add_hash_columns(
name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns)
)
self._artifacts_container[name] = artifact
if metadata:
self._artifacts_container.add_metadata(name, metadata)
def unregister_artifact(self, name):
# type: (str) -> ()
# Remove artifact from the watch list
self._unregister_request.add(name)
self.flush()
def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
# type: (str, Optional[object], Optional[dict], bool) -> bool
if not Session.check_min_api_version('2.3'):
LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
'please upgrade to the latest server version')
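
For context, the register_artifact call reformatted above is normally reached through the Task API (assuming the usual Task.register_artifact wrapper); a minimal sketch with illustrative names, where the 'id' column is used to compute the uniqueness hash:

    import pandas as pd
    from trains import Task

    task = Task.init(project_name='examples', task_name='artifacts demo')   # illustrative names
    df = pd.DataFrame({'id': [1, 2, 3], 'value': [0.1, 0.2, 0.3]})

    # register the DataFrame so it is monitored and re-uploaded when its content changes
    task.register_artifact('train_data', df, metadata={'source': 'demo'}, uniqueness_columns=['id'])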
@@ -314,7 +340,10 @@ class Artifacts(object):
os.close(fd)
artifact_type_data.preview = preview
delete_after_upload = True
elif isinstance(artifact_object, six.string_types) and urlparse(artifact_object).scheme in remote_driver_schemes:
elif (
isinstance(artifact_object, six.string_types)
and urlparse(artifact_object).scheme in remote_driver_schemes
):
# we should not upload this, just register
local_filename = None
uri = artifact_object
@@ -344,7 +373,9 @@ class Artifacts(object):
files = list(Path(folder).rglob(wildcard))
override_filename_ext_in_uri = '.zip'
override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
fd, zip_file = mkstemp(prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri)
fd, zip_file = mkstemp(
prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri
)
try:
artifact_type_data.content_type = 'application/zip'
artifact_type_data.preview = 'Archive content {}:\n'.format(artifact_object.as_posix())
@@ -428,12 +459,14 @@ class Artifacts(object):
return True
def flush(self):
# type: () -> ()
# start the thread if it hasn't already:
self._start()
# flush the current state of all artifacts
self._flush_event.set()
def stop(self, wait=True):
# type: (bool) -> ()
# stop the daemon thread and quit
# wait until thread exits
self._exit_flag = True
@@ -449,6 +482,7 @@ class Artifacts(object):
pass
def _start(self):
# type: () -> ()
""" Start daemon thread if any artifacts are registered and thread is not up yet """
if not self._thread and self._artifacts_container:
# start the daemon thread
@@ -458,6 +492,7 @@ class Artifacts(object):
self._thread.start()
def _daemon(self):
# type: () -> ()
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear()
@@ -472,6 +507,7 @@ class Artifacts(object):
self._summary = self._get_statistics()
def _upload_data_audit_artifacts(self, name):
# type: (str) -> ()
logger = self._task.get_logger()
pd_artifact = self._artifacts_container.get(name)
pd_metadata = self._artifacts_container.get_metadata(name)
@@ -546,9 +582,10 @@ class Artifacts(object):
self._task.set_artifacts(self._task_artifact_list)
def _upload_local_file(self, local_file, name, delete_after_upload=False,
override_filename=None,
override_filename_ext=None):
def _upload_local_file(
self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None
):
# type: (str, str, bool, Optional[str], Optional[str]) -> str
"""
Upload local file and return uri of the uploaded file (uploading in the background)
"""
@@ -570,6 +607,7 @@ class Artifacts(object):
return uri
def _get_statistics(self, artifacts_dict=None):
# type: (Optional[Dict[str, Artifact]]) -> str
summary = ''
artifacts_dict = artifacts_dict or self._artifacts_container
thread_pool = ThreadPool()
@@ -632,6 +670,7 @@ class Artifacts(object):
return summary
def _get_temp_folder(self, force_new=False):
# type: (bool) -> str
if force_new or not self._temp_folder:
new_temp = mkdtemp(prefix='artifacts_')
self._temp_folder.append(new_temp)
@@ -639,12 +678,14 @@ class Artifacts(object):
return self._temp_folder[0]
def _get_storage_uri_prefix(self):
# type: () -> str
if not self._storage_prefix:
self._storage_prefix = self._task._get_output_destination_suffix()
return self._storage_prefix
@staticmethod
def sha256sum(filename, skip_header=0):
# type: (str, int) -> (Optional[str], Optional[str])
# create sha2 of the file, notice we skip the header of the file (32 bytes)
# because sometimes that is the only change
h = hashlib.sha256()