Add initial artifacts support

Fix sub-process support
Fix delete_after_upload option when uploading images
Add loguru support
Subsample plots if they are too big
Fix requests over 15 MB
allegroai 2019-07-28 21:04:45 +03:00
parent 637269a340
commit 62bc54d7be
12 changed files with 462 additions and 74 deletions
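For orientation, a minimal usage sketch of the features introduced in this commit. This is illustrative only: it assumes a reachable trains-server, pandas installed, and a hypothetical local image file; the calls (Task.init, Task.add_artifact, Logger.report_image_and_upload with delete_after_upload) come from the diffs below.

from trains import Task
import pandas as pd

task = Task.init(project_name='examples', task_name='artifacts demo')

# register a DataFrame artifact; the background Artifacts thread will
# periodically upload it as a compressed csv and print a uniqueness summary
df = pd.DataFrame({'id': [1, 2, 3], 'score': [0.2, 0.9, 0.9]})
task.add_artifact('train_set', df)

# debug images can now delete the local copy once the upload completes
task.get_logger().report_image_and_upload(
    'debug', 'sample', iteration=0,
    path='/tmp/sample.jpg',  # hypothetical local file
    delete_after_upload=True)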

View File

@ -6,7 +6,7 @@ api {
# Notice: 'api_server' is the api server (default port 8008), not the web server.
api_server: "http://localhost:8008"
# file server onport 8081
# file server on port 8081
files_server: "http://localhost:8081"
# Credentials are generated in the webapp, http://localhost:8080/admin

View File

@ -15,6 +15,13 @@ logging.warning('This is a warning message')
logging.error('This is an error message')
logging.critical('This is a critical message')
# this is a loguru test example
try:
from loguru import logger
logger.debug("That's it, beautiful and simple logging! (using ANSI colors)")
except ImportError:
pass
# get TRAINS logger object for any metrics / reports
logger = task.get_logger()

View File

@ -23,6 +23,10 @@ class LoginError(Exception):
pass
class MaxRequestSizeError(Exception):
pass
class Session(TokenManager):
""" TRAINS API Session class. """
@ -34,7 +38,9 @@ class Session(TokenManager):
_async_status_code = 202
_session_requests = 0
_session_initial_timeout = (3.0, 10.)
_session_timeout = (5.0, 300.)
_session_timeout = (10.0, 300.)
_write_session_data_size = 15000
_write_session_timeout = (300.0, 300.)
api_version = '2.1'
default_host = "https://demoapi.trainsai.io"
@ -181,10 +187,15 @@ class Session(TokenManager):
else "{host}/{service}.{action}"
).format(**locals())
while True:
if data and len(data) > self._write_session_data_size:
timeout = self._write_session_timeout
elif self._session_requests < 1:
timeout = self._session_initial_timeout
else:
timeout = self._session_timeout
res = self.__http_session.request(
method, url, headers=headers, auth=auth, data=data, json=json,
timeout=self._session_initial_timeout if self._session_requests < 1 else self._session_timeout,
)
method, url, headers=headers, auth=auth, data=data, json=json, timeout=timeout)
if (
refresh_token_if_unauthorized
and res.status_code == requests.codes.unauthorized
@ -294,7 +305,7 @@ class Session(TokenManager):
results = []
while True:
size = self.__max_req_size
slice = req_data[cur : cur + size]
slice = req_data[cur: cur + size]
if not slice:
break
if len(slice) < size:
@ -304,7 +315,10 @@ class Session(TokenManager):
# search for the last newline in order to send a coherent request
size = slice.rfind("\n") + 1
# readjust the slice
slice = req_data[cur : cur + size]
slice = req_data[cur: cur + size]
if not slice:
raise MaxRequestSizeError('Error: {}.{} request exceeds limit {} > {} bytes'.format(
service, action, len(req_data), self.__max_req_size))
res = self.send_request(
method=method,
service=service,
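As a plain illustration of the timeout rule added above: _pick_timeout is not the library code, just a standalone sketch with the thresholds copied from the class attributes in this hunk (_write_session_data_size, _session_initial_timeout, _session_timeout, _write_session_timeout).

def _pick_timeout(data, session_requests,
                  initial=(3.0, 10.), normal=(10.0, 300.),
                  write=(300.0, 300.), write_size=15000):
    # payloads above _write_session_data_size get the long write timeout
    if data and len(data) > write_size:
        return write
    # the very first request on a session uses the short initial timeout
    if session_requests < 1:
        return initial
    return normal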

View File

@ -2,7 +2,10 @@ import abc
import requests.exceptions
import six
from ..backend_api import Session
from ..backend_api import Session, CallResult
from ..backend_api.session.session import MaxRequestSizeError
from ..backend_api.session.response import ResponseMeta
from ..backend_api.session import BatchRequest
from ..backend_api.session.defs import ENV_ACCESS_KEY, ENV_SECRET_KEY
@ -42,6 +45,7 @@ class InterfaceBase(SessionInterface):
def _send(cls, session, req, ignore_errors=False, raise_on_errors=True, log=None, async_enable=False):
""" Convenience send() method providing a standardized error reporting """
while True:
error_msg = ''
try:
res = session.send(req, async_enable=async_enable)
if res.meta.result_code in (200, 202) or ignore_errors:
@ -58,6 +62,9 @@ class InterfaceBase(SessionInterface):
except requests.exceptions.BaseHTTPError as e:
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
except MaxRequestSizeError as e:
res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e)))
error_msg = 'Failed sending: %s' % str(e)
except Exception as e:
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))

View File

@ -46,6 +46,9 @@ class MetricsEventAdapter(object):
exception = attr.attrib(default=None)
delete_local_file = attr.attrib(default=True)
""" Local file path, if exists, delete the file after upload completed """
def set_exception(self, exp):
self.exception = exp
self.event.upload_exception = exp
@ -162,7 +165,7 @@ class ImageEventNoUpload(MetricsEventAdapter):
**self._get_base_dict())
class ImageEvent(MetricsEventAdapter):
class UploadEvent(MetricsEventAdapter):
""" Image event adapter """
_format = '.' + str(config.get('metrics.images.format', 'JPEG')).upper().lstrip('.')
_quality = int(config.get('metrics.images.quality', 87))
@ -173,7 +176,7 @@ class ImageEvent(MetricsEventAdapter):
_image_file_history_size = int(config.get('metrics.file_history_size', 5))
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, **kwargs):
image_file_history_size=None, delete_after_upload=False, **kwargs):
if image_data is not None and not hasattr(image_data, 'shape'):
raise ValueError('Image must have a shape attribute')
self._image_data = image_data
@ -188,13 +191,14 @@ class ImageEvent(MetricsEventAdapter):
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._upload_uri = upload_uri
self._delete_after_upload = delete_after_upload
# get upload uri upfront
# get upload uri upfront, either predefined image format or local file extension
# e.g.: image.png -> .png or image.raw.gz -> .raw.gz
image_format = self._format.lower() if self._image_data is not None else \
pathlib2.Path(self._local_image_path).suffix
'.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs)
super(UploadEvent, self).__init__(metric, variant, iter=iter, **kwargs)
@classmethod
def _get_metric_count(cls, metric, variant, next=True):
@ -210,20 +214,19 @@ class ImageEvent(MetricsEventAdapter):
finally:
cls._metric_counters_lock.release()
# return No event (just the upload)
def get_api_event(self):
return events.MetricsImageEvent(
url=self._url,
key=self._key,
**self._get_base_dict())
return None
def update(self, url=None, key=None, **kwargs):
super(ImageEvent, self).update(**kwargs)
super(UploadEvent, self).update(**kwargs)
if url is not None:
self._url = url
if key is not None:
self._key = key
def get_file_entry(self):
local_file = None
# don't provide file in case this event is out of the history window
last_count = self._get_metric_count(self.metric, self.variant, next=False)
if abs(self._count - last_count) > self._image_file_history_size:
@ -253,9 +256,14 @@ class ImageEvent(MetricsEventAdapter):
output = six.BytesIO(img_bytes.tostring())
output.seek(0)
else:
with open(self._local_image_path, 'rb') as f:
output = six.BytesIO(f.read())
output.seek(0)
local_file = self._local_image_path
try:
output = open(local_file, 'rb')
except Exception as e:
# something happened to the file, we should skip it
from ...debugging.log import LoggerRoot
LoggerRoot.get_base_logger().warning(str(e))
return None
return self.FileEntry(
event=self,
@ -263,7 +271,8 @@ class ImageEvent(MetricsEventAdapter):
stream=output,
url_prop='url',
key_prop='key',
upload_uri=self._upload_uri
upload_uri=self._upload_uri,
delete_local_file=local_file if self._delete_after_upload else None,
)
def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
@ -273,3 +282,18 @@ class ImageEvent(MetricsEventAdapter):
key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
return key, url
class ImageEvent(UploadEvent):
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, delete_after_upload=False, **kwargs):
super(ImageEvent, self).__init__(metric, variant, image_data=image_data, local_image_path=local_image_path,
iter=iter, upload_uri=upload_uri,
image_file_history_size=image_file_history_size,
delete_after_upload=delete_after_upload, **kwargs)
def get_api_event(self):
return events.MetricsImageEvent(
url=self._url,
key=self._key,
**self._get_base_dict())
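A standalone sketch of the new extension handling in UploadEvent above (using stdlib pathlib instead of pathlib2; full_suffix is an illustrative name): the upload filename now keeps the full multi-part suffix of a local file rather than only the last component.

from pathlib import Path

def full_suffix(local_path):
    # 'image.png' -> '.png', 'image.raw.gz' -> '.raw.gz'
    name = Path(local_path).parts[-1]
    return '.' + '.'.join(name.split('.')[1:])

print(full_suffix('image.raw.gz'))  # .raw.gz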

View File

@ -4,6 +4,8 @@ from threading import Lock
from time import time
from humanfriendly import format_timespan
from pathlib2 import Path
from ...backend_api.services import events as api_events
from ..base import InterfaceBase
from ...config import config
@ -150,12 +152,18 @@ class Metrics(InterfaceBase):
url = storage.upload_from_stream(e.stream, e.url)
e.event.update(url=url)
except Exception as exp:
log.debug("Failed uploading to {} ({})".format(
log.warning("Failed uploading to {} ({})".format(
upload_uri if upload_uri else "(Could not calculate upload uri)",
exp,
))
e.set_exception(exp)
e.stream.close()
if e.delete_local_file:
try:
Path(e.delete_local_file).unlink()
except Exception:
pass
res = file_upload_pool.map_async(upload, entries)
res.wait()
@ -180,8 +188,10 @@ class Metrics(InterfaceBase):
))
if good_events:
batched_requests = [api_events.AddRequest(event=ev.get_api_event()) for ev in good_events]
req = api_events.AddBatchRequest(requests=batched_requests)
return self.send(req, raise_on_errors=False)
_events = [ev.get_api_event() for ev in good_events]
batched_requests = [api_events.AddRequest(event=ev) for ev in _events if ev]
if batched_requests:
req = api_events.AddBatchRequest(requests=batched_requests)
return self.send(req, raise_on_errors=False)
return None
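A compact sketch (with illustrative helper names) of the batching rule above: upload-only events such as UploadEvent return None from get_api_event() and are skipped, and when nothing remains no batch request is sent at all.

def batch_events(good_events, add_request, add_batch_request):
    # drop upload-only events that have no API event to report
    api_evts = [ev.get_api_event() for ev in good_events]
    requests_ = [add_request(event=e) for e in api_evts if e]
    return add_batch_request(requests=requests_) if requests_ else None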

View File

@ -11,7 +11,7 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
create_image_plot
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent
class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
@ -183,7 +183,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._report(ev)
def report_image_and_upload(self, title, series, iter, path=None, matrix=None, upload_uri=None,
max_image_history=None):
max_image_history=None, delete_after_upload=False):
"""
Report an image and upload its contents. Image is uploaded to a preconfigured bucket (see setup_upload()) with
a key (filename) describing the task ID, title, series and iteration.
@ -199,6 +199,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type matrix: str
:param max_image_history: maximum number of images to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
:param delete_after_upload: if True, once the file was uploaded the local copy will be deleted
:type delete_after_upload: boolean
"""
if not self._storage_uri and not upload_uri:
raise ValueError('Upload configuration is required (use setup_upload())')
@ -206,7 +208,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
delete_after_upload=delete_after_upload, **kwargs)
self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
@ -463,7 +466,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
)
def report_image_plot_and_upload(self, title, series, iter, path=None, matrix=None,
upload_uri=None, max_image_history=None):
upload_uri=None, max_image_history=None, delete_after_upload=False):
"""
Report an image as plot and upload its contents.
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
@ -481,6 +484,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type matrix: str
:param max_image_history: maximum number of images to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
:param delete_after_upload: if True, once the file was uploaded the local copy will be deleted
:type delete_after_upload: boolean
"""
if not upload_uri and not self._storage_uri:
raise ValueError('Upload configuration is required (use setup_upload())')
@ -488,8 +493,16 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
raise ValueError('Expected only one of [filename, matrix]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path, **kwargs)
ev = UploadEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
delete_after_upload=delete_after_upload, **kwargs)
_, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
# Hack: if the url doesn't start with http/s then plotly will not be able to show it,
# so we put the link under images instead of plots
if not url.startswith('http'):
return self.report_image_and_upload(title=title, series=series, iter=iter, path=path, matrix=matrix,
upload_uri=upload_uri, max_image_history=max_image_history)
self._report(ev)
plotly_dict = create_image_plot(
image_src=url,

View File

@ -180,35 +180,41 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _update_repository(self):
def check_package_update():
# check latest version
from ...utilities.check_updates import CheckPackageUpdates
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
self.get_logger().console(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
)
else:
self.get_logger().console(
'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
latest_version[0]),
)
try:
# check latest version
from ...utilities.check_updates import CheckPackageUpdates
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
self.get_logger().console(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
)
else:
self.get_logger().console(
'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
latest_version[0]),
)
except Exception:
pass
check_package_update_thread = Thread(target=check_package_update)
check_package_update_thread.daemon = True
check_package_update_thread.start()
result = ScriptInfo.get(log=self.log)
for msg in result.warning_messages:
self.get_logger().console(msg)
try:
check_package_update_thread = Thread(target=check_package_update)
check_package_update_thread.daemon = True
check_package_update_thread.start()
result = ScriptInfo.get(log=self.log)
for msg in result.warning_messages:
self.get_logger().console(msg)
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (lest someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '')
check_package_update_thread.join()
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (lest someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
self._update_requirements(result.script.get('requirements') if result.script.get('requirements') else '')
check_package_update_thread.join()
except Exception as e:
get_logger('task').warning(str(e))
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
@ -663,6 +669,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._project_name[1]
res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
if not res or not res.response or not res.response.project:
return None
self._project_name = (self.project, res.response.project.name)
return self._project_name[1]

trains/binding/artifacts.py (new file, 202 additions)
View File

@ -0,0 +1,202 @@
import os
import weakref
import numpy as np
import hashlib
from tempfile import mkstemp, mkdtemp
from threading import Thread, Event
from multiprocessing.pool import ThreadPool
from pathlib2 import Path
from ..debugging.log import LoggerRoot
try:
import pandas as pd
except ImportError:
pd = None
class Artifacts(object):
_flush_frequency_sec = 300.
# notice these two should match
_save_format = '.csv.gz'
_compression = 'gzip'
# hashing constants
_hash_block_size = 65536
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
def __init__(self, artifacts_manager, *args, **kwargs):
super(Artifacts._ProxyDictWrite, self).__init__(*args, **kwargs)
self._artifacts_manager = artifacts_manager
# list of artifacts we should not upload (by name & weak-reference)
self.local_artifacts = {}
def __setitem__(self, key, value):
# check that value is of type pandas
if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
if self._artifacts_manager:
self._artifacts_manager.flush()
else:
raise ValueError('Artifacts currently supports pandas.DataFrame objects only')
def disable_upload(self, name):
if name in self.keys():
self.local_artifacts[name] = weakref.ref(self.get(name))
def do_upload(self, name):
# return True if this artifact should be uploaded
return name not in self.local_artifacts or self.local_artifacts[name] != self.get(name)
@property
def artifacts(self):
return self._artifacts_dict
@property
def summary(self):
return self._summary
def __init__(self, task):
self._task = task
# notice the double link, this is important since the Artifact
# dictionary needs to signal the Artifacts base on changes
self._artifacts_dict = self._ProxyDictWrite(self)
self._last_artifacts_upload = {}
self._thread = None
self._flush_event = Event()
self._exit_flag = False
self._thread_pool = ThreadPool()
self._summary = ''
self._temp_folder = []
def add_artifact(self, name, artifact, upload=True):
# currently we support pandas.DataFrame (which we will upload as csv.gz)
# or numpy array, which we will upload as npz
self._artifacts_dict[name] = artifact
if not upload:
self._artifacts_dict.disable_upload(name)
def flush(self):
# start the thread if it hasn't already:
self._start()
# flush the current state of all artifacts
self._flush_event.set()
def stop(self, wait=True):
# stop the daemon thread and quit
# wait until thread exits
self._exit_flag = True
self._flush_event.set()
if wait:
if self._thread:
self._thread.join()
# remove all temp folders
for f in self._temp_folder:
try:
Path(f).rmdir()
except Exception:
pass
def _start(self):
if not self._thread:
# start the daemon thread
self._flush_event.clear()
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
def _daemon(self):
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear()
try:
self._upload_artifacts()
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
# create summary
self._summary = self._get_statistics()
def _upload_artifacts(self):
logger = self._task.get_logger()
for name, artifact in self._artifacts_dict.items():
if not self._artifacts_dict.do_upload(name):
# only register artifacts, and leave, TBD
continue
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists():
# we are still uploading... get another temp folder
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
local_csv.unlink()
continue
self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
def _get_statistics(self):
summary = ''
thread_pool = ThreadPool()
try:
# build hash row sets
artifacts_summary = []
for a_name, a_df in self._artifacts_dict.items():
if not pd or not isinstance(a_df, pd.DataFrame):
continue
a_unique_hash = set()
def hash_row(r):
a_unique_hash.add(hash(bytes(r)))
a_shape = a_df.shape
# parallelize
thread_pool.map(hash_row, a_df.values)
# add result
artifacts_summary.append((a_name, a_shape, a_unique_hash,))
# build intersection summary
for i, (name, shape, unique_hash) in enumerate(artifacts_summary):
summary += '[{name}]: shape={shape}, {unique} unique rows, {percentage:.1f}% uniqueness\n'.format(
name=name, shape=shape, unique=len(unique_hash), percentage=100*len(unique_hash)/float(shape[0]))
for name2, shape2, unique_hash2 in artifacts_summary[i+1:]:
intersection = len(unique_hash & unique_hash2)
summary += '\tIntersection with [{name2}] {intersection} rows: {percentage:.1f}%\n'.format(
name2=name2, intersection=intersection, percentage=100*intersection/float(len(unique_hash2)))
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
finally:
thread_pool.close()
thread_pool.terminate()
return summary
def _get_temp_folder(self, force_new=False):
if force_new or not self._temp_folder:
new_temp = mkdtemp(prefix='artifacts_')
self._temp_folder.append(new_temp)
return new_temp
return self._temp_folder[0]
@staticmethod
def sha256sum(filename, skip_header=0):
# create sha2 of the file, notice we skip the header of the file (32 bytes)
# because sometimes that is the only change
h = hashlib.sha256()
b = bytearray(Artifacts._hash_block_size)
mv = memoryview(b)
with open(filename, 'rb', buffering=0) as f:
# skip header
f.read(skip_header)
for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n])
return h.hexdigest()
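For reference, a self-contained sketch (illustrative, not the library code) of the uniqueness/intersection summary computed in _get_statistics() above, operating on plain numpy arrays instead of DataFrame.values:

import numpy as np

def summarize(frames):
    # frames: dict of name -> 2d numpy array (stand-in for DataFrame.values)
    hashes = {name: {hash(bytes(row)) for row in arr} for name, arr in frames.items()}
    lines = []
    names = list(frames)
    for i, name in enumerate(names):
        arr, uniq = frames[name], hashes[name]
        lines.append('[%s]: shape=%s, %d unique rows, %.1f%% uniqueness' % (
            name, arr.shape, len(uniq), 100.0 * len(uniq) / arr.shape[0]))
        for other in names[i + 1:]:
            inter = len(uniq & hashes[other])
            lines.append('\tIntersection with [%s] %d rows: %.1f%%' % (
                other, inter, 100.0 * inter / len(hashes[other])))
    return '\n'.join(lines)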

View File

@ -94,6 +94,39 @@ class Logger(object):
pass
sys.stdout = Logger._stdout_proxy
sys.stderr = Logger._stderr_proxy
# patch the base streams of sys (this way colorama will keep its ANSI colors)
# noinspection PyBroadException
try:
sys.__stderr__ = sys.stderr
except Exception:
pass
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
# now check if we have loguru and make it re-register the handlers
# because it stores the stream.write function internally, which we can't patch
# noinspection PyBroadException
try:
from loguru import logger
register_stderr = None
register_stdout = None
for k, v in logger._handlers.items():
if v._name == '<stderr>':
register_stderr = k
elif v._name == '<stdout>':
register_stdout = k
if register_stderr is not None:
logger.remove(register_stderr)
logger.add(sys.stderr)
if register_stdout is not None:
logger.remove(register_stdout)
logger.add(sys.stdout)
except Exception:
pass
elif DevWorker.report_stdout and not running_remotely():
self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
if Logger._stdout_proxy:
@ -495,7 +528,8 @@ class Logger(object):
)
@_safe_names
def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Report an image and upload its contents.
@ -515,6 +549,8 @@ class Logger(object):
:param max_image_history: maximum number of images to store per metric/variant combination \
use negative value for unlimited. default is set in global configuration (default=5)
:type max_image_history: int
:param delete_after_upload: if True, once the file was uploaded the local copy will be deleted
:type delete_after_upload: boolean
"""
# if task was not started, we have to start it
@ -536,9 +572,11 @@ class Logger(object):
iter=iteration,
upload_uri=upload_uri,
max_image_history=max_image_history,
delete_after_upload=delete_after_upload,
)
def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Report an image, upload its contents, and present in plots section using plotly
@ -558,6 +596,8 @@ class Logger(object):
:param max_image_history: maximum number of images to store per metric/variant combination \
use negative value for unlimited. default is set in global configuration (default=5)
:type max_image_history: int
:param delete_after_upload: if True, once the file was uploaded the local copy will be deleted
:type delete_after_upload: boolean
"""
# if task was not started, we have to start it
@ -579,6 +619,7 @@ class Logger(object):
iter=iteration,
upload_uri=upload_uri,
max_image_history=max_image_history,
delete_after_upload=delete_after_upload,
)
def set_default_upload_destination(self, uri):
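A simpler, hedged sketch of the loguru re-registration idea above. It uses only loguru's public API (remove all sinks, then re-add them against the patched streams) rather than the internal _handlers table the library code walks; this is an alternative illustration, not the exact approach of the commit.

import sys
try:
    from loguru import logger
    # loguru keeps a reference to the original stream.write, so after
    # stdout/stderr are replaced by proxies the sinks must be re-added
    logger.remove()           # drop all existing sinks (default sink is stderr)
    logger.add(sys.stderr)    # re-register against the (possibly patched) stream
except ImportError:
    pass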

View File

@ -26,6 +26,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.artifacts import Artifacts
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
@ -92,7 +93,7 @@ class Task(_Task):
if private is not Task.__create_protection:
raise UsageError(
'Task object cannot be instantiated externally, use Task.current_task() or Task.get_task(...)')
self._lock = threading.RLock()
self._repo_detect_lock = threading.RLock()
super(Task, self).__init__(**kwargs)
self._arguments = _Arguments(self)
@ -103,6 +104,7 @@ class Task(_Task):
self._connected_parameter_type = None
self._detect_repo_async_thread = None
self._resource_monitor = None
self._artifacts_manager = Artifacts(self)
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
@ -467,6 +469,14 @@ class Task(_Task):
def output_uri(self, value):
self.storage_uri = value
@property
def artifacts(self):
"""
dictionary of Task artifacts (name, artifact)
:return: dict
"""
return self._artifacts_manager.artifacts
def set_comment(self, comment):
"""
Set a comment text to the task.
@ -579,15 +589,6 @@ class Task(_Task):
:param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed
:return: True
"""
# wait for detection repo sync
if self._detect_repo_async_thread:
with self._lock:
if self._detect_repo_async_thread:
try:
self._detect_repo_async_thread.join()
self._detect_repo_async_thread = None
except Exception:
pass
# make sure model upload is done
if BackendModel.get_num_results() > 0 and wait_for_uploads:
@ -625,6 +626,17 @@ class Task(_Task):
if self.is_main_task():
self.__register_at_exit(None)
def add_artifact(self, name, artifact):
"""
Add artifact for the current Task, used mostly for Data Audition.
Currently supported artifacts object types: pandas.DataFrame
:param name: name of the artifact. Can override a previous artifact if the name already exists
:type name: str
:param artifact: artifact object, supported artifacts object types: pandas.DataFrame
:type artifact: pandas.DataFrame
"""
self._artifacts_manager.add_artifact(name=name, artifact=artifact)
def is_current_task(self):
"""
Check if this task is the main task (returned by Task.init())
@ -933,6 +945,24 @@ class Task(_Task):
if not flush_period or flush_period > self._dev_worker.report_period:
logger.set_flush_period(self._dev_worker.report_period)
def _wait_for_repo_detection(self, timeout=None):
# wait for detection repo sync
if self._detect_repo_async_thread:
with self._repo_detect_lock:
if self._detect_repo_async_thread:
try:
if self._detect_repo_async_thread.is_alive():
self._detect_repo_async_thread.join(timeout=timeout)
self._detect_repo_async_thread = None
except Exception:
pass
def _summary_artifacts(self):
# signal artifacts upload, and stop daemon
self._artifacts_manager.stop(wait=True)
# print artifacts summary
self.get_logger().console(self._artifacts_manager.summary)
def _at_exit(self):
"""
Will happen automatically once we exit code, i.e. atexit
@ -967,6 +997,13 @@ class Task(_Task):
else:
task_status = ('stopped', )
# wait for repository detection (if we didn't crash)
if not is_sub_process and wait_for_uploads:
# we should print summary here
self._summary_artifacts()
# make sure that if we crashed the thread we are not waiting forever
self._wait_for_repo_detection(timeout=10.)
# wait for uploads
print_done_waiting = False
if wait_for_uploads and (BackendModel.get_num_results() > 0 or self.reporter.get_num_results() > 0):
@ -1035,7 +1072,6 @@ class Task(_Task):
self.hook()
else:
# un register int hook
print('removing int hook', self._orig_exc_handler)
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
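The repo-detection wait above follows a double-checked locking pattern with a bounded join, so a crashed worker thread can never block shutdown forever. A standalone sketch with illustrative names:

import threading

class BackgroundDetector(object):
    def __init__(self, target):
        self._lock = threading.RLock()
        self._thread = threading.Thread(target=target)
        self._thread.daemon = True
        self._thread.start()

    def wait(self, timeout=10.0):
        # check outside the lock first, then re-check inside it, and never
        # block longer than `timeout` even if the worker thread is stuck
        if self._thread:
            with self._lock:
                if self._thread:
                    try:
                        if self._thread.is_alive():
                            self._thread.join(timeout=timeout)
                        self._thread = None
                    except Exception:
                        pass

BackgroundDetector(lambda: None).wait()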

View File

@ -71,6 +71,32 @@ def create_line_plot(title, series, xtitle, ytitle, mode='lines', reverse_xaxis=
if reverse_xaxis:
plotly_obj["layout"]["xaxis"]["autorange"] = "reversed"
# check maximum size of data
_MAX_SIZE = 800000
series_sizes = [s.data.size for s in series]
total_size = sum(series_sizes)
if total_size > _MAX_SIZE:
# we need to downscale
base_size = _MAX_SIZE / len(series_sizes)
baseused_size = sum([min(s, base_size) for s in series_sizes])
leftover = _MAX_SIZE - baseused_size
for s in series:
# if we need to down-sample, use low-pass average filter and sampling
if s.data.size >= base_size:
budget = int(leftover * s.data.size/(total_size-baseused_size))
step = int(np.ceil(s.data.size / float(budget)))
x = s.data[:, 0][::-step][::-1]
y = s.data[:, 1]
y_low_pass = np.convolve(y, np.ones(shape=(step,), dtype=y.dtype)/float(step), mode='same')
y = y_low_pass[::-step][::-1]
s.data = np.array([x, y], dtype=s.data.dtype).T
# decide on number of points between mean and max
s_max = np.max(np.abs(s.data), axis=0)
digits = np.maximum(np.array([1, 1]), np.array([6, 6]) - np.floor(np.abs(np.log10(s_max))))
s.data[:, 0] = np.round(s.data[:, 0] * (10 ** digits[0])) / (10 ** digits[0])
s.data[:, 1] = np.round(s.data[:, 1] * (10 ** digits[1])) / (10 ** digits[1])
plotly_obj["data"].extend({
"name": s.name,
"x": s.data[:, 0].tolist(),
@ -251,8 +277,8 @@ def create_image_plot(image_src, title, width=640, height=480, series=None, comm
"layout": {
"xaxis": {"visible": False, "range": [0, width]},
"yaxis": {"visible": False, "range": [0, height]},
"width": width,
"height": height,
# "width": width,
# "height": height,
"margin": {'l': 0, 'r': 0, 't': 0, 'b': 0},
"images": [{
"sizex": width,