Fix WeightsFileHandler support for model=None

This commit is contained in:
allegroai 2020-05-31 12:08:14 +03:00
parent 38230626c2
commit 92d003657b

View File

@ -114,14 +114,15 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
id(model) if model is not None else None, (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '')
model_name_id = getattr(model, 'name', '') if model else ''
# noinspection PyBroadException
try:
config_text = None
@ -149,7 +150,7 @@ class WeightsFileHandler(object):
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + ' ' + model_name_id,
name=task.name + (' ' + model_name_id) if model_name_id else '',
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
@ -159,12 +160,13 @@ class WeightsFileHandler(object):
for cb in WeightsFileHandler._model_post_callbacks.values():
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
if model is not None:
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
# todo: support multiple models for the same task
task.connect(trains_in_model)
# if we are running remotely we should deserialize the object
@ -196,7 +198,8 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
id(model) if model is not None else None, (None, None))
# notice ref_model() is not an error/typo this is a weakref object call
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
@ -232,20 +235,24 @@ class WeightsFileHandler(object):
# check if object already has InputModel
if trains_out_model is None:
# if we are overwriting a local file, try to load registered model
# if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri:
try:
in_model_id = InputModel.load_model(weights_url=saved_path)
if in_model_id:
in_model_id = in_model_id.id
get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path))
except:
in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None))
if not in_model_id:
# if we are overwriting a local file, try to load registered model
# if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri:
try:
in_model_id = InputModel.load_model(weights_url=saved_path)
if in_model_id:
in_model_id = in_model_id.id
get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path))
except:
in_model_id = None
else:
in_model_id = None
else:
in_model_id = None
trains_out_model = OutputModel(
task=task,
@ -255,12 +262,13 @@ class WeightsFileHandler(object):
framework=framework,
base_model_id=in_model_id
)
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if model is not None:
# noinspection PyBroadException
try:
ref_model = weakref.ref(model)
except Exception:
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
# call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values():