import os
import weakref

import numpy as np
import hashlib
from tempfile import mkstemp, mkdtemp
from threading import Thread, Event
from multiprocessing.pool import ThreadPool

from pathlib2 import Path

from ..debugging.log import LoggerRoot

try:
    import pandas as pd
except ImportError:
    pd = None


class Artifacts(object):
    _flush_frequency_sec = 300.
    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that signals the artifacts manager on any item set in the dictionary """
        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # dict of artifacts we should not upload (by name & weak-reference)
            self.local_artifacts = {}

        def __setitem__(self, key, value):
            # check that the value is a supported type (numpy array or pandas DataFrame)
            if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                if self._artifacts_manager:
                    self._artifacts_manager.flush()
            else:
                raise ValueError('Artifacts currently supports pandas.DataFrame and numpy.ndarray objects only')

        def disable_upload(self, name):
            if name in self.keys():
                # keep a weak reference, so we can tell later whether the caller replaced the object
                self.local_artifacts[name] = weakref.ref(self.get(name))

        def do_upload(self, name):
            # return True if this artifact should be uploaded
            # (i.e. it was never disabled, or the object stored under this name has since been replaced)
            return name not in self.local_artifacts or self.local_artifacts[name]() is not self.get(name)

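    # Illustrative behaviour sketch for _ProxyDictWrite (hypothetical names, not part of the class):
    #
    #   d = Artifacts(task).artifacts      # the _ProxyDictWrite instance
    #   d['data'] = df                     # registers the DataFrame and triggers a flush
    #   d.disable_upload('data')           # mark 'data' as local-only
    #   d.do_upload('data')                # -> False, 'data' still maps to the same object
    #   d['data'] = new_df                 # replacing the object makes it uploadable again
    #   d.do_upload('data')                # -> True
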
    @property
    def artifacts(self):
        return self._artifacts_dict

    @property
    def summary(self):
        return self._summary

    def __init__(self, task):
        self._task = task
        # notice the double link, this is important since the Artifact
        # dictionary needs to signal the Artifacts manager on changes
        self._artifacts_dict = self._ProxyDictWrite(self)
        self._last_artifacts_upload = {}
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []

    def add_artifact(self, name, artifact, upload=True):
        # currently we support pandas.DataFrame (which we will upload as csv.gz)
        # or numpy array, which we will upload as npz (npz upload is not implemented yet, see _upload_artifacts)
        self._artifacts_dict[name] = artifact
        if not upload:
            self._artifacts_dict.disable_upload(name)

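    # Illustrative usage sketch (the `task` object is hypothetical here; the real owner of this
    # manager lives elsewhere in the code base):
    #
    #   manager = Artifacts(task)
    #   manager.add_artifact('train_set', pd.DataFrame({'a': [1, 2, 3]}))
    #   manager.add_artifact('scratch', pd.DataFrame({'b': [4, 5]}), upload=False)
    #   ...
    #   manager.stop(wait=True)
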
    def flush(self):
        # start the thread if it hasn't already:
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        # stop the daemon thread and quit
        # wait until the thread exits
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    pass

    def _start(self):
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

    def _daemon(self):
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            try:
                self._upload_artifacts()
            except Exception as e:
                LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

    def _upload_artifacts(self):
        logger = self._task.get_logger()
        for name, artifact in self._artifacts_dict.items():
            if not self._artifacts_dict.do_upload(name):
                # only register artifacts, and leave, TBD
                continue
            if not pd or not isinstance(artifact, pd.DataFrame):
                # only pandas DataFrames are serialized by this upload path for now
                continue
            local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
            if local_csv.exists():
                # we are still uploading... get another temp folder
                local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
            artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
            current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
            if name in self._last_artifacts_upload:
                previous_sha2 = self._last_artifacts_upload[name]
                if previous_sha2 == current_sha2:
                    # nothing changed since the last upload, we can skip this one
                    local_csv.unlink()
                    continue
            self._last_artifacts_upload[name] = current_sha2
            # now upload, and delete the local copy once the upload is done
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)

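    # NOTE: numpy arrays accepted by add_artifact() are skipped by _upload_artifacts() above;
    # only pandas DataFrames are serialized for now. A minimal sketch of how an ndarray artifact
    # could be dumped instead (an assumption, not part of the current upload flow):
    #
    #   local_npz = Path(self._get_temp_folder()) / (name + '.npz')
    #   np.savez_compressed(local_npz.as_posix(), data=artifact)
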
    def _get_statistics(self):
        summary = ''
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in self._artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                a_unique_hash = set()

                def hash_row(r):
                    a_unique_hash.add(hash(bytes(r)))

                a_shape = a_df.shape
                # parallelize
                thread_pool.map(hash_row, a_df.values)
                # add result
                artifacts_summary.append((a_name, a_shape, a_unique_hash,))

            # build intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()
        return summary

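    # Example of the summary format produced by _get_statistics() (illustrative values only):
    #
    #   [train]: shape=(1000, 10), 987 unique rows, 98.7% uniqueness
    #       Intersection with [validation] 12 rows: 6.0%
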
    def _get_temp_folder(self, force_new=False):
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    @staticmethod
    def sha256sum(filename, skip_header=0):
        # create sha2 of the file, notice we skip the header of the file (32 bytes)
        # because sometimes that is the only change (e.g. the timestamp embedded in the gzip header)
        h = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        with open(filename, 'rb', buffering=0) as f:
            # skip the header bytes before hashing the rest of the file
            f.read(skip_header)
            for n in iter(lambda: f.readinto(mv), 0):
                h.update(mv[:n])
        return h.hexdigest()
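
# Illustrative sanity check for sha256sum (hypothetical file names; assumes both .csv.gz files
# hold identical rows and differ only in the timestamp stored in the gzip header):
#
#   same = Artifacts.sha256sum('dump_a.csv.gz', skip_header=32) == \
#          Artifacts.sha256sum('dump_b.csv.gz', skip_header=32)
#
# skipping the first 32 bytes drops the gzip header (and a few leading payload bytes),
# so `same` should be True even though the raw files are not byte-identical.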