Change callback structure, access through the designated class WeightsFileHandler.ModelInfo

See https://github.com/pytorch/ignite/issues/1056
This commit is contained in:
allegroai 2020-06-11 15:05:11 +03:00
parent 5475d00d52
commit 6a28a6e21d

View File

@ -5,7 +5,7 @@ import threading
import weakref import weakref
from random import randint from random import randint
from tempfile import mkstemp from tempfile import mkstemp
from typing import TYPE_CHECKING, Callable, Union from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
import six import six
from pathlib2 import Path from pathlib2 import Path
@ -24,6 +24,7 @@ _recursion_guard = {}
def _patched_call(original_fn, patched_fn): def _patched_call(original_fn, patched_fn):
def _inner_patch(*args, **kwargs): def _inner_patch(*args, **kwargs):
# noinspection PyProtectedMember,PyUnresolvedReferences
ident = threading._get_ident() if six.PY2 else threading.get_ident() ident = threading._get_ident() if six.PY2 else threading.get_ident()
if ident in _recursion_guard: if ident in _recursion_guard:
return original_fn(*args, **kwargs) return original_fn(*args, **kwargs)
@ -55,62 +56,118 @@ class WeightsFileHandler(object):
_model_pre_callbacks = {} _model_pre_callbacks = {}
_model_post_callbacks = {} _model_post_callbacks = {}
@staticmethod class ModelInfo(object):
def add_pre_callback(callback_function): def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
# type: (Callable[[str, str, str, Task], str]) -> int # type: (Optional[Model], Optional[str], str, str, str, Task) -> None
"""
:param model: None, OutputModel or InputModel
:param upload_filename: example 'filename.ext'
:param local_model_path: example '/local/copy/filename.random_number.ext'
:param local_model_id: example '/local/copy/filename.ext'
:param framework: example 'PyTorch'
:param task: Task object
"""
self.model = model
self.upload_filename = upload_filename
self.local_model_path = local_model_path
self.local_model_id = local_model_id
self.framework = framework
self.task = task
# callback is Callable[[Union['load', 'save'], str, str, Task], str] @staticmethod
if callback_function in WeightsFileHandler._model_pre_callbacks.values(): def _add_callback(func, target):
return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0] # type: (Callable, Dict[int, Callable]) -> int
if func in target.values():
return [k for k, v in target.items() if v == func][0]
while True: while True:
h = randint(0, 1 << 31) h = randint(0, 1 << 31)
if h not in WeightsFileHandler._model_pre_callbacks: if h not in target:
break break
WeightsFileHandler._model_pre_callbacks[h] = callback_function
target[h] = func
return h return h
@staticmethod @staticmethod
def add_post_callback(callback_function): def _remove_callback(handle, target):
# type: (Callable[[str, Model, str, str, str, Task], Model]) -> int # type: (int, Dict[int, Callable]) -> bool
if handle in target:
# callback is Callable[[Union['load', 'save'], Model, str, str, str, Task], Model] target.pop(handle, None)
if callback_function in WeightsFileHandler._model_post_callbacks.values():
return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0]
while True:
h = randint(0, 1 << 31)
if h not in WeightsFileHandler._model_post_callbacks:
break
WeightsFileHandler._model_post_callbacks[h] = callback_function
return h
@staticmethod
def remove_pre_callback(handle):
# type: (int) -> bool
if handle in WeightsFileHandler._model_pre_callbacks:
WeightsFileHandler._model_pre_callbacks.pop(handle, None)
return True return True
return False return False
@staticmethod @classmethod
def remove_post_callback(handle): def add_pre_callback(cls, callback_function):
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int
"""
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
return the existing handle.
Use this callback to modify the weights filename registered in the Trains Server. In case Trains is
configured to upload the weights file, this will affect the uploaded filename as well.
:param callback_function: A function accepting action type ("load" or "save"),
callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
:return Callback handle
"""
return cls._add_callback(callback_function, cls._model_pre_callbacks)
@classmethod
def add_post_callback(cls, callback_function):
    # type: (Callable[[str, ModelInfo], ModelInfo]) -> int
    """
    Add a post-save/load callback for weights files and return its handle.
    If the callback was already added, return the existing handle.

    Post callbacks are called after ``ModelInfo.model`` has been set, and the object returned by the
    callback replaces the ModelInfo used from that point on. An exception raised inside the callback
    is silently ignored by the caller.

    :param callback_function: A function accepting the action type ("load" or "save") and the model info,
        callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
    :return: Callback handle
    """
    # registration and handle allocation are shared with add_pre_callback, only the target registry differs
    return cls._add_callback(callback_function, cls._model_post_callbacks)
@classmethod
def remove_pre_callback(cls, handle):
# type: (int) -> bool # type: (int) -> bool
if handle in WeightsFileHandler._model_post_callbacks: """
WeightsFileHandler._model_post_callbacks.pop(handle, None) Remove a previously added pre-save/load callback for weights files.
return True If the handle is not registered, nothing is removed.
return False
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_pre_callback
:return True if callback removed, False otherwise
"""
return cls._remove_callback(handle, cls._model_pre_callbacks)
@classmethod
def remove_post_callback(cls, handle):
    # type: (int) -> bool
    """
    Remove a previously added post-save/load callback for weights files.
    If the handle is not registered, nothing is removed.

    :param handle: A callback handle returned from :meth:`WeightsFileHandler.add_post_callback`
    :return: True if callback removed, False otherwise
    """
    # removal is shared with remove_pre_callback, only the target registry differs
    return cls._remove_callback(handle, cls._model_post_callbacks)
