Fix WeightsFileHandler support for model=None

This commit is contained in:
allegroai 2020-05-31 12:08:14 +03:00
parent 38230626c2
commit 92d003657b

View File

@ -114,14 +114,15 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
id(model) if model is not None else None, (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '')
model_name_id = getattr(model, 'name', '') if model else ''
# noinspection PyBroadException
try:
config_text = None
@ -149,7 +150,7 @@ class WeightsFileHandler(object):
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
name=task.name + (' ' + model_name_id) if model_name_id else '',
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
@ -159,6 +160,7 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values():
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
if model is not None:
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
@ -196,7 +198,8 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
id(model) if model is not None else None, (None, None))
# notice ref_model() is not an error/typo this is a weakref object call
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
@ -232,6 +235,9 @@ class WeightsFileHandler(object):
# check if object already has InputModel
if trains_out_model is None:
in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None))
if not in_model_id:
# if we are overwriting a local file, try to load registered model
# if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri:
@ -239,6 +245,7 @@ class WeightsFileHandler(object):
in_model_id = InputModel.load_model(weights_url=saved_path)
if in_model_id:
in_model_id = in_model_id.id
get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path))
@ -255,6 +262,7 @@ class WeightsFileHandler(object):
framework=framework,
base_model_id=in_model_id
)
if model is not None:
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)