Calculate data-audit artifact uniqueness by user-criteria

allegroai 2020-01-06 17:19:44 +02:00
parent a169b43885
commit bc33ad0da3
2 changed files with 74 additions and 23 deletions
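
In practice, the new criteria are passed through Task.register_artifact(). A minimal usage sketch (the project, task, and column names here are illustrative, not taken from this commit):

    import pandas as pd
    from trains import Task

    task = Task.init(project_name='examples', task_name='data audit')
    df = pd.DataFrame({'user_id': [1, 2, 2], 'score': [0.1, 0.5, 0.5]})
    # uniqueness is computed only over 'user_id' rather than over all
    # columns (the default, uniqueness_columns=True)
    task.register_artifact('train_set', df, uniqueness_columns=['user_id'])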

View File

@ -2,24 +2,26 @@ import hashlib
 import json
 import mimetypes
 import os
-from zipfile import ZipFile, ZIP_DEFLATED
 from copy import deepcopy
 from datetime import datetime
+from multiprocessing import RLock, Event
 from multiprocessing.pool import ThreadPool
 from tempfile import mkdtemp, mkstemp
 from threading import Thread
-from multiprocessing import RLock, Event
 from time import time
+from zipfile import ZipFile, ZIP_DEFLATED
 import humanfriendly
 import six
-from pathlib2 import Path
 from PIL import Image
+from pathlib2 import Path
 from six.moves.urllib.parse import urlparse
-from ..backend_interface.metrics.events import UploadEvent
 from ..backend_api import Session
-from ..debugging.log import LoggerRoot
 from ..backend_api.services import tasks
+from ..backend_interface.metrics.events import UploadEvent
+from ..debugging.log import LoggerRoot
 from ..storage.helper import remote_driver_schemes
 try:
     import pandas as pd
@ -204,6 +206,8 @@ class Artifacts(object):
             self._artifacts_manager = artifacts_manager
             # dictionary of artifact metadata (stored per artifact name)
             self.artifact_metadata = {}
+            # dictionary of hash columns used to compute artifact uniqueness (stored per artifact name)
+            self.artifact_hash_columns = {}
         def __setitem__(self, key, value):
             # check that value is of type pandas
@ -225,9 +229,15 @@ class Artifacts(object):
         def get_metadata(self, name):
             return self.artifact_metadata.get(name)
+        def add_hash_columns(self, artifact_name, hash_columns):
+            self.artifact_hash_columns[artifact_name] = hash_columns
+        def get_hash_columns(self, artifact_name):
+            return self.artifact_hash_columns.get(artifact_name)
     @property
     def registered_artifacts(self):
-        return self._artifacts_dict
+        return self._artifacts_container
     @property
     def summary(self):
@ -237,7 +247,7 @@ class Artifacts(object):
         self._task = task
         # note the double link; this is important since the Artifact
         # dictionary needs to signal the Artifacts instance on changes
-        self._artifacts_dict = self._ProxyDictWrite(self)
+        self._artifacts_container = self._ProxyDictWrite(self)
         self._last_artifacts_upload = {}
         self._unregister_request = set()
         self._thread = None
@ -249,13 +259,21 @@ class Artifacts(object):
         self._task_edit_lock = RLock()
         self._storage_prefix = None
-    def register_artifact(self, name, artifact, metadata=None):
+    def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
         """
         :param str name: name of the artifact. Note: this will override any previous artifact with the same name.
         :param pandas.DataFrame artifact: artifact object; supported artifact object types: pandas.DataFrame
         :param dict metadata: dictionary of key/value pairs to store with the artifact (visible in the UI)
+        :param list uniqueness_columns: list of columns to use as the artifact uniqueness comparison criteria.
+            The default value is True, which means all the columns are used (same as artifact.columns).
         """
         # currently we support pandas.DataFrame (which we will upload as csv.gz)
-        if name in self._artifacts_dict:
+        if name in self._artifacts_container:
             LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
-        self._artifacts_dict[name] = artifact
+        self._artifacts_container.add_hash_columns(name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns))
+        self._artifacts_container[name] = artifact
         if metadata:
-            self._artifacts_dict.add_metadata(name, metadata)
+            self._artifacts_container.add_metadata(name, metadata)
     def unregister_artifact(self, name):
         # Remove artifact from the watch list
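
How the True default resolves to concrete columns can be seen in isolation; a standalone sketch of the add_hash_columns argument above (not part of the commit):

    import pandas as pd

    df = pd.DataFrame({'a': [1], 'b': [2]})
    for uniqueness_columns in (True, ['a']):
        # True resolves to every dataframe column, same as register_artifact()
        print(list(df.columns if uniqueness_columns is True else uniqueness_columns))
    # prints ['a', 'b'] then ['a']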
@ -268,7 +286,7 @@ class Artifacts(object):
                 'please upgrade to the latest server version')
             return False
-        if name in self._artifacts_dict:
+        if name in self._artifacts_container:
             raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
         artifact_type_data = tasks.ArtifactTypeData()
@ -453,7 +471,7 @@ class Artifacts(object):
     def _start(self):
         """ Start daemon thread if any artifacts are registered and thread is not up yet """
-        if not self._thread and self._artifacts_dict:
+        if not self._thread and self._artifacts_container:
             # start the daemon thread
             self._flush_event.clear()
             self._thread = Thread(target=self._daemon)
@ -464,7 +482,7 @@ class Artifacts(object):
         while not self._exit_flag:
             self._flush_event.wait(self._flush_frequency_sec)
             self._flush_event.clear()
-            artifact_keys = list(self._artifacts_dict.keys())
+            artifact_keys = list(self._artifacts_container.keys())
             for name in artifact_keys:
                 try:
                     self._upload_data_audit_artifacts(name)
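
The daemon loop above uses the flush event as an interruptible sleep: wait() returns early when the event is set, otherwise it times out after _flush_frequency_sec. A minimal standalone sketch of the pattern (threading.Event stands in for the multiprocessing Event used by the code, and all names are illustrative):

    from threading import Event, Thread

    flush_event = Event()
    exit_flag = False

    def daemon(period_sec=2.0):
        while not exit_flag:
            # returns early if flush_event is set, otherwise times out
            flush_event.wait(period_sec)
            flush_event.clear()
            print('flushing artifacts')

    worker = Thread(target=daemon)
    worker.daemon = True
    worker.start()
    flush_event.set()  # force an immediate flush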
@ -476,8 +494,8 @@ class Artifacts(object):
     def _upload_data_audit_artifacts(self, name):
         logger = self._task.get_logger()
-        pd_artifact = self._artifacts_dict.get(name)
-        pd_metadata = self._artifacts_dict.get_metadata(name)
+        pd_artifact = self._artifacts_container.get(name)
+        pd_metadata = self._artifacts_container.get_metadata(name)
         # remove from artifacts watch list
         if name in self._unregister_request:
@ -485,7 +503,7 @@ class Artifacts(object):
                 self._unregister_request.remove(name)
             except KeyError:
                 pass
-            self._artifacts_dict.unregister_artifact(name)
+            self._artifacts_container.unregister_artifact(name)
         if pd_artifact is None:
             return
@ -574,16 +592,39 @@ class Artifacts(object):
     def _get_statistics(self, artifacts_dict=None):
         summary = ''
-        artifacts_dict = artifacts_dict or self._artifacts_dict
+        artifacts_dict = artifacts_dict or self._artifacts_container
         thread_pool = ThreadPool()
         try:
             # build hash row sets
             artifacts_summary = []
             for a_name, a_df in artifacts_dict.items():
+                hash_cols = self._artifacts_container.get_hash_columns(a_name)
                 if not pd or not isinstance(a_df, pd.DataFrame):
                     continue
+                if hash_cols is True:
+                    hash_col_drop = []
+                else:
+                    hash_cols = set(hash_cols)
+                    missing_cols = hash_cols.difference(a_df.columns)
+                    if missing_cols == hash_cols:
+                        LoggerRoot.get_base_logger().warning(
+                            'Uniqueness columns {} not found in artifact {}. '
+                            'Skipping uniqueness check for artifact.'.format(list(missing_cols), a_name)
+                        )
+                        continue
+                    elif missing_cols:
+                        # missing_cols must be a subset of hash_cols
+                        hash_cols.difference_update(missing_cols)
+                        LoggerRoot.get_base_logger().warning(
+                            'Uniqueness columns {} not found in artifact {}. Using {}.'.format(
+                                list(missing_cols), a_name, list(hash_cols)
+                            )
+                        )
+                    hash_col_drop = [col for col in a_df.columns if col not in hash_cols]
                 a_unique_hash = set()
                 def hash_row(r):
@ -591,7 +632,8 @@ class Artifacts(object):
                 a_shape = a_df.shape
                 # parallelize
-                thread_pool.map(hash_row, a_df.values)
+                a_hash_cols = a_df.drop(columns=hash_col_drop)
+                thread_pool.map(hash_row, a_hash_cols.values)
                 # add result
                 artifacts_summary.append((a_name, a_shape, a_unique_hash,))
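
Net effect: the statistics pass hashes only the user-selected columns. A standalone sketch of the computation (hash_row's body falls outside the hunk, so the md5-based row hash below is an assumption):

    import hashlib
    from multiprocessing.pool import ThreadPool

    import pandas as pd

    df = pd.DataFrame({'user_id': [1, 2, 2], 'score': [0.1, 0.5, 0.5]})
    hash_cols = {'user_id'}

    hash_col_drop = [col for col in df.columns if col not in hash_cols]
    a_unique_hash = set()

    def hash_row(r):
        # assumed row hash; the committed hash_row implementation is not shown here
        a_unique_hash.add(hashlib.md5(str(r).encode()).hexdigest())

    ThreadPool().map(hash_row, df.drop(columns=hash_col_drop).values)
    print('{} unique rows out of {}'.format(len(a_unique_hash), df.shape[0]))  # 2 of 3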

View File

@ -7,8 +7,11 @@ import time
 from argparse import ArgumentParser
 from tempfile import mkstemp
 from pathlib2 import Path
 from collections import OrderedDict, Callable
+try:
+    from collections.abc import Sequence
+except ImportError:
+    from collections import Sequence
 from typing import Optional
 import psutil
@ -703,7 +706,7 @@ class Task(_Task):
         if self.is_main_task():
             self.__register_at_exit(None)
-    def register_artifact(self, name, artifact, metadata=None):
+    def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
         """
         Add artifact for the current Task, used mostly for Data Auditing.
         Currently supported artifact object types: pandas.DataFrame
@ -711,8 +714,14 @@ class Task(_Task):
         :param str name: name of the artifact. Note: this will override any previous artifact with the same name.
         :param pandas.DataFrame artifact: artifact object; supported artifact object types: pandas.DataFrame
         :param dict metadata: dictionary of key/value pairs to store with the artifact (visible in the UI)
+        :param Sequence uniqueness_columns: Sequence of columns to use as the artifact uniqueness comparison
+            criteria. The default value is True, which means all the columns are used (same as artifact.columns).
         """
-        self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata)
+        if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True:
+            raise ValueError('uniqueness_columns should be a sequence or True')
+        if isinstance(uniqueness_columns, str):
+            uniqueness_columns = [uniqueness_columns]
+        self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)
     def unregister_artifact(self, name):
         """