Allow reporting a pre-uploaded image url in Logger.report_image using the url parameter

This commit is contained in:
allegroai 2020-01-26 15:29:35 +02:00
parent 8772bc2755
commit 923e45bb17
4 changed files with 58 additions and 34 deletions

View File

@ -156,8 +156,8 @@ class PlotEvent(MetricsEventAdapter):
class ImageEventNoUpload(MetricsEventAdapter): class ImageEventNoUpload(MetricsEventAdapter):
def __init__(self, metric, variant, src, iter=0, **kwargs): def __init__(self, metric, variant, src, iter=0, **kwargs):
self._url = src
parts = urlparse(src) parts = urlparse(src)
self._url = urlunparse((parts.scheme, parts.netloc, '', '', '', ''))
self._key = urlunparse(('', '', parts.path, parts.params, parts.query, parts.fragment)) self._key = urlunparse(('', '', parts.path, parts.params, parts.query, parts.fragment))
super(ImageEventNoUpload, self).__init__(metric, variant, iter=iter, **kwargs) super(ImageEventNoUpload, self).__init__(metric, variant, iter=iter, **kwargs)

View File

@ -79,15 +79,15 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
return results[0] return results[0]
def at_least_one(_exception_cls=Exception, **kwargs): def at_least_one(_exception_cls=Exception, _check_none=False, **kwargs):
actual = [k for k, v in kwargs.items() if v] actual = [k for k, v in kwargs.items() if (v is not None if _check_none else v)]
if len(actual) < 1: if len(actual) < 1:
raise _exception_cls('At least one of (%s) is required' % ', '.join(kwargs.keys())) raise _exception_cls('At least one of (%s) is required' % ', '.join(kwargs.keys()))
def mutually_exclusive(_exception_cls=Exception, _require_at_least_one=True, **kwargs): def mutually_exclusive(_exception_cls=Exception, _require_at_least_one=True, _check_none=False, **kwargs):
""" Helper for checking mutually exclusive options """ """ Helper for checking mutually exclusive options """
actual = [k for k, v in kwargs.items() if v] actual = [k for k, v in kwargs.items() if (v is not None if _check_none else v)]
if _require_at_least_one: if _require_at_least_one:
at_least_one(_exception_cls=_exception_cls, **kwargs) at_least_one(_exception_cls=_exception_cls, **kwargs)
if len(actual) > 1: if len(actual) > 1:
@ -106,4 +106,3 @@ def validate_dict(obj, key_types, value_types, desc=''):
def exact_match_regex(name): def exact_match_regex(name):
""" Convert string to a regex representing an exact match """ """ Convert string to a regex representing an exact match """
return '^%s$' % re.escape(name) return '^%s$' % re.escape(name)

View File

@ -5,15 +5,17 @@ import numpy as np
from PIL import Image from PIL import Image
from pathlib2 import Path from pathlib2 import Path
from .backend_api.services import tasks
from .backend_interface.logger import StdStreamPatch, LogFlusher from .backend_interface.logger import StdStreamPatch, LogFlusher
from .debugging.log import LoggerRoot from .backend_interface.task import Task as _Task
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler from .backend_interface.task.log import TaskHandler
from .backend_interface.util import mutually_exclusive
from .config import running_remotely, get_cache_dir
from .debugging.log import LoggerRoot
from .errors import UsageError
from .storage import StorageHelper from .storage import StorageHelper
from .utilities.plotly_reporter import SeriesInfo from .utilities.plotly_reporter import SeriesInfo
from .backend_api.services import tasks
from .backend_interface.task import Task as _Task
from .config import running_remotely, get_cache_dir
# Make sure that DeprecationWarning within this package always gets printed # Make sure that DeprecationWarning within this package always gets printed
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__) warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
@ -353,17 +355,22 @@ class Logger(object):
) )
def report_image(self, title, series, iteration, local_path=None, image=None, matrix=None, max_image_history=None, def report_image(self, title, series, iteration, local_path=None, image=None, matrix=None, max_image_history=None,
delete_after_upload=False): delete_after_upload=False, url=None):
""" """
Report an image and upload its contents. Report an image and upload its contents.
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename) Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
describing the task ID, title, series and iteration. describing the task ID, title, series and iteration.
.. note::
:paramref:`~.Logger.report_image.local_path`, :paramref:`~.Logger.report_image.url`, :paramref:`~.Logger.report_image.image` and :paramref:`~.Logger.report_image.matrix`
are mutually exclusive, and at least one must be provided.
:param str title: Title (AKA metric) :param str title: Title (AKA metric)
:param str series: Series (AKA variant) :param str series: Series (AKA variant)
:param int iteration: Iteration number :param int iteration: Iteration number
:param str local_path: A path to an image file. Required unless matrix is provided. :param str local_path: A path to an image file.
:param str url: A URL to the location of a pre-uploaded image.
:param np.ndarray or PIL.Image.Image image: Could be a PIL.Image.Image object or a 3D numpy.ndarray :param np.ndarray or PIL.Image.Image image: Could be a PIL.Image.Image object or a 3D numpy.ndarray
object containing image data (RGB). object containing image data (RGB).
:param np.ndarray matrix: A 3D numpy.ndarray object containing image data (RGB). :param np.ndarray matrix: A 3D numpy.ndarray object containing image data (RGB).
@ -372,10 +379,12 @@ class Logger(object):
use negative value for unlimited. default is set in global configuration (default=5) use negative value for unlimited. default is set in global configuration (default=5)
:param bool delete_after_upload: if True, one the file was uploaded the local copy will be deleted :param bool delete_after_upload: if True, one the file was uploaded the local copy will be deleted
""" """
mutually_exclusive(
UsageError, _check_none=True,
local_path=local_path or None, url=url or None, image=image, matrix=matrix
)
if matrix is not None: if matrix is not None:
warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning) warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning)
if len([x for x in (matrix, image, local_path) if x is not None]) != 1:
raise ValueError('Expected only one of [image, matrix, local_path]')
if image is None: if image is None:
image = matrix image = matrix
if image is not None and not isinstance(image, (np.ndarray, Image.Image)): if image is not None and not isinstance(image, (np.ndarray, Image.Image)):
@ -383,6 +392,18 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series)
if url:
self._task.reporter.report_image(
title=title,
series=series,
src=url,
iter=iteration,
)
else:
upload_uri = self.get_default_upload_destination() upload_uri = self.get_default_upload_destination()
if not upload_uri: if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images' upload_uri = Path(get_cache_dir()) / 'debug_images'
@ -391,7 +412,6 @@ class Logger(object):
upload_uri = str(upload_uri) upload_uri = str(upload_uri)
storage = StorageHelper.get(upload_uri) storage = StorageHelper.get(upload_uri)
upload_uri = storage.verify_upload(folder_uri=upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri)
self._touch_title_series(title, series)
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = np.array(image) image = np.array(image)

View File

@ -336,6 +336,11 @@ class CheckPackageUpdates(object):
@staticmethod @staticmethod
def get_version_from_updates_server(cur_version): def get_version_from_updates_server(cur_version):
"""
Get the latest version for trains from updates server
:param cur_version: The current running version of trains
:type cur_version: Version
"""
try: try:
_ = requests.get('https://updates.trains.allegro.ai/updates', _ = requests.get('https://updates.trains.allegro.ai/updates',
data=json.dumps({"versions": {"trains": str(cur_version)}}), data=json.dumps({"versions": {"trains": str(cur_version)}}),