Improve artifacts support

This commit is contained in:
allegroai 2019-09-07 23:27:16 +03:00
parent 7cb7d891b3
commit dc632e160f

View File

@ -1,4 +1,5 @@
import weakref import weakref
from copy import deepcopy
import numpy as np import numpy as np
import hashlib import hashlib
@ -29,7 +30,7 @@ class Artifacts(object):
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs) super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
self._artifacts_manager = artifacts_manager self._artifacts_manager = artifacts_manager
# list of artifacts we should not upload (by name & weak-reference) # list of artifacts we should not upload (by name & weak-reference)
self.local_artifacts = {} self.artifact_metadata = {}
def __setitem__(self, key, value): def __setitem__(self, key, value):
# check that value is of type pandas # check that value is of type pandas
@ -39,15 +40,17 @@ class Artifacts(object):
if self._artifacts_manager: if self._artifacts_manager:
self._artifacts_manager.flush() self._artifacts_manager.flush()
else: else:
raise ValueError('Artifacts currently supports pandas.DataFrame objects only') raise ValueError('Artifacts currently support pandas.DataFrame objects only')
def disable_upload(self, name): def unregister_artifact(self, name):
if name in self.keys(): self.artifact_metadata.pop(name, None)
self.local_artifacts[name] = weakref.ref(self.get(name)) self.pop(name, None)
def do_upload(self, name): def add_metadata(self, name, metadata):
# return True is this artifact should be uploaded self.artifact_metadata[name] = deepcopy(metadata)
return name not in self.local_artifacts or self.local_artifacts[name] != self.get(name)
def get_metadata(self, name):
return self.artifact_metadata.get(name)
@property @property
def artifacts(self): def artifacts(self):
@ -63,6 +66,7 @@ class Artifacts(object):
# dictionary needs to signal the Artifacts base on changes # dictionary needs to signal the Artifacts base on changes
self._artifacts_dict = self._ProxyDictWrite(self) self._artifacts_dict = self._ProxyDictWrite(self)
self._last_artifacts_upload = {} self._last_artifacts_upload = {}
self._unregister_request = set()
self._thread = None self._thread = None
self._flush_event = Event() self._flush_event = Event()
self._exit_flag = False self._exit_flag = False
@ -70,12 +74,16 @@ class Artifacts(object):
self._summary = '' self._summary = ''
self._temp_folder = [] self._temp_folder = []
def add_artifact(self, name, artifact, upload=True): def register_artifact(self, name, artifact, metadata=None):
# currently we support pandas.DataFrame (which we will upload as csv.gz) # currently we support pandas.DataFrame (which we will upload as csv.gz)
# or numpy array, which we will upload as npz
self._artifacts_dict[name] = artifact self._artifacts_dict[name] = artifact
if not upload: if metadata:
self._artifacts_dict.disable_upload(name) self._artifacts_dict.add_metadata(name, metadata)
def unregister_artifact(self, name):
# Remove artifact from the watch list
self._unregister_request.add(name)
self.flush()
def flush(self): def flush(self):
# start the thread if it hasn't already: # start the thread if it hasn't already:
@ -111,19 +119,30 @@ class Artifacts(object):
self._flush_event.wait(self._flush_frequency_sec) self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear() self._flush_event.clear()
try: try:
self._upload_artifacts() artifact_keys = list(self._artifacts_dict.keys())
for name in artifact_keys:
self._upload_artifacts(name)
except Exception as e: except Exception as e:
LoggerRoot.get_base_logger().warning(str(e)) LoggerRoot.get_base_logger().warning(str(e))
# create summary # create summary
self._summary = self._get_statistics() self._summary = self._get_statistics()
def _upload_artifacts(self): def _upload_artifacts(self, name):
logger = self._task.get_logger() logger = self._task.get_logger()
for name, artifact in self._artifacts_dict.items(): artifact = self._artifacts_dict.get(name)
if not self._artifacts_dict.do_upload(name):
# only register artifacts, and leave, TBD # remove from artifacts watch list
continue if name in self._unregister_request:
try:
self._unregister_request.remove(name)
except KeyError:
pass
self._artifacts_dict.unregister_artifact(name)
if artifact is None:
return
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute() local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists(): if local_csv.exists():
# we are still uploading... get another temp folder # we are still uploading... get another temp folder
@ -135,7 +154,7 @@ class Artifacts(object):
if previous_sha2 == current_sha2: if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload # nothing to do, we can skip the upload
local_csv.unlink() local_csv.unlink()
continue return
self._last_artifacts_upload[name] = current_sha2 self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end. # now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(), logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),