Add raise_on_error (default=False) argument to Artifact.get_local_copy()

This commit is contained in:
allegroai 2020-05-22 11:13:13 +03:00
parent 03bf764dc7
commit 2393ac5f7f

View File

@ -17,6 +17,7 @@ import six
from PIL import Image from PIL import Image
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import urlparse from six.moves.urllib.parse import urlparse
from typing import Dict, Union, Optional, Any
from ..backend_api import Session from ..backend_api import Session
from ..backend_api.services import tasks from ..backend_api.services import tasks
@ -41,6 +42,7 @@ class Artifact(object):
@property @property
def url(self): def url(self):
# type: () -> str
""" """
:return: url of uploaded artifact :return: url of uploaded artifact
""" """
@ -48,6 +50,7 @@ class Artifact(object):
@property @property
def name(self): def name(self):
# type: () -> str
""" """
:return: name of artifact :return: name of artifact
""" """
@ -55,6 +58,7 @@ class Artifact(object):
@property @property
def size(self): def size(self):
# type: () -> int
""" """
:return: size in bytes of artifact :return: size in bytes of artifact
""" """
@ -62,6 +66,7 @@ class Artifact(object):
@property @property
def type(self): def type(self):
# type: () -> str
""" """
:return: type (str) of of artifact :return: type (str) of of artifact
""" """
@ -69,6 +74,7 @@ class Artifact(object):
@property @property
def mode(self): def mode(self):
# type: () -> Union["input", "output"]
""" """
:return: mode (str) of of artifact. either "input" or "output" :return: mode (str) of of artifact. either "input" or "output"
""" """
@ -76,6 +82,7 @@ class Artifact(object):
@property @property
def hash(self): def hash(self):
# type: () -> str
""" """
:return: SHA2 hash (str) of of artifact content. :return: SHA2 hash (str) of of artifact content.
""" """
@ -83,6 +90,7 @@ class Artifact(object):
@property @property
def timestamp(self): def timestamp(self):
# type: () -> datetime
""" """
:return: Timestamp (datetime) of uploaded artifact. :return: Timestamp (datetime) of uploaded artifact.
""" """
@ -90,6 +98,7 @@ class Artifact(object):
@property @property
def metadata(self): def metadata(self):
# type: () -> Optional[Dict[str, str]]
""" """
:return: Key/Value dictionary attached to artifact. :return: Key/Value dictionary attached to artifact.
""" """
@ -97,6 +106,7 @@ class Artifact(object):
@property @property
def preview(self): def preview(self):
# type: () -> str
""" """
:return: string (str) representation of the artifact. :return: string (str) representation of the artifact.
""" """
@ -120,6 +130,7 @@ class Artifact(object):
self._object = None self._object = None
def get(self): def get(self):
# type: () -> Any
""" """
Return an object constructed from the artifact file Return an object constructed from the artifact file
@ -131,7 +142,7 @@ class Artifact(object):
if self._object: if self._object:
return self._object return self._object
local_file = self.get_local_copy() local_file = self.get_local_copy(raise_on_error=True)
if self.type == 'numpy' and np: if self.type == 'numpy' and np:
self._object = np.load(local_file)[self.name] self._object = np.load(local_file)[self.name]
@ -150,18 +161,26 @@ class Artifact(object):
return self._object return self._object
def get_local_copy(self, extract_archive=True): def get_local_copy(self, extract_archive=True, raise_on_error=False):
# type: (bool, bool) -> str
""" """
:param bool extract_archive: If True and artifact is of type 'archive' (compressed folder) :param bool extract_archive: If True and artifact is of type 'archive' (compressed folder)
The returned path will be a temporary folder containing the archive content The returned path will be a temporary folder containing the archive content
:param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:return: a local path to a downloaded copy of the artifact :return: a local path to a downloaded copy of the artifact
""" """
from trains.storage import StorageManager from trains.storage import StorageManager
return StorageManager.get_local_copy( local_copy = StorageManager.get_local_copy(
remote_url=self.url, remote_url=self.url,
extract_archive=extract_archive and self.type == 'archive', extract_archive=extract_archive and self.type == 'archive',
name=self.name name=self.name
) )
if raise_on_error and local_copy is None:
raise ValueError(
"Could not retrieve a local copy of artifact {}, failed downloading {}".format(self.name, self.url))
return local_copy
def __repr__(self): def __repr__(self):
return str({'name': self.name, 'size': self.size, 'type': self.type, 'mode': self.mode, 'url': self.url, return str({'name': self.name, 'size': self.size, 'type': self.type, 'mode': self.mode, 'url': self.url,
@ -216,10 +235,12 @@ class Artifacts(object):
@property @property
def registered_artifacts(self): def registered_artifacts(self):
# type: () -> Dict[str, Artifact]
return self._artifacts_container return self._artifacts_container
@property @property
def summary(self): def summary(self):
# type: () -> str
return self._summary return self._summary
def __init__(self, task): def __init__(self, task):
@ -239,6 +260,7 @@ class Artifacts(object):
self._storage_prefix = None self._storage_prefix = None
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, object, Optional[dict], bool) -> ()
""" """
:param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists. :param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
:param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame :param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
@ -249,17 +271,21 @@ class Artifacts(object):
# currently we support pandas.DataFrame (which we will upload as csv.gz) # currently we support pandas.DataFrame (which we will upload as csv.gz)
if name in self._artifacts_container: if name in self._artifacts_container:
LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name)) LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
self._artifacts_container.add_hash_columns(name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns)) self._artifacts_container.add_hash_columns(
name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns)
)
self._artifacts_container[name] = artifact self._artifacts_container[name] = artifact
if metadata: if metadata:
self._artifacts_container.add_metadata(name, metadata) self._artifacts_container.add_metadata(name, metadata)
def unregister_artifact(self, name): def unregister_artifact(self, name):
# type: (str) -> ()
# Remove artifact from the watch list # Remove artifact from the watch list
self._unregister_request.add(name) self._unregister_request.add(name)
self.flush() self.flush()
def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False): def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
# type: (str, Optional[object], Optional[dict], bool) -> bool
if not Session.check_min_api_version('2.3'): if not Session.check_min_api_version('2.3'):
LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, ' LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
'please upgrade to the latest server version') 'please upgrade to the latest server version')
@ -314,7 +340,10 @@ class Artifacts(object):
os.close(fd) os.close(fd)
artifact_type_data.preview = preview artifact_type_data.preview = preview
delete_after_upload = True delete_after_upload = True
elif isinstance(artifact_object, six.string_types) and urlparse(artifact_object).scheme in remote_driver_schemes: elif (
isinstance(artifact_object, six.string_types)
and urlparse(artifact_object).scheme in remote_driver_schemes
):
# we should not upload this, just register # we should not upload this, just register
local_filename = None local_filename = None
uri = artifact_object uri = artifact_object
@ -344,7 +373,9 @@ class Artifacts(object):
files = list(Path(folder).rglob(wildcard)) files = list(Path(folder).rglob(wildcard))
override_filename_ext_in_uri = '.zip' override_filename_ext_in_uri = '.zip'
override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
fd, zip_file = mkstemp(prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri) fd, zip_file = mkstemp(
prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri
)
try: try:
artifact_type_data.content_type = 'application/zip' artifact_type_data.content_type = 'application/zip'
artifact_type_data.preview = 'Archive content {}:\n'.format(artifact_object.as_posix()) artifact_type_data.preview = 'Archive content {}:\n'.format(artifact_object.as_posix())
@ -428,12 +459,14 @@ class Artifacts(object):
return True return True
def flush(self): def flush(self):
# type: () -> ()
# start the thread if it hasn't already: # start the thread if it hasn't already:
self._start() self._start()
# flush the current state of all artifacts # flush the current state of all artifacts
self._flush_event.set() self._flush_event.set()
def stop(self, wait=True): def stop(self, wait=True):
# type: (str) -> ()
# stop the daemon thread and quit # stop the daemon thread and quit
# wait until thread exists # wait until thread exists
self._exit_flag = True self._exit_flag = True
@ -449,6 +482,7 @@ class Artifacts(object):
pass pass
def _start(self): def _start(self):
# type: () -> ()
""" Start daemon thread if any artifacts are registered and thread is not up yet """ """ Start daemon thread if any artifacts are registered and thread is not up yet """
if not self._thread and self._artifacts_container: if not self._thread and self._artifacts_container:
# start the daemon thread # start the daemon thread
@ -458,6 +492,7 @@ class Artifacts(object):
self._thread.start() self._thread.start()
def _daemon(self): def _daemon(self):
# type: () -> ()
while not self._exit_flag: while not self._exit_flag:
self._flush_event.wait(self._flush_frequency_sec) self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear() self._flush_event.clear()
@ -472,6 +507,7 @@ class Artifacts(object):
self._summary = self._get_statistics() self._summary = self._get_statistics()
def _upload_data_audit_artifacts(self, name): def _upload_data_audit_artifacts(self, name):
# type: (str) -> ()
logger = self._task.get_logger() logger = self._task.get_logger()
pd_artifact = self._artifacts_container.get(name) pd_artifact = self._artifacts_container.get(name)
pd_metadata = self._artifacts_container.get_metadata(name) pd_metadata = self._artifacts_container.get_metadata(name)
@ -546,9 +582,10 @@ class Artifacts(object):
self._task.set_artifacts(self._task_artifact_list) self._task.set_artifacts(self._task_artifact_list)
def _upload_local_file(self, local_file, name, delete_after_upload=False, def _upload_local_file(
override_filename=None, self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None
override_filename_ext=None): ):
# type: (str, str, bool, Optional[str], Optional[str]) -> str
""" """
Upload local file and return uri of the uploaded file (uploading in the background) Upload local file and return uri of the uploaded file (uploading in the background)
""" """
@ -570,6 +607,7 @@ class Artifacts(object):
return uri return uri
def _get_statistics(self, artifacts_dict=None): def _get_statistics(self, artifacts_dict=None):
# type: (Optional[Dict[str, Artifact]]) -> str
summary = '' summary = ''
artifacts_dict = artifacts_dict or self._artifacts_container artifacts_dict = artifacts_dict or self._artifacts_container
thread_pool = ThreadPool() thread_pool = ThreadPool()
@ -632,6 +670,7 @@ class Artifacts(object):
return summary return summary
def _get_temp_folder(self, force_new=False): def _get_temp_folder(self, force_new=False):
# type: (bool) -> str
if force_new or not self._temp_folder: if force_new or not self._temp_folder:
new_temp = mkdtemp(prefix='artifacts_') new_temp = mkdtemp(prefix='artifacts_')
self._temp_folder.append(new_temp) self._temp_folder.append(new_temp)
@ -639,12 +678,14 @@ class Artifacts(object):
return self._temp_folder[0] return self._temp_folder[0]
def _get_storage_uri_prefix(self): def _get_storage_uri_prefix(self):
# type: () -> str
if not self._storage_prefix: if not self._storage_prefix:
self._storage_prefix = self._task._get_output_destination_suffix() self._storage_prefix = self._task._get_output_destination_suffix()
return self._storage_prefix return self._storage_prefix
@staticmethod @staticmethod
def sha256sum(filename, skip_header=0): def sha256sum(filename, skip_header=0):
# type: (str, int) -> (Optional[str], Optional[str])
# create sha2 of the file, notice we skip the header of the file (32 bytes) # create sha2 of the file, notice we skip the header of the file (32 bytes)
# because sometimes that is the only change # because sometimes that is the only change
h = hashlib.sha256() h = hashlib.sha256()