clearml/trains/binding/artifacts.py

398 lines
16 KiB
Python

import hashlib
import json
import logging
import mimetypes
import os
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock
from time import time
import six
from pathlib2 import Path
from PIL import Image
from ..backend_interface.metrics.events import UploadEvent
from ..backend_api import Session
from ..debugging.log import LoggerRoot
from ..backend_api.services import tasks
try:
import pandas as pd
except ImportError:
pd = None
try:
import numpy as np
except ImportError:
np = None
class Artifacts(object):
    """
    Track and upload the artifacts of a single task.

    pandas.DataFrame objects added with ``register_artifact`` are watched by a background
    daemon thread: on every flush each one is serialized to a compressed csv and uploaded
    if its content hash changed since the previous upload.
    ``upload_artifact`` performs a one-off upload of a numpy array, a PIL image,
    a dictionary or a local file.
    """

    # maximum time (seconds) the daemon thread waits between two flushes
    _flush_frequency_sec = 300.
    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536
    # artifact type used for registered (pandas.DataFrame) artifacts
    _pd_artifact_type = 'data-audit-table'

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that signals (flush) the owning artifacts manager on any item set in the dictionary """

        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # metadata attached to the artifacts, stored by artifact name
            self.artifact_metadata = {}

        def __setitem__(self, key, value):
            """
            Store a pandas.DataFrame under ``key`` and request a flush from the manager (if there is one).

            :raise ValueError: if pandas is not available or ``value`` is not a pandas.DataFrame
            """
            # check that value is of type pandas
            if not pd or not isinstance(value, pd.DataFrame):
                raise ValueError('Artifacts currently support pandas.DataFrame objects only')

            super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
            if self._artifacts_manager:
                self._artifacts_manager.flush()

        def unregister_artifact(self, name):
            """ Drop the artifact and its metadata, missing names are silently ignored """
            self.artifact_metadata.pop(name, None)
            self.pop(name, None)

        def add_metadata(self, name, metadata):
            """ Store a private (deep) copy of ``metadata`` for the artifact ``name`` """
            self.artifact_metadata[name] = deepcopy(metadata)

        def get_metadata(self, name):
            """ Return the metadata stored for ``name``, or None if there is none """
            return self.artifact_metadata.get(name)

    @property
    def artifacts(self):
        """ The dictionary of registered (watched) artifacts """
        return self._artifacts_dict

    @property
    def summary(self):
        """ The statistics summary text created by the last daemon flush ('' before the first one) """
        return self._summary

    def __init__(self, task):
        """
        :param task: the task object the artifacts are reported to
        """
        self._task = task
        # notice the double link, this important since the Artifact
        # dictionary needs to signal the Artifacts base on changes
        self._artifacts_dict = self._ProxyDictWrite(self)
        # content hash of the last upload, per registered artifact name
        self._last_artifacts_upload = {}
        # names waiting to be removed from the watch list by the daemon thread
        self._unregister_request = set()
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []
        self._task_artifact_list = []
        self._task_edit_lock = RLock()
        self._storage_prefix = None

    def register_artifact(self, name, artifact, metadata=None):
        """
        Add a pandas.DataFrame to the watch list (an existing artifact with the same name is overwritten).

        :param name: artifact name
        :param artifact: pandas.DataFrame object
        :param metadata: optional dictionary, a copy of it is stored with the artifact
        :raise ValueError: if ``artifact`` is not a pandas.DataFrame
        """
        # currently we support pandas.DataFrame (which we will upload as csv.gz)
        if name in self._artifacts_dict:
            LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))

        # an unregister request that was not processed yet (or was issued for a name that was
        # never registered) must not remove the artifact we are registering now
        had_pending_unregister = name in self._unregister_request
        self._unregister_request.discard(name)
        try:
            self._artifacts_dict[name] = artifact
        except ValueError:
            # nothing was registered, the original unregister request is still valid
            if had_pending_unregister:
                self._unregister_request.add(name)
            raise

        if metadata:
            self._artifacts_dict.add_metadata(name, metadata)

    def unregister_artifact(self, name):
        """ Request the removal of ``name`` from the watch list (performed by the daemon thread) """
        # Remove artifact from the watch list
        self._unregister_request.add(name)
        self.flush()

    def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
        """
        Upload a one-off artifact and add it to the task artifacts list.

        Supported objects: numpy.ndarray, PIL.Image.Image, dict, or a path (string / Path) to a local file.

        :param name: artifact name
        :param artifact_object: the object to upload
        :param metadata: optional dictionary, stored as string pairs with the artifact
        :param delete_after_upload: delete the local file after the upload
            (always True for the temporary files created for in-memory objects)
        :return: True if the upload was sent, False if the server does not support
            artifacts or the local file could not be found
        :raise ValueError: if the name is used by a registered artifact or the object type is not supported
        """
        if not Session.check_min_api_version('2.3'):
            LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
                                                 'please upgrade to the latest server version')
            return False

        if name in self._artifacts_dict:
            raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))

        # an entry created for a registered artifact cannot be replaced by a one-off upload,
        # verify it before a temporary file is created, so nothing is left behind
        with self._task_edit_lock:
            previous = self._find_task_artifact(name)
        if previous is not None and previous.type == self._pd_artifact_type:
            raise ValueError("Artifact of name {} already registered, "
                             "use register_artifact instead".format(name))

        artifact_type, artifact_type_data, local_filename, delete_after_upload, use_filename_in_uri = \
            self._prepare_artifact_file(name, artifact_object, delete_after_upload)

        # check that the file to upload exists
        local_filename = Path(local_filename).absolute()
        if not local_filename.exists() or not local_filename.is_file():
            LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
                local_filename.as_posix()))
            return False

        file_hash, _ = self.sha256sum(local_filename.as_posix())
        timestamp = int(time())
        file_size = local_filename.stat().st_size

        uri = self._upload_local_file(local_filename, name,
                                      delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)

        artifact = tasks.Artifact(key=name, type=artifact_type,
                                  uri=uri,
                                  content_size=file_size,
                                  hash=file_hash,
                                  timestamp=timestamp,
                                  type_data=artifact_type_data,
                                  display_data=self._format_display_data(metadata))

        # update task artifacts, the previous entry is replaced only now (and under the lock)
        # so a failed upload keeps it and the daemon thread never sees a half updated list
        with self._task_edit_lock:
            previous = self._find_task_artifact(name)
            if previous is not None:
                self._task_artifact_list.remove(previous)
            self._task_artifact_list.append(artifact)
            self._task.set_artifacts(self._task_artifact_list)

        return True

    def _prepare_artifact_file(self, name, artifact_object, delete_after_upload):
        """
        Make sure the artifact object is stored in a local file, in-memory objects are
        serialized into a temporary file (which is always deleted after the upload).

        :return: tuple (artifact_type, artifact_type_data, local_filename, delete_after_upload, use_filename_in_uri)
        :raise ValueError: if the object type is not supported
        """
        artifact_type_data = tasks.ArtifactTypeData()

        if np and isinstance(artifact_object, np.ndarray):
            artifact_type_data.content_type = 'application/numpy'
            artifact_type_data.preview = str(artifact_object.__repr__())
            fd, local_filename = mkstemp(suffix='.npz')
            os.close(fd)
            np.savez_compressed(local_filename, **{name: artifact_object})
            return 'numpy', artifact_type_data, local_filename, True, False

        if isinstance(artifact_object, Image.Image):
            artifact_type_data.content_type = 'image/png'
            desc = str(artifact_object.__repr__())
            # keep the part of the repr located between the opening '<' and the ' at ' text
            artifact_type_data.preview = desc[1:desc.find(' at ')]
            fd, local_filename = mkstemp(suffix='.png')
            os.close(fd)
            artifact_object.save(local_filename)
            return 'image', artifact_type_data, local_filename, True, False

        if isinstance(artifact_object, dict):
            artifact_type_data.content_type = 'application/json'
            preview = json.dumps(artifact_object, sort_keys=True, indent=4)
            fd, local_filename = mkstemp(suffix='.json')
            # the file object writes the entire buffer and closes the descriptor, also on errors
            with os.fdopen(fd, 'wb') as json_file:
                json_file.write(preview.encode())
            artifact_type_data.preview = preview
            return 'JSON', artifact_type_data, local_filename, True, False

        if isinstance(artifact_object, (Path,) + tuple(six.string_types)):
            if isinstance(artifact_object, Path):
                artifact_object = artifact_object.as_posix()
            artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
            # user file, keep the file name in the uri and respect the delete request of the caller
            return 'custom', artifact_type_data, artifact_object, delete_after_upload, True

        raise ValueError("Artifact type {} not supported".format(type(artifact_object)))

    def _find_task_artifact(self, name):
        """ Return the first entry of the task artifacts list with the key ``name``, None if not found """
        for an_artifact in self._task_artifact_list:
            if an_artifact.key == name:
                return an_artifact
        return None

    @staticmethod
    def _format_display_data(metadata):
        """ Convert a metadata dictionary into a list of (str, str) pairs, None if there is no metadata """
        if not metadata:
            return None
        return [(str(k), str(v)) for k, v in metadata.items()]

    def flush(self):
        """ Signal the daemon thread (started on first use) to process the registered artifacts """
        # start the thread if it hasn't already:
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        """
        Signal the daemon thread to exit.

        :param wait: if True, wait for the daemon thread to end and remove the temporary folders
        """
        # stop the daemon thread and quit
        # wait until thread exists
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    # best effort, a folder that still contains files is left behind
                    pass

    def _start(self):
        """ Start the daemon thread, if it was not started yet """
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

    def _daemon(self):
        """ Daemon thread loop: wait for a flush signal (or the flush period) and upload the changed artifacts """
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            # iterate over a copy of the keys, artifacts might be unregistered while we iterate
            artifact_keys = list(self._artifacts_dict.keys())
            for name in artifact_keys:
                try:
                    self._upload_data_audit_artifacts(name)
                except Exception as e:
                    LoggerRoot.get_base_logger().warning(str(e))

            # create summary
            self._summary = self._get_statistics()

    def _upload_data_audit_artifacts(self, name):
        """
        Serialize the registered artifact ``name`` into a compressed csv and upload it,
        the upload is skipped if the content hash is identical to the previously uploaded one.
        A pending unregister request is processed first (the current content is still uploaded).
        """
        logger = self._task.get_logger()
        pd_artifact = self._artifacts_dict.get(name)
        pd_metadata = self._artifacts_dict.get_metadata(name)

        # remove from artifacts watch list
        if name in self._unregister_request:
            self._unregister_request.discard(name)
            self._artifacts_dict.unregister_artifact(name)

        if pd_artifact is None:
            return

        local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
        if local_csv.exists():
            # we are still uploading... get another temp folder
            local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
        pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
        # current_sha2 ignores the first 32 bytes of the file, file_sha2 covers the entire file
        current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
        if name in self._last_artifacts_upload:
            previous_sha2 = self._last_artifacts_upload[name]
            if previous_sha2 == current_sha2:
                # nothing to do, we can skip the upload
                local_csv.unlink()
                return
        self._last_artifacts_upload[name] = current_sha2

        # If old trains-server, upload as debug image
        if not Session.check_min_api_version('2.3'):
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)
            return

        file_size = local_csv.stat().st_size

        # upload file
        uri = self._upload_local_file(local_csv, name, delete_after_upload=True)

        # build the preview before taking the lock, creating the statistics might take a while
        artifact_type_data = tasks.ArtifactTypeData()
        artifact_type_data.data_hash = current_sha2
        artifact_type_data.content_type = "text/csv"
        artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})

        # update task artifacts
        with self._task_edit_lock:
            # Find our artifact (under the lock, the list is also edited by upload_artifact)
            artifact = self._find_task_artifact(name)
            if not artifact:
                artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
                self._task_artifact_list.append(artifact)

            artifact.type_data = artifact_type_data
            artifact.uri = uri
            artifact.content_size = file_size
            artifact.hash = file_sha2
            artifact.timestamp = int(time())
            artifact.display_data = self._format_display_data(pd_metadata)

            self._task.set_artifacts(self._task_artifact_list)

    def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
        """
        Upload local file and return uri of the uploaded file (uploading in the background)
        """
        upload_uri = self._task.output_uri or self._task.get_logger().get_default_upload_destination()
        if not isinstance(local_file, Path):
            local_file = Path(local_file)
        ev = UploadEvent(metric='artifacts', variant=name,
                         image_data=None, upload_uri=upload_uri,
                         local_image_path=local_file.as_posix(),
                         delete_after_upload=delete_after_upload,
                         override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
                         override_storage_key_prefix=self._get_storage_uri_prefix())
        _, uri = ev.get_target_full_upload_uri(upload_uri)

        # send for upload
        self._task.reporter._report(ev)

        return uri

    def _get_statistics(self, artifacts_dict=None):
        """
        Create a text summary of the DataFrame artifacts: shape and number of unique rows of
        every artifact, and the number of rows it shares with each of the artifacts following it.

        :param artifacts_dict: dictionary of name -> DataFrame, the registered artifacts are used if empty
        :return: the summary text (partial, if an error was raised while creating it)
        """
        summary = ''
        artifacts_dict = artifacts_dict or self._artifacts_dict
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                # parallelize, every call returns the hash of a single row
                a_unique_hash = set(thread_pool.map(self._hash_row, a_df.values))
                # add result
                artifacts_summary.append((a_name, a_df.shape, a_unique_hash,))

            # build intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash),
                    percentage=self._percentage(len(unique_hash), shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection,
                        percentage=self._percentage(intersection, len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()

        return summary

    @staticmethod
    def _hash_row(row):
        """ Return the hash of the raw bytes of a single row """
        return hash(bytes(row))

    @staticmethod
    def _percentage(part, total):
        """ Return ``part`` as a percentage of ``total``, 0 when ``total`` is zero (e.g. empty DataFrame) """
        if not total:
            return 0.
        return 100 * part / float(total)

    def _get_temp_folder(self, force_new=False):
        """
        Return a temporary folder for the csv files.

        :param force_new: if True always create (and return) a new temporary folder,
            otherwise the first folder is reused (created on first use)
        """
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    def _get_storage_uri_prefix(self):
        """ Return the task output destination suffix, retrieved from the task once and cached """
        if not self._storage_prefix:
            self._storage_prefix = self._task._get_output_destination_suffix()
        return self._storage_prefix

    @staticmethod
    def sha256sum(filename, skip_header=0):
        """
        Calculate the sha256 of a file.

        :param filename: path of the file to hash
        :param skip_header: number of bytes, at the beginning of the file, excluded from the first hash
        :return: tuple (hash of the content after the header, hash of the entire file),
            the second item is None when ``skip_header`` is 0 (the first hash already covers the entire file).
            (None, None) is returned if the file could not be read.
        """
        # create sha2 of the file, notice we skip the header of the file (32 bytes)
        # because sometimes that is the only change
        h = hashlib.sha256()
        file_hash = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        try:
            with open(filename, 'rb', buffering=0) as f:
                # skip header
                if skip_header:
                    file_hash.update(f.read(skip_header))
                # read block by block into the same buffer, until readinto returns 0 (end of file)
                for n in iter(lambda: f.readinto(mv), 0):
                    h.update(mv[:n])
                    if skip_header:
                        file_hash.update(mv[:n])
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
            return None, None

        return h.hexdigest(), file_hash.hexdigest() if skip_header else None