Add support for GIF images in Tensorboard (issue #372)

This commit is contained in:
allegroai 2021-06-01 00:18:39 +03:00
parent fcc3c12b59
commit d22cf7557d

View File

@ -6,7 +6,8 @@ from collections import defaultdict
from functools import partial
from io import BytesIO
from mimetypes import guess_extension
from typing import Any
from tempfile import mkstemp
from typing import Any, Union
import numpy as np
import six
@ -23,6 +24,11 @@ try:
except ImportError:
MessageToDict = None
try:
from PIL import GifImagePlugin # noqa
except ImportError:
pass
class TensorflowBinding(object):
@classmethod
@ -360,6 +366,17 @@ class EventTrainsWriter(object):
imdata = base64.b64decode(img_str)
output = BytesIO(imdata)
im = Image.open(output)
# if this is a GIF store as is
if getattr(im, 'is_animated'):
output.close()
fd, temp_file = mkstemp(
suffix=guess_extension(im.get_format_mimetype()) if hasattr(im, 'get_format_mimetype')
else ".{}".format(str(im.format).lower())
)
os.write(fd, imdata)
os.close(fd)
return temp_file
image = np.asarray(im)
output.close()
if height is not None and height > 0 and width is not None and width > 0:
@ -389,7 +406,7 @@ class EventTrainsWriter(object):
return val
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
# type: (str, int, np.ndarray, int) -> ()
# type: (str, int, Union[None, np.ndarray, str], int) -> ()
# only report images every specific interval
if step % self.image_report_freq != 0:
return None
@ -403,6 +420,18 @@ class EventTrainsWriter(object):
force_add_prefix=self._logger._get_tensorboard_series_prefix())
step = self._fix_step_counter(title, series, step)
# check if this is a local temp file
if isinstance(img_data_np, str):
self._logger.report_image(
title=title,
series=series,
iteration=step,
local_path=img_data_np,
delete_after_upload=True,
max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images,
)
return
if img_data_np.dtype != np.uint8:
# assume scale 0-1
img_data_np = (img_data_np * 255).astype(np.uint8)