Change callback structure; access is now through the designated class WeightsFileHandler.ModelInfo

See https://github.com/pytorch/ignite/issues/1056
This commit is contained in:
allegroai 2020-06-11 15:05:11 +03:00
parent 5475d00d52
commit 6a28a6e21d

View File

@ -5,7 +5,7 @@ import threading
import weakref
from random import randint
from tempfile import mkstemp
from typing import TYPE_CHECKING, Callable, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
import six
from pathlib2 import Path
@ -24,6 +24,7 @@ _recursion_guard = {}
def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs):
# noinspection PyProtectedMember,PyUnresolvedReferences
ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard:
return original_fn(*args, **kwargs)
@ -55,62 +56,118 @@ class WeightsFileHandler(object):
_model_pre_callbacks = {}
_model_post_callbacks = {}
@staticmethod
def add_pre_callback(callback_function):
# type: (Callable[[str, str, str, Task], str]) -> int
class ModelInfo(object):
    """
    Mutable record describing a single weights-file load/save event.

    An instance is handed to every registered pre/post callback, and whatever the
    callback returns replaces it, so callbacks may change any attribute
    (for example ``upload_filename`` or ``model``).
    """

    def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
        # type: (Optional[Model], Optional[str], Optional[str], Optional[str], Optional[str], Task) -> None
        """
        :param model: None, OutputModel or InputModel
        :param upload_filename: example 'filename.ext'
        :param local_model_path: example '/local/copy/filename.random_number.ext'
        :param local_model_id: example '/local/copy/filename.ext'
        :param framework: example 'PyTorch'
        :param task: Task object
        """
        self.model = model
        self.upload_filename = upload_filename
        self.local_model_path = local_model_path
        self.local_model_id = local_model_id
        self.framework = framework
        self.task = task

    def __repr__(self):
        # type: () -> str
        # Only the plain string fields are shown; model and task are left out so that
        # printing a ModelInfo from inside a callback stays short and cheap.
        return '{}(upload_filename={!r}, local_model_path={!r}, local_model_id={!r}, framework={!r})'.format(
            type(self).__name__, self.upload_filename, self.local_model_path,
            self.local_model_id, self.framework)
# callback is Callable[[Union['load', 'save'], str, str, Task], str]
if callback_function in WeightsFileHandler._model_pre_callbacks.values():
return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0]
@staticmethod
def _add_callback(func, target):
    # type: (Callable, Dict[int, Callable]) -> int
    """
    Register ``func`` in the ``target`` handle-to-callback dictionary and return its handle.

    If ``func`` is already stored in ``target``, its existing handle is returned and
    ``target`` is left unchanged.

    :param func: The callback to register
    :param target: The registry to update in place (maps handle to callback)
    :return: The integer handle under which ``func`` is stored in ``target``
    """
    # single pass: reuse the existing handle when the callback is already registered
    for handle, registered in target.items():
        if registered == func:
            return handle
    # draw random handles until one is found that is not in use in *this* registry
    while True:
        handle = randint(0, 1 << 31)
        if handle not in target:
            break
    target[handle] = func
    return handle
@staticmethod
def add_post_callback(callback_function):
# type: (Callable[[str, Model, str, str, str, Task], Model]) -> int
# callback is Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]
if callback_function in WeightsFileHandler._model_post_callbacks.values():
return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0]
while True:
h = randint(0, 1 << 31)
if h not in WeightsFileHandler._model_post_callbacks:
break
WeightsFileHandler._model_post_callbacks[h] = callback_function
return h
@staticmethod
def _remove_callback(handle, target):
    # type: (int, Dict[int, Callable]) -> bool
    """
    Remove the callback stored under ``handle`` from the ``target`` registry.

    :param handle: A handle previously returned when the callback was added
    :param target: The registry to update in place (maps handle to callback)
    :return: True if ``handle`` was found and removed, False otherwise
    """
    if handle not in target:
        return False
    target.pop(handle, None)
    return True
