Mirror of https://github.com/clearml/clearml (synced 2025-03-03 02:32:11 +00:00)
Improve artifacts support
commit dc632e160f
parent 7cb7d891b3
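In short: `add_artifact(name, artifact, upload=True)` becomes `register_artifact(name, artifact, metadata=None)`, the weak-reference "do not upload" machinery is dropped in favor of per-artifact metadata, and uploads plus deferred unregistration now run one artifact at a time in the background thread. A rough sketch of how the Task-level wrappers built on this manager are used; the `trains` import and the exact wrapper signatures are assumptions, not part of this diff:

import pandas as pd
from trains import Task

task = Task.init(project_name='examples', task_name='artifacts demo')

df = pd.DataFrame({'epoch': [1, 2], 'loss': [0.50, 0.31]})

# register a live artifact: the manager snapshots it to csv.gz in the
# background and re-uploads it only when its content hash changes
task.register_artifact('train_stats', df, metadata={'split': 'train'})

# later: stop watching the DataFrame (removal happens in the upload thread)
task.unregister_artifact('train_stats')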
@@ -1,4 +1,5 @@
 import weakref
+from copy import deepcopy
 
 import numpy as np
 import hashlib
@@ -29,7 +30,7 @@ class Artifacts(object):
             super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
             self._artifacts_manager = artifacts_manager
             # list of artifacts we should not upload (by name & weak-reference)
-            self.local_artifacts = {}
+            self.artifact_metadata = {}
 
         def __setitem__(self, key, value):
             # check that value is of type pandas
@@ -39,15 +40,17 @@ class Artifacts(object):
                 if self._artifacts_manager:
                     self._artifacts_manager.flush()
             else:
-                raise ValueError('Artifacts currently supports pandas.DataFrame objects only')
+                raise ValueError('Artifacts currently support pandas.DataFrame objects only')
 
-        def disable_upload(self, name):
-            if name in self.keys():
-                self.local_artifacts[name] = weakref.ref(self.get(name))
+        def unregister_artifact(self, name):
+            self.artifact_metadata.pop(name, None)
+            self.pop(name, None)
 
-        def do_upload(self, name):
-            # return True is this artifact should be uploaded
-            return name not in self.local_artifacts or self.local_artifacts[name] != self.get(name)
+        def add_metadata(self, name, metadata):
+            self.artifact_metadata[name] = deepcopy(metadata)
+
+        def get_metadata(self, name):
+            return self.artifact_metadata.get(name)
 
         @property
         def artifacts(self):
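The hunk above keeps the write-proxy idea: the artifacts dictionary itself wakes its manager whenever an entry is assigned, so a plain `d[name] = df` triggers a flush. A minimal standalone sketch of that pattern, with the pandas type check omitted and all names hypothetical:

class ProxyDictWrite(dict):
    """Dict subclass that notifies a manager on every write."""

    def __init__(self, manager, *args, **kwargs):
        super(ProxyDictWrite, self).__init__(*args, **kwargs)
        self._manager = manager
        self.artifact_metadata = {}

    def __setitem__(self, key, value):
        super(ProxyDictWrite, self).__setitem__(key, value)
        if self._manager:
            self._manager.flush()  # wake the upload thread


class DummyManager(object):
    def flush(self):
        print('flush requested')


d = ProxyDictWrite(DummyManager())
d['train_stats'] = [1, 2, 3]  # prints 'flush requested'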
@@ -63,6 +66,7 @@ class Artifacts(object):
         # dictionary needs to signal the Artifacts base on changes
         self._artifacts_dict = self._ProxyDictWrite(self)
         self._last_artifacts_upload = {}
+        self._unregister_request = set()
         self._thread = None
         self._flush_event = Event()
         self._exit_flag = False
@@ -70,12 +74,16 @@ class Artifacts(object):
         self._summary = ''
         self._temp_folder = []
 
-    def add_artifact(self, name, artifact, upload=True):
+    def register_artifact(self, name, artifact, metadata=None):
         # currently we support pandas.DataFrame (which we will upload as csv.gz)
-        # or numpy array, which we will upload as npz
         self._artifacts_dict[name] = artifact
-        if not upload:
-            self._artifacts_dict.disable_upload(name)
+        if metadata:
+            self._artifacts_dict.add_metadata(name, metadata)
+
+    def unregister_artifact(self, name):
+        # Remove artifact from the watch list
+        self._unregister_request.add(name)
+        self.flush()
 
     def flush(self):
         # start the thread if it hasn't already:
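Note that unregistering is deferred: the caller only records the name in `_unregister_request` and calls `flush()`; the actual `pop` happens inside the worker's upload pass, so the dictionary is only mutated for removal in one place. A toy, single-threaded sketch of that handshake (names hypothetical):

unregister_request = set()
artifacts_dict = {'train_stats': object()}

def unregister_artifact(name):
    # request only; no direct removal from the caller's side
    unregister_request.add(name)

def upload_pass(name):
    if name in unregister_request:
        unregister_request.discard(name)
        artifacts_dict.pop(name, None)
        return  # nothing left to upload for this name
    # ... snapshot, hash and upload would go here ...

unregister_artifact('train_stats')
upload_pass('train_stats')
assert 'train_stats' not in artifacts_dict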
@@ -111,36 +119,47 @@ class Artifacts(object):
             self._flush_event.wait(self._flush_frequency_sec)
             self._flush_event.clear()
             try:
-                self._upload_artifacts()
+                artifact_keys = list(self._artifacts_dict.keys())
+                for name in artifact_keys:
+                    self._upload_artifacts(name)
             except Exception as e:
                 LoggerRoot.get_base_logger().warning(str(e))
 
         # create summary
         self._summary = self._get_statistics()
 
-    def _upload_artifacts(self):
+    def _upload_artifacts(self, name):
         logger = self._task.get_logger()
-        for name, artifact in self._artifacts_dict.items():
-            if not self._artifacts_dict.do_upload(name):
-                # only register artifacts, and leave, TBD
-                continue
-            local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
-            if local_csv.exists():
-                # we are still uploading... get another temp folder
-                local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
-            artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
-            current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
-            if name in self._last_artifacts_upload:
-                previous_sha2 = self._last_artifacts_upload[name]
-                if previous_sha2 == current_sha2:
-                    # nothing to do, we can skip the upload
-                    local_csv.unlink()
-                    continue
-            self._last_artifacts_upload[name] = current_sha2
-            # now upload and delete at the end.
-            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
-                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
-                                           max_image_history=2)
+        artifact = self._artifacts_dict.get(name)
+
+        # remove from artifacts watch list
+        if name in self._unregister_request:
+            try:
+                self._unregister_request.remove(name)
+            except KeyError:
+                pass
+            self._artifacts_dict.unregister_artifact(name)
+
+        if artifact is None:
+            return
+
+        local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
+        if local_csv.exists():
+            # we are still uploading... get another temp folder
+            local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
+        artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
+        current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
+        if name in self._last_artifacts_upload:
+            previous_sha2 = self._last_artifacts_upload[name]
+            if previous_sha2 == current_sha2:
+                # nothing to do, we can skip the upload
+                local_csv.unlink()
+                return
+        self._last_artifacts_upload[name] = current_sha2
+        # now upload and delete at the end.
+        logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
+                                       delete_after_upload=True, iteration=self._task.get_last_iteration(),
+                                       max_image_history=2)
 
     def _get_statistics(self):
         summary = ''
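The dedup in `_upload_artifacts` hinges on `sha256sum(path, skip_header=32)`, which this diff references but does not show. A plausible sketch of such a helper, assuming `skip_header` exists to ignore the gzip header, whose embedded mtime would change the hash on every save even when the csv content is identical:

import hashlib

def sha256sum(filename, skip_header=0, block_size=65536):
    # hash the file contents, skipping the first skip_header bytes
    # (e.g. the timestamped gzip header of a csv.gz snapshot)
    h = hashlib.sha256()
    with open(filename, 'rb') as f:
        f.read(skip_header)
        for chunk in iter(lambda: f.read(block_size), b''):
            h.update(chunk)
    return h.hexdigest()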