mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
import weakref
import hashlib
from tempfile import mkdtemp
from threading import Thread, Event
from multiprocessing.pool import ThreadPool

import numpy as np
from pathlib2 import Path

from ..debugging.log import LoggerRoot

try:
    import pandas as pd
except ImportError:
    pd = None


class Artifacts(object):
    _flush_frequency_sec = 300.

    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'

    # hashing constants
    _hash_block_size = 65536

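    # Typical usage, as a sketch (the `task` argument is assumed to be a Task-like
    # object exposing get_logger() and get_last_iteration(), as used further below):
    #
    #     artifacts = Artifacts(task)
    #     artifacts.add_artifact('train', train_df)                     # uploaded as csv.gz
    #     artifacts.add_artifact('debug_rows', debug_df, upload=False)  # registered, never uploaded
    #     artifacts.flush()   # trigger an immediate upload cycle
    #     artifacts.stop()    # stop the daemon thread and remove temp folders
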
    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that notifies the artifacts manager on any item set in the dictionary """

        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # list of artifacts we should not upload (by name & weak-reference)
            self.local_artifacts = {}

        def __setitem__(self, key, value):
            # check that the value is a supported artifact type
            if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                if self._artifacts_manager:
                    self._artifacts_manager.flush()
            else:
                raise ValueError('Artifacts currently supports pandas.DataFrame and numpy.ndarray objects only')

        def disable_upload(self, name):
            if name in self.keys():
                self.local_artifacts[name] = weakref.ref(self.get(name))

        def do_upload(self, name):
            # return True if this artifact should be uploaded, i.e. it was never marked
            # as local-only, or the object stored under this name was replaced since
            # (note the weak reference is dereferenced before the identity check)
            return name not in self.local_artifacts or self.local_artifacts[name]() is not self.get(name)

    @property
    def artifacts(self):
        return self._artifacts_dict

    @property
    def summary(self):
        return self._summary

    def __init__(self, task):
        self._task = task
        # notice the double link, this is important since the Artifacts
        # dictionary needs to signal the Artifacts manager on changes
        self._artifacts_dict = self._ProxyDictWrite(self)
        self._last_artifacts_upload = {}
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []

    def add_artifact(self, name, artifact, upload=True):
        # currently we support pandas.DataFrame (which we will upload as csv.gz)
        # or numpy array, which we will upload as npz
        self._artifacts_dict[name] = artifact
        if not upload:
            self._artifacts_dict.disable_upload(name)

    def flush(self):
        # start the thread if it hasn't started already
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        # stop the daemon thread and quit
        # wait until the thread exits
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    pass

    def _start(self):
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

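    # Thread lifecycle, in short: _start() lazily spawns the daemon thread on the first
    # flush(); _daemon() below wakes up every _flush_frequency_sec seconds, or as soon as
    # _flush_event is set, and uploads any artifacts that changed since the last upload;
    # once stop() sets _exit_flag and wakes the thread, the loop ends and a statistics
    # summary of the artifacts is stored in _summary.
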
    def _daemon(self):
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            try:
                self._upload_artifacts()
            except Exception as e:
                LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

    def _upload_artifacts(self):
        logger = self._task.get_logger()
        for name, artifact in self._artifacts_dict.items():
            if not self._artifacts_dict.do_upload(name):
                # only register artifacts, and leave, TBD
                continue
            if not pd or not isinstance(artifact, pd.DataFrame):
                # numpy arrays are accepted by the proxy dict, but serialization below
                # is csv based, so skip anything that is not a DataFrame for now
                continue
            local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
            if local_csv.exists():
                # we are still uploading... get another temp folder
                local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
            artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
            current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
            if name in self._last_artifacts_upload:
                previous_sha2 = self._last_artifacts_upload[name]
                if previous_sha2 == current_sha2:
                    # nothing to do, we can skip the upload
                    local_csv.unlink()
                    continue
            self._last_artifacts_upload[name] = current_sha2
            # now upload and delete at the end
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)

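    # Upload flow above, in short: serialize each DataFrame to csv.gz in a temp folder,
    # hash the file with sha256sum() while skipping the compression header, and upload
    # only when the hash differs from the last uploaded version. The actual transfer
    # piggybacks on the logger's debug-image upload mechanism (report_image_and_upload),
    # which also deletes the local file after the upload.
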
    def _get_statistics(self):
        summary = ''
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in self._artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                a_unique_hash = set()

                def hash_row(r):
                    a_unique_hash.add(hash(bytes(r)))

                a_shape = a_df.shape
                # parallelize
                thread_pool.map(hash_row, a_df.values)
                # add result
                artifacts_summary.append((a_name, a_shape, a_unique_hash,))

            # build intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()
        return summary

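    # The summary built above looks roughly like this (illustrative values only):
    #
    #   [train]: shape=(1000, 10), 987 unique rows, 98.7% uniqueness
    #       Intersection with [validation] 12 rows: 6.0%
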
    def _get_temp_folder(self, force_new=False):
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    @staticmethod
    def sha256sum(filename, skip_header=0):
        # create sha2 of the file, notice we skip the header of the file (32 bytes)
        # because sometimes that is the only change (e.g. the gzip header embeds a timestamp)
        h = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        with open(filename, 'rb', buffering=0) as f:
            # skip the header
            f.read(skip_header)
            for n in iter(lambda: f.readinto(mv), 0):
                h.update(mv[:n])
        return h.hexdigest()
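
# sha256sum() usage sketch (the path below is hypothetical): hash a csv.gz artifact
# while ignoring its first 32 bytes, so two files whose only difference is the
# compression header (e.g. the gzip timestamp) produce the same digest:
#
#     digest = Artifacts.sha256sum('/tmp/artifacts_xyz/train.csv.gz', skip_header=32)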