@staticmethod
def remove_post_callback(handle):
@classmethod
def add_pre_callback(cls, callback_function):
    # type: (Callable[[str, ModelInfo], ModelInfo]) -> int
    """
    Register a callback that runs before a weights file is loaded or saved, and return its handle.
    Registering a callback that was already added returns the existing handle.

    The callback can be used to modify the weights filename registered in the Trains Server. In case
    Trains is configured to upload the weights file, this will affect the uploaded filename as well.

    :param callback_function: A function accepting the action type ("load" or "save") and the model info,
        callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
    :return: Callback handle
    """
    pre_callbacks = cls._model_pre_callbacks
    return cls._add_callback(callback_function, pre_callbacks)
@classmethod
def add_post_callback(cls, callback_function):
    # type: (Callable[[str, ModelInfo], ModelInfo]) -> int
    """
    Add a post-save/load callback for weights files and return its handle.
    If the callback was already added, return the existing handle.

    :param callback_function: A function accepting the action type ("load" or "save") and the model info,
        callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
    :return: Callback handle
    """
    return cls._add_callback(callback_function, cls._model_post_callbacks)
@classmethod
def remove_pre_callback(cls, handle):
    # type: (int) -> bool
    """
    Remove a previously added pre-save/load callback for weights files.

    :param handle: A callback handle returned from :meth:WeightsFileHandler.add_pre_callback
    :return: True if callback removed, False otherwise
    """
    return cls._remove_callback(handle, cls._model_pre_callbacks)
@classmethod
def remove_post_callback(cls, handle):
    # type: (int) -> bool
    """
    Remove a previously added post-save/load callback for weights files.

    :param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
    :return: True if callback removed, False otherwise
    """
    return cls._remove_callback(handle, cls._model_post_callbacks)
