mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Add support for GIF images in Tensorboard (issue #372)
This commit is contained in:
parent
fcc3c12b59
commit
d22cf7557d
@ -6,7 +6,8 @@ from collections import defaultdict
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from mimetypes import guess_extension
|
from mimetypes import guess_extension
|
||||||
from typing import Any
|
from tempfile import mkstemp
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
@ -23,6 +24,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
MessageToDict = None
|
MessageToDict = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from PIL import GifImagePlugin # noqa
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TensorflowBinding(object):
|
class TensorflowBinding(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -360,6 +366,17 @@ class EventTrainsWriter(object):
|
|||||||
imdata = base64.b64decode(img_str)
|
imdata = base64.b64decode(img_str)
|
||||||
output = BytesIO(imdata)
|
output = BytesIO(imdata)
|
||||||
im = Image.open(output)
|
im = Image.open(output)
|
||||||
|
# if this is a GIF store as is
|
||||||
|
if getattr(im, 'is_animated'):
|
||||||
|
output.close()
|
||||||
|
fd, temp_file = mkstemp(
|
||||||
|
suffix=guess_extension(im.get_format_mimetype()) if hasattr(im, 'get_format_mimetype')
|
||||||
|
else ".{}".format(str(im.format).lower())
|
||||||
|
)
|
||||||
|
os.write(fd, imdata)
|
||||||
|
os.close(fd)
|
||||||
|
return temp_file
|
||||||
|
|
||||||
image = np.asarray(im)
|
image = np.asarray(im)
|
||||||
output.close()
|
output.close()
|
||||||
if height is not None and height > 0 and width is not None and width > 0:
|
if height is not None and height > 0 and width is not None and width > 0:
|
||||||
@ -389,7 +406,7 @@ class EventTrainsWriter(object):
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
|
def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
|
||||||
# type: (str, int, np.ndarray, int) -> ()
|
# type: (str, int, Union[None, np.ndarray, str], int) -> ()
|
||||||
# only report images every specific interval
|
# only report images every specific interval
|
||||||
if step % self.image_report_freq != 0:
|
if step % self.image_report_freq != 0:
|
||||||
return None
|
return None
|
||||||
@ -403,6 +420,18 @@ class EventTrainsWriter(object):
|
|||||||
force_add_prefix=self._logger._get_tensorboard_series_prefix())
|
force_add_prefix=self._logger._get_tensorboard_series_prefix())
|
||||||
step = self._fix_step_counter(title, series, step)
|
step = self._fix_step_counter(title, series, step)
|
||||||
|
|
||||||
|
# check if this is a local temp file
|
||||||
|
if isinstance(img_data_np, str):
|
||||||
|
self._logger.report_image(
|
||||||
|
title=title,
|
||||||
|
series=series,
|
||||||
|
iteration=step,
|
||||||
|
local_path=img_data_np,
|
||||||
|
delete_after_upload=True,
|
||||||
|
max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if img_data_np.dtype != np.uint8:
|
if img_data_np.dtype != np.uint8:
|
||||||
# assume scale 0-1
|
# assume scale 0-1
|
||||||
img_data_np = (img_data_np * 255).astype(np.uint8)
|
img_data_np = (img_data_np * 255).astype(np.uint8)
|
||||||
|
Loading…
Reference in New Issue
Block a user