Fix tf.saved_model.load() binding for TF >=2.0

This commit is contained in:
allegroai 2021-11-08 10:02:39 +02:00
parent 1c198a47fd
commit 6ed6c3ff70

View File

@ -1864,7 +1864,12 @@ class PatchTensorflowModelIO(object):
from tensorflow.saved_model import load # noqa
# noinspection PyUnresolvedReferences
import tensorflow.saved_model as saved_model_load # noqa
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
saved_model_load.load = _patched_call(
saved_model_load.load,
PatchTensorflowModelIO._load
if int(tensorflow.__version__.partition(".")[0]) >= 2
else PatchTensorflowModelIO._load_lt_2_0
)
except ImportError:
pass
except Exception:
@ -1957,7 +1962,7 @@ class PatchTensorflowModelIO(object):
return model
@staticmethod
def _load(original_fn, sess, tags, export_dir, *args, **saver_kwargs):
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None:
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
@ -1965,14 +1970,48 @@ class PatchTensorflowModelIO(object):
empty = _Empty()
# Hack: disabled
if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
export_dir = WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
)
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
else:
# try to load model before registering, it might fail
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
)
if empty.trains_in_model:
# noinspection PyBroadException
try:
model.trains_in_model = empty.trains_in_model
except Exception:
pass
return model
@staticmethod
def _load(original_fn, export_dir, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None:
return original_fn(export_dir, *args, **saver_kwargs)
# register input model
empty = _Empty()
# Hack: disabled
if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
)
model = original_fn(export_dir, *args, **saver_kwargs)
else:
# try to load model before registering, it might fail
model = original_fn(export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
)
if empty.trains_in_model:
# noinspection PyBroadException