Add support for GIF images in Tensorboard (issue #372)

This commit is contained in:
allegroai 2021-06-01 00:18:39 +03:00
parent fcc3c12b59
commit d22cf7557d

View File

@ -6,7 +6,8 @@ from collections import defaultdict
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from mimetypes import guess_extension from mimetypes import guess_extension
from typing import Any from tempfile import mkstemp
from typing import Any, Union
import numpy as np import numpy as np
import six import six
@ -23,6 +24,11 @@ try:
except ImportError: except ImportError:
MessageToDict = None MessageToDict = None
try:
from PIL import GifImagePlugin # noqa
except ImportError:
pass
class TensorflowBinding(object): class TensorflowBinding(object):
@classmethod @classmethod
@ -360,6 +366,17 @@ class EventTrainsWriter(object):
imdata = base64.b64decode(img_str) imdata = base64.b64decode(img_str)
output = BytesIO(imdata) output = BytesIO(imdata)
im = Image.open(output) im = Image.open(output)
# if this is a GIF store as is
if getattr(im, 'is_animated'):
output.close()
fd, temp_file = mkstemp(
suffix=guess_extension(im.get_format_mimetype()) if hasattr(im, 'get_format_mimetype')
else ".{}".format(str(im.format).lower())
)
os.write(fd, imdata)
os.close(fd)
return temp_file
image = np.asarray(im) image = np.asarray(im)
output.close() output.close()
if height is not None and height > 0 and width is not None and width > 0: if height is not None and height > 0 and width is not None and width > 0:
@ -389,7 +406,7 @@ class EventTrainsWriter(object):
return val return val
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None): def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
# type: (str, int, np.ndarray, int) -> () # type: (str, int, Union[None, np.ndarray, str], int) -> ()
# only report images every specific interval # only report images every specific interval
if step % self.image_report_freq != 0: if step % self.image_report_freq != 0:
return None return None
@ -403,6 +420,18 @@ class EventTrainsWriter(object):
force_add_prefix=self._logger._get_tensorboard_series_prefix()) force_add_prefix=self._logger._get_tensorboard_series_prefix())
step = self._fix_step_counter(title, series, step) step = self._fix_step_counter(title, series, step)
# check if this is a local temp file
if isinstance(img_data_np, str):
self._logger.report_image(
title=title,
series=series,
iteration=step,
local_path=img_data_np,
delete_after_upload=True,
max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images,
)
return
if img_data_np.dtype != np.uint8: if img_data_np.dtype != np.uint8:
# assume scale 0-1 # assume scale 0-1
img_data_np = (img_data_np * 255).astype(np.uint8) img_data_np = (img_data_np * 255).astype(np.uint8)