Replace opencv-python with more standard Pillow

This commit is contained in:
allegroai 2019-08-09 02:18:01 +03:00
parent 7beddf97da
commit 761082b474
8 changed files with 35 additions and 42 deletions

View File

@ -48,11 +48,13 @@ logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=sca
scatter3d = np.random.randint(10, size=(10, 3)) scatter3d = np.random.randint(10, size=(10, 3))
logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d) logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d)
# report image # report images
m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image_and_upload("test case", "image uint", iteration=1, matrix=m)
m = np.eye(256, 256, dtype=np.float) m = np.eye(256, 256, dtype=np.float)
logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m) logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m)
m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image_and_upload("test case", "image uint8", iteration=1, matrix=m)
m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
logger.report_image_and_upload("test case", "image color red", iteration=1, matrix=m)
# flush reports (otherwise it will be flushed in the background, every couple of seconds) # flush reports (otherwise it will be flushed in the background, every couple of seconds)
logger.flush() logger.flush()

View File

@ -2,7 +2,7 @@
# #
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import cv2 from PIL import Image
from trains import Task from trains import Task
task = Task.init(project_name='examples', task_name='tensorboard toy example') task = Task.init(project_name='examples', task_name='tensorboard toy example')
@ -46,7 +46,8 @@ all_combined = tf.concat(all_distributions, 0)
tf.summary.histogram("all_combined", all_combined) tf.summary.histogram("all_combined", all_combined)
# convert to 4d [batch, col, row, RGB-channels] # convert to 4d [batch, col, row, RGB-channels]
image = cv2.imread('./samples/picasso.jpg')[:, :, ::-1] image_open = Image.open('./samples/picasso.jpg')
image = np.asarray(image_open)
image_gray = image[:, :, 0][np.newaxis, :, :, np.newaxis] image_gray = image[:, :, 0][np.newaxis, :, :, np.newaxis]
image_rgba = np.concatenate((image, 255*np.atleast_3d(np.ones(shape=image.shape[:2], dtype=np.uint8))), axis=2) image_rgba = np.concatenate((image, 255*np.atleast_3d(np.ones(shape=image.shape[:2], dtype=np.uint8))), axis=2)
image_rgba = image_rgba[np.newaxis, :, :, :] image_rgba = image_rgba[np.newaxis, :, :, :]

View File

@ -13,8 +13,8 @@ humanfriendly>=2.1
jsonmodels>=2.2 jsonmodels>=2.2
jsonschema>=2.6.0 jsonschema>=2.6.0
numpy>=1.10 numpy>=1.10
opencv-python>=3.2.0.8
pathlib2>=2.3.0 pathlib2>=2.3.0
Pillow>=4.1.1
pigar>=0.9.2 pigar>=0.9.2
plotly>=3.9.0 plotly>=3.9.0
psutil>=3.4.2 psutil>=3.4.2

View File

@ -3,13 +3,13 @@ import time
from threading import Lock from threading import Lock
import attr import attr
import cv2
import numpy as np import numpy as np
import pathlib2 import pathlib2
import six import six
from ...backend_api.services import events from PIL import Image
from six.moves.urllib.parse import urlparse, urlunparse from six.moves.urllib.parse import urlparse, urlunparse
from ...backend_api.services import events
from ...config import config from ...config import config
@ -248,12 +248,10 @@ class UploadEvent(MetricsEventAdapter):
image_data = np.reshape(image_data, (height, width)) image_data = np.reshape(image_data, (height, width))
# serialize image # serialize image
_, img_bytes = cv2.imencode( image = Image.fromarray(image_data)
self._format, image_data, output = six.BytesIO()
params=(cv2.IMWRITE_JPEG_QUALITY, self._quality), image_format = Image.registered_extensions().get(self._format.lower(), 'JPEG')
) image.save(output, format=image_format, quality=self._quality)
output = six.BytesIO(img_bytes.tostring())
output.seek(0) output.seek(0)
else: else:
local_file = self._local_image_path local_file = self._local_image_path

View File

