Mirror of https://github.com/clearml/clearml
Add raise_on_error (default=False) argument to Artifact.get_local_copy()
parent 03bf764dc7
commit 2393ac5f7f
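The change below lets a caller decide how Artifact.get_local_copy() behaves when the download fails: with the default raise_on_error=False a warning is logged and None is returned, while raise_on_error=True raises a ValueError. A minimal usage sketch (the task ID and artifact name are placeholders, not part of this commit):

from trains import Task

# Attach to the task that produced the artifact (the ID is a placeholder).
source_task = Task.get_task(task_id='<source_task_id>')
artifact = source_task.artifacts['training_data']

# Default behaviour: a failed download logs a warning and returns None.
local_path = artifact.get_local_copy()
if local_path is None:
    print('artifact download failed')

# With the new flag, a failed download raises ValueError instead of returning None.
try:
    local_path = artifact.get_local_copy(raise_on_error=True)
except ValueError as error:
    print('artifact download failed: {}'.format(error))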
@@ -17,6 +17,7 @@ import six
 from PIL import Image
 from pathlib2 import Path
 from six.moves.urllib.parse import urlparse
+from typing import Dict, Union, Optional, Any
 
 from ..backend_api import Session
 from ..backend_api.services import tasks
@@ -41,6 +42,7 @@ class Artifact(object):
 
     @property
     def url(self):
+        # type: () -> str
         """
         :return: url of uploaded artifact
         """
@@ -48,6 +50,7 @@ class Artifact(object):
 
     @property
     def name(self):
+        # type: () -> str
         """
         :return: name of artifact
         """
@@ -55,6 +58,7 @@ class Artifact(object):
 
     @property
     def size(self):
+        # type: () -> int
         """
         :return: size in bytes of artifact
         """
@@ -62,6 +66,7 @@ class Artifact(object):
 
     @property
     def type(self):
+        # type: () -> str
         """
         :return: type (str) of artifact
         """
@@ -69,6 +74,7 @@ class Artifact(object):
 
     @property
     def mode(self):
+        # type: () -> Union["input", "output"]
         """
         :return: mode (str) of artifact. either "input" or "output"
         """
@@ -76,6 +82,7 @@ class Artifact(object):
 
     @property
     def hash(self):
+        # type: () -> str
         """
         :return: SHA2 hash (str) of artifact content.
         """
@@ -83,6 +90,7 @@ class Artifact(object):
 
     @property
     def timestamp(self):
+        # type: () -> datetime
         """
         :return: Timestamp (datetime) of uploaded artifact.
         """
@@ -90,6 +98,7 @@ class Artifact(object):
 
     @property
    def metadata(self):
+        # type: () -> Optional[Dict[str, str]]
         """
         :return: Key/Value dictionary attached to artifact.
         """
@@ -97,6 +106,7 @@ class Artifact(object):
 
     @property
     def preview(self):
+        # type: () -> str
         """
         :return: string (str) representation of the artifact.
         """
@@ -120,6 +130,7 @@ class Artifact(object):
         self._object = None
 
     def get(self):
+        # type: () -> Any
         """
         Return an object constructed from the artifact file
 
@@ -131,7 +142,7 @@ class Artifact(object):
         if self._object:
             return self._object
 
-        local_file = self.get_local_copy()
+        local_file = self.get_local_copy(raise_on_error=True)
 
         if self.type == 'numpy' and np:
             self._object = np.load(local_file)[self.name]
@@ -150,18 +161,26 @@ class Artifact(object):
 
         return self._object
 
-    def get_local_copy(self, extract_archive=True):
+    def get_local_copy(self, extract_archive=True, raise_on_error=False):
+        # type: (bool, bool) -> str
         """
         :param bool extract_archive: If True and artifact is of type 'archive' (compressed folder)
             The returned path will be a temporary folder containing the archive content
+        :param bool raise_on_error: If True and the artifact could not be downloaded,
+            raise ValueError, otherwise return None on failure and output log warning.
         :return: a local path to a downloaded copy of the artifact
         """
         from trains.storage import StorageManager
-        return StorageManager.get_local_copy(
+        local_copy = StorageManager.get_local_copy(
             remote_url=self.url,
             extract_archive=extract_archive and self.type == 'archive',
             name=self.name
         )
+        if raise_on_error and local_copy is None:
+            raise ValueError(
+                "Could not retrieve a local copy of artifact {}, failed downloading {}".format(self.name, self.url))
+
+        return local_copy
 
     def __repr__(self):
         return str({'name': self.name, 'size': self.size, 'type': self.type, 'mode': self.mode, 'url': self.url,
@@ -216,10 +235,12 @@ class Artifacts(object):
 
     @property
     def registered_artifacts(self):
+        # type: () -> Dict[str, Artifact]
         return self._artifacts_container
 
     @property
     def summary(self):
+        # type: () -> str
         return self._summary
 
     def __init__(self, task):
@@ -239,6 +260,7 @@ class Artifacts(object):
         self._storage_prefix = None
 
     def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
+        # type: (str, object, Optional[dict], bool) -> ()
         """
         :param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
         :param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
@@ -249,17 +271,21 @@ class Artifacts(object):
         # currently we support pandas.DataFrame (which we will upload as csv.gz)
         if name in self._artifacts_container:
             LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
-        self._artifacts_container.add_hash_columns(name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns))
+        self._artifacts_container.add_hash_columns(
+            name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns)
+        )
         self._artifacts_container[name] = artifact
         if metadata:
             self._artifacts_container.add_metadata(name, metadata)
 
     def unregister_artifact(self, name):
+        # type: (str) -> ()
         # Remove artifact from the watch list
         self._unregister_request.add(name)
         self.flush()
 
     def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
+        # type: (str, Optional[object], Optional[dict], bool) -> bool
         if not Session.check_min_api_version('2.3'):
             LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
                                                  'please upgrade to the latest server version')
@@ -314,7 +340,10 @@ class Artifacts(object):
             os.close(fd)
             artifact_type_data.preview = preview
             delete_after_upload = True
-        elif isinstance(artifact_object, six.string_types) and urlparse(artifact_object).scheme in remote_driver_schemes:
+        elif (
+            isinstance(artifact_object, six.string_types)
+            and urlparse(artifact_object).scheme in remote_driver_schemes
+        ):
             # we should not upload this, just register
             local_filename = None
             uri = artifact_object
@@ -344,7 +373,9 @@ class Artifacts(object):
             files = list(Path(folder).rglob(wildcard))
             override_filename_ext_in_uri = '.zip'
             override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
-            fd, zip_file = mkstemp(prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri)
+            fd, zip_file = mkstemp(
+                prefix=quote(folder.parts[-1], safe="")+'.', suffix=override_filename_ext_in_uri
+            )
             try:
                 artifact_type_data.content_type = 'application/zip'
                 artifact_type_data.preview = 'Archive content {}:\n'.format(artifact_object.as_posix())
@@ -428,12 +459,14 @@ class Artifacts(object):
         return True
 
     def flush(self):
+        # type: () -> ()
         # start the thread if it hasn't already:
         self._start()
         # flush the current state of all artifacts
         self._flush_event.set()
 
     def stop(self, wait=True):
+        # type: (str) -> ()
         # stop the daemon thread and quit
         # wait until thread exits
         self._exit_flag = True
@@ -449,6 +482,7 @@ class Artifacts(object):
             pass
 
     def _start(self):
+        # type: () -> ()
         """ Start daemon thread if any artifacts are registered and thread is not up yet """
         if not self._thread and self._artifacts_container:
             # start the daemon thread
@@ -458,6 +492,7 @@ class Artifacts(object):
             self._thread.start()
 
     def _daemon(self):
+        # type: () -> ()
         while not self._exit_flag:
             self._flush_event.wait(self._flush_frequency_sec)
             self._flush_event.clear()
@@ -472,6 +507,7 @@ class Artifacts(object):
         self._summary = self._get_statistics()
 
     def _upload_data_audit_artifacts(self, name):
+        # type: (str) -> ()
         logger = self._task.get_logger()
         pd_artifact = self._artifacts_container.get(name)
         pd_metadata = self._artifacts_container.get_metadata(name)
@@ -546,9 +582,10 @@ class Artifacts(object):
 
         self._task.set_artifacts(self._task_artifact_list)
 
-    def _upload_local_file(self, local_file, name, delete_after_upload=False,
-                           override_filename=None,
-                           override_filename_ext=None):
+    def _upload_local_file(
+        self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None
+    ):
+        # type: (str, str, bool, Optional[str], Optional[str]) -> str
         """
         Upload local file and return uri of the uploaded file (uploading in the background)
         """
@@ -570,6 +607,7 @@ class Artifacts(object):
         return uri
 
     def _get_statistics(self, artifacts_dict=None):
+        # type: (Optional[Dict[str, Artifact]]) -> str
         summary = ''
         artifacts_dict = artifacts_dict or self._artifacts_container
         thread_pool = ThreadPool()
@@ -632,6 +670,7 @@ class Artifacts(object):
         return summary
 
     def _get_temp_folder(self, force_new=False):
+        # type: (bool) -> str
         if force_new or not self._temp_folder:
             new_temp = mkdtemp(prefix='artifacts_')
             self._temp_folder.append(new_temp)
@@ -639,12 +678,14 @@ class Artifacts(object):
         return self._temp_folder[0]
 
     def _get_storage_uri_prefix(self):
+        # type: () -> str
         if not self._storage_prefix:
             self._storage_prefix = self._task._get_output_destination_suffix()
         return self._storage_prefix
 
     @staticmethod
     def sha256sum(filename, skip_header=0):
+        # type: (str, int) -> (Optional[str], Optional[str])
         # create sha2 of the file, notice we skip the header of the file (32 bytes)
         # because sometimes that is the only change
         h = hashlib.sha256()
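For reference, the header-skipping hash described by the comments above can be sketched roughly as follows (an illustrative outline only; the function name, chunk size, and single return value are assumptions, not the committed implementation):

import hashlib

def sha256_skipping_header(filename, skip_header=0, block_size=65536):
    # Hash the file content while ignoring the first skip_header bytes,
    # e.g. a 32-byte header that may be the only part that changed.
    digest = hashlib.sha256()
    with open(filename, 'rb') as f:
        f.seek(skip_header)
        chunk = f.read(block_size)
        while chunk:
            digest.update(chunk)
            chunk = f.read(block_size)
    return digest.hexdigest()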