clearml/trains/binding/artifacts.py


import hashlib
import json
import logging
import mimetypes
import os

from copy import deepcopy
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock
from time import time

import six
from pathlib2 import Path
from PIL import Image

from ..backend_interface.metrics.events import UploadEvent
from ..backend_api import Session
from ..backend_api.services import tasks
from ..debugging.log import LoggerRoot

try:
    import pandas as pd
except ImportError:
    pd = None

try:
    import numpy as np
except ImportError:
    np = None


class Artifacts(object):
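    """
    Manage a task's artifacts: register pandas.DataFrame objects for periodic,
    change-detected background upload, and upload one-off objects (numpy arrays,
    PIL images, dicts, or local files) as task artifacts.

    Typical usage (a minimal sketch, assuming `task` is a valid Task instance
    and `dataframe` is a pandas.DataFrame):
        artifacts = Artifacts(task)
        artifacts.register_artifact('training_data', dataframe, metadata={'split': 'train'})
        artifacts.upload_artifact('results', artifact_object={'accuracy': 0.9})
        artifacts.stop(wait=True)
    """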
    _flush_frequency_sec = 300.
    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536

    _pd_artifact_type = 'data-audit-table'

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that notifies the artifacts manager on any item set in the dictionary """

        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # per-artifact metadata, keyed by artifact name
            self.artifact_metadata = {}

        def __setitem__(self, key, value):
            # check that value is of type pandas
            if pd and isinstance(value, pd.DataFrame):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
                if self._artifacts_manager:
                    self._artifacts_manager.flush()
            else:
                raise ValueError('Artifacts currently support pandas.DataFrame objects only')

        def unregister_artifact(self, name):
            self.artifact_metadata.pop(name, None)
            self.pop(name, None)

        def add_metadata(self, name, metadata):
            self.artifact_metadata[name] = deepcopy(metadata)

        def get_metadata(self, name):
            return self.artifact_metadata.get(name)

    @property
    def artifacts(self):
        return self._artifacts_dict

    @property
    def summary(self):
        return self._summary

    def __init__(self, task):
        self._task = task
        # notice the double link, this is important since the artifact
        # dictionary needs to signal the Artifacts manager on changes
        self._artifacts_dict = self._ProxyDictWrite(self)
        self._last_artifacts_upload = {}
        self._unregister_request = set()
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []
        self._task_artifact_list = []
        self._task_edit_lock = RLock()
        self._storage_prefix = None

    def register_artifact(self, name, artifact, metadata=None):
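        """
        Register a pandas.DataFrame to be watched and uploaded (as csv.gz) by the
        background thread, on flush() or every _flush_frequency_sec seconds;
        optional metadata is shown as display data on the artifact.
        """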
        # currently we support pandas.DataFrame (which we will upload as csv.gz)
        if name in self._artifacts_dict:
            LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
        self._artifacts_dict[name] = artifact
        if metadata:
            self._artifacts_dict.add_metadata(name, metadata)

    def unregister_artifact(self, name):
        # Remove artifact from the watch list
        self._unregister_request.add(name)
        self.flush()

    def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
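        """
        Upload a single object as a task artifact (numpy.ndarray, PIL.Image.Image,
        dict, or a path to a local file). The object is serialized to a temporary
        file when needed and uploaded in the background.
        Return True if the artifact was scheduled for upload, False otherwise.
        """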
        if not Session.check_min_api_version('2.3'):
            LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
                                                 'please upgrade to the latest server version')
            return False

        if name in self._artifacts_dict:
            raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))

        artifact_type_data = tasks.ArtifactTypeData()
        use_filename_in_uri = True
        if np and isinstance(artifact_object, np.ndarray):
            artifact_type = 'numpy'
            artifact_type_data.content_type = 'application/numpy'
            artifact_type_data.preview = str(artifact_object.__repr__())
            fd, local_filename = mkstemp(suffix='.npz')
            os.close(fd)
            np.savez_compressed(local_filename, **{name: artifact_object})
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, Image.Image):
            artifact_type = 'image'
            artifact_type_data.content_type = 'image/png'
            desc = str(artifact_object.__repr__())
            artifact_type_data.preview = desc[1:desc.find(' at ')]
            fd, local_filename = mkstemp(suffix='.png')
            os.close(fd)
            artifact_object.save(local_filename)
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, dict):
            artifact_type = 'JSON'
            artifact_type_data.content_type = 'application/json'
            preview = json.dumps(artifact_object, sort_keys=True, indent=4)
            fd, local_filename = mkstemp(suffix='.json')
            os.write(fd, bytes(preview.encode()))
            os.close(fd)
            artifact_type_data.preview = preview
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
            if isinstance(artifact_object, Path):
                artifact_object = artifact_object.as_posix()
            artifact_type = 'custom'
            artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
            local_filename = artifact_object
        else:
            raise ValueError("Artifact type {} not supported".format(type(artifact_object)))

        # remove from existing list, if exists
        for artifact in self._task_artifact_list:
            if artifact.key == name:
                if artifact.type == self._pd_artifact_type:
                    raise ValueError("Artifact of name {} already registered, "
                                     "use register_artifact instead".format(name))
                self._task_artifact_list.remove(artifact)
                break

        # check that the file to upload exists
        local_filename = Path(local_filename).absolute()
        if not local_filename.exists() or not local_filename.is_file():
            LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
                local_filename.as_posix()))
            return False

        file_hash, _ = self.sha256sum(local_filename.as_posix())
        timestamp = int(time())
        file_size = local_filename.stat().st_size

        uri = self._upload_local_file(local_filename, name,
                                      delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)

        artifact = tasks.Artifact(key=name, type=artifact_type,
                                  uri=uri,
                                  content_size=file_size,
                                  hash=file_hash,
                                  timestamp=timestamp,
                                  type_data=artifact_type_data,
                                  display_data=[(str(k), str(v)) for k, v in metadata.items()] if metadata else None)

        # update task artifacts
        with self._task_edit_lock:
            self._task_artifact_list.append(artifact)
            self._task.set_artifacts(self._task_artifact_list)

        return True

    def flush(self):
        # start the background thread if it hasn't started already
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        # stop the daemon thread and quit
        # wait until the thread exits
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    pass

    def _start(self):
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

    def _daemon(self):
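        """
        Background worker: wake up on flush() or every _flush_frequency_sec seconds,
        upload registered DataFrame artifacts that changed, and build the statistics
        summary once the exit flag is set.
        """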
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            artifact_keys = list(self._artifacts_dict.keys())
            for name in artifact_keys:
                try:
                    self._upload_data_audit_artifacts(name)
                except Exception as e:
                    LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

    def _upload_data_audit_artifacts(self, name):
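        """
        Snapshot a registered pandas.DataFrame to a compressed CSV and upload it as a
        'data-audit-table' artifact, skipping the upload when the content hash has not
        changed since the previous upload.
        """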
        logger = self._task.get_logger()
        pd_artifact = self._artifacts_dict.get(name)
        pd_metadata = self._artifacts_dict.get_metadata(name)

        # remove from artifacts watch list
        if name in self._unregister_request:
            try:
                self._unregister_request.remove(name)
            except KeyError:
                pass
            self._artifacts_dict.unregister_artifact(name)

        if pd_artifact is None:
            return

        local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
        if local_csv.exists():
            # we are still uploading... get another temp folder
            local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
        pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
        current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
        if name in self._last_artifacts_upload:
            previous_sha2 = self._last_artifacts_upload[name]
            if previous_sha2 == current_sha2:
                # nothing to do, we can skip the upload
                local_csv.unlink()
                return
        self._last_artifacts_upload[name] = current_sha2

        # If old trains-server, upload as debug image
        if not Session.check_min_api_version('2.3'):
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)
            return

        # Find our artifact
        artifact = None
        for an_artifact in self._task_artifact_list:
            if an_artifact.key == name:
                artifact = an_artifact
                break

        file_size = local_csv.stat().st_size

        # upload file
        uri = self._upload_local_file(local_csv, name, delete_after_upload=True)

        # update task artifacts
        with self._task_edit_lock:
            if not artifact:
                artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
                self._task_artifact_list.append(artifact)
            artifact_type_data = tasks.ArtifactTypeData()
            artifact_type_data.data_hash = current_sha2
            artifact_type_data.content_type = "text/csv"
            artifact_type_data.preview = str(pd_artifact.__repr__()) + '\n\n' + self._get_statistics({name: pd_artifact})
            artifact.type_data = artifact_type_data
            artifact.uri = uri
            artifact.content_size = file_size
            artifact.hash = file_sha2
            artifact.timestamp = int(time())
            artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None

            self._task.set_artifacts(self._task_artifact_list)

    def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
        """
        Upload a local file and return the URI of the uploaded file (the upload itself happens in the background)
        """
        upload_uri = self._task.output_uri or self._task.get_logger().get_default_upload_destination()
        if not isinstance(local_file, Path):
            local_file = Path(local_file)
        ev = UploadEvent(metric='artifacts', variant=name,
                         image_data=None, upload_uri=upload_uri,
                         local_image_path=local_file.as_posix(),
                         delete_after_upload=delete_after_upload,
                         override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
                         override_storage_key_prefix=self._get_storage_uri_prefix())
        _, uri = ev.get_target_full_upload_uri(upload_uri)
        # send for upload
        self._task.reporter._report(ev)
        return uri

    def _get_statistics(self, artifacts_dict=None):
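        """
        Build a text summary of the registered DataFrames: per-artifact shape and row
        uniqueness, plus pairwise row-intersection percentages between artifacts.
        """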
        summary = ''
        artifacts_dict = artifacts_dict or self._artifacts_dict
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                a_unique_hash = set()

                def hash_row(r):
                    a_unique_hash.add(hash(bytes(r)))

                a_shape = a_df.shape
                # parallelize
                thread_pool.map(hash_row, a_df.values)
                # add result
                artifacts_summary.append((a_name, a_shape, a_unique_hash,))

            # build intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash), percentage=100 * len(unique_hash) / float(shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i + 1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection, percentage=100 * intersection / float(len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()
        return summary

    def _get_temp_folder(self, force_new=False):
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    def _get_storage_uri_prefix(self):
        if not self._storage_prefix:
            self._storage_prefix = self._task._get_output_destination_suffix()
        return self._storage_prefix

    @staticmethod
    def sha256sum(filename, skip_header=0):
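        """
        Return (content_sha256, full_file_sha256): the first digest skips the first
        skip_header bytes of the file, the second covers the whole file and is None
        when skip_header is 0.
        """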
        # create sha2 of the file, notice we skip the header of the file (32 bytes)
        # because sometimes that is the only change
        h = hashlib.sha256()
        file_hash = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        try:
            with open(filename, 'rb', buffering=0) as f:
                # skip header
                if skip_header:
                    file_hash.update(f.read(skip_header))
                for n in iter(lambda: f.readinto(mv), 0):
                    h.update(mv[:n])
                    if skip_header:
                        file_hash.update(mv[:n])
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
            return None, None

        return h.hexdigest(), file_hash.hexdigest() if skip_header else None