diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py index 04d80255..e44a7dc2 100644 --- a/clearml/binding/frameworks/tensorflow_bind.py +++ b/clearml/binding/frameworks/tensorflow_bind.py @@ -1864,7 +1864,12 @@ class PatchTensorflowModelIO(object): from tensorflow.saved_model import load # noqa # noinspection PyUnresolvedReferences import tensorflow.saved_model as saved_model_load # noqa - saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load) + saved_model_load.load = _patched_call( + saved_model_load.load, + PatchTensorflowModelIO._load + if int(tensorflow.__version__.partition(".")[0]) >= 2 + else PatchTensorflowModelIO._load_lt_2_0 + ) except ImportError: pass except Exception: @@ -1957,7 +1962,7 @@ class PatchTensorflowModelIO(object): return model @staticmethod - def _load(original_fn, sess, tags, export_dir, *args, **saver_kwargs): + def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs): if PatchTensorflowModelIO.__main_task is None: return original_fn(sess, tags, export_dir, *args, **saver_kwargs) @@ -1965,14 +1970,48 @@ class PatchTensorflowModelIO(object): empty = _Empty() # Hack: disabled if False and running_remotely(): - export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow, - PatchTensorflowModelIO.__main_task) + export_dir = WeightsFileHandler.restore_weights_file( + empty, export_dir, Framework.tensorflow, + PatchTensorflowModelIO.__main_task + ) model = original_fn(sess, tags, export_dir, *args, **saver_kwargs) else: # try to load model before registering, it might fail model = original_fn(sess, tags, export_dir, *args, **saver_kwargs) - WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow, - PatchTensorflowModelIO.__main_task) + WeightsFileHandler.restore_weights_file( + empty, export_dir, Framework.tensorflow, + PatchTensorflowModelIO.__main_task + ) + + if empty.trains_in_model: + # noinspection PyBroadException + try: + model.trains_in_model = empty.trains_in_model + except Exception: + pass + return model + + @staticmethod + def _load(original_fn, export_dir, *args, **saver_kwargs): + if PatchTensorflowModelIO.__main_task is None: + return original_fn(export_dir, *args, **saver_kwargs) + + # register input model + empty = _Empty() + # Hack: disabled + if False and running_remotely(): + export_dir = WeightsFileHandler.restore_weights_file( + empty, export_dir, Framework.tensorflow, + PatchTensorflowModelIO.__main_task + ) + model = original_fn(export_dir, *args, **saver_kwargs) + else: + # try to load model before registering, it might fail + model = original_fn(export_dir, *args, **saver_kwargs) + WeightsFileHandler.restore_weights_file( + empty, export_dir, Framework.tensorflow, + PatchTensorflowModelIO.__main_task + ) if empty.trains_in_model: # noinspection PyBroadException