@staticmethod
def restore_weights_file(model, filepath, framework, task):
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task]) -> str
if task is None:
return filepath
model_info = WeightsFileHandler.ModelInfo(
model=None, upload_filename=None, local_model_path=filepath,
local_model_id=filepath, framework=framework, task=task)
# call pre model callback functions
for cb in WeightsFileHandler._model_pre_callbacks.values():
filepath = cb('load', filepath, framework, task)
# noinspection PyBroadException
try:
model_info = cb('load', model_info)
except Exception:
pass
if not filepath:
if not model_info.local_model_path:
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
return filepath
@ -118,12 +175,16 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
id(model) if model is not None else None, (None, None))
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
if model_info.model:
trains_in_model = model_info.model
else:
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
id(model) if model is not None else None, (None, None))
# noinspection PyCallingNonCallable
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel
model_name_id = getattr(model, 'name', '') if model else ''
@ -139,30 +200,39 @@ class WeightsFileHandler(object):
except Exception:
config_text = None
# check if we already have the model object:
model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None))
if model_id:
# noinspection PyBroadException
try:
trains_in_model = InputModel(model_id)
except Exception:
model_id = None
if not trains_in_model:
# check if we already have the model object:
# noinspection PyProtectedMember
model_id, model_uri = Model._local_model_to_id_uri.get(
model_info.local_model_id or model_info.local_model_path, (None, None))
if model_id:
# noinspection PyBroadException
try:
trains_in_model = InputModel(model_id)
except Exception:
model_id = None
# if we do not, we need to import the model
if not model_id:
trains_in_model = InputModel.import_model(
weights_url=filepath,
config_dict=config_dict,
config_text=config_text,
name=task.name + (' ' + model_name_id) if model_name_id else '',
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
# if we do not, we need to import the model
if not model_id:
trains_in_model = InputModel.import_model(
weights_url=model_info.local_model_path,
config_dict=config_dict,
config_text=config_text,
name=task.name + (' ' + model_name_id) if model_name_id else '',
label_enumeration=task.get_labels_enumeration(),
framework=framework,
create_as_published=False,
)
model_info.model = trains_in_model
# call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values():
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
# noinspection PyBroadException
try:
model_info = cb('load', model_info)
except Exception:
pass
trains_in_model = model_info.model
if model is not None:
# noinspection PyBroadException
@ -179,13 +249,19 @@ class WeightsFileHandler(object):
if False and running_remotely():
# reload the model
model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model
# verify that this is the same model so we are not deserializing a different model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config):
filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block
# update back the internal Model lookup, and replace the local file with our file
# noinspection PyProtectedMember
Model._local_model_to_id_uri[model_info.local_model_id] = (
trains_in_model.id, trains_in_model.url)
except Exception as ex:
get_logger(TrainsFrameworkAdapter).debug(str(ex))
finally:
@ -195,6 +271,7 @@ class WeightsFileHandler(object):
@staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
if task is None:
return saved_path
@ -205,55 +282,70 @@ class WeightsFileHandler(object):
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
id(model) if model is not None else None, (None, None))
# notice ref_model() is not an error/typo this is a weakref object call
# noinspection PyCallingNonCallable
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None
if not saved_path:
model_info = WeightsFileHandler.ModelInfo(
model=trains_out_model, upload_filename=None, local_model_path=saved_path,
local_model_id=saved_path, framework=framework, task=task)
if not model_info.local_model_path:
get_logger(TrainsFrameworkAdapter).warning(
"Could not retrieve model location, skipping auto model logging")
return saved_path
# check if we have output storage, and generate list of files to upload
if Path(saved_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
if Path(model_info.local_model_path).is_dir():
files = [str(f) for f in Path(model_info.local_model_path).rglob('*') if f.is_file()]
elif singlefile:
files = [str(Path(saved_path).absolute())]
files = [str(Path(model_info.local_model_path).absolute())]
else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
files = [str(f) for f in Path(model_info.local_model_path).parent.glob(
str(Path(model_info.local_model_path).name) + '.*')]
target_filename = None
if len(files) > 1:
# noinspection PyBroadException
try:
target_filename = Path(saved_path).stem
target_filename = Path(model_info.local_model_path).stem
except Exception:
pass
else:
target_filename = Path(files[0]).name
# call pre model callback functions
model_info.upload_filename = target_filename
for cb in WeightsFileHandler._model_pre_callbacks.values():
target_filename = cb('save', target_filename, framework, task)
# noinspection PyBroadException
try:
model_info = cb('save', model_info)
except Exception:
pass
trains_out_model = model_info.model
# check if object already has InputModel
if trains_out_model is None:
in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None))
# noinspection PyProtectedMember
in_model_id, model_uri = Model._local_model_to_id_uri.get(
model_info.local_model_id or model_info.local_model_path, (None, None))
if not in_model_id:
# if we are overwriting a local file, try to load registered model
# if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri:
# noinspection PyBroadException
try:
in_model_id = InputModel.load_model(weights_url=saved_path)
in_model_id = InputModel.load_model(weights_url=model_info.local_model_path)
if in_model_id:
in_model_id = in_model_id.id
get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path))
except:
in_model_id, model_info.local_model_path))
except Exception:
in_model_id = None
else:
in_model_id = None
@ -274,9 +366,16 @@ class WeightsFileHandler(object):
ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
model_info.model = trains_out_model
# call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values():
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
# noinspection PyBroadException
try:
model_info = cb('save', model_info)
except Exception:
pass
trains_out_model = model_info.model
target_filename = model_info.upload_filename
# upload files if we found them, or just register the original path
if trains_out_model.upload_storage_uri:
@ -294,9 +393,16 @@ class WeightsFileHandler(object):
os.close(fd)
shutil.copy(files[0], temp_file)
trains_out_model.update_weights(
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename,
update_comment=False)
else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
trains_out_model.update_weights(weights_filename=None, register_uri=model_info.local_model_path)
# update back the internal Model lookup, and replace the local file with our file
# noinspection PyProtectedMember
Model._local_model_to_id_uri[model_info.local_model_id] = (
trains_out_model.id, trains_out_model.url)
except Exception as ex:
get_logger(TrainsFrameworkAdapter).debug(str(ex))
finally: