From bc33ad0da3f1ead3b0450329f6aa4bfb6b904041 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 6 Jan 2020 17:19:44 +0200
Subject: [PATCH] Calculate data-audit artifact uniqueness by user criteria

---
 trains/binding/artifacts.py | 80 ++++++++++++++++++++++++++++---------
 trains/task.py              | 17 ++++++--
 2 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/trains/binding/artifacts.py b/trains/binding/artifacts.py
index 151a48aa..19d8744b 100644
--- a/trains/binding/artifacts.py
+++ b/trains/binding/artifacts.py
@@ -2,24 +2,26 @@ import hashlib
 import json
 import mimetypes
 import os
-from zipfile import ZipFile, ZIP_DEFLATED
 from copy import deepcopy
 from datetime import datetime
+from multiprocessing import RLock, Event
 from multiprocessing.pool import ThreadPool
 from tempfile import mkdtemp, mkstemp
 from threading import Thread
-from multiprocessing import RLock, Event
 from time import time
+from zipfile import ZipFile, ZIP_DEFLATED
 
 import humanfriendly
 import six
-from pathlib2 import Path
 from PIL import Image
+from pathlib2 import Path
+from six.moves.urllib.parse import urlparse
 
-from ..backend_interface.metrics.events import UploadEvent
 from ..backend_api import Session
-from ..debugging.log import LoggerRoot
 from ..backend_api.services import tasks
+from ..backend_interface.metrics.events import UploadEvent
+from ..debugging.log import LoggerRoot
+from ..storage.helper import remote_driver_schemes
 
 try:
     import pandas as pd
@@ -204,6 +206,8 @@ class Artifacts(object):
         self._artifacts_manager = artifacts_manager
         # list of artifacts we should not upload (by name & weak-reference)
         self.artifact_metadata = {}
+        # maps artifact name to the hash columns used for its uniqueness calculation
+        self.artifact_hash_columns = {}
 
     def __setitem__(self, key, value):
         # check that value is of type pandas
@@ -225,9 +229,15 @@
     def get_metadata(self, name):
         return self.artifact_metadata.get(name)
 
+    def add_hash_columns(self, artifact_name, hash_columns):
+        self.artifact_hash_columns[artifact_name] = hash_columns
+
+    def get_hash_columns(self, artifact_name):
+        return self.artifact_hash_columns.get(artifact_name)
+
     @property
     def registered_artifacts(self):
-        return self._artifacts_dict
+        return self._artifacts_container
 
     @property
     def summary(self):
@@ -237,7 +247,7 @@
         self._task = task
         # notice the double link, this important since the Artifact
         # dictionary needs to signal the Artifacts base on changes
-        self._artifacts_dict = self._ProxyDictWrite(self)
+        self._artifacts_container = self._ProxyDictWrite(self)
         self._last_artifacts_upload = {}
         self._unregister_request = set()
         self._thread = None
@@ -249,13 +259,21 @@
         self._task_edit_lock = RLock()
         self._storage_prefix = None
 
-    def register_artifact(self, name, artifact, metadata=None):
+    def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
+        """
+        :param str name: name of the artifact. Notice! This will override a previous artifact if the name already exists.
+        :param pandas.DataFrame artifact: artifact object. Currently only pandas.DataFrame objects are supported.
+        :param dict metadata: dictionary of key-value pairs to store with the artifact (visible in the UI)
+        :param list uniqueness_columns: list of columns used as the artifact uniqueness comparison criteria. The
+            default value is True, which means all columns are used (same as artifact.columns).
+        """
         # currently we support pandas.DataFrame (which we will upload as csv.gz)
-        if name in self._artifacts_dict:
+        if name in self._artifacts_container:
             LoggerRoot.get_base_logger().info('Register artifact, overwriting existing artifact \"{}\"'.format(name))
-        self._artifacts_dict[name] = artifact
+        self._artifacts_container.add_hash_columns(name, list(artifact.columns if uniqueness_columns is True else uniqueness_columns))
+        self._artifacts_container[name] = artifact
         if metadata:
-            self._artifacts_dict.add_metadata(name, metadata)
+            self._artifacts_container.add_metadata(name, metadata)
 
     def unregister_artifact(self, name):
         # Remove artifact from the watch list
@@ -268,7 +286,7 @@
             'please upgrade to the latest server version')
             return False
 
-        if name in self._artifacts_dict:
+        if name in self._artifacts_container:
             raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
 
         artifact_type_data = tasks.ArtifactTypeData()
@@ -453,7 +471,7 @@
     def _start(self):
         """ Start daemon thread if any artifacts are registered and thread is not up yet """
-        if not self._thread and self._artifacts_dict:
+        if not self._thread and self._artifacts_container:
             # start the daemon thread
             self._flush_event.clear()
             self._thread = Thread(target=self._daemon)
@@ -464,7 +482,7 @@
         while not self._exit_flag:
             self._flush_event.wait(self._flush_frequency_sec)
             self._flush_event.clear()
-            artifact_keys = list(self._artifacts_dict.keys())
+            artifact_keys = list(self._artifacts_container.keys())
             for name in artifact_keys:
                 try:
                     self._upload_data_audit_artifacts(name)
@@ -476,8 +494,8 @@
     def _upload_data_audit_artifacts(self, name):
         logger = self._task.get_logger()
-        pd_artifact = self._artifacts_dict.get(name)
-        pd_metadata = self._artifacts_dict.get_metadata(name)
+        pd_artifact = self._artifacts_container.get(name)
+        pd_metadata = self._artifacts_container.get_metadata(name)
 
         # remove from artifacts watch list
         if name in self._unregister_request:
@@ -485,7 +503,7 @@
             try:
                 self._unregister_request.remove(name)
             except KeyError:
                 pass
-            self._artifacts_dict.unregister_artifact(name)
+            self._artifacts_container.unregister_artifact(name)
 
         if pd_artifact is None:
             return
@@ -574,16 +592,39 @@
     def _get_statistics(self, artifacts_dict=None):
         summary = ''
-        artifacts_dict = artifacts_dict or self._artifacts_dict
+        artifacts_dict = artifacts_dict or self._artifacts_container
         thread_pool = ThreadPool()
 
         try:
             # build hash row sets
             artifacts_summary = []
             for a_name, a_df in artifacts_dict.items():
+                hash_cols = self._artifacts_container.get_hash_columns(a_name)
                 if not pd or not isinstance(a_df, pd.DataFrame):
                     continue
 
+                if hash_cols is True:
+                    hash_col_drop = []
+                else:
+                    hash_cols = set(hash_cols)
+                    missing_cols = hash_cols.difference(a_df.columns)
+                    if missing_cols == hash_cols:
+                        LoggerRoot.get_base_logger().warning(
+                            'Uniqueness columns {} not found in artifact {}. '
+                            'Skipping uniqueness check for artifact.'.format(list(missing_cols), a_name)
+                        )
+                        continue
+                    elif missing_cols:
+                        # missing_cols is a strict subset of hash_cols here; keep only the existing columns
+                        hash_cols.difference_update(missing_cols)
+                        LoggerRoot.get_base_logger().warning(
+                            'Uniqueness columns {} not found in artifact {}. Using {}.'.format(
+                                list(missing_cols), a_name, list(hash_cols)
+                            )
+                        )
+
+                    hash_col_drop = [col for col in a_df.columns if col not in hash_cols]
+
                 a_unique_hash = set()
 
                 def hash_row(r):
@@ -591,7 +632,8 @@
                 a_shape = a_df.shape
                 # parallelize
-                thread_pool.map(hash_row, a_df.values)
+                a_hash_cols = a_df.drop(columns=hash_col_drop)
+                thread_pool.map(hash_row, a_hash_cols.values)
 
                 # add result
                 artifacts_summary.append((a_name, a_shape, a_unique_hash,))
diff --git a/trains/task.py b/trains/task.py
index 7dec1d8d..5716ce49 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -7,8 +7,11 @@ import time
 from argparse import ArgumentParser
 from tempfile import mkstemp
 
-from pathlib2 import Path
-from collections import OrderedDict, Callable
+try:
+    from collections.abc import Sequence
+except ImportError:
+    from collections import Sequence
+
 from typing import Optional
 
 import psutil
@@ -703,7 +706,7 @@
         if self.is_main_task():
             self.__register_at_exit(None)
 
-    def register_artifact(self, name, artifact, metadata=None):
+    def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
         """
         Add artifact for the current Task, used mostly for Data Audition.
         Currently supported artifacts object types: pandas.DataFrame
 
@@ -711,8 +714,14 @@
         :param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists.
         :param pandas.DataFrame artifact: artifact object, supported artifacts object types: pandas.DataFrame
         :param dict metadata: dictionary of key value to store with the artifact (visible in the UI)
+        :param Sequence uniqueness_columns: sequence of columns used as the artifact uniqueness comparison criteria.
+            The default value is True, which means all columns are used (same as artifact.columns).
         """
-        self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata)
+        if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True:
+            raise ValueError('uniqueness_columns should be a sequence or True')
+        if isinstance(uniqueness_columns, str):
+            uniqueness_columns = [uniqueness_columns]
+        self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)
 
     def unregister_artifact(self, name):
         """
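
Below is a minimal usage sketch of the new uniqueness_columns argument, for reviewers. The project name, task
name, and DataFrame columns are illustrative assumptions rather than part of the patch, and the snippet assumes
a reachable trains server for Task.init().

    import pandas as pd
    from trains import Task

    task = Task.init(project_name='examples', task_name='data audit demo')

    # illustrative DataFrame; the column names are placeholders
    df = pd.DataFrame({
        'user_id': [1, 2, 2],
        'score': [0.5, 0.1, 0.1],
        'comment': ['a', 'b', 'b'],
    })

    # default: uniqueness_columns=True, rows are hashed over all columns
    task.register_artifact('train_set_full', df)

    # uniqueness is computed over user_id and score only; 'comment' is ignored
    task.register_artifact('train_set', df, uniqueness_columns=['user_id', 'score'])

    # a single column may be passed as a plain string (Task.register_artifact wraps it into a list)
    task.register_artifact('train_set_by_user', df, uniqueness_columns='user_id')

Passing anything that is neither a sequence nor True raises ValueError. Requested columns that do not exist in
the DataFrame are dropped from the criteria with a warning, and if none of them exist the uniqueness check is
skipped for that artifact.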
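
For completeness, a self-contained sketch of the column-filtering logic _get_statistics now applies before
hashing rows. The md5-based row hash below is an assumption for illustration only; the body of hash_row is not
shown in this hunk, and the real daemon hashes rows through a thread pool rather than a plain loop.

    import hashlib

    import pandas as pd

    def count_unique_rows(a_df, hash_cols):
        # mirrors the patched logic: True means "hash over all columns"
        if hash_cols is True:
            hash_col_drop = []
        else:
            hash_cols = set(hash_cols)
            missing_cols = hash_cols.difference(a_df.columns)
            if missing_cols == hash_cols:
                # none of the requested columns exist; the daemon skips the artifact
                return None
            hash_cols.difference_update(missing_cols)
            hash_col_drop = [col for col in a_df.columns if col not in hash_cols]

        # drop the ignored columns, then hash each remaining row
        a_unique_hash = set()
        for row in a_df.drop(columns=hash_col_drop).values:
            a_unique_hash.add(hashlib.md5(str(row.tolist()).encode('utf-8')).hexdigest())
        return len(a_unique_hash)

    df = pd.DataFrame({'a': [1, 1, 2], 'b': ['x', 'y', 'y'], 'c': [0, 0, 0]})
    print(count_unique_rows(df, True))        # 3: every full row is distinct
    print(count_unique_rows(df, ['a']))       # 2: only column a is compared
    print(count_unique_rows(df, ['a', 'z']))  # 2: the unknown column z is dropped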