mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Add callback for framework save/load binding
This commit is contained in:
parent
b865fc0072
commit
f86198bbe5
@ -3,15 +3,20 @@ import shutil
|
||||
import sys
|
||||
import threading
|
||||
import weakref
|
||||
from random import randint
|
||||
from tempfile import mkstemp
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ...debugging.log import get_logger
|
||||
from ...config import running_remotely
|
||||
from ...model import InputModel, OutputModel
|
||||
from ...backend_interface.model import Model
|
||||
from ...config import running_remotely
|
||||
from ...debugging.log import get_logger
|
||||
from ...model import InputModel, OutputModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...task import Task
|
||||
|
||||
TrainsFrameworkAdapter = 'frameworks'
|
||||
_recursion_guard = {}
|
||||
@ -47,12 +52,60 @@ class WeightsFileHandler(object):
|
||||
_model_out_store_lookup = {}
|
||||
_model_in_store_lookup = {}
|
||||
_model_store_lookup_lock = threading.Lock()
|
||||
_model_pre_callbacks = {}
|
||||
_model_post_callbacks = {}
|
||||
|
||||
@staticmethod
def add_pre_callback(callback_function):
    # type: (Callable[[Union['load', 'save'], str, str, Task], str]) -> int
    """
    Register a function to be invoked before a framework model is loaded or saved.

    The callback is called as ``callback_function(action, path, framework, task)``
    where ``action`` is ``'load'`` or ``'save'``; the value it returns replaces
    the file path / target filename used by the caller.

    :param callback_function: the callable to register.
    :return: integer handle identifying the registration (pass it to
        ``remove_pre_callback`` to unregister). Registering a function that is
        already registered returns its existing handle.
    """
    registry = WeightsFileHandler._model_pre_callbacks

    # already registered - hand back the handle it was stored under
    if callback_function in registry.values():
        known_handles = [key for key, registered in registry.items() if registered == callback_function]
        return known_handles[0]

    # draw random handles until we hit one that is not in use
    handle = randint(0, 1 << 31)
    while handle in registry:
        handle = randint(0, 1 << 31)

    registry[handle] = callback_function
    return handle
|
||||
|
||||
@staticmethod
def add_post_callback(callback_function):
    # type: (Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]) -> int
    """
    Register a function to be invoked after a framework model was loaded or saved.

    The callback is called as
    ``callback_function(action, model, filename, path, framework, task)``
    where ``action`` is ``'load'`` or ``'save'``; the model object it returns
    replaces the model used by the caller.

    :param callback_function: the callable to register.
    :return: integer handle identifying the registration (pass it to
        ``remove_post_callback`` to unregister). Registering a function that is
        already registered returns its existing handle.
    """
    registry = WeightsFileHandler._model_post_callbacks

    # already registered - hand back the handle it was stored under
    if callback_function in registry.values():
        known_handles = [key for key, registered in registry.items() if registered == callback_function]
        return known_handles[0]

    # draw random handles until we hit one that is not in use
    handle = randint(0, 1 << 31)
    while handle in registry:
        handle = randint(0, 1 << 31)

    registry[handle] = callback_function
    return handle
|
||||
|
||||
@staticmethod
def remove_pre_callback(handle):
    # type: (int) -> bool
    """
    Unregister a pre load/save callback.

    :param handle: the handle returned by ``add_pre_callback``.
    :return: True if a callback was registered under ``handle`` and has been
        removed, False if the handle is unknown.
    """
    registry = WeightsFileHandler._model_pre_callbacks
    if handle not in registry:
        return False

    registry.pop(handle, None)
    return True
|
||||
|
||||
@staticmethod
def remove_post_callback(handle):
    # type: (int) -> bool
    """
    Unregister a post load/save callback.

    :param handle: the handle returned by ``add_post_callback``.
    :return: True if a callback was registered under ``handle`` and has been
        removed, False if the handle is unknown.
    """
    registry = WeightsFileHandler._model_post_callbacks
    if handle not in registry:
        return False

    registry.pop(handle, None)
    return True
|
||||
|
||||
@staticmethod
|
||||
def restore_weights_file(model, filepath, framework, task):
|
||||
if task is None:
|
||||
return filepath
|
||||
|
||||
# call pre model callback functions
|
||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||
filepath = cb('load', filepath, framework, task)
|
||||
|
||||
if not filepath:
|
||||
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
|
||||
return filepath
|
||||
@ -102,6 +155,10 @@ class WeightsFileHandler(object):
|
||||
create_as_published=False,
|
||||
)
|
||||
|
||||
# call post model callback functions
|
||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
@ -167,7 +224,11 @@ class WeightsFileHandler(object):
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
target_filename = files[0]
|
||||
target_filename = Path(files[0]).name
|
||||
|
||||
# call pre model callback functions
|
||||
for cb in WeightsFileHandler._model_pre_callbacks.values():
|
||||
target_filename = cb('save', target_filename, framework, task)
|
||||
|
||||
# check if object already has InputModel
|
||||
if trains_out_model is None:
|
||||
@ -191,7 +252,9 @@ class WeightsFileHandler(object):
|
||||
# config_dict=config,
|
||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||
label_enumeration=task.get_labels_enumeration(),
|
||||
framework=framework, base_model_id=in_model_id)
|
||||
framework=framework,
|
||||
base_model_id=in_model_id
|
||||
)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
ref_model = weakref.ref(model)
|
||||
@ -199,23 +262,27 @@ class WeightsFileHandler(object):
|
||||
ref_model = None
|
||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||
|
||||
# call post model callback functions
|
||||
for cb in WeightsFileHandler._model_post_callbacks.values():
|
||||
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
|
||||
|
||||
# upload files if we found them, or just register the original path
|
||||
if trains_out_model.upload_storage_uri:
|
||||
if len(files) > 1:
|
||||
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
||||
target_filename=target_filename)
|
||||
trains_out_model.update_weights_package(
|
||||
weights_filenames=files, auto_delete_file=False, target_filename=target_filename)
|
||||
else:
|
||||
# create a copy of the stored file,
|
||||
# protect against someone deleting/renaming the file before the async upload is done
|
||||
target_filename = Path(files[0]).name
|
||||
|
||||
# HACK: if pytorch-lightning is used, remove the temp '.part' file extension
|
||||
if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'):
|
||||
target_filename = target_filename[:-len('.part')]
|
||||
fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp')
|
||||
os.close(fd)
|
||||
shutil.copy(files[0], temp_file)
|
||||
trains_out_model.update_weights(weights_filename=temp_file, auto_delete_file=True,
|
||||
target_filename=target_filename)
|
||||
trains_out_model.update_weights(
|
||||
weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
|
||||
else:
|
||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||
except Exception as ex:
|
||||
|
Loading…
Reference in New Issue
Block a user