Reuse Model objects if we are storing local files (reduce clutter)

This commit is contained in:
allegroai 2020-03-22 18:15:32 +02:00
parent 4e2564cd3a
commit 493cce443a

View File

@ -1,15 +1,15 @@
import threading import threading
import weakref import weakref
from logging import getLogger
import six import six
from pathlib2 import Path from pathlib2 import Path
from ...debugging.log import get_logger
from ...config import running_remotely from ...config import running_remotely
from ...model import InputModel, OutputModel from ...model import InputModel, OutputModel
from ...backend_interface.model import Model from ...backend_interface.model import Model
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter' TrainsFrameworkAdapter = 'frameworks'
_recursion_guard = {} _recursion_guard = {}
@ -50,7 +50,7 @@ class WeightsFileHandler(object):
return filepath return filepath
if not filepath: if not filepath:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored") get_logger(TrainsFrameworkAdapter).debug("Could retrieve model file location, model is not logged")
return filepath return filepath
try: try:
@ -120,7 +120,7 @@ class WeightsFileHandler(object):
# update filepath to point to downloaded weights file # update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block # actual model weights loading will be done outside the try/exception block
except Exception as ex: except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex)) get_logger(TrainsFrameworkAdapter).debug(str(ex))
finally: finally:
WeightsFileHandler._model_store_lookup_lock.release() WeightsFileHandler._model_store_lookup_lock.release()
@ -136,19 +136,57 @@ class WeightsFileHandler(object):
# check if object already has InputModel # check if object already has InputModel
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None)) trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
# notice ref_model() is not an error/typo this is a weakref object call
if ref_model is not None and model != ref_model(): if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead # old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model)) WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None trains_out_model, ref_model = None, None
if not saved_path:
get_logger(TrainsFrameworkAdapter).warning("Could retrieve model location, skipping auto model logging")
return saved_path
# check if we have output storage, and generate list of files to upload
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
target_filename = None
if len(files) > 1:
# noinspection PyBroadException
try:
target_filename = Path(saved_path).stem
except Exception:
pass
else:
target_filename = files[0]
# check if object already has InputModel # check if object already has InputModel
if trains_out_model is None: if trains_out_model is None:
# if we are overwriting a local file, try to load registered model
# if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri:
try:
in_model_id = InputModel.load_model(weights_url=saved_path)
if in_model_id:
in_model_id = in_model_id.id
get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path))
except:
in_model_id = None
else:
in_model_id = None
trains_out_model = OutputModel( trains_out_model = OutputModel(
task=task, task=task,
# config_dict=config, # config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None, name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(), label_enumeration=task.get_labels_enumeration(),
framework=framework, ) framework=framework, base_model_id=in_model_id)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
ref_model = weakref.ref(model) ref_model = weakref.ref(model)
@ -156,29 +194,9 @@ class WeightsFileHandler(object):
ref_model = None ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
if not saved_path:
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
return saved_path
# check if we have output storage, and generate list of files to upload
if trains_out_model.upload_storage_uri:
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
else:
files = None
# upload files if we found them, or just register the original path # upload files if we found them, or just register the original path
if files: if trains_out_model.upload_storage_uri:
if len(files) > 1: if len(files) > 1:
# noinspection PyBroadException
try:
target_filename = Path(saved_path).stem
except Exception:
target_filename = None
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
target_filename=target_filename) target_filename=target_filename)
else: else:
@ -186,7 +204,7 @@ class WeightsFileHandler(object):
else: else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex: except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex)) get_logger(TrainsFrameworkAdapter).debug(str(ex))
finally: finally:
WeightsFileHandler._model_store_lookup_lock.release() WeightsFileHandler._model_store_lookup_lock.release()