Fix WeightsFileHandler support for model=None

This commit is contained in:
allegroai 2020-05-31 12:08:14 +03:00
parent 38230626c2
commit 92d003657b

View File

@ -114,14 +114,15 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire() WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel # check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None)) trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
id(model) if model is not None else None, (None, None))
if ref_model is not None and model != ref_model(): if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead # old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model)) WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None trains_in_model, ref_model = None, None
# check if object already has InputModel # check if object already has InputModel
model_name_id = getattr(model, 'name', '') model_name_id = getattr(model, 'name', '') if model else ''
# noinspection PyBroadException # noinspection PyBroadException
try: try:
config_text = None config_text = None
@ -149,7 +150,7 @@ class WeightsFileHandler(object):
weights_url=filepath, weights_url=filepath,
config_dict=config_dict, config_dict=config_dict,
config_text=config_text, config_text=config_text,
name=task.name + ' ' + model_name_id, name=task.name + (' ' + model_name_id) if model_name_id else '',
label_enumeration=task.get_labels_enumeration(), label_enumeration=task.get_labels_enumeration(),
framework=framework, framework=framework,
create_as_published=False, create_as_published=False,
@ -159,12 +160,13 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values(): for cb in WeightsFileHandler._model_post_callbacks.values():
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task) trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
# noinspection PyBroadException if model is not None:
try: # noinspection PyBroadException
ref_model = weakref.ref(model) try:
except Exception: ref_model = weakref.ref(model)
ref_model = None except Exception:
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model) ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task # todo: support multiple models for the same task
task.connect(trains_in_model) task.connect(trains_in_model)
# if we are running remotely we should deserialize the object # if we are running remotely we should deserialize the object
@ -196,7 +198,8 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire() WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel # check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None)) trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
id(model) if model is not None else None, (None, None))
# notice ref_model() is not an error/typo this is a weakref object call # notice ref_model() is not an error/typo this is a weakref object call
if ref_model is not None and model != ref_model(): if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead # old id pop it - it was probably reused because the object is dead
@ -232,20 +235,24 @@ class WeightsFileHandler(object):
# check if object already has InputModel # check if object already has InputModel
if trains_out_model is None: if trains_out_model is None:
# if we are overwriting a local file, try to load registered model in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None))
# if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri: if not in_model_id:
try: # if we are overwriting a local file, try to load registered model
in_model_id = InputModel.load_model(weights_url=saved_path) # if there is an output_uri, then by definition we will not overwrite previously stored models.
if in_model_id: if not task.output_uri:
in_model_id = in_model_id.id try:
get_logger(TrainsFrameworkAdapter).info( in_model_id = InputModel.load_model(weights_url=saved_path)
"Found existing registered model id={} [{}] reusing it.".format( if in_model_id:
in_model_id, saved_path)) in_model_id = in_model_id.id
except:
get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path))
except:
in_model_id = None
else:
in_model_id = None in_model_id = None
else:
in_model_id = None
trains_out_model = OutputModel( trains_out_model = OutputModel(
task=task, task=task,
@ -255,12 +262,13 @@ class WeightsFileHandler(object):
framework=framework, framework=framework,
base_model_id=in_model_id base_model_id=in_model_id
) )
# noinspection PyBroadException if model is not None:
try: # noinspection PyBroadException
ref_model = weakref.ref(model) try:
except Exception: ref_model = weakref.ref(model)
ref_model = None except Exception:
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
# call post model callback functions # call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values(): for cb in WeightsFileHandler._model_post_callbacks.values():