@staticmethod @staticmethod
def restore_weights_file(model, filepath, framework, task): def restore_weights_file(model, filepath, framework, task):
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task]) -> str
if task is None: if task is None:
return filepath return filepath
model_info = WeightsFileHandler.ModelInfo(
model=None, upload_filename=None, local_model_path=filepath,
local_model_id=filepath, framework=framework, task=task)
# call pre model callback functions # call pre model callback functions
for cb in WeightsFileHandler._model_pre_callbacks.values(): for cb in WeightsFileHandler._model_pre_callbacks.values():
filepath = cb('load', filepath, framework, task) # noinspection PyBroadException
try:
model_info = cb('load', model_info)
except Exception:
pass
if not filepath: if not model_info.local_model_path:
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged") get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
return filepath return filepath
@ -118,12 +175,16 @@ class WeightsFileHandler(object):
WeightsFileHandler._model_store_lookup_lock.acquire() WeightsFileHandler._model_store_lookup_lock.acquire()
# check if object already has InputModel # check if object already has InputModel
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get( if model_info.model:
id(model) if model is not None else None, (None, None)) trains_in_model = model_info.model
if ref_model is not None and model != ref_model(): else:
# old id pop it - it was probably reused because the object is dead trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
WeightsFileHandler._model_in_store_lookup.pop(id(model)) id(model) if model is not None else None, (None, None))
trains_in_model, ref_model = None, None # noinspection PyCallingNonCallable
if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_in_store_lookup.pop(id(model))
trains_in_model, ref_model = None, None
# check if object already has InputModel # check if object already has InputModel
model_name_id = getattr(model, 'name', '') if model else '' model_name_id = getattr(model, 'name', '') if model else ''
@ -139,30 +200,39 @@ class WeightsFileHandler(object):
except Exception: except Exception:
config_text = None config_text = None
# check if we already have the model object: if not trains_in_model:
model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None)) # check if we already have the model object:
if model_id: # noinspection PyProtectedMember
# noinspection PyBroadException model_id, model_uri = Model._local_model_to_id_uri.get(
try: model_info.local_model_id or model_info.local_model_path, (None, None))
trains_in_model = InputModel(model_id) if model_id:
except Exception: # noinspection PyBroadException
model_id = None try:
trains_in_model = InputModel(model_id)
except Exception:
model_id = None
# if we do not, we need to import the model # if we do not, we need to import the model
if not model_id: if not model_id:
trains_in_model = InputModel.import_model( trains_in_model = InputModel.import_model(
weights_url=filepath, weights_url=model_info.local_model_path,
config_dict=config_dict, config_dict=config_dict,
config_text=config_text, config_text=config_text,
name=task.name + (' ' + model_name_id) if model_name_id else '', name=task.name + (' ' + model_name_id) if model_name_id else '',
label_enumeration=task.get_labels_enumeration(), label_enumeration=task.get_labels_enumeration(),
framework=framework, framework=framework,
create_as_published=False, create_as_published=False,
) )
model_info.model = trains_in_model
# call post model callback functions # call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values(): for cb in WeightsFileHandler._model_post_callbacks.values():
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task) # noinspection PyBroadException
try:
model_info = cb('load', model_info)
except Exception:
pass
trains_in_model = model_info.model
if model is not None: if model is not None:
# noinspection PyBroadException # noinspection PyBroadException
@ -179,13 +249,19 @@ class WeightsFileHandler(object):
if False and running_remotely(): if False and running_remotely():
# reload the model # reload the model
model_config = trains_in_model.config_dict model_config = trains_in_model.config_dict
# verify that this is the same model so we are not deserializing a diff model # verify that this is the same model so we are not deserializing a different model
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \ config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
(not config_dict and not model_config): (not config_dict and not model_config):
filepath = trains_in_model.get_weights() filepath = trains_in_model.get_weights()
# update filepath to point to downloaded weights file # update filepath to point to downloaded weights file
# actual model weights loading will be done outside the try/exception block # actual model weights loading will be done outside the try/exception block
# update back the internal Model lookup, and replace the local file with our file
# noinspection PyProtectedMember
Model._local_model_to_id_uri[model_info.local_model_id] = (
trains_in_model.id, trains_in_model.url)
except Exception as ex: except Exception as ex:
get_logger(TrainsFrameworkAdapter).debug(str(ex)) get_logger(TrainsFrameworkAdapter).debug(str(ex))
finally: finally:
@ -195,6 +271,7 @@ class WeightsFileHandler(object):
@staticmethod @staticmethod
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None): def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
if task is None: if task is None:
return saved_path return saved_path
@ -205,55 +282,70 @@ class WeightsFileHandler(object):
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get( trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
id(model) if model is not None else None, (None, None)) id(model) if model is not None else None, (None, None))
# notice ref_model() is not an error/typo this is a weakref object call # notice ref_model() is not an error/typo this is a weakref object call
# noinspection PyCallingNonCallable
if ref_model is not None and model != ref_model(): if ref_model is not None and model != ref_model():
# old id pop it - it was probably reused because the object is dead # old id pop it - it was probably reused because the object is dead
WeightsFileHandler._model_out_store_lookup.pop(id(model)) WeightsFileHandler._model_out_store_lookup.pop(id(model))
trains_out_model, ref_model = None, None trains_out_model, ref_model = None, None
if not saved_path: model_info = WeightsFileHandler.ModelInfo(
model=trains_out_model, upload_filename=None, local_model_path=saved_path,
local_model_id=saved_path, framework=framework, task=task)
if not model_info.local_model_path:
get_logger(TrainsFrameworkAdapter).warning( get_logger(TrainsFrameworkAdapter).warning(
"Could not retrieve model location, skipping auto model logging") "Could not retrieve model location, skipping auto model logging")
return saved_path return saved_path
# check if we have output storage, and generate list of files to upload # check if we have output storage, and generate list of files to upload
if Path(saved_path).is_dir(): if Path(model_info.local_model_path).is_dir():
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()] files = [str(f) for f in Path(model_info.local_model_path).rglob('*') if f.is_file()]
elif singlefile: elif singlefile:
files = [str(Path(saved_path).absolute())] files = [str(Path(model_info.local_model_path).absolute())]
else: else:
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')] files = [str(f) for f in Path(model_info.local_model_path).parent.glob(
str(Path(model_info.local_model_path).name) + '.*')]
target_filename = None target_filename = None
if len(files) > 1: if len(files) > 1:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
target_filename = Path(saved_path).stem target_filename = Path(model_info.local_model_path).stem
except Exception: except Exception:
pass pass
else: else:
target_filename = Path(files[0]).name target_filename = Path(files[0]).name
# call pre model callback functions # call pre model callback functions
model_info.upload_filename = target_filename
for cb in WeightsFileHandler._model_pre_callbacks.values(): for cb in WeightsFileHandler._model_pre_callbacks.values():
target_filename = cb('save', target_filename, framework, task) # noinspection PyBroadException
try:
model_info = cb('save', model_info)
except Exception:
pass
trains_out_model = model_info.model
# check if object already has InputModel # check if object already has InputModel
if trains_out_model is None: if trains_out_model is None:
in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None)) # noinspection PyProtectedMember
in_model_id, model_uri = Model._local_model_to_id_uri.get(
model_info.local_model_id or model_info.local_model_path, (None, None))
if not in_model_id: if not in_model_id:
# if we are overwriting a local file, try to load registered model # if we are overwriting a local file, try to load registered model
# if there is an output_uri, then by definition we will not overwrite previously stored models. # if there is an output_uri, then by definition we will not overwrite previously stored models.
if not task.output_uri: if not task.output_uri:
# noinspection PyBroadException
try: try:
in_model_id = InputModel.load_model(weights_url=saved_path) in_model_id = InputModel.load_model(weights_url=model_info.local_model_path)
if in_model_id: if in_model_id:
in_model_id = in_model_id.id in_model_id = in_model_id.id
get_logger(TrainsFrameworkAdapter).info( get_logger(TrainsFrameworkAdapter).info(
"Found existing registered model id={} [{}] reusing it.".format( "Found existing registered model id={} [{}] reusing it.".format(
in_model_id, saved_path)) in_model_id, model_info.local_model_path))
except: except Exception:
in_model_id = None in_model_id = None
else: else:
in_model_id = None in_model_id = None
@ -274,9 +366,16 @@ class WeightsFileHandler(object):
ref_model = None ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
model_info.model = trains_out_model
# call post model callback functions # call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values(): for cb in WeightsFileHandler._model_post_callbacks.values():
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task) # noinspection PyBroadException
try:
model_info = cb('save', model_info)
except Exception:
pass
trains_out_model = model_info.model
target_filename = model_info.upload_filename
# upload files if we found them, or just register the original path # upload files if we found them, or just register the original path
if trains_out_model.upload_storage_uri: if trains_out_model.upload_storage_uri:
@ -294,9 +393,16 @@ class WeightsFileHandler(object):
os.close(fd) os.close(fd)
shutil.copy(files[0], temp_file) shutil.copy(files[0], temp_file)
trains_out_model.update_weights( trains_out_model.update_weights(
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename) weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename,
update_comment=False)
else: else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) trains_out_model.update_weights(weights_filename=None, register_uri=model_info.local_model_path)
# update back the internal Model lookup, and replace the local file with our file
# noinspection PyProtectedMember
Model._local_model_to_id_uri[model_info.local_model_id] = (
trains_out_model.id, trains_out_model.url)
except Exception as ex: except Exception as ex:
get_logger(TrainsFrameworkAdapter).debug(str(ex)) get_logger(TrainsFrameworkAdapter).debug(str(ex))
finally: finally: