mirror of
https://github.com/clearml/clearml
synced 2025-04-05 21:26:20 +00:00
Fix tf.saved_model.load() binding for TF >=2.0
This commit is contained in:
parent
1c198a47fd
commit
6ed6c3ff70
@ -1864,7 +1864,12 @@ class PatchTensorflowModelIO(object):
|
||||
from tensorflow.saved_model import load # noqa
|
||||
# noinspection PyUnresolvedReferences
|
||||
import tensorflow.saved_model as saved_model_load # noqa
|
||||
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
|
||||
saved_model_load.load = _patched_call(
|
||||
saved_model_load.load,
|
||||
PatchTensorflowModelIO._load
|
||||
if int(tensorflow.__version__.partition(".")[0]) >= 2
|
||||
else PatchTensorflowModelIO._load_lt_2_0
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
@ -1957,7 +1962,7 @@ class PatchTensorflowModelIO(object):
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, sess, tags, export_dir, *args, **saver_kwargs):
|
||||
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
|
||||
if PatchTensorflowModelIO.__main_task is None:
|
||||
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||
|
||||
@ -1965,14 +1970,48 @@ class PatchTensorflowModelIO(object):
|
||||
empty = _Empty()
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
export_dir = WeightsFileHandler.restore_weights_file(
|
||||
empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task
|
||||
)
|
||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||
else:
|
||||
# try to load model before registering, it might fail
|
||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
WeightsFileHandler.restore_weights_file(
|
||||
empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task
|
||||
)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, export_dir, *args, **saver_kwargs):
|
||||
if PatchTensorflowModelIO.__main_task is None:
|
||||
return original_fn(export_dir, *args, **saver_kwargs)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
export_dir = WeightsFileHandler.restore_weights_file(
|
||||
empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task
|
||||
)
|
||||
model = original_fn(export_dir, *args, **saver_kwargs)
|
||||
else:
|
||||
# try to load model before registering, it might fail
|
||||
model = original_fn(export_dir, *args, **saver_kwargs)
|
||||
WeightsFileHandler.restore_weights_file(
|
||||
empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task
|
||||
)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
|
Loading…
Reference in New Issue
Block a user