Replace opencv-python with more standard Pillow

This commit is contained in:
allegroai 2019-08-09 02:18:01 +03:00
parent 7beddf97da
commit 761082b474
8 changed files with 35 additions and 42 deletions

View File

@ -48,11 +48,13 @@ logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=sca
scatter3d = np.random.randint(10, size=(10, 3))
logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d)
# report image
m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image_and_upload("test case", "image uint", iteration=1, matrix=m)
# report images
m = np.eye(256, 256, dtype=np.float)
logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m)
m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image_and_upload("test case", "image uint8", iteration=1, matrix=m)
m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
logger.report_image_and_upload("test case", "image color red", iteration=1, matrix=m)
# flush reports (otherwise it will be flushed in the background, every couple of seconds)
logger.flush()

View File

@ -2,7 +2,7 @@
#
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from trains import Task
task = Task.init(project_name='examples', task_name='tensorboard toy example')
@ -46,7 +46,8 @@ all_combined = tf.concat(all_distributions, 0)
tf.summary.histogram("all_combined", all_combined)
# convert to 4d [batch, col, row, RGB-channels]
image = cv2.imread('./samples/picasso.jpg')[:, :, ::-1]
image_open = Image.open('./samples/picasso.jpg')
image = np.asarray(image_open)
image_gray = image[:, :, 0][np.newaxis, :, :, np.newaxis]
image_rgba = np.concatenate((image, 255*np.atleast_3d(np.ones(shape=image.shape[:2], dtype=np.uint8))), axis=2)
image_rgba = image_rgba[np.newaxis, :, :, :]

View File

@ -13,8 +13,8 @@ humanfriendly>=2.1
jsonmodels>=2.2
jsonschema>=2.6.0
numpy>=1.10
opencv-python>=3.2.0.8
pathlib2>=2.3.0
Pillow>=4.1.1
pigar>=0.9.2
plotly>=3.9.0
psutil>=3.4.2

View File

@ -3,13 +3,13 @@ import time
from threading import Lock
import attr
import cv2
import numpy as np
import pathlib2
import six
from ...backend_api.services import events
from PIL import Image
from six.moves.urllib.parse import urlparse, urlunparse
from ...backend_api.services import events
from ...config import config
@ -248,12 +248,10 @@ class UploadEvent(MetricsEventAdapter):
image_data = np.reshape(image_data, (height, width))
# serialize image
_, img_bytes = cv2.imencode(
self._format, image_data,
params=(cv2.IMWRITE_JPEG_QUALITY, self._quality),
)
output = six.BytesIO(img_bytes.tostring())
image = Image.fromarray(image_data)
output = six.BytesIO()
image_format = Image.registered_extensions().get(self._format.lower(), 'JPEG')
image.save(output, format=image_format, quality=self._quality)
output.seek(0)
else:
local_file = self._local_image_path

View File

@ -203,7 +203,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type value: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided.
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
@ -488,7 +488,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type value: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless filename is provided.
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)

View File

@ -4,11 +4,12 @@ import sys
import threading
from collections import defaultdict
from functools import partial
from io import BytesIO
from logging import ERROR, WARNING, getLogger
from typing import Any
import cv2
import numpy as np
from PIL import Image
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
from ..import_bind import PostImportHookPatching
@ -163,8 +164,11 @@ class EventTrainsWriter(object):
def _decode_image(self, img_str, width, height, color_channels):
# noinspection PyBroadException
try:
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
imdata = base64.b64decode(img_str)
output = BytesIO(imdata)
im = Image.open(output)
image = np.asarray(im)
output.close()
val = image.reshape(height, width, -1).astype(np.uint8)
if val.ndim == 3 and val.shape[2] == 3:
if self._visualization_mode == 'BGR':

View File

@ -6,6 +6,7 @@ from tempfile import mkstemp
import six
from six import BytesIO
import threading
from ..debugging.log import LoggerRoot
from ..config import running_remotely
@ -21,6 +22,7 @@ class PatchedMatplotlib:
_global_image_counter = -1
_current_task = None
_support_image_plot = False
_recursion_guard = {}
class _PatchWarnings(object):
def __init__(self):
@ -115,19 +117,21 @@ class PatchedMatplotlib:
@staticmethod
def patched_figure_show(self, *args, **kw):
if hasattr(self, '_trains_show'):
# flag will be cleared when calling clf() (object will be replaced)
tid = threading._get_ident() if six.PY2 else threading.get_ident()
if PatchedMatplotlib._recursion_guard.get(tid):
# we are inside a gaurd do nothing
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
try:
self._trains_show = True
except Exception:
pass
PatchedMatplotlib._recursion_guard[tid] = True
PatchedMatplotlib._report_figure(set_active=False, specific_fig=self)
ret = PatchedMatplotlib._patched_original_figure(self, *args, **kw)
PatchedMatplotlib._recursion_guard[tid] = False
return ret
@staticmethod
def patched_show(*args, **kw):
tid = threading._get_ident() if six.PY2 else threading.get_ident()
PatchedMatplotlib._recursion_guard[tid] = True
# noinspection PyBroadException
try:
figures = PatchedMatplotlib._get_output_figures(None, all_figures=True)
@ -145,6 +149,7 @@ class PatchedMatplotlib:
plt.clf()
except Exception:
pass
PatchedMatplotlib._recursion_guard[tid] = False
return ret
@staticmethod
@ -172,12 +177,6 @@ class PatchedMatplotlib:
else:
mpl_fig = specific_fig
# mark as processed, so nested calls to figure.show will do nothing
try:
mpl_fig._trains_show = True
except Exception:
pass
# convert to plotly
image = None
plotly_fig = None

View File

@ -5,11 +5,6 @@ try:
import numpy as np
except Exception:
np = None
try:
import cv2
except Exception:
cv2 = None
def make_deterministic(seed=1337, cudnn_deterministic=False):
"""
@ -40,12 +35,6 @@ def make_deterministic(seed=1337, cudnn_deterministic=False):
if np is not None:
np.random.seed(seed)
if cv2 is not None:
try:
cv2.setRNGSeed(seed)
except Exception:
pass
if torch is not None:
try:
torch.manual_seed(seed)