import weakref
import numpy as np
import hashlib
from tempfile import mkdtemp
from threading import Thread, Event
from multiprocessing.pool import ThreadPool
from pathlib2 import Path
from ..debugging.log import LoggerRoot

try:
    import pandas as pd
except ImportError:
    pd = None


class Artifacts(object):
    _flush_frequency_sec = 300.
    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that notifies the artifacts manager on any item set in the dictionary """
        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # artifacts we should not upload (name -> weak reference)
            self.local_artifacts = {}

        def __setitem__(self, key, value):
            # check that the value is a supported type (numpy.ndarray or pandas.DataFrame)
            if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
                if self._artifacts_manager:
                    self._artifacts_manager.flush()
            else:
                raise ValueError('Artifacts currently supports pandas.DataFrame and numpy.ndarray objects only')

        def disable_upload(self, name):
            if name in self.keys():
                self.local_artifacts[name] = weakref.ref(self.get(name))

        def do_upload(self, name):
            # return True if this artifact should be uploaded, i.e. it was not disabled,
            # or it was replaced by a new object since it was disabled
            return name not in self.local_artifacts or self.local_artifacts[name]() is not self.get(name)

    @property
    def artifacts(self):
        return self._artifacts_dict

    @property
    def summary(self):
        return self._summary

    def __init__(self, task):
        self._task = task
        # note the double link: the artifacts dictionary needs to signal
        # the Artifacts manager on any change
        self._artifacts_dict = self._ProxyDictWrite(self)
        self._last_artifacts_upload = {}
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []

    def add_artifact(self, name, artifact, upload=True):
        # currently we support pandas.DataFrame (which we upload as csv.gz);
        # numpy array support (upload as npz) is planned but not implemented yet
        self._artifacts_dict[name] = artifact
        if not upload:
            self._artifacts_dict.disable_upload(name)

    def flush(self):
        # start the thread if it hasn't started already
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        # stop the daemon thread and quit
        # wait until the thread exits
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    pass

    def _start(self):
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

    def _daemon(self):
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            try:
                self._upload_artifacts()
            except Exception as e:
                LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

    def _upload_artifacts(self):
        logger = self._task.get_logger()
        for name, artifact in self._artifacts_dict.items():
            if not self._artifacts_dict.do_upload(name):
                # upload is disabled for this artifact; registration only is still TBD
                continue
            if not pd or not isinstance(artifact, pd.DataFrame):
                # only pandas.DataFrame serialization is implemented below,
                # numpy arrays (planned npz support) are skipped for now
                continue
            local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
            if local_csv.exists():
                # the previous file is still being uploaded, use another temp folder
                local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
            artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
            current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
            if name in self._last_artifacts_upload:
                previous_sha2 = self._last_artifacts_upload[name]
                if previous_sha2 == current_sha2:
                    # nothing changed, we can skip the upload
                    local_csv.unlink()
                    continue
            self._last_artifacts_upload[name] = current_sha2
            # now upload the file and delete it once the upload completes
            logger.report_image_and_upload(
                title='artifacts', series=name, path=local_csv.as_posix(),
                delete_after_upload=True, iteration=self._task.get_last_iteration(),
                max_image_history=2)

    def _get_statistics(self):
        summary = ''
        thread_pool = ThreadPool()

        try:
            # build the per-artifact sets of row hashes
            artifacts_summary = []
            for a_name, a_df in self._artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                a_unique_hash = set()

                def hash_row(r):
                    a_unique_hash.add(hash(bytes(r)))

                a_shape = a_df.shape
                # parallelize the row hashing
                thread_pool.map(hash_row, a_df.values)
                # add result
                artifacts_summary.append((a_name, a_shape, a_unique_hash,))

            # build the pairwise intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash),
                    percentage=100 * len(unique_hash) / float(shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection,
                        percentage=100 * intersection / float(len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()
        return summary

    def _get_temp_folder(self, force_new=False):
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    @staticmethod
    def sha256sum(filename, skip_header=0):
        # create the sha256 of the file; we optionally skip the file header
        # (the caller skips the first 32 bytes, covering the gzip header whose
        # embedded timestamp changes on every write), because sometimes the
        # header is the only change
        h = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        with open(filename, 'rb', buffering=0) as f:
            # skip header
            f.read(skip_header)
            for n in iter(lambda: f.readinto(mv), 0):
                h.update(mv[:n])
        return h.hexdigest()
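

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API). It assumes an
# existing `task` instance exposing get_logger() and get_last_iteration() as
# used above, and pandas DataFrames `df` / `debug_df`:
#
#     artifacts = Artifacts(task)
#     # registering a DataFrame triggers an asynchronous flush/upload
#     artifacts.add_artifact('train_set', df)
#     # register without uploading (kept local only)
#     artifacts.add_artifact('debug_set', debug_df, upload=False)
#     ...
#     # flush, wait for the upload thread, and clean up temp folders
#     artifacts.stop(wait=True)
#     print(artifacts.summary)
# -----------------------------------------------------------------------------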