mirror of
https://github.com/clearml/clearml
synced 2025-01-31 09:07:00 +00:00
Improve frameworks binding
This commit is contained in:
parent
3bc1ec2362
commit
a896f5b465
@ -66,6 +66,8 @@ class PatchOsFork(object):
|
||||
task = Task.init()
|
||||
task.get_logger().flush()
|
||||
|
||||
# Hack: now make sure we setup the reporter thread
|
||||
task._setup_reporter()
|
||||
# if we got here patch the os._exit of our instance to call us
|
||||
def _at_exit_callback(*args, **kwargs):
|
||||
# call at exit manually
|
||||
|
@ -7,6 +7,7 @@ from pathlib2 import Path
|
||||
|
||||
from ...config import running_remotely
|
||||
from ...model import InputModel, OutputModel
|
||||
from ...backend_interface.model import Model
|
||||
|
||||
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
|
||||
_recursion_guard = {}
|
||||
@ -75,15 +76,28 @@ class WeightsFileHandler(object):
|
||||
config_text = trains_in_model.config_text if trains_in_model else None
|
||||
except Exception:
|
||||
config_text = None
|
||||
trains_in_model = InputModel.import_model(
|
||||
weights_url=filepath,
|
||||
config_dict=config_dict,
|
||||
config_text=config_text,
|
||||
name=task.name + ' ' + model_name_id,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
create_as_published=False,
|
||||
)
|
||||
|
||||
# check if we already have the model object:
|
||||
model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None))
|
||||
if model_id:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
trains_in_model = InputModel(model_id)
|
||||
except Exception:
|
||||
model_id = None
|
||||
|
||||
# if we do not, we need to import the model
|
||||
if not model_id:
|
||||
trains_in_model = InputModel.import_model(
|
||||
weights_url=filepath,
|
||||
config_dict=config_dict,
|
||||
config_text=config_text,
|
||||
name=task.name + ' ' + model_name_id,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
create_as_published=False,
|
||||
)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
@ -94,7 +108,8 @@ class WeightsFileHandler(object):
|
||||
task.connect(trains_in_model)
|
||||
# if we are running remotely we should deserialize the object
|
||||
# because someone might have changed the config_dict
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# reload the model
|
||||
model_config = trains_in_model.config_dict
|
||||
# verify that this is the same model so we are not deserializing a diff model
|
||||
|
@ -47,16 +47,22 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
if not PatchPyTorchModelIO.__main_task:
|
||||
return ret
|
||||
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'as_posix'):
|
||||
filename = f.as_posix()
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
filename = None
|
||||
except Exception:
|
||||
filename = None
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
@ -65,31 +71,40 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
model_name = None
|
||||
|
||||
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
|
||||
singlefile=True, model_name=model_name)
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchPyTorchModelIO.__main_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'as_posix'):
|
||||
filename = f.as_posix()
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
except Exception:
|
||||
filename = None
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
else:
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
model = original_fn(f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||
PatchPyTorchModelIO.__main_task)
|
||||
|
||||
|
@ -63,7 +63,7 @@ class EventTrainsWriter(object):
|
||||
return self.variants.copy()
|
||||
|
||||
def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
|
||||
logdir_header='series'):
|
||||
logdir_header='series', auto_reduce_num_split=False):
|
||||
"""
|
||||
Split a tf.summary tag line to variant and metric.
|
||||
Variant is the first part of the split tag, metric is the second.
|
||||
@ -74,9 +74,13 @@ class EventTrainsWriter(object):
|
||||
:param str default_title: variant to use in case no variant can be inferred automatically
|
||||
:param str logdir_header: if 'series_last' then series=header: series, if 'series then series=series :header,
|
||||
if 'title_last' then title=header title, if 'title' then title=title header
|
||||
:param boolean auto_reduce_num_split: if True and the tag is split for less parts then requested,
|
||||
then requested number of split parts is adjusted.
|
||||
:return: (str, str) variant and metric
|
||||
"""
|
||||
splitted_tag = tag.split(split_char)
|
||||
if auto_reduce_num_split and num_split_parts > len(splitted_tag)-1:
|
||||
num_split_parts = max(1, len(splitted_tag)-1)
|
||||
series = join_char.join(splitted_tag[-num_split_parts:])
|
||||
title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
|
||||
|
||||
@ -195,7 +199,8 @@ class EventTrainsWriter(object):
|
||||
if img_data_np is None:
|
||||
return
|
||||
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
|
||||
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title',
|
||||
auto_reduce_num_split=True)
|
||||
if img_data_np.dtype != np.uint8:
|
||||
# assume scale 0-1
|
||||
img_data_np = (img_data_np * 255).astype(np.uint8)
|
||||
@ -998,11 +1003,19 @@ class PatchKerasModelIO(object):
|
||||
if Sequential is not None:
|
||||
Sequential._updated_config = _patched_call(Sequential._updated_config,
|
||||
PatchKerasModelIO._updated_config)
|
||||
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
|
||||
if hasattr(Sequential.from_config, '__func__'):
|
||||
Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__,
|
||||
PatchKerasModelIO._from_config)
|
||||
else:
|
||||
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
|
||||
|
||||
if Network is not None:
|
||||
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
|
||||
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
|
||||
if hasattr(Sequential.from_config, '__func__'):
|
||||
Network.from_config.__func__ = _patched_call(Network.from_config.__func__,
|
||||
PatchKerasModelIO._from_config)
|
||||
else:
|
||||
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
|
||||
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
|
||||
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
|
||||
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
|
||||
@ -1072,7 +1085,8 @@ class PatchKerasModelIO(object):
|
||||
PatchKerasModelIO.__main_task.connect(self.trains_in_model)
|
||||
# if we are running remotely we should deserialize the object
|
||||
# because someone might have changed the configuration
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# reload the model
|
||||
model_config = self.trains_in_model.config_dict
|
||||
# verify that this is the same model so we are not deserializing a diff model
|
||||
@ -1100,7 +1114,8 @@ class PatchKerasModelIO(object):
|
||||
|
||||
# get filepath
|
||||
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# register/load model weights
|
||||
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
|
||||
PatchKerasModelIO.__main_task)
|
||||
@ -1183,7 +1198,8 @@ class PatchKerasModelIO(object):
|
||||
return original_fn(filepath, *args, **kwargs)
|
||||
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# register/load model weights
|
||||
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
|
||||
PatchKerasModelIO.__main_task)
|
||||
@ -1351,7 +1367,8 @@ class PatchTensorflowModelIO(object):
|
||||
if PatchTensorflowModelIO.__main_task is None:
|
||||
return original_fn(self, sess, save_path, *args, **kwargs)
|
||||
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# register/load model weights
|
||||
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
@ -1372,7 +1389,8 @@ class PatchTensorflowModelIO(object):
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||
@ -1415,7 +1433,8 @@ class PatchTensorflowModelIO(object):
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
model = original_fn(self, save_path, *args, **kwargs)
|
||||
|
@ -82,7 +82,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||
PatchXGBoostModelIO.__main_task)
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
|
@ -90,7 +90,8 @@ class PatchedJoblib(object):
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# we assume scikit-learn, for the time being
|
||||
current_framework = Framework.scikitlearn
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
|
||||
|
Loading…
Reference in New Issue
Block a user