mirror of
https://github.com/clearml/clearml
synced 2025-04-08 22:54:44 +00:00
Fix tf.saved_model.load() binding for TF >=2.0
This commit is contained in:
parent
1c198a47fd
commit
6ed6c3ff70
@ -1864,7 +1864,12 @@ class PatchTensorflowModelIO(object):
|
|||||||
from tensorflow.saved_model import load # noqa
|
from tensorflow.saved_model import load # noqa
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import tensorflow.saved_model as saved_model_load # noqa
|
import tensorflow.saved_model as saved_model_load # noqa
|
||||||
saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
|
saved_model_load.load = _patched_call(
|
||||||
|
saved_model_load.load,
|
||||||
|
PatchTensorflowModelIO._load
|
||||||
|
if int(tensorflow.__version__.partition(".")[0]) >= 2
|
||||||
|
else PatchTensorflowModelIO._load_lt_2_0
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1957,7 +1962,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, sess, tags, export_dir, *args, **saver_kwargs):
|
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO.__main_task is None:
|
||||||
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||||
|
|
||||||
@ -1965,14 +1970,48 @@ class PatchTensorflowModelIO(object):
|
|||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
|
export_dir = WeightsFileHandler.restore_weights_file(
|
||||||
PatchTensorflowModelIO.__main_task)
|
empty, export_dir, Framework.tensorflow,
|
||||||
|
PatchTensorflowModelIO.__main_task
|
||||||
|
)
|
||||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, it might fail
|
# try to load model before registering, it might fail
|
||||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
|
WeightsFileHandler.restore_weights_file(
|
||||||
PatchTensorflowModelIO.__main_task)
|
empty, export_dir, Framework.tensorflow,
|
||||||
|
PatchTensorflowModelIO.__main_task
|
||||||
|
)
|
||||||
|
|
||||||
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model.trains_in_model = empty.trains_in_model
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load(original_fn, export_dir, *args, **saver_kwargs):
|
||||||
|
if PatchTensorflowModelIO.__main_task is None:
|
||||||
|
return original_fn(export_dir, *args, **saver_kwargs)
|
||||||
|
|
||||||
|
# register input model
|
||||||
|
empty = _Empty()
|
||||||
|
# Hack: disabled
|
||||||
|
if False and running_remotely():
|
||||||
|
export_dir = WeightsFileHandler.restore_weights_file(
|
||||||
|
empty, export_dir, Framework.tensorflow,
|
||||||
|
PatchTensorflowModelIO.__main_task
|
||||||
|
)
|
||||||
|
model = original_fn(export_dir, *args, **saver_kwargs)
|
||||||
|
else:
|
||||||
|
# try to load model before registering, it might fail
|
||||||
|
model = original_fn(export_dir, *args, **saver_kwargs)
|
||||||
|
WeightsFileHandler.restore_weights_file(
|
||||||
|
empty, export_dir, Framework.tensorflow,
|
||||||
|
PatchTensorflowModelIO.__main_task
|
||||||
|
)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
|
Loading…
Reference in New Issue
Block a user