From 6ed6c3ff70114a7a57ad2eaab55c3b448d2e3cfc Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 8 Nov 2021 10:02:39 +0200
Subject: [PATCH] Fix tf.saved_model.load() binding for TF >=2.0

---
 clearml/binding/frameworks/tensorflow_bind.py | 51 ++++++++++++++++---
 1 file changed, 45 insertions(+), 6 deletions(-)

diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py
index 04d80255..e44a7dc2 100644
--- a/clearml/binding/frameworks/tensorflow_bind.py
+++ b/clearml/binding/frameworks/tensorflow_bind.py
@@ -1864,7 +1864,12 @@ class PatchTensorflowModelIO(object):
             from tensorflow.saved_model import load  # noqa
             # noinspection PyUnresolvedReferences
             import tensorflow.saved_model as saved_model_load  # noqa
-            saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
+            saved_model_load.load = _patched_call(
+                saved_model_load.load,
+                PatchTensorflowModelIO._load
+                if int(tensorflow.__version__.partition(".")[0]) >= 2
+                else PatchTensorflowModelIO._load_lt_2_0
+            )
         except ImportError:
             pass
         except Exception:
@@ -1957,7 +1962,7 @@ class PatchTensorflowModelIO(object):
         return model
 
     @staticmethod
-    def _load(original_fn, sess, tags, export_dir, *args, **saver_kwargs):
+    def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
         if PatchTensorflowModelIO.__main_task is None:
             return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
 
@@ -1965,14 +1970,48 @@ class PatchTensorflowModelIO(object):
         empty = _Empty()
         # Hack: disabled
         if False and running_remotely():
-            export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
-                                                                 PatchTensorflowModelIO.__main_task)
+            export_dir = WeightsFileHandler.restore_weights_file(
+                empty, export_dir, Framework.tensorflow,
+                PatchTensorflowModelIO.__main_task
+            )
             model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
         else:
             # try to load model before registering, it might fail
             model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
-            WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
-                                                    PatchTensorflowModelIO.__main_task)
+            WeightsFileHandler.restore_weights_file(
+                empty, export_dir, Framework.tensorflow,
+                PatchTensorflowModelIO.__main_task
+            )
+
+        if empty.trains_in_model:
+            # noinspection PyBroadException
+            try:
+                model.trains_in_model = empty.trains_in_model
+            except Exception:
+                pass
+        return model
+
+    @staticmethod
+    def _load(original_fn, export_dir, *args, **saver_kwargs):
+        if PatchTensorflowModelIO.__main_task is None:
+            return original_fn(export_dir, *args, **saver_kwargs)
+
+        # register input model
+        empty = _Empty()
+        # Hack: disabled
+        if False and running_remotely():
+            export_dir = WeightsFileHandler.restore_weights_file(
+                empty, export_dir, Framework.tensorflow,
+                PatchTensorflowModelIO.__main_task
+            )
+            model = original_fn(export_dir, *args, **saver_kwargs)
+        else:
+            # try to load model before registering, it might fail
+            model = original_fn(export_dir, *args, **saver_kwargs)
+            WeightsFileHandler.restore_weights_file(
+                empty, export_dir, Framework.tensorflow,
+                PatchTensorflowModelIO.__main_task
+            )
 
         if empty.trains_in_model:
             # noinspection PyBroadException