# clearml/trains/binding/artifacts.py
import weakref
import numpy as np
import hashlib
from tempfile import mkdtemp
from threading import Thread, Event
from multiprocessing.pool import ThreadPool

from pathlib2 import Path

from ..debugging.log import LoggerRoot

try:
    import pandas as pd
except ImportError:
    pd = None


class Artifacts(object):
    _flush_frequency_sec = 300.
    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that notifies the artifacts manager on any item set in the dictionary """
        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # list of artifacts we should not upload (by name & weak-reference)
            self.local_artifacts = {}

        def __setitem__(self, key, value):
            # only pandas DataFrames and numpy arrays are currently supported
            if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
                if self._artifacts_manager:
                    self._artifacts_manager.flush()
            else:
                raise ValueError('Artifacts currently support pandas.DataFrame and numpy.ndarray objects only')

        def disable_upload(self, name):
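            # mark this artifact as local-only: keep it registered, but skip its upload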
            if name in self.keys():
                self.local_artifacts[name] = weakref.ref(self.get(name))

        def do_upload(self, name):
            # return True if this artifact should be uploaded
            # (dereference the weakref and compare identity, DataFrame equality is elementwise)
            return name not in self.local_artifacts or self.local_artifacts[name]() is not self.get(name)

    @property
    def artifacts(self):
        return self._artifacts_dict

    @property
    def summary(self):
        return self._summary

    def __init__(self, task):
        self._task = task
        # notice the double link, this is important since the artifact
        # dictionary needs to signal the Artifacts manager on changes
        self._artifacts_dict = self._ProxyDictWrite(self)
        self._last_artifacts_upload = {}
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []

    def add_artifact(self, name, artifact, upload=True):
        # currently we support pandas.DataFrame (uploaded as csv.gz) and numpy
        # arrays (npz upload is still TBD, see _upload_artifacts)
        self._artifacts_dict[name] = artifact
        if not upload:
            self._artifacts_dict.disable_upload(name)

    def flush(self):
        # start the daemon thread if it hasn't started already
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        # stop the daemon thread and quit
        # wait until the thread exits
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    pass

    def _start(self):
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

    def _daemon(self):
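        # background loop: wake on flush() or every _flush_frequency_sec seconds,
        # upload changed artifacts, and compute the summary once asked to exit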
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            try:
                self._upload_artifacts()
            except Exception as e:
                LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

    def _upload_artifacts(self):
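        # serialize each DataFrame to a temporary csv.gz, skip it if its content
        # hash did not change since the last upload, then hand the file to the logger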
        logger = self._task.get_logger()
        for name, artifact in self._artifacts_dict.items():
            if not self._artifacts_dict.do_upload(name):
                # only register artifacts, and leave, TBD
                continue
            if not pd or not isinstance(artifact, pd.DataFrame):
                # numpy arrays are accepted but npz upload is still TBD, skip for now
                continue
            local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
            if local_csv.exists():
                # we are still uploading... get another temp folder
                local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
            artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
            current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
            if name in self._last_artifacts_upload:
                previous_sha2 = self._last_artifacts_upload[name]
                if previous_sha2 == current_sha2:
                    # nothing to do, we can skip the upload
                    local_csv.unlink()
                    continue
            self._last_artifacts_upload[name] = current_sha2
            # now upload and delete at the end.
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)

    def _get_statistics(self):
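        # hash every row of every DataFrame artifact to count unique rows, then
        # report the pairwise row-set intersection between artifacts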
        summary = ''
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in self._artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                a_unique_hash = set()

                def hash_row(r):
                    a_unique_hash.add(hash(bytes(r)))

                a_shape = a_df.shape
                # parallelize
                thread_pool.map(hash_row, a_df.values)

                # add result
                artifacts_summary.append((a_name, a_shape, a_unique_hash,))

            # build intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()
        return summary

    def _get_temp_folder(self, force_new=False):
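        # reuse the first temp folder unless a fresh one is explicitly requested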
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    @staticmethod
    def sha256sum(filename, skip_header=0):
        # create the sha256 of the file, skipping the first `skip_header` bytes;
        # callers skip the compressed file header (32 bytes) because sometimes
        # that is the only part that changed
        h = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        with open(filename, 'rb', buffering=0) as f:
            # skip header
            f.read(skip_header)
            for n in iter(lambda: f.readinto(mv), 0):
                h.update(mv[:n])
        return h.hexdigest()
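

# Minimal usage sketch (assumes a Task-like object exposing get_logger() and
# get_last_iteration(); in practice the Task class drives this manager):
#
#   artifacts = Artifacts(task)
#   artifacts.add_artifact('train_set', pd.DataFrame({'a': [1, 2, 3]}))
#   artifacts.flush()          # wake the daemon thread and upload now
#   artifacts.stop(wait=True)  # join the thread and clean up temp folders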