mirror of
https://github.com/clearml/clearml
synced 2025-04-08 06:34:37 +00:00
Change callback structure, access thorough designated class WeightsFileHandler.ModelInfo
See https://github.com/pytorch/ignite/issues/1056
This commit is contained in:
parent
5475d00d52
commit
6a28a6e21d
@ -5,7 +5,7 @@ import threading
|
|||||||
import weakref
|
import weakref
|
||||||
from random import randint
|
from random import randint
|
||||||
from tempfile import mkstemp
|
from tempfile import mkstemp
|
||||||
from typing import TYPE_CHECKING, Callable, Union
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Any
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
@ -24,6 +24,7 @@ _recursion_guard = {}
|
|||||||
|
|
||||||
def _patched_call(original_fn, patched_fn):
|
def _patched_call(original_fn, patched_fn):
|
||||||
def _inner_patch(*args, **kwargs):
|
def _inner_patch(*args, **kwargs):
|
||||||
|
# noinspection PyProtectedMember,PyUnresolvedReferences
|
||||||
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||||
if ident in _recursion_guard:
|
if ident in _recursion_guard:
|
||||||
return original_fn(*args, **kwargs)
|
return original_fn(*args, **kwargs)
|
||||||
@ -55,62 +56,118 @@ class WeightsFileHandler(object):
|
|||||||
_model_pre_callbacks = {}
|
_model_pre_callbacks = {}
|
||||||
_model_post_callbacks = {}
|
_model_post_callbacks = {}
|
||||||
|
|
||||||
@staticmethod
|
class ModelInfo(object):
|
||||||
def add_pre_callback(callback_function):
|
def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task):
|
||||||
# type: (Callable[[str, str, str, Task], str]) -> int
|
# type: (Optional[Model], Optional[str], str, str, str, Task) -> None
|
||||||
|
"""
|
||||||
|
:param model: None, OutputModel or InputModel
|
||||||
|
:param upload_filename: example 'filename.ext'
|
||||||
|
:param local_model_path: example /local/copy/filename.random_number.ext'
|
||||||
|
:param local_model_id: example /local/copy/filename.ext'
|
||||||
|
:param framework: example 'PyTorch'
|
||||||
|
:param task: Task object
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.upload_filename = upload_filename
|
||||||
|
self.local_model_path = local_model_path
|
||||||
|
self.local_model_id = local_model_id
|
||||||
|
self.framework = framework
|
||||||
|
self.task = task
|
||||||
|
|
||||||
# callback is Callable[[Union['load', 'save'], str, str, Task], str]
|
@staticmethod
|
||||||
if callback_function in WeightsFileHandler._model_pre_callbacks.values():
|
def _add_callback(func, target):
|
||||||
return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0]
|
# type: (Callable, Dict[int, Callable]) -> int
|
||||||
|
|
||||||
|
if func in target.values():
|
||||||
|
return [k for k, v in target.items() if v == func][0]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
h = randint(0, 1 << 31)
|
h = randint(0, 1 << 31)
|
||||||
if h not in WeightsFileHandler._model_pre_callbacks:
|
if h not in target:
|
||||||
break
|
break
|
||||||
WeightsFileHandler._model_pre_callbacks[h] = callback_function
|
|
||||||
|
target[h] = func
|
||||||
return h
|
return h
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_post_callback(callback_function):
|
def _remove_callback(handle, target):
|
||||||
# type: (Callable[[str, Model, str, str, str, Task], Model]) -> int
|
# type: (int, Dict[int, Callable]) -> bool
|
||||||
|
if handle in target:
|
||||||
# callback is Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]
|
target.pop(handle, None)
|
||||||
if callback_function in WeightsFileHandler._model_post_callbacks.values():
|
|
||||||
return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0]
|
|
||||||
|
|
||||||
while True:
|
|
||||||
h = randint(0, 1 << 31)
|
|
||||||
if h not in WeightsFileHandler._model_post_callbacks:
|
|
||||||
break
|
|
||||||
WeightsFileHandler._model_post_callbacks[h] = callback_function
|
|
||||||
return h
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_pre_callback(handle):
|
|
||||||
# type: (int) -> bool
|
|
||||||
if handle in WeightsFileHandler._model_pre_callbacks:
|
|
||||||
WeightsFileHandler._model_pre_callbacks.pop(handle, None)
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def remove_post_callback(handle):
|
def add_pre_callback(cls, callback_function):
|
||||||
|
# type: (Callable[[str, ModelInfo], ModelInfo]) -> int
|
||||||
|
"""
|
||||||
|
Add a pre-save/load callback for weights files and return its handle. If the callback was already added,
|
||||||
|
return the existing handle.
|
||||||
|
|
||||||
|
Use this callback to modify the weights filename registered in the Trains Server. In case Trains is
|
||||||
|
configured to upload the weights file, this will affect the uploaded filename as well.
|
||||||
|
|
||||||
|
:param callback_function: A function accepting action type ("load" or "save"),
|
||||||
|
callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
|
||||||
|
:return Callback handle
|
||||||
|
"""
|
||||||
|
return cls._add_callback(callback_function, cls._model_pre_callbacks)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_post_callback(cls, callback_function):
|
||||||
|
# type: (Callable[[str, dict], dict]) -> int
|
||||||
|
"""
|
||||||
|
Add a post-save/load callback for weights files and return its handle.
|
||||||
|
If the callback was already added, return the existing handle.
|
||||||
|
|
||||||
|
:param callback_function: A function accepting action type ("load" or "save"),
|
||||||
|
callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo
|
||||||
|
:return Callback handle
|
||||||
|
"""
|
||||||
|
return cls._add_callback(callback_function, cls._model_post_callbacks)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def remove_pre_callback(cls, handle):
|
||||||
# type: (int) -> bool
|
# type: (int) -> bool
|
||||||
if handle in WeightsFileHandler._model_post_callbacks:
|
"""
|
||||||
WeightsFileHandler._model_post_callbacks.pop(handle, None)
|
Add a pre-save/load callback for weights files and return its handle.
|
||||||
return True
|
If the callback was already added, return the existing handle.
|
||||||
return False
|
|
||||||
|
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_pre_callback
|
||||||
|
:return True if callback removed, False otherwise
|
||||||
|
"""
|
||||||
|
return cls._remove_callback(handle, cls._model_pre_callbacks)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def remove_post_callback(cls, handle):
|
||||||
|
# type: (int) -> bool
|
||||||
|
"""
|
||||||
|
Add a pre-save/load callback for weights files and return its handle.
|
||||||
|
If the callback was already added, return the existing handle.
|
||||||
|
|
||||||
|
:param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback
|
||||||
|
:return True if callback removed, False otherwise
|
||||||
|
"""
|
||||||
|
return cls._remove_callback(handle, cls._model_post_callbacks)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def restore_weights_file(model, filepath, framework, task):
|
def restore_weights_file(model, filepath, framework, task):
|
||||||
|
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task]) -> str
|
||||||
if task is None:
|
if task is None:
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
|
model_info = WeightsFileHandler.ModelInfo(
|
||||||
|
model=None, upload_filename=None, local_model_path=filepath,
|
||||||
|
local_model_id=filepath, framework=framework, task=task)
|
||||||
# call pre model callback functions
|
# call pre model callback functions
|
||||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
filepath = cb('load', filepath, framework, task)
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_info = cb('load', model_info)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if not filepath:
|
if not model_info.local_model_path:
|
||||||
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
|
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
@ -118,12 +175,16 @@ class WeightsFileHandler(object):
|
|||||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
|
if model_info.model:
|
||||||
id(model) if model is not None else None, (None, None))
|
trains_in_model = model_info.model
|
||||||
if ref_model is not None and model != ref_model():
|
else:
|
||||||
# old id pop it - it was probably reused because the object is dead
|
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(
|
||||||
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
id(model) if model is not None else None, (None, None))
|
||||||
trains_in_model, ref_model = None, None
|
# noinspection PyCallingNonCallable
|
||||||
|
if ref_model is not None and model != ref_model():
|
||||||
|
# old id pop it - it was probably reused because the object is dead
|
||||||
|
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
||||||
|
trains_in_model, ref_model = None, None
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
model_name_id = getattr(model, 'name', '') if model else ''
|
model_name_id = getattr(model, 'name', '') if model else ''
|
||||||
@ -139,30 +200,39 @@ class WeightsFileHandler(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
config_text = None
|
config_text = None
|
||||||
|
|
||||||
# check if we already have the model object:
|
if not trains_in_model:
|
||||||
model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None))
|
# check if we already have the model object:
|
||||||
if model_id:
|
# noinspection PyProtectedMember
|
||||||
# noinspection PyBroadException
|
model_id, model_uri = Model._local_model_to_id_uri.get(
|
||||||
try:
|
model_info.local_model_id or model_info.local_model_path, (None, None))
|
||||||
trains_in_model = InputModel(model_id)
|
if model_id:
|
||||||
except Exception:
|
# noinspection PyBroadException
|
||||||
model_id = None
|
try:
|
||||||
|
trains_in_model = InputModel(model_id)
|
||||||
|
except Exception:
|
||||||
|
model_id = None
|
||||||
|
|
||||||
# if we do not, we need to import the model
|
# if we do not, we need to import the model
|
||||||
if not model_id:
|
if not model_id:
|
||||||
trains_in_model = InputModel.import_model(
|
trains_in_model = InputModel.import_model(
|
||||||
weights_url=filepath,
|
weights_url=model_info.local_model_path,
|
||||||
config_dict=config_dict,
|
config_dict=config_dict,
|
||||||
config_text=config_text,
|
config_text=config_text,
|
||||||
name=task.name + (' ' + model_name_id) if model_name_id else '',
|
name=task.name + (' ' + model_name_id) if model_name_id else '',
|
||||||
label_enumeration=task.get_labels_enumeration(),
|
label_enumeration=task.get_labels_enumeration(),
|
||||||
framework=framework,
|
framework=framework,
|
||||||
create_as_published=False,
|
create_as_published=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_info.model = trains_in_model
|
||||||
# call post model callback functions
|
# call post model callback functions
|
||||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||||
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_info = cb('load', model_info)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
trains_in_model = model_info.model
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -179,13 +249,19 @@ class WeightsFileHandler(object):
|
|||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
# reload the model
|
# reload the model
|
||||||
model_config = trains_in_model.config_dict
|
model_config = trains_in_model.config_dict
|
||||||
# verify that this is the same model so we are not deserializing a diff model
|
# verify that this is the same model so we are not deserializing a different model
|
||||||
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
||||||
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
||||||
(not config_dict and not model_config):
|
(not config_dict and not model_config):
|
||||||
filepath = trains_in_model.get_weights()
|
filepath = trains_in_model.get_weights()
|
||||||
# update filepath to point to downloaded weights file
|
# update filepath to point to downloaded weights file
|
||||||
# actual model weights loading will be done outside the try/exception block
|
# actual model weights loading will be done outside the try/exception block
|
||||||
|
|
||||||
|
# update back the internal Model lookup, and replace the local file with our file
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
Model._local_model_to_id_uri[model_info.local_model_id] = (
|
||||||
|
trains_in_model.id, trains_in_model.url)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
||||||
finally:
|
finally:
|
||||||
@ -195,6 +271,7 @@ class WeightsFileHandler(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
||||||
|
# type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str
|
||||||
if task is None:
|
if task is None:
|
||||||
return saved_path
|
return saved_path
|
||||||
|
|
||||||
@ -205,55 +282,70 @@ class WeightsFileHandler(object):
|
|||||||
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
|
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(
|
||||||
id(model) if model is not None else None, (None, None))
|
id(model) if model is not None else None, (None, None))
|
||||||
# notice ref_model() is not an error/typo this is a weakref object call
|
# notice ref_model() is not an error/typo this is a weakref object call
|
||||||
|
# noinspection PyCallingNonCallable
|
||||||
if ref_model is not None and model != ref_model():
|
if ref_model is not None and model != ref_model():
|
||||||
# old id pop it - it was probably reused because the object is dead
|
# old id pop it - it was probably reused because the object is dead
|
||||||
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
||||||
trains_out_model, ref_model = None, None
|
trains_out_model, ref_model = None, None
|
||||||
|
|
||||||
if not saved_path:
|
model_info = WeightsFileHandler.ModelInfo(
|
||||||
|
model=trains_out_model, upload_filename=None, local_model_path=saved_path,
|
||||||
|
local_model_id=saved_path, framework=framework, task=task)
|
||||||
|
|
||||||
|
if not model_info.local_model_path:
|
||||||
get_logger(TrainsFrameworkAdapter).warning(
|
get_logger(TrainsFrameworkAdapter).warning(
|
||||||
"Could not retrieve model location, skipping auto model logging")
|
"Could not retrieve model location, skipping auto model logging")
|
||||||
return saved_path
|
return saved_path
|
||||||
|
|
||||||
# check if we have output storage, and generate list of files to upload
|
# check if we have output storage, and generate list of files to upload
|
||||||
if Path(saved_path).is_dir():
|
if Path(model_info.local_model_path).is_dir():
|
||||||
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
files = [str(f) for f in Path(model_info.local_model_path).rglob('*') if f.is_file()]
|
||||||
elif singlefile:
|
elif singlefile:
|
||||||
files = [str(Path(saved_path).absolute())]
|
files = [str(Path(model_info.local_model_path).absolute())]
|
||||||
else:
|
else:
|
||||||
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
|
files = [str(f) for f in Path(model_info.local_model_path).parent.glob(
|
||||||
|
str(Path(model_info.local_model_path).name) + '.*')]
|
||||||
|
|
||||||
target_filename = None
|
target_filename = None
|
||||||
if len(files) > 1:
|
if len(files) > 1:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
target_filename = Path(saved_path).stem
|
target_filename = Path(model_info.local_model_path).stem
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
target_filename = Path(files[0]).name
|
target_filename = Path(files[0]).name
|
||||||
|
|
||||||
# call pre model callback functions
|
# call pre model callback functions
|
||||||
|
model_info.upload_filename = target_filename
|
||||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
target_filename = cb('save', target_filename, framework, task)
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_info = cb('save', model_info)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
trains_out_model = model_info.model
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
if trains_out_model is None:
|
if trains_out_model is None:
|
||||||
in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None))
|
# noinspection PyProtectedMember
|
||||||
|
in_model_id, model_uri = Model._local_model_to_id_uri.get(
|
||||||
|
model_info.local_model_id or model_info.local_model_path, (None, None))
|
||||||
|
|
||||||
if not in_model_id:
|
if not in_model_id:
|
||||||
# if we are overwriting a local file, try to load registered model
|
# if we are overwriting a local file, try to load registered model
|
||||||
# if there is an output_uri, then by definition we will not overwrite previously stored models.
|
# if there is an output_uri, then by definition we will not overwrite previously stored models.
|
||||||
if not task.output_uri:
|
if not task.output_uri:
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
in_model_id = InputModel.load_model(weights_url=saved_path)
|
in_model_id = InputModel.load_model(weights_url=model_info.local_model_path)
|
||||||
if in_model_id:
|
if in_model_id:
|
||||||
in_model_id = in_model_id.id
|
in_model_id = in_model_id.id
|
||||||
|
|
||||||
get_logger(TrainsFrameworkAdapter).info(
|
get_logger(TrainsFrameworkAdapter).info(
|
||||||
"Found existing registered model id={} [{}] reusing it.".format(
|
"Found existing registered model id={} [{}] reusing it.".format(
|
||||||
in_model_id, saved_path))
|
in_model_id, model_info.local_model_path))
|
||||||
except:
|
except Exception:
|
||||||
in_model_id = None
|
in_model_id = None
|
||||||
else:
|
else:
|
||||||
in_model_id = None
|
in_model_id = None
|
||||||
@ -274,9 +366,16 @@ class WeightsFileHandler(object):
|
|||||||
ref_model = None
|
ref_model = None
|
||||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||||
|
|
||||||
|
model_info.model = trains_out_model
|
||||||
# call post model callback functions
|
# call post model callback functions
|
||||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||||
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_info = cb('save', model_info)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
trains_out_model = model_info.model
|
||||||
|
target_filename = model_info.upload_filename
|
||||||
|
|
||||||
# upload files if we found them, or just register the original path
|
# upload files if we found them, or just register the original path
|
||||||
if trains_out_model.upload_storage_uri:
|
if trains_out_model.upload_storage_uri:
|
||||||
@ -294,9 +393,16 @@ class WeightsFileHandler(object):
|
|||||||
os.close(fd)
|
os.close(fd)
|
||||||
shutil.copy(files[0], temp_file)
|
shutil.copy(files[0], temp_file)
|
||||||
trains_out_model.update_weights(
|
trains_out_model.update_weights(
|
||||||
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
|
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename,
|
||||||
|
update_comment=False)
|
||||||
else:
|
else:
|
||||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
trains_out_model.update_weights(weights_filename=None, register_uri=model_info.local_model_path)
|
||||||
|
|
||||||
|
# update back the internal Model lookup, and replace the local file with our file
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
Model._local_model_to_id_uri[model_info.local_model_id] = (
|
||||||
|
trains_out_model.id, trains_out_model.url)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
get_logger(TrainsFrameworkAdapter).debug(str(ex))
|
||||||
finally:
|
finally:
|
||||||
|
Loading…
Reference in New Issue
Block a user