mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Add callback for framework save/load binding
This commit is contained in:
parent
b865fc0072
commit
f86198bbe5
@ -3,15 +3,20 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
|
from random import randint
|
||||||
from tempfile import mkstemp
|
from tempfile import mkstemp
|
||||||
|
from typing import TYPE_CHECKING, Callable, Union
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
from ...debugging.log import get_logger
|
|
||||||
from ...config import running_remotely
|
|
||||||
from ...model import InputModel, OutputModel
|
|
||||||
from ...backend_interface.model import Model
|
from ...backend_interface.model import Model
|
||||||
|
from ...config import running_remotely
|
||||||
|
from ...debugging.log import get_logger
|
||||||
|
from ...model import InputModel, OutputModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...task import Task
|
||||||
|
|
||||||
TrainsFrameworkAdapter = 'frameworks'
|
TrainsFrameworkAdapter = 'frameworks'
|
||||||
_recursion_guard = {}
|
_recursion_guard = {}
|
||||||
@ -47,12 +52,60 @@ class WeightsFileHandler(object):
|
|||||||
_model_out_store_lookup = {}
|
_model_out_store_lookup = {}
|
||||||
_model_in_store_lookup = {}
|
_model_in_store_lookup = {}
|
||||||
_model_store_lookup_lock = threading.Lock()
|
_model_store_lookup_lock = threading.Lock()
|
||||||
|
_model_pre_callbacks = {}
|
||||||
|
_model_post_callbacks = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_pre_callback(callback_function):
|
||||||
|
# type: (Callable[[Union['load', 'save'], str, str, Task], str]) -> int
|
||||||
|
if callback_function in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
|
return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
h = randint(0, 1 << 31)
|
||||||
|
if h not in WeightsFileHandler._model_pre_callbacks:
|
||||||
|
break
|
||||||
|
WeightsFileHandler._model_pre_callbacks[h] = callback_function
|
||||||
|
return h
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_post_callback(callback_function):
|
||||||
|
# type: (Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]) -> int
|
||||||
|
if callback_function in WeightsFileHandler._model_post_callbacks.values():
|
||||||
|
return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
h = randint(0, 1 << 31)
|
||||||
|
if h not in WeightsFileHandler._model_post_callbacks:
|
||||||
|
break
|
||||||
|
WeightsFileHandler._model_post_callbacks[h] = callback_function
|
||||||
|
return h
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def remove_pre_callback(handle):
|
||||||
|
# type: (int) -> bool
|
||||||
|
if handle in WeightsFileHandler._model_pre_callbacks:
|
||||||
|
WeightsFileHandler._model_pre_callbacks.pop(handle, None)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def remove_post_callback(handle):
|
||||||
|
# type: (int) -> bool
|
||||||
|
if handle in WeightsFileHandler._model_post_callbacks:
|
||||||
|
WeightsFileHandler._model_post_callbacks.pop(handle, None)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def restore_weights_file(model, filepath, framework, task):
|
def restore_weights_file(model, filepath, framework, task):
|
||||||
if task is None:
|
if task is None:
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
|
# call pre model callback functions
|
||||||
|
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
|
filepath = cb('load', filepath, framework, task)
|
||||||
|
|
||||||
if not filepath:
|
if not filepath:
|
||||||
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
|
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
|
||||||
return filepath
|
return filepath
|
||||||
@ -102,6 +155,10 @@ class WeightsFileHandler(object):
|
|||||||
create_as_published=False,
|
create_as_published=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# call post model callback functions
|
||||||
|
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||||
|
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
ref_model = weakref.ref(model)
|
ref_model = weakref.ref(model)
|
||||||
@ -167,7 +224,11 @@ class WeightsFileHandler(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
target_filename = files[0]
|
target_filename = Path(files[0]).name
|
||||||
|
|
||||||
|
# call pre model callback functions
|
||||||
|
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||||
|
target_filename = cb('save', target_filename, framework, task)
|
||||||
|
|
||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
if trains_out_model is None:
|
if trains_out_model is None:
|
||||||
@ -191,7 +252,9 @@ class WeightsFileHandler(object):
|
|||||||
# config_dict=config,
|
# config_dict=config,
|
||||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||||
label_enumeration=task.get_labels_enumeration(),
|
label_enumeration=task.get_labels_enumeration(),
|
||||||
framework=framework, base_model_id=in_model_id)
|
framework=framework,
|
||||||
|
base_model_id=in_model_id
|
||||||
|
)
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
ref_model = weakref.ref(model)
|
ref_model = weakref.ref(model)
|
||||||
@ -199,23 +262,27 @@ class WeightsFileHandler(object):
|
|||||||
ref_model = None
|
ref_model = None
|
||||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||||
|
|
||||||
|
# call post model callback functions
|
||||||
|
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||||
|
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
|
||||||
|
|
||||||
# upload files if we found them, or just register the original path
|
# upload files if we found them, or just register the original path
|
||||||
if trains_out_model.upload_storage_uri:
|
if trains_out_model.upload_storage_uri:
|
||||||
if len(files) > 1:
|
if len(files) > 1:
|
||||||
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
trains_out_model.update_weights_package(
|
||||||
target_filename=target_filename)
|
weights_filenames=files, auto_delete_file=False, target_filename=target_filename)
|
||||||
else:
|
else:
|
||||||
# create a copy of the stored file,
|
# create a copy of the stored file,
|
||||||
# protect against someone deletes/renames the file before async upload finish is done
|
# protect against someone deletes/renames the file before async upload finish is done
|
||||||
target_filename = Path(files[0]).name
|
|
||||||
# HACK: if pytorch-lightning is used, remove the temp '.part' file extension
|
# HACK: if pytorch-lightning is used, remove the temp '.part' file extension
|
||||||
if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'):
|
if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'):
|
||||||
target_filename = target_filename[:-len('.part')]
|
target_filename = target_filename[:-len('.part')]
|
||||||
fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp')
|
fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp')
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
shutil.copy(files[0], temp_file)
|
shutil.copy(files[0], temp_file)
|
||||||
trains_out_model.update_weights(weights_filename=temp_file, auto_delete_file=True,
|
trains_out_model.update_weights(
|
||||||
target_filename=target_filename)
|
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
|
||||||
else:
|
else:
|
||||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
Loading…
Reference in New Issue
Block a user