diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index a3d15d16..3b94f943 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -6,7 +6,8 @@ from collections import defaultdict from functools import partial from io import BytesIO from mimetypes import guess_extension -from typing import Any +from tempfile import mkstemp +from typing import Any, Union import numpy as np import six @@ -23,6 +24,11 @@ try: except ImportError: MessageToDict = None +try: + from PIL import GifImagePlugin # noqa +except ImportError: + pass + class TensorflowBinding(object): @classmethod @@ -360,6 +366,17 @@ class EventTrainsWriter(object): imdata = base64.b64decode(img_str) output = BytesIO(imdata) im = Image.open(output) + # if this is a GIF store as is + if getattr(im, 'is_animated'): + output.close() + fd, temp_file = mkstemp( + suffix=guess_extension(im.get_format_mimetype()) if hasattr(im, 'get_format_mimetype') + else ".{}".format(str(im.format).lower()) + ) + os.write(fd, imdata) + os.close(fd) + return temp_file + image = np.asarray(im) output.close() if height is not None and height > 0 and width is not None and width > 0: @@ -389,7 +406,7 @@ class EventTrainsWriter(object): return val def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None): - # type: (str, int, np.ndarray, int) -> () + # type: (str, int, Union[None, np.ndarray, str], int) -> () # only report images every specific interval if step % self.image_report_freq != 0: return None @@ -403,6 +420,18 @@ class EventTrainsWriter(object): force_add_prefix=self._logger._get_tensorboard_series_prefix()) step = self._fix_step_counter(title, series, step) + # check if this is a local temp file + if isinstance(img_data_np, str): + self._logger.report_image( + title=title, + series=series, + iteration=step, + local_path=img_data_np, + delete_after_upload=True, + max_image_history=self.max_keep_images if max_keep_images is None else max_keep_images, + ) + return + if img_data_np.dtype != np.uint8: # assume scale 0-1 img_data_np = (img_data_np * 255).astype(np.uint8)