Mirror of https://github.com/clearml/clearml (synced 2025-02-07 21:33:25 +00:00)
Calculate data-audit artifact uniqueness by user-criteria
This commit is contained in:
parent a169b43885
commit bc33ad0da3
@@ -2,24 +2,26 @@ import hashlib
import json
import mimetypes
import os
-from zipfile import ZipFile, ZIP_DEFLATED
from copy import deepcopy
from datetime import datetime
+from multiprocessing import RLock, Event
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp
from threading import Thread
-from multiprocessing import RLock, Event
from time import time
+from zipfile import ZipFile, ZIP_DEFLATED

import humanfriendly
import six
-from pathlib2 import Path
from PIL import Image
+from pathlib2 import Path
from six.moves.urllib.parse import urlparse

-from ..backend_interface.metrics.events import UploadEvent
from ..backend_api import Session
-from ..debugging.log import LoggerRoot
from ..backend_api.services import tasks
+from ..backend_interface.metrics.events import UploadEvent
+from ..debugging.log import LoggerRoot
from ..storage.helper import remote_driver_schemes

try:
import pandas as pd
@@ -204,6 +206,8 @@ class Artifacts(object):
self._artifacts_manager = artifacts_manager
# list of artifacts we should not upload (by name & weak-reference)
self.artifact_metadata = {}
+# list of hash columns to calculate uniqueness for the artifacts
+self.artifact_hash_columns = {}

def __setitem__(self, key, value):
# check that value is of type pandas
@@ -225,9 +229,15 @@ class Artifacts(object):
def get_metadata(self, name):
return self.artifact_metadata.get(name)

+def add_hash_columns(self, artifact_name, hash_columns):
+self.artifact_hash_columns[artifact_name] = hash_columns
+
+def get_hash_columns(self, artifact_name):
+return self.artifact_hash_columns.get(artifact_name)
+
@property
def registered_artifacts(self):
-return self._artifacts_dict
+return self._artifacts_container

@property
def summary(self):
@@ -237,7 +247,7 @@ class Artifacts(object):
self._task = task
# notice the double link, this important since the Artifact
# dictionary needs to signal the Artifacts base on changes
-self._artifacts_dict = self._ProxyDictWrite(self)
+self._artifacts_container = self._ProxyDictWrite(self)
self._last_artifacts_upload = {}
self._unregister_request = set()
self._thread = None
@@ -249,13 +259,21 @@ class Artifacts(object):
self._task_edit_lock = RLock()
self._storage_prefix = None

-def register_artifact(self, name, artifact, metadata=None):
+def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
"""
:param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
:param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
:param dict metadata: dictionary of key value to store with the artifact (visible in the UI)
+:param list uniqueness_columns: list of columns for artifact uniqueness comparison criteria. The default value
+    is True, which equals to all the columns (same as artifact.columns).
"""
# currently we support pandas.DataFrame (which we will upload as csv.gz)
-if name in self._artifacts_dict:
+if name in self._artifacts_container:
LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
-self._artifacts_dict[name] = artifact
+self._artifacts_container.add_hash_columns(name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns))
+self._artifacts_container[name] = artifact
if metadata:
-self._artifacts_dict.add_metadata(name, metadata)
+self._artifacts_container.add_metadata(name, metadata)

def unregister_artifact(self, name):
# Remove artifact from the watch list
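As an aside (not part of the commit): the add_hash_columns call above resolves the default uniqueness_columns=True to the full column list. A minimal sketch of that resolution rule, with illustrative names:

import pandas as pd

def resolve_hash_columns(artifact, uniqueness_columns=True):
    # True means "every column participates in the uniqueness check",
    # otherwise only the explicitly requested columns do.
    return list(artifact.columns if uniqueness_columns is True else uniqueness_columns)

df = pd.DataFrame({'id': [1, 2], 'value': [0.5, 0.7]})
print(resolve_hash_columns(df))          # ['id', 'value']
print(resolve_hash_columns(df, ['id']))  # ['id']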
@@ -268,7 +286,7 @@ class Artifacts(object):
'please upgrade to the latest server version')
return False

-if name in self._artifacts_dict:
+if name in self._artifacts_container:
raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))

artifact_type_data = tasks.ArtifactTypeData()
@@ -453,7 +471,7 @@ class Artifacts(object):

def _start(self):
""" Start daemon thread if any artifacts are registered and thread is not up yet """
-if not self._thread and self._artifacts_dict:
+if not self._thread and self._artifacts_container:
# start the daemon thread
self._flush_event.clear()
self._thread = Thread(target=self._daemon)
@@ -464,7 +482,7 @@ class Artifacts(object):
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear()
-artifact_keys = list(self._artifacts_dict.keys())
+artifact_keys = list(self._artifacts_container.keys())
for name in artifact_keys:
try:
self._upload_data_audit_artifacts(name)
@@ -476,8 +494,8 @@ class Artifacts(object):

def _upload_data_audit_artifacts(self, name):
logger = self._task.get_logger()
-pd_artifact = self._artifacts_dict.get(name)
-pd_metadata = self._artifacts_dict.get_metadata(name)
+pd_artifact = self._artifacts_container.get(name)
+pd_metadata = self._artifacts_container.get_metadata(name)

# remove from artifacts watch list
if name in self._unregister_request:
@@ -485,7 +503,7 @@ class Artifacts(object):
self._unregister_request.remove(name)
except KeyError:
pass
-self._artifacts_dict.unregister_artifact(name)
+self._artifacts_container.unregister_artifact(name)

if pd_artifact is None:
return
@@ -574,16 +592,39 @@ class Artifacts(object):

def _get_statistics(self, artifacts_dict=None):
summary = ''
-artifacts_dict = artifacts_dict or self._artifacts_dict
+artifacts_dict = artifacts_dict or self._artifacts_container
thread_pool = ThreadPool()

try:
# build hash row sets
artifacts_summary = []
for a_name, a_df in artifacts_dict.items():
+hash_cols = self._artifacts_container.get_hash_columns(a_name)
if not pd or not isinstance(a_df, pd.DataFrame):
continue

+if hash_cols is True:
+hash_col_drop = []
+else:
+hash_cols = set(hash_cols)
+missing_cols = hash_cols.difference(a_df.columns)
+if missing_cols == hash_cols:
+LoggerRoot.get_base_logger().warning(
+'Uniqueness columns {} not found in artifact {}. '
+'Skipping uniqueness check for artifact.'.format(list(missing_cols), a_name)
+)
+continue
+elif missing_cols:
+# missing_cols must be a subset of hash_cols
+hash_cols.difference_update(missing_cols)
+LoggerRoot.get_base_logger().warning(
+'Uniqueness columns {} not found in artifact {}. Using {}.'.format(
+list(missing_cols), a_name, list(hash_cols)
+)
+)

+hash_col_drop = [col for col in a_df.columns if col not in hash_cols]

a_unique_hash = set()

def hash_row(r):
@@ -591,7 +632,8 @@ class Artifacts(object):

a_shape = a_df.shape
# parallelize
-thread_pool.map(hash_row, a_df.values)
+a_hash_cols = a_df.drop(columns=hash_col_drop)
+thread_pool.map(hash_row, a_hash_cols.values)
# add result
artifacts_summary.append((a_name, a_shape, a_unique_hash,))
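A rough, self-contained sketch of the uniqueness calculation these two hunks implement (not the library code; the sha256-over-str row encoding stands in for the real hash_row helper):

import hashlib

import pandas as pd

def unique_row_count(df, uniqueness_columns=True):
    # Resolve hash columns the same way register_artifact does: True -> all columns.
    hash_cols = list(df.columns) if uniqueness_columns is True else list(uniqueness_columns)
    hash_col_drop = [col for col in df.columns if col not in hash_cols]
    reduced = df.drop(columns=hash_col_drop)

    unique_hashes = set()
    for row in reduced.values:
        # Hash each row restricted to the selected columns and collect into a set.
        unique_hashes.add(hashlib.sha256(str(row.tolist()).encode()).hexdigest())
    return len(unique_hashes)

df = pd.DataFrame({'a': [1, 1, 2], 'b': ['x', 'x', 'y'], 'c': [0.3, 0.7, 0.7]})
print(unique_row_count(df))                                 # 3: every row differs somewhere
print(unique_row_count(df, uniqueness_columns=['a', 'b']))  # 2: rows 0 and 1 collide on (a, b)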
@@ -7,8 +7,11 @@ import time
from argparse import ArgumentParser
from tempfile import mkstemp

from pathlib2 import Path
from collections import OrderedDict, Callable
+try:
+from collections.abc import Sequence
+except ImportError:
+from collections import Sequence

from typing import Optional

import psutil
@@ -703,7 +706,7 @@ class Task(_Task):
if self.is_main_task():
self.__register_at_exit(None)

-def register_artifact(self, name, artifact, metadata=None):
+def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
"""
Add artifact for the current Task, used mostly for Data Audition.
Currently supported artifacts object types: pandas.DataFrame
@@ -711,8 +714,14 @@ class Task(_Task):
:param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
:param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
:param dict metadata: dictionary of key value to store with the artifact (visible in the UI)
+:param Sequence uniqueness_columns: Sequence of columns for artifact uniqueness comparison criteria.
+    The default value is True, which equals to all the columns (same as artifact.columns).
"""
-self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata)
+if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True:
+raise ValueError('uniqueness_columns should be a sequence or True')
+if isinstance(uniqueness_columns, str):
+uniqueness_columns = [uniqueness_columns]
+self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)

def unregister_artifact(self, name):
"""
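For context, a minimal usage sketch of the new parameter from the Task API (project name, task name, and the DataFrame are invented; the import assumes the current clearml package name, the package was still published as trains when this commit landed):

import pandas as pd
from clearml import Task

task = Task.init(project_name='examples', task_name='data audit demo')

df = pd.DataFrame({
    'user_id': [1, 2, 2, 3],
    'country': ['US', 'DE', 'DE', 'US'],
    'score': [0.1, 0.9, 0.9, 0.4],
})

# Row uniqueness in the data-audit summary is now computed only over the
# listed columns instead of every column in the DataFrame.
task.register_artifact(
    name='training data',
    artifact=df,
    metadata={'source': 'demo'},
    uniqueness_columns=['user_id', 'country'],
)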