@ -203,7 +203,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type value: int :type value: int
:param path: A path to an image file. Required unless matrix is provided. :param path: A path to an image file. Required unless matrix is provided.
:type path: str :type path: str
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided. :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str :type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination :param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5) use negative value for unlimited. default is set in global configuration (default=5)
@ -488,7 +488,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type value: int :type value: int
:param path: A path to an image file. Required unless matrix is provided. :param path: A path to an image file. Required unless matrix is provided.
:type path: str :type path: str
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided. :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str :type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination :param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5) use negative value for unlimited. default is set in global configuration (default=5)

View File

@ -4,11 +4,12 @@ import sys
import threading import threading
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from io import BytesIO
from logging import ERROR, WARNING, getLogger from logging import ERROR, WARNING, getLogger
from typing import Any from typing import Any
import cv2
import numpy as np import numpy as np
from PIL import Image
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
from ..import_bind import PostImportHookPatching from ..import_bind import PostImportHookPatching
@ -163,8 +164,11 @@ class EventTrainsWriter(object):
def _decode_image(self, img_str, width, height, color_channels): def _decode_image(self, img_str, width, height, color_channels):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8) imdata = base64.b64decode(img_str)
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR) output = BytesIO(imdata)
im = Image.open(output)
image = np.asarray(im)
output.close()
val = image.reshape(height, width, -1).astype(np.uint8) val = image.reshape(height, width, -1).astype(np.uint8)
if val.ndim == 3 and val.shape[2] == 3: if val.ndim == 3 and val.shape[2] == 3:
if self._visualization_mode == 'BGR': if self._visualization_mode == 'BGR':

View File

@ -6,6 +6,7 @@ from tempfile import mkstemp
import six import six
from six import BytesIO from six import BytesIO
import threading
from ..debugging.log import LoggerRoot from ..debugging.log import LoggerRoot
from ..config import running_remotely from ..config import running_remotely
@ -21,6 +22,7 @@ class PatchedMatplotlib:
_global_image_counter = -1 _global_image_counter = -1
_current_task = None _current_task = None
_support_image_plot = False _support_image_plot = False
_recursion_guard = {}
class _PatchWarnings(object): class _PatchWarnings(object):
def __init__(self): def __init__(self):
@ -115,19 +117,21 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_figure_show(self, *args, **kw): def patched_figure_show(self, *args, **kw):
if hasattr(self, '_trains_show'): tid = threading._get_ident() if six.PY2 else threading.get_ident()
# flag will be cleared when calling clf() (object will be replaced) if PatchedMatplotlib._recursion_guard.get(tid):
# we are inside a gaurd do nothing
return PatchedMatplotlib._patched_original_figure(self, *args, **kw) return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
try:
self._trains_show = True PatchedMatplotlib._recursion_guard[tid] = True
except Exception:
pass
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self) PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw) ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
PatchedMatplotlib._recursion_guard[tid] = False
return ret return ret
@staticmethod @staticmethod
def patched_show(*args, **kw): def patched_show(*args, **kw):
tid = threading._get_ident() if six.PY2 else threading.get_ident()
PatchedMatplotlib._recursion_guard[tid] = True
# noinspection PyBroadException # noinspection PyBroadException
try: try:
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True) figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
@ -145,6 +149,7 @@ class PatchedMatplotlib:
plt.clf() plt.clf()
except Exception: except Exception:
pass pass
PatchedMatplotlib._recursion_guard[tid] = False
return ret return ret
@staticmethod @staticmethod
@ -172,12 +177,6 @@ class PatchedMatplotlib:
else: else:
mpl_fig = specific_fig mpl_fig = specific_fig
# mark as processed, so nested calls to figure.show will do nothing
try:
mpl_fig._trains_show = True
except Exception:
pass
# convert to plotly # convert to plotly
image = None image = None
plotly_fig = None plotly_fig = None

View File

@ -5,11 +5,6 @@ try:
import numpy as np import numpy as np
except Exception: except Exception:
np = None np = None
try:
import cv2
except Exception:
cv2 = None
def make_deterministic(seed=1337, cudnn_deterministic=False): def make_deterministic(seed=1337, cudnn_deterministic=False):
""" """
@ -40,12 +35,6 @@ def make_deterministic(seed=1337, cudnn_deterministic=False):
if np is not None: if np is not None:
np.random.seed(seed) np.random.seed(seed)
if cv2 is not None:
try:
cv2.setRNGSeed(seed)
except Exception:
pass
if torch is not None: if torch is not None:
try: try:
torch.manual_seed(seed) torch.manual_seed(seed)