Improve frameworks binding

allegroai 2019-08-19 21:18:44 +03:00
parent 3bc1ec2362
commit a896f5b465
6 changed files with 94 additions and 41 deletions

View File

@@ -66,6 +66,8 @@ class PatchOsFork(object):
task = Task.init()
task.get_logger().flush()
# Hack: now make sure we set up the reporter thread
task._setup_reporter()
# if we got here patch the os._exit of our instance to call us
def _at_exit_callback(*args, **kwargs):
# call at exit manually
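For context, the hunk above re-creates the reporter after a fork and then wraps os._exit so the forked child can flush before it exits. A minimal, self-contained sketch of that wrapping pattern (the flush function is a stand-in for the trains reporter flush, not the actual API):

import os

_original_os_exit = os._exit

def _flush_reports():
    # stand-in for task.get_logger().flush() / the reporter thread flush
    pass

def _patched_os_exit(code=0):
    # os._exit skips atexit handlers, so run the at-exit work manually first
    _flush_reports()
    _original_os_exit(code)

os._exit = _patched_os_exit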

View File

@@ -7,6 +7,7 @@ from pathlib2 import Path
from ...config import running_remotely
from ...model import InputModel, OutputModel
from ...backend_interface.model import Model
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
_recursion_guard = {}
@@ -75,15 +76,28 @@ class WeightsFileHandler(object):
config_text = trains_in_model.config_text if trains_in_model else None
except Exception:
config_text = None
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
# check if we already have the model object:
model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None))
if model_id:
# noinspection PyBroadException
try:
trains_in_model = InputModel(model_id)
except Exception:
model_id = None
# if we do not, we need to import the model
if not model_id:
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
@@ -94,7 +108,8 @@ class WeightsFileHandler(object):
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the config_dict
if running_remotely():
# Hack: disabled
if False and running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a different model
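The first hunk in this file adds a reuse-before-import step: a local cache maps a weights file path to the model id that was previously registered for it, so repeated saves of the same file do not create duplicate input models. A simplified, self-contained sketch of that pattern (names here are illustrative, not the trains API):

# filepath -> model id of a previously registered input model
_local_model_to_id = {}

class _StubModel:
    def __init__(self, model_id):
        self.id = model_id

def _import_model(filepath):
    # stand-in for InputModel.import_model(); here it just mints a new id
    return _StubModel('model-{}'.format(len(_local_model_to_id)))

def get_or_import(filepath):
    model_id = _local_model_to_id.get(filepath)
    if model_id is not None:
        # reuse the registration we already have for this file
        return _StubModel(model_id)
    model = _import_model(filepath)          # first time this file is seen
    _local_model_to_id[filepath] = model.id
    return model

assert get_or_import('weights.pt').id == get_or_import('weights.pt').id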

View File

@@ -47,16 +47,22 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
if not PatchPyTorchModelIO.__main_task:
return ret
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
pass
else:
# noinspection PyBroadException
try:
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'as_posix'):
filename = f.as_posix()
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
pass
else:
filename = None
except Exception:
filename = None
# give the model a descriptive name based on the file name
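The save-path handling above now also accepts pathlib objects. A standalone sketch of the resolution order (plain str instead of six.string_types, otherwise the same idea):

from pathlib import Path

def _resolve_filename(f):
    # accept a plain path string, a pathlib.Path (via as_posix),
    # or an open file object (via .name); anything else resolves to None
    try:
        if isinstance(f, str):
            return f
        if hasattr(f, 'as_posix'):
            return f.as_posix()
        if hasattr(f, 'name'):
            try:
                f.flush()  # make sure the weights hit the disk before registering
            except Exception:
                pass
            return f.name
    except Exception:
        pass
    return None

assert _resolve_filename('model.pt') == 'model.pt'
assert _resolve_filename(Path('dir') / 'model.pt') == 'dir/model.pt'
assert _resolve_filename(12345) is None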
@@ -65,31 +71,40 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
if not PatchPyTorchModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# noinspection PyBroadException
try:
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'as_posix'):
filename = f.as_posix()
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
except Exception:
filename = None
# register input model
empty = _Empty()
if running_remotely():
# Hack: disabled
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load the model before registering it, in case loading fails
model = original_fn(filename or f, *args, **kwargs)
model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
PatchPyTorchModelIO.__main_task)
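In the non-remote branch above, the model is loaded before the weights file is registered, so a failed load never leaves a stray input-model registration behind. The ordering in isolation (load_fn and register_fn are placeholders, not trains calls):

def load_then_register(path, load_fn, register_fn):
    model = load_fn(path)   # may raise; at this point nothing has been registered
    register_fn(path)       # only reached once the load succeeded
    return model

# throw-away placeholders standing in for torch.load / WeightsFileHandler
print(load_then_register('weights.pt', lambda p: {'loaded': p}, lambda p: None))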

View File

@@ -63,7 +63,7 @@ class EventTrainsWriter(object):
return self.variants.copy()
def tag_splitter(self, tag, num_split_parts, split_char='/', join_char='_', default_title='variant',
logdir_header='series'):
logdir_header='series', auto_reduce_num_split=False):
"""
Split a tf.summary tag line to variant and metric.
Variant is the first part of the split tag, metric is the second.
@@ -74,9 +74,13 @@ class EventTrainsWriter(object):
:param str default_title: variant to use in case no variant can be inferred automatically
:param str logdir_header: if 'series_last' then series=header: series, if 'series' then series=series :header,
if 'title_last' then title=header title, if 'title' then title=title header
:param boolean auto_reduce_num_split: if True and the tag splits into fewer parts than requested,
the requested number of split parts is reduced accordingly.
:return: (str, str) variant and metric
"""
splitted_tag = tag.split(split_char)
if auto_reduce_num_split and num_split_parts > len(splitted_tag)-1:
num_split_parts = max(1, len(splitted_tag)-1)
series = join_char.join(splitted_tag[-num_split_parts:])
title = join_char.join(splitted_tag[:-num_split_parts]) or default_title
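A standalone version of the splitting logic above, showing what auto_reduce_num_split changes for a short tag (the logdir_header title/series naming is omitted here):

def tag_splitter(tag, num_split_parts, split_char='/', join_char='_',
                 default_title='variant', auto_reduce_num_split=False):
    # sketch of the splitting above, without the logdir_header handling
    parts = tag.split(split_char)
    if auto_reduce_num_split and num_split_parts > len(parts) - 1:
        num_split_parts = max(1, len(parts) - 1)
    series = join_char.join(parts[-num_split_parts:])
    title = join_char.join(parts[:-num_split_parts]) or default_title
    return title, series

# 'input/image' only has two parts, so a request for 3 series parts is reduced to 1
assert tag_splitter('input/image', 3, auto_reduce_num_split=True) == ('input', 'image')
assert tag_splitter('input/image', 3) == ('variant', 'input_image')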
@@ -195,7 +199,8 @@ class EventTrainsWriter(object):
if img_data_np is None:
return
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title')
title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Images', logdir_header='title',
auto_reduce_num_split=True)
if img_data_np.dtype != np.uint8:
# assume scale 0-1
img_data_np = (img_data_np * 255).astype(np.uint8)
@@ -998,11 +1003,19 @@ class PatchKerasModelIO(object):
if Sequential is not None:
Sequential._updated_config = _patched_call(Sequential._updated_config,
PatchKerasModelIO._updated_config)
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
if hasattr(Sequential.from_config, '__func__'):
Sequential.from_config.__func__ = _patched_call(Sequential.from_config.__func__,
PatchKerasModelIO._from_config)
else:
Sequential.from_config = _patched_call(Sequential.from_config, PatchKerasModelIO._from_config)
if Network is not None:
Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
if hasattr(Network.from_config, '__func__'):
Network.from_config.__func__ = _patched_call(Network.from_config.__func__,
PatchKerasModelIO._from_config)
else:
Network.from_config = _patched_call(Network.from_config, PatchKerasModelIO._from_config)
Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)
@@ -1072,7 +1085,8 @@ class PatchKerasModelIO(object):
PatchKerasModelIO.__main_task.connect(self.trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the configuration
if running_remotely():
# Hack: disabled
if False and running_remotely():
# reload the model
model_config = self.trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a different model
@@ -1100,7 +1114,8 @@ class PatchKerasModelIO(object):
# get filepath
filepath = kwargs['filepath'] if 'filepath' in kwargs else args[0]
if running_remotely():
# Hack: disabled
if False and running_remotely():
# register/load model weights
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
PatchKerasModelIO.__main_task)
@@ -1183,7 +1198,8 @@ class PatchKerasModelIO(object):
return original_fn(filepath, *args, **kwargs)
empty = _Empty()
if running_remotely():
# Hack: disabled
if False and running_remotely():
# register/load model weights
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
PatchKerasModelIO.__main_task)
@@ -1351,7 +1367,8 @@ class PatchTensorflowModelIO(object):
if PatchTensorflowModelIO.__main_task is None:
return original_fn(self, sess, save_path, *args, **kwargs)
if running_remotely():
# Hack: disabled
if False and running_remotely():
# register/load model weights
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
@@ -1372,7 +1389,8 @@ class PatchTensorflowModelIO(object):
# register input model
empty = _Empty()
if running_remotely():
# Hack: disabled
if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file(empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
@@ -1415,7 +1433,8 @@ class PatchTensorflowModelIO(object):
# register input model
empty = _Empty()
if running_remotely():
# Hack: disabled
if False and running_remotely():
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
model = original_fn(self, save_path, *args, **kwargs)

View File

@@ -82,7 +82,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
# register input model
empty = _Empty()
if running_remotely():
# Hack: disabled
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task)
model = original_fn(filename or f, *args, **kwargs)

View File

@@ -90,7 +90,8 @@ class PatchedJoblib(object):
# register input model
empty = _Empty()
if running_remotely():
# Hack: disabled
if False and running_remotely():
# we assume scikit-learn, for the time being
current_framework = Framework.scikitlearn
filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework,