From a5b1ed03300d62a3cd8b35b8ed75e5c1cad87914 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Sat, 13 Jun 2020 22:09:45 +0300
Subject: [PATCH] Improve matplotlib integration, issue #140

---
 trains/backend_interface/metrics/reporter.py | 17 +++-
 trains/binding/matplotlib_bind.py            | 90 ++++++++++++++------
 trains/utilities/plotly_reporter.py          | 11 +--
 3 files changed, 84 insertions(+), 34 deletions(-)

diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index 7afb6d16..678b4f1c 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -629,6 +629,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             raise ValueError('Expected only one of [filename, matrix]')
         kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
                       file_history_size=max_image_history)
+
+        if matrix is not None:
+            width = matrix.shape[1]
+            height = matrix.shape[0]
+        else:
+            # noinspection PyBroadException
+            try:
+                from PIL import Image
+                width, height = Image.open(path).size
+            except Exception:
+                width = 640
+                height = 480
+
         ev = UploadEvent(image_data=matrix, upload_uri=upload_uri, local_image_path=path,
                          delete_after_upload=delete_after_upload, **kwargs)
         _, url = ev.get_target_full_upload_uri(upload_uri or self._storage_uri, self._metrics.storage_key_prefix)
@@ -643,8 +656,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         plotly_dict = create_image_plot(
             image_src=url,
             title=title + '/' + series,
-            width=matrix.shape[1] if matrix is not None else 640,
-            height=matrix.shape[0] if matrix is not None else 480,
+            width=640,
+            height=int(640*float(height or 480)/float(width or 640)),
         )
 
         return self.report_plot(
diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py
index 84830888..b5f58006 100644
--- a/trains/binding/matplotlib_bind.py
+++ b/trains/binding/matplotlib_bind.py
@@ -25,7 +25,7 @@ class PatchedMatplotlib:
     _global_plot_counter = -1
     _global_image_counter = -1
     _global_image_counter_limit = None
-    _last_iteration_plot_titles = (-1, [])
+    _last_iteration_plot_titles = {}
     _current_task = None
     _support_image_plot = False
     _matplotlylib = None
@@ -179,6 +179,21 @@ class PatchedMatplotlib:
     @staticmethod
     def patched_savefig(self, *args, **kw):
         ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
+        # noinspection PyBroadException
+        try:
+            fname = kw.get('fname') or args[0]
+            from pathlib2 import Path
+            if six.PY3:
+                from pathlib import Path as Path3
+            else:
+                Path3 = Path
+
+            # if we are not storing into a file (str/Path) do not log the matplotlib
+            if not isinstance(fname, (str, Path, Path3)):
+                return ret
+        except Exception:
+            pass
+
         tid = threading._get_ident() if six.PY2 else threading.get_ident()
         if not PatchedMatplotlib._recursion_guard.get(tid):
             PatchedMatplotlib._recursion_guard[tid] = True
@@ -273,35 +288,36 @@ class PatchedMatplotlib:
                     def our_mpl_to_plotly(fig):
                         if not PatchedMatplotlib._matplotlylib or not PatchedMatplotlib._plotly_renderer:
                             return None
-                        PatchedMatplotlib._matplotlylib.Exporter(PatchedMatplotlib._plotly_renderer,
-                                                                 close_mpl=False).run(fig)
-                        x_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_xticklabels())
+                        plotly_renderer = PatchedMatplotlib._matplotlylib.PlotlyRenderer()
+                        PatchedMatplotlib._matplotlylib.Exporter(plotly_renderer, close_mpl=False).run(fig)
+
+                        x_ticks = list(plotly_renderer.current_mpl_ax.get_xticklabels())
                         if x_ticks:
                             try:
                                 # check if all values can be cast to float
                                 values = [float(t.get_text().replace('−', '-')) for t in x_ticks]
                             except:
                                 try:
