Improve artifacts support

This commit is contained in:
allegroai 2019-09-07 23:27:16 +03:00
parent 7cb7d891b3
commit dc632e160f

View File

@ -1,4 +1,5 @@
import weakref
from copy import deepcopy
import numpy as np
import hashlib
@ -29,7 +30,7 @@ class Artifacts(object):
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
self._artifacts_manager = artifacts_manager
# list of artifacts we should not upload (by name & weak-reference)
self.local_artifacts = {}
self.artifact_metadata = {}
def __setitem__(self, key, value):
# check that value is of type pandas
@ -39,15 +40,17 @@ class Artifacts(object):
if self._artifacts_manager:
self._artifacts_manager.flush()
else:
raise ValueError('Artifacts currently supports pandas.DataFrame objects only')
raise ValueError('Artifacts currently support pandas.DataFrame objects only')
def disable_upload(self, name):
    """Record a weak reference to the object currently stored under `name`.

    The reference is kept in `self.local_artifacts`, keyed by `name`.
    Names that are not present in this dictionary are ignored.
    """
    if name not in self.keys():
        return
    stored = self.get(name)
    self.local_artifacts[name] = weakref.ref(stored)
def unregister_artifact(self, name):
    """Drop `name` from the metadata mapping and from this dictionary.

    A name that is missing from either container is silently ignored.
    """
    # metadata entry first, then the artifact entry itself
    for container in (self.artifact_metadata, self):
        container.pop(name, None)
def do_upload(self, name):
    """Return True if the artifact stored under `name` should be uploaded.

    An artifact is skipped (False) only when `self.local_artifacts` holds a
    live weak reference for `name` that still points at the very object
    currently stored under `name`. Unknown names, dead references and
    entries that were replaced by a different object all return True.
    """
    local_ref = self.local_artifacts.get(name)
    if local_ref is None:
        return True
    # A weakref.ref object never compares equal to its referent, so the
    # reference must be dereferenced first. Identity (not ==/!=) is used
    # because comparing DataFrames with != is element-wise and has no
    # single truth value.
    referent = local_ref()
    return referent is None or referent is not self.get(name)
def add_metadata(self, name, metadata):
    """Store an independent deep copy of `metadata` under `name`.

    Later changes the caller makes to `metadata` do not affect the stored
    copy; an existing entry for `name` is overwritten.
    """
    snapshot = deepcopy(metadata)
    self.artifact_metadata[name] = snapshot
def get_metadata(self, name):
    """Return the metadata stored for `name`, or None when there is none."""
    known = self.artifact_metadata
    return known.get(name)
@property
def artifacts(self):
@ -63,6 +66,7 @@ class Artifacts(object):
# dictionary needs to signal the Artifacts base on changes
self._artifacts_dict = self._ProxyDictWrite(self)
self._last_artifacts_upload = {}
self._unregister_request = set()
self._thread = None
self._flush_event = Event()
self._exit_flag = False
@ -70,12 +74,16 @@ class Artifacts(object):
self._summary = ''
self._temp_folder = []
# NOTE(review): this span reads as a rendered diff hunk whose +/- markers and
# indentation were lost. Two signatures appear back to back (`add_artifact`
# taking `upload`, `register_artifact` taking `metadata`) and the body below
# references parameters of both. Presumably `register_artifact` is the
# replacement for `add_artifact` -- confirm against the repository.
def add_artifact(self, name, artifact, upload=True):
def register_artifact(self, name, artifact, metadata=None):
# currently we support pandas.DataFrame (which we will upload as csv.gz)
# or numpy array, which we will upload as npz
# store the artifact under its name in the proxy dictionary
self._artifacts_dict[name] = artifact
# `upload` is a parameter of the `add_artifact` signature only
if not upload:
self._artifacts_dict.disable_upload(name)
# `metadata` is a parameter of the `register_artifact` signature only;
# falsy values (None, empty containers) are not stored
if metadata:
self._artifacts_dict.add_metadata(name, metadata)
def unregister_artifact(self, name):
    """Queue `name` for removal from the artifacts watch list.

    The name is added to the `_unregister_request` set and `flush()` is
    called right away.
    """
    pending = self._unregister_request
    pending.add(name)
    self.flush()
def flush(self):
# start the thread if it hasn't already:
@ -111,36 +119,47 @@ class Artifacts(object):
self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear()
try:
self._upload_artifacts()
artifact_keys = list(self._artifacts_dict.keys())
for name in artifact_keys:
self._upload_artifacts(name)
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
# create summary
self._summary = self._get_statistics()
# NOTE(review): this span reads as a rendered diff hunk whose +/- markers and
# indentation were lost. Two signatures appear back to back and two bodies
# follow: first one that loops over every entry of `self._artifacts_dict`
# (uses `continue`), then one that handles the single `name` passed in (uses
# `return`). Presumably the per-name variant is the replacement -- confirm
# against the repository.
def _upload_artifacts(self):
def _upload_artifacts(self, name):
logger = self._task.get_logger()
# --- variant 1: iterate over all registered artifacts ---
for name, artifact in self._artifacts_dict.items():
if not self._artifacts_dict.do_upload(name):
# only register artifacts, and leave, TBD
continue
# target file: <temp folder>/<name><save format suffix>
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists():
# we are still uploading... get another temp folder
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
# serialize through the artifact's own `to_csv`, without the index column
artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
# hash the written file; presumably `skip_header=32` skips a file header
# that varies between writes -- TODO confirm in `sha256sum`
current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
# compare with the hash recorded at the previous upload of this name
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
local_csv.unlink()
continue
self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
# --- variant 2: handle only the artifact registered under `name` ---
artifact = self._artifacts_dict.get(name)
# remove from artifacts watch list
if name in self._unregister_request:
try:
self._unregister_request.remove(name)
except KeyError:
pass
self._artifacts_dict.unregister_artifact(name)
# `artifact` was fetched before the unregister step above, so an artifact
# that was just unregistered still continues to the steps below
if artifact is None:
return
# target file: <temp folder>/<name><save format suffix>
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists():
# we are still uploading... get another temp folder
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
# serialize through the artifact's own `to_csv`, without the index column
artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
# hash the written file; presumably `skip_header=32` skips a file header
# that varies between writes -- TODO confirm in `sha256sum`
current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
# compare with the hash recorded at the previous upload of this name
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
local_csv.unlink()
return
self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
def _get_statistics(self):
summary = ''