mirror of
https://github.com/clearml/clearml
synced 2025-04-06 05:35:32 +00:00
Change callback structure, access thorough designated class WeightsFileHandler.ModelInfo
See https://github.com/pytorch/ignite/issues/1056
This commit is contained in:
parent
5475d00d52
commit
6a28a6e21d
@ -5,7 +5,7 @@ import threading
|
||||
import weakref
|
||||
from random import randint
|
||||
from tempfile import mkstemp
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
@ -24,6 +24,7 @@ _recursion_guard = {}
|
||||
|
||||
def _patched_call(original_fn, patched_fn):
|
||||
def _inner_patch(*args, **kwargs):
|
||||
# noinspection PyProtectedMember,PyUnresolvedReferences
|
||||
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||
if ident in _recursion_guard:
|
||||
return original_fn(*args, **kwargs)
|
||||
@ -55,62 +56,118 @@ class WeightsFileHandler(object):
|
||||
_model_pre_callbacks = {}
|
||||
_model_post_callbacks = {}
|
||||
|
||||
@staticmethod
|
||||
def add_pre_callback(callback_function):
|
||||
# type: (Callable[[str, str, str, Task], str]) -> int
|
||||
class ModelInfo(object):
|
||||
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
|
||||
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None
|
||||
"""
|
||||
:param model: None, OutputModel or InputModel
|
||||
:param upload_filename: example 'filename.ext'
|
||||
:param local_model_path: example /local/copy/filename.random_number.ext'
|
||||
:param local_model_id: example /local/copy/filename.ext'
|
||||
:param framework: example 'PyTorch'
|
||||
:param task: Task object
|
||||
"""
|
||||
self.model = model
|
||||
self.upload_filename = upload_filename
|
||||
self.local_model_path = local_model_path
|
||||
self.local_model_id = local_model_id
|
||||
self.framework = framework
|
||||
self.task = task
|
||||
|
||||
# callback is Callable[[Union['load', 'save'], str, str, Task], str]
|
||||
if callback_function in WeightsFileHandler._model_pre_callbacks.values():
|
||||
return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0]
|
||||
@staticmethod
|
||||
def _add_callback(func, target):
|
||||
# type: (Callable, Dict[int, Callable]) -> int
|
||||
|
||||
if func in target.values():
|
||||
return [k for k, v in target.items() if v == func][0]
|
||||
|
||||
while True:
|
||||
h = randint(0, 1 << 31)
|
||||
if h not in WeightsFileHandler._model_pre_callbacks:
|
||||
if h not in target:
|
||||
break
|
||||
WeightsFileHandler._model_pre_callbacks[h] = callback_function
|
||||
|
||||
target[h] = func
|
||||
return h
|
||||
|
||||
@staticmethod
|
||||
def add_post_callback(callback_function):
|
||||
# type: (Callable[[str, Model, str, str, str, Task], Model]) -> int
|
||||
|
||||
# callback is Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]
|
||||
if callback_function in WeightsFileHandler._model_post_callbacks.values():
|
||||
return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0]
|
||||
|
||||
while True:
|
||||
h = randint(0, 1 << 31)
|
||||
if h not in WeightsFileHandler._model_post_callbacks:
|
||||
break
|
||||
WeightsFileHandler._model_post_callbacks[h] = callback_function
|
||||
return h
|
||||
|
||||
@staticmethod
|
||||
def remove_pre_callback(handle):
|
||||
# type: (int) -> bool
|
||||
if handle in WeightsFileHandler._model_pre_callbacks:
|
||||
WeightsFileHandler._model_pre_callbacks.pop(handle, None)
|
||||
def _remove_callback(handle, target):
|
||||
# type: (int, Dict[int, Callable]) -> bool
|
||||
if handle in target:
|
||||
target.pop(handle, None)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def remove_post_callback(handle):
|
||||
@classmethod
|
||||
def add_pre_callback(cls, callback_function):
|
||||
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int
|
||||
"""
|
||||
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
|
||||
return the existing handle.
|
||||
|
||||
Use this callback to modify the weights filename registered in the Trains Server. In case Trains is
|
||||
configured to upload the weights file, this will affect the uploaded filename as well.
|
||||
|
||||
:param callback_function: A function accepting action type ("load" or "save"),
|
||||
callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
|
||||
:return Callback handle
|
||||
"""
|
||||
return cls._add_callback(callback_function, cls._model_pre_callbacks)
|
||||
|
||||
@classmethod
|
||||
def add_post_callback(cls, callback_function):
|
||||
# type: (Callable[[str, dict], dict]) -> int
|
||||
"""
|
||||
Add a post-save/load callback for weights files and return its handle.
|
||||
If the callback was already added, return the existing handle.
|
||||
|
||||
:param callback_function: A function accepting action type ("load" or "save"),
|
||||
callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
|
||||
:return Callback handle
|
||||
"""
|
||||
return cls._add_callback(callback_function, cls._model_post_callbacks)
|
||||
|
||||
@classmethod
|
||||
def remove_pre_callback(cls, handle):
|
||||
# type: (int) -> bool
|
||||
if handle in WeightsFileHandler._model_post_callbacks:
|
||||
WeightsFileHandler._model_post_callbacks.pop(handle, None)
|
||||
return True
|
||||
return False
|
||||
"""
|
||||
Add a pre-save/load callback for weights files and return its handle.
|
||||
If the callback was already added, return the existing handle.
|
||||
|
||||
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_pre_callback
|
||||
:return True if callback removed, False otherwise
|
||||
"""
|
||||
return cls._remove_callback(handle, cls._model_pre_callbacks)
|
||||
|
||||
@classmethod
|
||||
def remove_post_callback(cls, handle):
|
||||
# type: (int) -> bool
|
||||
"""
|
||||
Add a pre-save/load callback for weights files and return its handle.
|
||||
If the callback was already added, return the existing handle.
|
||||
|
||||
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
|
||||
:return True if callback removed, False otherwise
|
||||
"""
|
||||
return cls._remove_callback(handle, cls._model_post_callbacks)
|
||||
|
||||
@staticmethod
|
||||
def restore_weights_file(model, filepath, framework, task):
|
||||
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task]) -> str
|
||||
if task is None:
|
||||
return filepath
|
||||
|
||||
model_info = WeightsFileHandler.ModelInfo(
|
||||
model=None, upload_filename=None, local_model_path=filepath,
|
||||
local_model_id=filepath, framework=framework, task=task)
|
||||
# call pre model callback functions
|
||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||
filepath = cb('load', filepath, framework, task)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('load', model_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not filepath:
|
||||
if not model_info.local_model_path:
|
||||
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
|
||||
return filepath
|
||||
|
||||
@ -118,12 +175,16 @@ class WeightsFileHandler(object):
|
||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||
|
||||
# check if object already has InputModel
|
||||
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
|
||||
id(model) if model is not None else None, (None, None))
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
||||
trains_in_model, ref_model = None, None
|
||||
if model_info.model:
|
||||
trains_in_model = model_info.model
|
||||
else:
|
||||
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
|
||||
id(model) if model is not None else None, (None, None))
|
||||
# noinspection PyCallingNonCallable
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
||||
trains_in_model, ref_model = None, None
|
||||
|
||||
# check if object already has InputModel
|
||||
model_name_id = getattr(model, 'name', '') if model else ''
|
||||
@ -139,30 +200,39 @@ class WeightsFileHandler(object):
|
||||
except Exception:
|
||||
config_text = None
|
||||
|
||||
# check if we already have the model object:
|
||||
model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None))
|
||||
if model_id:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
trains_in_model = InputModel(model_id)
|
||||
except Exception:
|
||||
model_id = None
|
||||
if not trains_in_model:
|
||||
# check if we already have the model object:
|
||||
# noinspection PyProtectedMember
|
||||
model_id, model_uri = Model._local_model_to_id_uri.get(
|
||||
model_info.local_model_id or model_info.local_model_path, (None, None))
|
||||
if model_id:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
trains_in_model = InputModel(model_id)
|
||||
except Exception:
|
||||
model_id = None
|
||||
|
||||
# if we do not, we need to import the model
|
||||
if not model_id:
|
||||
trains_in_model = InputModel.import_model(
|
||||
weights_url=filepath,
|
||||
config_dict=config_dict,
|
||||
config_text=config_text,
|
||||
name=task.name + (' ' + model_name_id) if model_name_id else '',
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
create_as_published=False,
|
||||
)
|
||||
# if we do not, we need to import the model
|
||||
if not model_id:
|
||||
trains_in_model = InputModel.import_model(
|
||||
weights_url=model_info.local_model_path,
|
||||
config_dict=config_dict,
|
||||
config_text=config_text,
|
||||
name=task.name + (' ' + model_name_id) if model_name_id else '',
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework,
|
||||
create_as_published=False,
|
||||
)
|
||||
|
||||
model_info.model = trains_in_model
|
||||
# call post model callback functions
|
||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('load', model_info)
|
||||
except Exception:
|
||||
pass
|
||||
trains_in_model = model_info.model
|
||||
|
||||
if model is not None:
|
||||
# noinspection PyBroadException
|
||||
@ -179,13 +249,19 @@ class WeightsFileHandler(object):
|
||||
if False and running_remotely():
|
||||
# reload the model
|
||||
model_config = trains_in_model.config_dict
|
||||
# verify that this is the same model so we are not deserializing a diff model
|
||||
# verify that this is the same model so we are not deserializing a different model
|
||||
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
||||
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
||||
(not config_dict and not model_config):
|
||||
filepath = trains_in_model.get_weights()
|
||||
# update filepath to point to downloaded weights file
|
||||
# actual model weights loading will be done outside the try/exception block
|
||||
|
||||
# update back the internal Model lookup, and replace the local file with our file
|
||||
# noinspection PyProtectedMember
|
||||
Model._local_model_to_id_uri[model_info.local_model_id] = (
|
||||
trains_in_model.id, trains_in_model.url)
|
||||
|
||||
except Exception as ex:
|
||||
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
||||
finally:
|
||||
@ -195,6 +271,7 @@ class WeightsFileHandler(object):
|
||||
|
||||
@staticmethod
|
||||
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
||||
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
|
||||
if task is None:
|
||||
return saved_path
|
||||
|
||||
@ -205,55 +282,70 @@ class WeightsFileHandler(object):
|
||||
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
|
||||
id(model) if model is not None else None, (None, None))
|
||||
# notice ref_model() is not an error/typo this is a weakref object call
|
||||
# noinspection PyCallingNonCallable
|
||||
if ref_model is not None and model != ref_model():
|
||||
# old id pop it - it was probably reused because the object is dead
|
||||
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
||||
trains_out_model, ref_model = None, None
|
||||
|
||||
if not saved_path:
|
||||
model_info = WeightsFileHandler.ModelInfo(
|
||||
model=trains_out_model, upload_filename=None, local_model_path=saved_path,
|
||||
local_model_id=saved_path, framework=framework, task=task)
|
||||
|
||||
if not model_info.local_model_path:
|
||||
get_logger(TrainsFrameworkAdapter).warning(
|
||||
"Could not retrieve model location, skipping auto model logging")
|
||||
return saved_path
|
||||
|
||||
# check if we have output storage, and generate list of files to upload
|
||||
if Path(saved_path).is_dir():
|
||||
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
||||
if Path(model_info.local_model_path).is_dir():
|
||||
files = [str(f) for f in Path(model_info.local_model_path).rglob('*') if f.is_file()]
|
||||
elif singlefile:
|
||||
files = [str(Path(saved_path).absolute())]
|
||||
files = [str(Path(model_info.local_model_path).absolute())]
|
||||
else:
|
||||
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
|
||||
files = [str(f) for f in Path(model_info.local_model_path).parent.glob(
|
||||
str(Path(model_info.local_model_path).name) + '.*')]
|
||||
|
||||
target_filename = None
|
||||
if len(files) > 1:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
target_filename = Path(saved_path).stem
|
||||
target_filename = Path(model_info.local_model_path).stem
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
target_filename = Path(files[0]).name
|
||||
|
||||
# call pre model callback functions
|
||||
model_info.upload_filename = target_filename
|
||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||
target_filename = cb('save', target_filename, framework, task)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('save', model_info)
|
||||
except Exception:
|
||||
pass
|
||||
trains_out_model = model_info.model
|
||||
|
||||
# check if object already has InputModel
|
||||
if trains_out_model is None:
|
||||
in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None))
|
||||
# noinspection PyProtectedMember
|
||||
in_model_id, model_uri = Model._local_model_to_id_uri.get(
|
||||
model_info.local_model_id or model_info.local_model_path, (None, None))
|
||||
|
||||
if not in_model_id:
|
||||
# if we are overwriting a local file, try to load registered model
|
||||
# if there is an output_uri, then by definition we will not overwrite previously stored models.
|
||||
if not task.output_uri:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
in_model_id = InputModel.load_model(weights_url=saved_path)
|
||||
in_model_id = InputModel.load_model(weights_url=model_info.local_model_path)
|
||||
if in_model_id:
|
||||
in_model_id = in_model_id.id
|
||||
|
||||
get_logger(TrainsFrameworkAdapter).info(
|
||||
"Found existing registered model id={} [{}] reusing it.".format(
|
||||
in_model_id, saved_path))
|
||||
except:
|
||||
in_model_id, model_info.local_model_path))
|
||||
except Exception:
|
||||
in_model_id = None
|
||||
else:
|
||||
in_model_id = None
|
||||
@ -274,9 +366,16 @@ class WeightsFileHandler(object):
|
||||
ref_model = None
|
||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||
|
||||
model_info.model = trains_out_model
|
||||
# call post model callback functions
|
||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_info = cb('save', model_info)
|
||||
except Exception:
|
||||
pass
|
||||
trains_out_model = model_info.model
|
||||
target_filename = model_info.upload_filename
|
||||
|
||||
# upload files if we found them, or just register the original path
|
||||
if trains_out_model.upload_storage_uri:
|
||||
@ -294,9 +393,16 @@ class WeightsFileHandler(object):
|
||||
os.close(fd)
|
||||
shutil.copy(files[0], temp_file)
|
||||
trains_out_model.update_weights(
|
||||
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
|
||||
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename,
|
||||
update_comment=False)
|
||||
else:
|
||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||
trains_out_model.update_weights(weights_filename=None, register_uri=model_info.local_model_path)
|
||||
|
||||
# update back the internal Model lookup, and replace the local file with our file
|
||||
# noinspection PyProtectedMember
|
||||
Model._local_model_to_id_uri[model_info.local_model_id] = (
|
||||
trains_out_model.id, trains_out_model.url)
|
||||
|
||||
except Exception as ex:
|
||||
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
||||
finally:
|
||||
|
Loading…
Reference in New Issue
Block a user