diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py index 398056c7..bd7cc682 100644 --- a/trains/binding/artifacts.py +++ b/trains/binding/artifacts.py @@ -1,4 +1,5 @@ import weakref +from copy import deepcopy import numpy as np import hashlib @@ -29,7 +30,7 @@ class Artifacts(object): super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs) self._artifacts_manager = artifacts_manager # list of artifacts we should not upload (by name & weak-reference) - self.local_artifacts = {} + self.artifact_metadata = {} def __setitem__(self, key, value): # check that value is of type pandas @@ -39,15 +40,17 @@ class Artifacts(object): if self._artifacts_manager: self._artifacts_manager.flush() else: - raise ValueError('Artifacts currently supports pandas.DataFrame objects only') + raise ValueError('Artifacts currently support pandas.DataFrame objects only') - def disable_upload(self, name): - if name in self.keys(): - self.local_artifacts[name] = weakref.ref(self.get(name)) + def unregister_artifact(self, name): + self.artifact_metadata.pop(name, None) + self.pop(name, None) - def do_upload(self, name): - # return True is this artifact should be uploaded - return name not in self.local_artifacts or self.local_artifacts[name] != self.get(name) + def add_metadata(self, name, metadata): + self.artifact_metadata[name] = deepcopy(metadata) + + def get_metadata(self, name): + return self.artifact_metadata.get(name) @property def artifacts(self): @@ -63,6 +66,7 @@ class Artifacts(object): # dictionary needs to signal the Artifacts base on changes self._artifacts_dict = self._ProxyDictWrite(self) self._last_artifacts_upload = {} + self._unregister_request = set() self._thread = None self._flush_event = Event() self._exit_flag = False @@ -70,12 +74,16 @@ class Artifacts(object): self._summary = '' self._temp_folder = [] - def add_artifact(self, name, artifact, upload=True): + def register_artifact(self, name, artifact, metadata=None): # currently we support pandas.DataFrame (which we will upload as csv.gz) - # or numpy array, which we will upload as npz self._artifacts_dict[name] = artifact - if not upload: - self._artifacts_dict.disable_upload(name) + if metadata: + self._artifacts_dict.add_metadata(name, metadata) + + def unregister_artifact(self, name): + # Remove artifact from the watch list + self._unregister_request.add(name) + self.flush() def flush(self): # start the thread if it hasn't already: @@ -111,36 +119,47 @@ class Artifacts(object): self._flush_event.wait(self._flush_frequency_sec) self._flush_event.clear() try: - self._upload_artifacts() + artifact_keys = list(self._artifacts_dict.keys()) + for name in artifact_keys: + self._upload_artifacts(name) except Exception as e: LoggerRoot.get_base_logger().warning(str(e)) # create summary self._summary = self._get_statistics() - def _upload_artifacts(self): + def _upload_artifacts(self, name): logger = self._task.get_logger() - for name, artifact in self._artifacts_dict.items(): - if not self._artifacts_dict.do_upload(name): - # only register artifacts, and leave, TBD - continue - local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute() - if local_csv.exists(): - # we are still uploading... get another temp folder - local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute() - artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression) - current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32) - if name in self._last_artifacts_upload: - previous_sha2 = self._last_artifacts_upload[name] - if previous_sha2 == current_sha2: - # nothing to do, we can skip the upload - local_csv.unlink() - continue - self._last_artifacts_upload[name] = current_sha2 - # now upload and delete at the end. - logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(), - delete_after_upload=True, iteration=self._task.get_last_iteration(), - max_image_history=2) + artifact = self._artifacts_dict.get(name) + + # remove from artifacts watch list + if name in self._unregister_request: + try: + self._unregister_request.remove(name) + except KeyError: + pass + self._artifacts_dict.unregister_artifact(name) + + if artifact is None: + return + + local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute() + if local_csv.exists(): + # we are still uploading... get another temp folder + local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute() + artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression) + current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32) + if name in self._last_artifacts_upload: + previous_sha2 = self._last_artifacts_upload[name] + if previous_sha2 == current_sha2: + # nothing to do, we can skip the upload + local_csv.unlink() + return + self._last_artifacts_upload[name] = current_sha2 + # now upload and delete at the end. + logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(), + delete_after_upload=True, iteration=self._task.get_last_iteration(), + max_image_history=2) def _get_statistics(self): summary = ''