clearml/trains/binding/artifacts.py


import hashlib
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp
from threading import Thread, Event
import numpy as np
from pathlib2 import Path
from ..debugging.log import LoggerRoot
try:
import pandas as pd
except ImportError:
pd = None
class Artifacts(object):
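    """
    Manage task artifacts: registered objects (currently pandas DataFrames) are
    periodically serialized to csv.gz and uploaded by a background daemon thread.
    """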
_flush_frequency_sec = 300.
# notice these two should match
_save_format = '.csv.gz'
_compression = 'gzip'
# hashing constants
_hash_block_size = 65536
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
def __init__(self, artifacts_manager, *args, **kwargs):
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
self._artifacts_manager = artifacts_manager
            # per-artifact metadata, keyed by artifact name
self.artifact_metadata = {}
def __setitem__(self, key, value):
            # check that the value is a supported type (numpy array or pandas DataFrame)
if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
if self._artifacts_manager:
self._artifacts_manager.flush()
else:
                raise ValueError('Artifacts currently support pandas.DataFrame and numpy.ndarray objects only')
def unregister_artifact(self, name):
self.artifact_metadata.pop(name, None)
self.pop(name, None)
def add_metadata(self, name, metadata):
self.artifact_metadata[name] = deepcopy(metadata)
def get_metadata(self, name):
return self.artifact_metadata.get(name)
@property
def artifacts(self):
return self._artifacts_dict
@property
def summary(self):
return self._summary
def __init__(self, task):
self._task = task
        # notice the double link, this is important since the Artifact
        # dictionary needs to signal the Artifacts manager on changes
self._artifacts_dict = self._ProxyDictWrite(self)
self._last_artifacts_upload = {}
self._unregister_request = set()
self._thread = None
self._flush_event = Event()
self._exit_flag = False
self._thread_pool = ThreadPool()
self._summary = ''
self._temp_folder = []
def register_artifact(self, name, artifact, metadata=None):
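        """ Add (or replace) a named artifact on the watch list; the daemon thread
        will serialize and upload it on the next flush. """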
# currently we support pandas.DataFrame (which we will upload as csv.gz)
if name in self._artifacts_dict:
LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
self._artifacts_dict[name] = artifact
if metadata:
self._artifacts_dict.add_metadata(name, metadata)
def unregister_artifact(self, name):
# Remove artifact from the watch list
self._unregister_request.add(name)
self.flush()
def flush(self):
# start the thread if it hasn't already:
self._start()
# flush the current state of all artifacts
self._flush_event.set()
def stop(self, wait=True):
# stop the daemon thread and quit
        # wait until the thread exits
self._exit_flag = True
self._flush_event.set()
if wait:
if self._thread:
self._thread.join()
# remove all temp folders
for f in self._temp_folder:
try:
Path(f).rmdir()
except Exception:
pass
def _start(self):
if not self._thread:
# start the daemon thread
self._flush_event.clear()
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
def _daemon(self):
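        """ Background loop: wake on an explicit flush request or after the flush
        timeout and upload all registered artifacts; build the statistics summary
        once the loop exits. """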
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear()
try:
artifact_keys = list(self._artifacts_dict.keys())
for name in artifact_keys:
self._upload_artifacts(name)
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
# create summary
self._summary = self._get_statistics()
def _upload_artifacts(self, name):
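        """ Serialize a single registered artifact to a temporary csv.gz and upload it.
        The upload is skipped if the file content (sha256, header excluded) did not
        change since the last upload; the local file is deleted after the upload. """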
logger = self._task.get_logger()
artifact = self._artifacts_dict.get(name)
# remove from artifacts watch list
if name in self._unregister_request:
try:
self._unregister_request.remove(name)
except KeyError:
pass
self._artifacts_dict.unregister_artifact(name)
if artifact is None:
return
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists():
# we are still uploading... get another temp folder
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
local_csv.unlink()
return
self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
def _get_statistics(self):
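        """ Build a text summary of the registered DataFrames: shape, number of
        unique rows, and pairwise row intersection between artifacts. """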
summary = ''
thread_pool = ThreadPool()
try:
# build hash row sets
artifacts_summary = []
for a_name, a_df in self._artifacts_dict.items():
if not pd or not isinstance(a_df, pd.DataFrame):
continue
a_unique_hash = set()
def hash_row(r):
a_unique_hash.add(hash(bytes(r)))
a_shape = a_df.shape
# parallelize
thread_pool.map(hash_row, a_df.values)
# add result
artifacts_summary.append((a_name, a_shape, a_unique_hash,))
# build intersection summary
for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
intersection = len(unique_hash & unique_hash2)
summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
finally:
thread_pool.close()
thread_pool.terminate()
return summary
def _get_temp_folder(self, force_new=False):
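        """ Return the current temporary folder, or create a new one when forced
        (e.g. while a previous file in it is still being uploaded). """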
if force_new or not self._temp_folder:
new_temp = mkdtemp(prefix='artifacts_')
self._temp_folder.append(new_temp)
return new_temp
return self._temp_folder[0]
@staticmethod
def sha256sum(filename, skip_header=0):
        # create the sha256 of the file; optionally skip the file header
        # (first 32 bytes) because sometimes the header is the only change
        # (e.g. the gzip header embeds a modification timestamp)
h = hashlib.sha256()
b = bytearray(Artifacts._hash_block_size)
mv = memoryview(b)
with open(filename, 'rb', buffering=0) as f:
# skip header
f.read(skip_header)
for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n])
return h.hexdigest()
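
# Usage sketch (illustrative only): the manager is normally driven by the Task
# binding, but assuming `task` is an initialized Task exposing get_logger() and
# get_last_iteration(), direct use would look roughly like this:
#
#     artifacts = Artifacts(task)
#     artifacts.register_artifact('train_set', train_df, metadata={'rows': len(train_df)})
#     artifacts.flush()
#     print(artifacts.summary)
#     artifacts.unregister_artifact('train_set')
#     artifacts.stop(wait=True)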