Improve artifacts support

This commit is contained in:
allegroai 2019-09-07 23:27:16 +03:00
parent 7cb7d891b3
commit dc632e160f

View File

@ -1,4 +1,5 @@
import weakref import weakref
from copy import deepcopy
import numpy as np import numpy as np
import hashlib import hashlib
@ -29,7 +30,7 @@ class Artifacts(object):
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs) super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
self._artifacts_manager = artifacts_manager self._artifacts_manager = artifacts_manager
# list of artifacts we should not upload (by name & weak-reference) # list of artifacts we should not upload (by name & weak-reference)
self.local_artifacts = {} self.artifact_metadata = {}
def __setitem__(self, key, value): def __setitem__(self, key, value):
# check that value is of type pandas # check that value is of type pandas
@ -39,15 +40,17 @@ class Artifacts(object):
if self._artifacts_manager: if self._artifacts_manager:
self._artifacts_manager.flush() self._artifacts_manager.flush()
else: else:
raise ValueError('Artifacts currently supports pandas.DataFrame objects only') raise ValueError('Artifacts currently support pandas.DataFrame objects only')
def disable_upload(self, name): def unregister_artifact(self, name):
if name in self.keys(): self.artifact_metadata.pop(name, None)
self.local_artifacts[name] = weakref.ref(self.get(name)) self.pop(name, None)
def do_upload(self, name): def add_metadata(self, name, metadata):
# return True is this artifact should be uploaded self.artifact_metadata[name] = deepcopy(metadata)
return name not in self.local_artifacts or self.local_artifacts[name] != self.get(name)
def get_metadata(self, name):
return self.artifact_metadata.get(name)
@property @property
def artifacts(self): def artifacts(self):
@ -63,6 +66,7 @@ class Artifacts(object):
# dictionary needs to signal the Artifacts base on changes # dictionary needs to signal the Artifacts base on changes
self._artifacts_dict = self._ProxyDictWrite(self) self._artifacts_dict = self._ProxyDictWrite(self)
self._last_artifacts_upload = {} self._last_artifacts_upload = {}
self._unregister_request = set()
self._thread = None self._thread = None
self._flush_event = Event() self._flush_event = Event()
self._exit_flag = False self._exit_flag = False
@ -70,12 +74,16 @@ class Artifacts(object):
self._summary = '' self._summary = ''
self._temp_folder = [] self._temp_folder = []
def add_artifact(self, name, artifact, upload=True): def register_artifact(self, name, artifact, metadata=None):
# currently we support pandas.DataFrame (which we will upload as csv.gz) # currently we support pandas.DataFrame (which we will upload as csv.gz)
# or numpy array, which we will upload as npz
self._artifacts_dict[name] = artifact self._artifacts_dict[name] = artifact
if not upload: if metadata:
self._artifacts_dict.disable_upload(name) self._artifacts_dict.add_metadata(name, metadata)
def unregister_artifact(self, name):
# Remove artifact from the watch list
self._unregister_request.add(name)
self.flush()
def flush(self): def flush(self):
# start the thread if it hasn't already: # start the thread if it hasn't already:
@ -111,36 +119,47 @@ class Artifacts(object):
self._flush_event.wait(self._flush_frequency_sec) self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear() self._flush_event.clear()
try: try:
self._upload_artifacts() artifact_keys = list(self._artifacts_dict.keys())
for name in artifact_keys:
self._upload_artifacts(name)
except Exception as e: except Exception as e:
LoggerRoot.get_base_logger().warning(str(e)) LoggerRoot.get_base_logger().warning(str(e))
# create summary # create summary
self._summary = self._get_statistics() self._summary = self._get_statistics()
def _upload_artifacts(self): def _upload_artifacts(self, name):
logger = self._task.get_logger() logger = self._task.get_logger()
for name, artifact in self._artifacts_dict.items(): artifact = self._artifacts_dict.get(name)
if not self._artifacts_dict.do_upload(name):
# only register artifacts, and leave, TBD # remove from artifacts watch list
continue if name in self._unregister_request:
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute() try:
if local_csv.exists(): self._unregister_request.remove(name)
# we are still uploading... get another temp folder except KeyError:
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute() pass
artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression) self._artifacts_dict.unregister_artifact(name)
current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload: if artifact is None:
previous_sha2 = self._last_artifacts_upload[name] return
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
local_csv.unlink() if local_csv.exists():
continue # we are still uploading... get another temp folder
self._last_artifacts_upload[name] = current_sha2 local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
# now upload and delete at the end. artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(), current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
delete_after_upload=True, iteration=self._task.get_last_iteration(), if name in self._last_artifacts_upload:
max_image_history=2) previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
local_csv.unlink()
return
self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
def _get_statistics(self): def _get_statistics(self):
summary = '' summary = ''