Changed report_image matrix argument to image (with backwards support)

This commit is contained in:
allegroai 2019-10-10 21:09:44 +03:00
parent c0cfe3ccb2
commit 0b875a2dea
3 changed files with 34 additions and 16 deletions

View File

@ -194,7 +194,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
variant=self._normalize_name(series), iter=iter, src=src)
self._report(ev)
def report_image_and_upload(self, title, series, iter, path=None, matrix=None, upload_uri=None,
def report_image_and_upload(self, title, series, iter, path=None, image=None, upload_uri=None,
max_image_history=None, delete_after_upload=False):
"""
Report an image and upload its contents. Image is uploaded to a preconfigured bucket (see setup_upload()) with
@ -204,11 +204,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant)
:type series: str
:param iter: Iteration number
:type value: int
:type iter: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:param image: Image data. Required unless filename is provided.
:type image: A PIL.Image.Image object or a 3D numpy.ndarray object
:param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
:param delete_after_upload: if True, one the file was uploaded the local copy will be deleted
@ -216,11 +216,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
"""
if not self._storage_uri and not upload_uri:
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, matrix) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, matrix]')
if len([x for x in (path, image) if x is not None]) != 1:
raise ValueError('Expected only one of [filename, image]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
ev = ImageEvent(image_data=image, upload_uri=upload_uri, local_image_path=path,
delete_after_upload=delete_after_upload, **kwargs)
self._report(ev)
@ -512,7 +512,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
# Hack: if the url doesn't start with http/s then the plotly will not be able to show it,
# then we put the link under images not plots
if not url.startswith('http'):
return self.report_image_and_upload(title=title, series=series, iter=iter, path=path, matrix=matrix,
return self.report_image_and_upload(title=title, series=series, iter=iter, path=path, image=matrix,
upload_uri=upload_uri, max_image_history=max_image_history)
self._report(ev)

View File

@ -236,7 +236,7 @@ class EventTrainsWriter(object):
title=title,
series=series,
iteration=step,
matrix=img_data_np,
image=img_data_np,
max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images,
)

View File

@ -1,6 +1,8 @@
import logging
import warnings
import numpy as np
from PIL import Image
from pathlib2 import Path
from .backend_interface.logger import StdStreamPatch, LogFlusher
@ -13,6 +15,9 @@ from .backend_api.services import tasks
from .backend_interface.task import Task as _Task
from .config import running_remotely, get_cache_dir
# Make sure that DeprecationWarning within this package always gets printed
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
class Logger(object):
"""
@ -324,7 +329,7 @@ class Logger(object):
comment=comment,
)
def report_image(self, title, series, iteration, local_path=None, matrix=None, max_image_history=None,
def report_image(self, title, series, iteration, local_path=None, image=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Report an image and upload its contents.
@ -336,13 +341,22 @@ class Logger(object):
:param str series: Series (AKA variant)
:param int iteration: Iteration number
:param str local_path: A path to an image file. Required unless matrix is provided.
Required unless matrix is provided.
:param np.ndarray or PIL.Image.Image image: Could be a PIL.Image.Image object or a 3D numpy.ndarray
object containing image data (RGB).
:param np.ndarray matrix: A 3D numpy.ndarray object containing image data (RGB).
Required unless filename is provided.
This is deprecated, use image variable instead.
:param int max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
:param bool delete_after_upload: if True, one the file was uploaded the local copy will be deleted
"""
if matrix is not None:
warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning)
if len([x for x in (matrix, image, local_path) if x is not None]) != 1:
raise ValueError('Expected only one of [image, matrix, local_path]')
if image is None:
image = matrix
if image is not None and not isinstance(image, (np.ndarray, Image.Image)):
raise ValueError("Supported 'image' types are: numpy.ndarray or PIL.Image")
# if task was not started, we have to start it
self._start_task_if_needed()
@ -355,11 +369,15 @@ class Logger(object):
storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri)
self._touch_title_series(title, series)
if isinstance(image, Image.Image):
image = np.array(image)
self._task.reporter.report_image_and_upload(
title=title,
series=series,
path=local_path,
matrix=matrix,
image=image,
iter=iteration,
upload_uri=upload_uri,
max_image_history=max_image_history,
@ -443,9 +461,9 @@ class Logger(object):
def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Backwards compatibility, please use report_image instead
Deprecated: Backwards compatibility, please use report_image instead
"""
self.report_image(title=title, series=series, iteration=iteration, local_path=path, matrix=matrix,
self.report_image(title=title, series=series, iteration=iteration, local_path=path, image=matrix,
max_image_history=max_image_history, delete_after_upload=delete_after_upload)
@classmethod
@ -591,7 +609,7 @@ class Logger(object):
title=title,
series=series,
path=path,
matrix=None,
image=None,
iter=iteration,
upload_uri=upload_uri,
max_image_history=max_file_history,