mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Add support for GIF images in Tensorboard (issue #372)
This commit is contained in:
parent
fcc3c12b59
commit
d22cf7557d
@ -6,7 +6,8 @@ from collections import defaultdict
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from mimetypes import guess_extension
|
||||
from typing import Any
|
||||
from tempfile import mkstemp
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -23,6 +24,11 @@ try:
|
||||
except ImportError:
|
||||
MessageToDict = None
|
||||
|
||||
try:
|
||||
from PIL import GifImagePlugin # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class TensorflowBinding(object):
|
||||
@classmethod
|
||||
@ -360,6 +366,17 @@ class EventTrainsWriter(object):
|
||||
imdata = base64.b64decode(img_str)
|
||||
output = BytesIO(imdata)
|
||||
im = Image.open(output)
|
||||
# if this is a GIF store as is
|
||||
if getattr(im, 'is_animated'):
|
||||
output.close()
|
||||
fd, temp_file = mkstemp(
|
||||
suffix=guess_extension(im.get_format_mimetype()) if hasattr(im, 'get_format_mimetype')
|
||||
else ".{}".format(str(im.format).lower())
|
||||
)
|
||||
os.write(fd, imdata)
|
||||
os.close(fd)
|
||||
return temp_file
|
||||
|
||||
image = np.asarray(im)
|
||||
output.close()
|
||||
if height is not None and height > 0 and width is not None and width > 0:
|
||||
@ -389,7 +406,7 @@ class EventTrainsWriter(object):
|
||||
return val
|
||||
|
||||
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
|
||||
# type: (str, int, np.ndarray, int) -> ()
|
||||
# type: (str, int, Union[None, np.ndarray, str], int) -> ()
|
||||
# only report images every specific interval
|
||||
if step % self.image_report_freq != 0:
|
||||
return None
|
||||
@ -403,6 +420,18 @@ class EventTrainsWriter(object):
|
||||
force_add_prefix=self._logger._get_tensorboard_series_prefix())
|
||||
step = self._fix_step_counter(title, series, step)
|
||||
|
||||
# check if this is a local temp file
|
||||
if isinstance(img_data_np, str):
|
||||
self._logger.report_image(
|
||||
title=title,
|
||||
series=series,
|
||||
iteration=step,
|
||||
local_path=img_data_np,
|
||||
delete_after_upload=True,
|
||||
max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images,
|
||||
)
|
||||
return
|
||||
|
||||
if img_data_np.dtype != np.uint8:
|
||||
# assume scale 0-1
|
||||
img_data_np = (img_data_np * 255).astype(np.uint8)
|
||||
|
Loading…
Reference in New Issue
Block a user