import hashlib
import json
import logging
import mimetypes
import os
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock
from time import time

import six
from pathlib2 import Path
from PIL import Image

from ..backend_interface.metrics.events import UploadEvent
from ..backend_api import Session
from ..debugging.log import LoggerRoot
from ..backend_api.services import tasks

try:
    import pandas as pd
except ImportError:
    pd = None
try:
    import numpy as np
except ImportError:
    np = None


class Artifacts(object):
    _flush_frequency_sec = 300.
    # notice these two should match
    _save_format = '.csv.gz'
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536
    _pd_artifact_type = 'data-audit-table'

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that notifies the artifacts manager on any item set in the dictionary """
        def __init__(self, artifacts_manager, *args, **kwargs):
            super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._artifacts_manager = artifacts_manager
            # metadata registered per artifact name
            self.artifact_metadata = {}

        def __setitem__(self, key, value):
            # check that value is of type pandas
            if pd and isinstance(value, pd.DataFrame):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                if self._artifacts_manager:
                    self._artifacts_manager.flush()
            else:
                raise ValueError('Artifacts currently support pandas.DataFrame objects only')

        def unregister_artifact(self, name):
            self.artifact_metadata.pop(name, None)
            self.pop(name, None)

        def add_metadata(self, name, metadata):
            self.artifact_metadata[name] = deepcopy(metadata)

        def get_metadata(self, name):
            return self.artifact_metadata.get(name)

    @property
    def artifacts(self):
        return self._artifacts_dict

    @property
    def summary(self):
        return self._summary

    def __init__(self, task):
        self._task = task
        # notice the double link, this is important since the Artifact
        # dictionary needs to signal the Artifacts manager on changes
        self._artifacts_dict = self._ProxyDictWrite(self)
        self._last_artifacts_upload = {}
        self._unregister_request = set()
        self._thread = None
        self._flush_event = Event()
        self._exit_flag = False
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []
        self._task_artifact_list = []
        self._task_edit_lock = RLock()
        self._storage_prefix = None

    def register_artifact(self, name, artifact, metadata=None):
        # currently we support pandas.DataFrame (which we will upload as csv.gz)
        if name in self._artifacts_dict:
            LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
        self._artifacts_dict[name] = artifact
        if metadata:
            self._artifacts_dict.add_metadata(name, metadata)

    def unregister_artifact(self, name):
        # Remove artifact from the watch list
        self._unregister_request.add(name)
        self.flush()

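    # upload_artifact handles one-shot (non-watched) objects: numpy arrays are stored as a
    # compressed .npz, PIL Images as .png, dicts as indented JSON, and strings / pathlib2
    # Paths are treated as existing local files and uploaded as-is.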
    def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
        if not Session.check_min_api_version('2.3'):
            LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
                                                 'please upgrade to the latest server version')
            return False

        if name in self._artifacts_dict:
            raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))

        artifact_type_data = tasks.ArtifactTypeData()
        use_filename_in_uri = True
        if np and isinstance(artifact_object, np.ndarray):
            artifact_type = 'numpy'
            artifact_type_data.content_type = 'application/numpy'
            artifact_type_data.preview = str(artifact_object.__repr__())
            fd, local_filename = mkstemp(suffix='.npz')
            os.close(fd)
            np.savez_compressed(local_filename, **{name: artifact_object})
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, Image.Image):
            artifact_type = 'image'
            artifact_type_data.content_type = 'image/png'
            desc = str(artifact_object.__repr__())
            artifact_type_data.preview = desc[1:desc.find(' at ')]
            fd, local_filename = mkstemp(suffix='.png')
            os.close(fd)
            artifact_object.save(local_filename)
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, dict):
            artifact_type = 'JSON'
            artifact_type_data.content_type = 'application/json'
            preview = json.dumps(artifact_object, sort_keys=True, indent=4)
            fd, local_filename = mkstemp(suffix='.json')
            os.write(fd, bytes(preview.encode()))
            os.close(fd)
            artifact_type_data.preview = preview
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
            if isinstance(artifact_object, Path):
                artifact_object = artifact_object.as_posix()
            artifact_type = 'custom'
            artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
            local_filename = artifact_object
        else:
            raise ValueError("Artifact type {} not supported".format(type(artifact_object)))

        # remove from existing list, if exists
        for artifact in self._task_artifact_list:
            if artifact.key == name:
                if artifact.type == self._pd_artifact_type:
                    raise ValueError("Artifact of name {} already registered, "
                                     "use register_artifact instead".format(name))

                self._task_artifact_list.remove(artifact)
                break

        # check that the file to upload exists
        local_filename = Path(local_filename).absolute()
        if not local_filename.exists() or not local_filename.is_file():
            LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
                local_filename.as_posix()))
            return False

        file_hash, _ = self.sha256sum(local_filename.as_posix())
        timestamp = int(time())
        file_size = local_filename.stat().st_size

        uri = self._upload_local_file(local_filename, name,
                                      delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)

        artifact = tasks.Artifact(key=name, type=artifact_type,
                                  uri=uri,
                                  content_size=file_size,
                                  hash=file_hash,
                                  timestamp=timestamp,
                                  type_data=artifact_type_data,
                                  display_data=[(str(k), str(v)) for k, v in metadata.items()] if metadata else None)

        # update task artifacts
        with self._task_edit_lock:
            self._task_artifact_list.append(artifact)
            self._task.set_artifacts(self._task_artifact_list)

        return True

    def flush(self):
        # start the thread if it hasn't already:
        self._start()
        # flush the current state of all artifacts
        self._flush_event.set()

    def stop(self, wait=True):
        # stop the daemon thread and quit
        # wait until the thread exits
        self._exit_flag = True
        self._flush_event.set()
        if wait:
            if self._thread:
                self._thread.join()
            # remove all temp folders
            for f in self._temp_folder:
                try:
                    Path(f).rmdir()
                except Exception:
                    pass

    def _start(self):
        if not self._thread:
            # start the daemon thread
            self._flush_event.clear()
            self._thread = Thread(target=self._daemon)
            self._thread.daemon = True
            self._thread.start()

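    # Background loop: wake up every _flush_frequency_sec (or whenever flush() sets the event)
    # and snapshot/upload each registered data-audit artifact.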
    def _daemon(self):
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            artifact_keys = list(self._artifacts_dict.keys())
            for name in artifact_keys:
                try:
                    self._upload_data_audit_artifacts(name)
                except Exception as e:
                    LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

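    # Snapshot a single registered (watched) DataFrame: write it to a temporary csv.gz, skip
    # the upload if the content hash (computed while skipping the 32-byte file header) is
    # unchanged since the last snapshot, otherwise upload it and update the task's artifact list.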
    def _upload_data_audit_artifacts(self, name):
        logger = self._task.get_logger()
        pd_artifact = self._artifacts_dict.get(name)
        pd_metadata = self._artifacts_dict.get_metadata(name)

        # remove from artifacts watch list
        if name in self._unregister_request:
            try:
                self._unregister_request.remove(name)
            except KeyError:
                pass
            self._artifacts_dict.unregister_artifact(name)

        if pd_artifact is None:
            return

        local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
        if local_csv.exists():
            # we are still uploading... get another temp folder
            local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
        pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
        current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
        if name in self._last_artifacts_upload:
            previous_sha2 = self._last_artifacts_upload[name]
            if previous_sha2 == current_sha2:
                # nothing to do, we can skip the upload
                local_csv.unlink()
                return
        self._last_artifacts_upload[name] = current_sha2

        # If old trains-server, upload as debug image
        if not Session.check_min_api_version('2.3'):
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)
            return

        # Find our artifact
        artifact = None
        for an_artifact in self._task_artifact_list:
            if an_artifact.key == name:
                artifact = an_artifact
                break

        file_size = local_csv.stat().st_size

        # upload file
        uri = self._upload_local_file(local_csv, name, delete_after_upload=True)

        # update task artifacts
        with self._task_edit_lock:
            if not artifact:
                artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
                self._task_artifact_list.append(artifact)
            artifact_type_data = tasks.ArtifactTypeData()

            artifact_type_data.data_hash = current_sha2
            artifact_type_data.content_type = "text/csv"
            artifact_type_data.preview = str(pd_artifact.__repr__()) + '\n\n' + self._get_statistics({name: pd_artifact})

            artifact.type_data = artifact_type_data
            artifact.uri = uri
            artifact.content_size = file_size
            artifact.hash = file_sha2
            artifact.timestamp = int(time())
            artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None

            self._task.set_artifacts(self._task_artifact_list)

    def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
        """
        Upload a local file and return the URI of the uploaded file (the upload itself runs in the background)
        """
        upload_uri = self._task.output_uri or self._task.get_logger().get_default_upload_destination()
        if not isinstance(local_file, Path):
            local_file = Path(local_file)
        ev = UploadEvent(metric='artifacts', variant=name,
                         image_data=None, upload_uri=upload_uri,
                         local_image_path=local_file.as_posix(),
                         delete_after_upload=delete_after_upload,
                         override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
                         override_storage_key_prefix=self._get_storage_uri_prefix())
        _, uri = ev.get_target_full_upload_uri(upload_uri)

        # send for upload
        self._task.reporter._report(ev)

        return uri

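    # Build a plain-text summary for the registered DataFrames: per-artifact shape and row
    # uniqueness, plus the pairwise row intersection between artifacts (rows are compared by
    # hashing their raw byte representation).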
    def _get_statistics(self, artifacts_dict=None):
        summary = ''
        artifacts_dict = artifacts_dict or self._artifacts_dict
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue

                a_unique_hash = set()

                def hash_row(r):
                    a_unique_hash.add(hash(bytes(r)))

                a_shape = a_df.shape
                # parallelize
                thread_pool.map(hash_row, a_df.values)
                # add result
                artifacts_summary.append((a_name, a_shape, a_unique_hash,))

            # build intersection summary
            for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
                summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
                    name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
                for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
                    intersection = len(unique_hash & unique_hash2)
                    summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
                        name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
        finally:
            thread_pool.close()
            thread_pool.terminate()
        return summary

    def _get_temp_folder(self, force_new=False):
        if force_new or not self._temp_folder:
            new_temp = mkdtemp(prefix='artifacts_')
            self._temp_folder.append(new_temp)
            return new_temp
        return self._temp_folder[0]

    def _get_storage_uri_prefix(self):
        if not self._storage_prefix:
            self._storage_prefix = self._task._get_output_destination_suffix()
        return self._storage_prefix

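    # Returns a pair of hex digests: the first skips the first `skip_header` bytes of the file
    # (so a header-only change, e.g. a gzip timestamp, does not alter it); the second covers the
    # whole file and is returned only when skip_header is non-zero.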
    @staticmethod
    def sha256sum(filename, skip_header=0):
        # create sha2 of the file, notice we skip the header of the file (32 bytes)
        # because sometimes that is the only change
        h = hashlib.sha256()
        file_hash = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        try:
            with open(filename, 'rb', buffering=0) as f:
                # skip header
                if skip_header:
                    file_hash.update(f.read(skip_header))
                for n in iter(lambda: f.readinto(mv), 0):
                    h.update(mv[:n])
                    if skip_header:
                        file_hash.update(mv[:n])
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
            return None, None

        return h.hexdigest(), file_hash.hexdigest() if skip_header else None
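
# Illustrative usage sketch (not part of the module): in practice this manager is created and
# driven by the trains Task object, but given a `task` instance and a pandas.DataFrame `df`
# the flow looks roughly like:
#
#   artifacts = Artifacts(task)
#   artifacts.register_artifact('train_set', df, metadata={'rows': len(df)})  # watched DataFrame
#   artifacts.upload_artifact('config', {'lr': 0.01})                         # one-shot upload (dict -> JSON)
#   artifacts.flush()          # force an immediate snapshot/upload cycle
#   artifacts.stop(wait=True)  # stop the background thread and remove temp folders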