-                                    PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['xaxis1'].update({
+                                    plotly_renderer.plotly_fig['layout']['xaxis1'].update({
                                         'ticktext': [t.get_text() for t in x_ticks],
                                         'tickvals': [t.get_position()[0] for t in x_ticks],
                                     })
                                 except:
                                     pass
-                        y_ticks = list(PatchedMatplotlib._plotly_renderer.current_mpl_ax.get_yticklabels())
+                        y_ticks = list(plotly_renderer.current_mpl_ax.get_yticklabels())
                         if y_ticks:
                             try:
                                 # check if all values can be cast to float
                                 values = [float(t.get_text().replace('−', '-')) for t in y_ticks]
                             except:
                                 try:
-                                    PatchedMatplotlib._plotly_renderer.plotly_fig['layout']['yaxis1'].update({
+                                    plotly_renderer.plotly_fig['layout']['yaxis1'].update({
                                         'ticktext': [t.get_text() for t in y_ticks],
                                         'tickvals': [t.get_position()[1] for t in y_ticks],
                                     })
                                 except:
                                     pass
-                        return deepcopy(PatchedMatplotlib._plotly_renderer.plotly_fig)
+                        return deepcopy(plotly_renderer.plotly_fig)
 
                     plotly_fig = our_mpl_to_plotly(mpl_fig)
                     try:
@@ -366,7 +382,7 @@ class PatchedMatplotlib:
                         title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
                     else:
                         PatchedMatplotlib._global_plot_counter += 1
-                        title = 'untitled %d' % PatchedMatplotlib._global_plot_counter
+                        title = 'untitled %02d' % PatchedMatplotlib._global_plot_counter
 
                     plotly_fig.layout.margin = {}
                     plotly_fig.layout.autosize = True
@@ -392,8 +408,8 @@ class PatchedMatplotlib:
                             title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
                         else:
                             PatchedMatplotlib._global_image_counter += 1
-                            title = 'untitled %d' % (PatchedMatplotlib._global_image_counter %
-                                                     PatchedMatplotlib._global_image_counter_limit)
+                            title = 'untitled %02d' % (PatchedMatplotlib._global_image_counter %
+                                                       PatchedMatplotlib._global_image_counter_limit)
 
                         PatchedMatplotlib._matplotlib_reported_titles.add(title)
                         logger.report_image(title=title, series='plot image', local_path=image,
@@ -405,12 +421,14 @@ class PatchedMatplotlib:
                             title = PatchedMatplotlib._enforce_unique_title_per_iteration(plot_title, last_iteration)
                         else:
                             PatchedMatplotlib._global_plot_counter += 1
-                            title = 'untitled %d' % (PatchedMatplotlib._global_plot_counter %
-                                                     PatchedMatplotlib._global_image_counter_limit)
+                            title = 'untitled %02d' % (PatchedMatplotlib._global_plot_counter %
+                                                       PatchedMatplotlib._global_image_counter_limit)
 
                         PatchedMatplotlib._matplotlib_reported_titles.add(title)
-                        logger._report_image_plot_and_upload(title=title, series='plot image', path=image,
-                                                             delete_after_upload=True, iteration=last_iteration)
+                        # noinspection PyProtectedMember
+                        logger._report_image_plot_and_upload(
+                            title=title, series='plot image', path=image,
+                            delete_after_upload=True, iteration=last_iteration)
         except Exception:
             # plotly failed
             pass
@@ -419,19 +437,37 @@ class PatchedMatplotlib:
 
     @staticmethod
     def _enforce_unique_title_per_iteration(title, last_iteration):
-        if last_iteration != PatchedMatplotlib._last_iteration_plot_titles[0]:
-            PatchedMatplotlib._last_iteration_plot_titles = (last_iteration, [title])
-        elif title not in PatchedMatplotlib._last_iteration_plot_titles[1]:
-            PatchedMatplotlib._last_iteration_plot_titles[1].append(title)
+        # type: (str, int) -> str
+        """
+        Matplotlib with specific title will reset the title counter on every new iteration.
+        Calling title twice each iteration will produce "title" and "title/1" for every iteration
+
+        :param title: original matplotlib title
+        :param last_iteration: the current "last_iteration"
+        :return: new title to use (with counter attached if necessary)
+        """
+        # check if we already encountered the title
+        if title in PatchedMatplotlib._last_iteration_plot_titles:
+            # if we have check the last iteration
+            title_last_iteration, title_counter = PatchedMatplotlib._last_iteration_plot_titles[title]
+            # if this is a new iteration start from the beginning
+            if last_iteration == title_last_iteration:
+                title_counter += 1
+            else:  # if this is a new iteration start from the beginning
+                title_last_iteration = last_iteration
+                title_counter = 0
         else:
-            base_title = title
-            counter = 1
-            while title in PatchedMatplotlib._last_iteration_plot_titles[1]:
-                # we already used this title in this iteration, we should change the title
-                title = base_title + ' %d' % counter
-                counter += 1
-            # store the new title
-            PatchedMatplotlib._last_iteration_plot_titles[1].append(title)
+            # this is a new title
+            title_last_iteration = last_iteration
+            title_counter = 0
+
+        base_title = title
+        # if this is the zero counter to not add the counter to the title
+        if title_counter != 0:
+            title = base_title + '/%d' % title_counter
+        # update back the title iteration counter
+        PatchedMatplotlib._last_iteration_plot_titles[base_title] = (title_last_iteration, title_counter)
+
         return title
 
     @staticmethod
diff --git a/trains/utilities/plotly_reporter.py b/trains/utilities/plotly_reporter.py
index 1e1d56eb..8daa93bf 100644
--- a/trains/utilities/plotly_reporter.py
+++ b/trains/utilities/plotly_reporter.py
@@ -314,9 +314,9 @@ def create_image_plot(image_src, title, width=640, height=480, series=None, comm
         "data": [],
         "layout": {
             "xaxis": {"visible": False, "range": [0, width]},
-            "yaxis": {"visible": False, "range": [0, height]},
-            # "width": width,
-            # "height": height,
+            "yaxis": {"visible": False, "range": [0, height], "scaleanchor": "x"},
+            "width": width,
+            "height": height,
             "margin": {'l': 0, 'r': 0, 't': 0, 'b': 0},
             "images": [{
                 "sizex": width,
@@ -325,8 +325,9 @@ def create_image_plot(image_src, title, width=640, height=480, series=None, comm
                 "yref": "y",
                 "opacity": 1.0,
                 "x": 0,
-                "y": int(height / 2),
-                "yanchor": "middle",
+                "y": height,
+                # "xanchor": "left",
+                # "yanchor": "bottom",
                 "sizing": "contain",
                 "layer": "below",
                 "source": image_src