mirror of
https://github.com/clearml/clearml
synced 2025-03-16 00:17:15 +00:00
Reuse Model objects if we are storing local files (reduce clutter)
This commit is contained in:
parent
4e2564cd3a
commit
493cce443a
@ -1,15 +1,15 @@
|
|||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
from logging import getLogger
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from ...debugging.log import get_logger
|
||||||
from ...config import running_remotely
|
from ...config import running_remotely
|
||||||
from ...model import InputModel, OutputModel
|
from ...model import InputModel, OutputModel
|
||||||
from ...backend_interface.model import Model
|
from ...backend_interface.model import Model
|
||||||
|
|
||||||
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
|
TrainsFrameworkAdapter = 'frameworks'
|
||||||
_recursion_guard = {}
|
_recursion_guard = {}
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ class WeightsFileHandler(object):
|
|||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
if not filepath:
|
if not filepath:
|
||||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
|
get_logger(TrainsFrameworkAdapter).debug("Could retrieve model file location, model is not logged")
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -120,7 +120,7 @@ class WeightsFileHandler(object):
|
|||||||
# update filepath to point to downloaded weights file
|
# update filepath to point to downloaded weights file
|
||||||
# actual model weights loading will be done outside the try/exception block
|
# actual model weights loading will be done outside the try/exception block
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
||||||
finally:
|
finally:
|
||||||
WeightsFileHandler._model_store_lookup_lock.release()
|
WeightsFileHandler._model_store_lookup_lock.release()
|
||||||
|
|
||||||
@ -136,19 +136,57 @@ class WeightsFileHandler(object):
|
|||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
|
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
|
||||||
|
# notice ref_model() is not an error/typo this is a weakref object call
|
||||||
if ref_model is not None and model != ref_model():
|
if ref_model is not None and model != ref_model():
|
||||||
# old id pop it - it was probably reused because the object is dead
|
# old id pop it - it was probably reused because the object is dead
|
||||||
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
||||||
trains_out_model, ref_model = None, None
|
trains_out_model, ref_model = None, None
|
||||||
|
|
||||||
|
if not saved_path:
|
||||||
|
get_logger(TrainsFrameworkAdapter).warning("Could retrieve model location, skipping auto model logging")
|
||||||
|
return saved_path
|
||||||
|
|
||||||
|
# check if we have output storage, and generate list of files to upload
|
||||||
|
if Path(saved_path).is_dir():
|
||||||
|
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
||||||
|
elif singlefile:
|
||||||
|
files = [str(Path(saved_path).absolute())]
|
||||||
|
else:
|
||||||
|
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
|
||||||
|
|
||||||
|
target_filename = None
|
||||||
|
if len(files) > 1:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
target_filename = Path(saved_path).stem
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
target_filename = files[0]
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
if trains_out_model is None:
|
if trains_out_model is None:
|
||||||
|
# if we are overwriting a local file, try to load registered model
|
||||||
|
# if there is an output_uri, then by definition we will not overwrite previously stored models.
|
||||||
|
if not task.output_uri:
|
||||||
|
try:
|
||||||
|
in_model_id = InputModel.load_model(weights_url=saved_path)
|
||||||
|
if in_model_id:
|
||||||
|
in_model_id = in_model_id.id
|
||||||
|
get_logger(TrainsFrameworkAdapter).info(
|
||||||
|
"Found existing registered model id={} [{}] reusing it.".format(
|
||||||
|
in_model_id, saved_path))
|
||||||
|
except:
|
||||||
|
in_model_id = None
|
||||||
|
else:
|
||||||
|
in_model_id = None
|
||||||
|
|
||||||
trains_out_model = OutputModel(
|
trains_out_model = OutputModel(
|
||||||
task=task,
|
task=task,
|
||||||
# config_dict=config,
|
# config_dict=config,
|
||||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||||
label_enumeration=task.get_labels_enumeration(),
|
label_enumeration=task.get_labels_enumeration(),
|
||||||
framework=framework, )
|
framework=framework, base_model_id=in_model_id)
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
ref_model = weakref.ref(model)
|
ref_model = weakref.ref(model)
|
||||||
@ -156,29 +194,9 @@ class WeightsFileHandler(object):
|
|||||||
ref_model = None
|
ref_model = None
|
||||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||||
|
|
||||||
if not saved_path:
|
|
||||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
|
|
||||||
return saved_path
|
|
||||||
|
|
||||||
# check if we have output storage, and generate list of files to upload
|
|
||||||
if trains_out_model.upload_storage_uri:
|
|
||||||
if Path(saved_path).is_dir():
|
|
||||||
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
|
||||||
elif singlefile:
|
|
||||||
files = [str(Path(saved_path).absolute())]
|
|
||||||
else:
|
|
||||||
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
|
|
||||||
else:
|
|
||||||
files = None
|
|
||||||
|
|
||||||
# upload files if we found them, or just register the original path
|
# upload files if we found them, or just register the original path
|
||||||
if files:
|
if trains_out_model.upload_storage_uri:
|
||||||
if len(files) > 1:
|
if len(files) > 1:
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
target_filename = Path(saved_path).stem
|
|
||||||
except Exception:
|
|
||||||
target_filename = None
|
|
||||||
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
||||||
target_filename=target_filename)
|
target_filename=target_filename)
|
||||||
else:
|
else:
|
||||||
@ -186,7 +204,7 @@ class WeightsFileHandler(object):
|
|||||||
else:
|
else:
|
||||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
||||||
finally:
|
finally:
|
||||||
WeightsFileHandler._model_store_lookup_lock.release()
|
WeightsFileHandler._model_store_lookup_lock.release()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user