Add callback for framework save/load binding

This commit is contained in:
allegroai 2020-05-31 12:06:15 +03:00
parent b865fc0072
commit f86198bbe5

View File

@ -3,15 +3,20 @@ import shutil
import sys import sys
import threading import threading
import weakref import weakref
from random import randint
from tempfile import mkstemp from tempfile import mkstemp
from typing import TYPE_CHECKING, Callable, Union
import six import six
from pathlib2 import Path from pathlib2 import Path
from ...debugging.log import get_logger
from ...config import running_remotely
from ...model import InputModel, OutputModel
from ...backend_interface.model import Model from ...backend_interface.model import Model
from ...config import running_remotely
from ...debugging.log import get_logger
from ...model import InputModel, OutputModel
if TYPE_CHECKING:
from ...task import Task
TrainsFrameworkAdapter = 'frameworks' TrainsFrameworkAdapter = 'frameworks'
_recursion_guard = {} _recursion_guard = {}
@ -47,12 +52,60 @@ class WeightsFileHandler(object):
_model_out_store_lookup = {} _model_out_store_lookup = {}
_model_in_store_lookup = {} _model_in_store_lookup = {}
_model_store_lookup_lock = threading.Lock() _model_store_lookup_lock = threading.Lock()
_model_pre_callbacks = {}
_model_post_callbacks = {}
@staticmethod
def add_pre_callback(callback_function):
    # type: (Callable[[Union['load', 'save'], str, str, Task], str]) -> int
    """
    Register a function to be invoked before a model is loaded or saved.

    The handler invokes the callback as ``callback(action, path, framework, task)``,
    where ``action`` is ``'load'`` or ``'save'``. The value the callback returns
    replaces the file path (on load) or the target file name (on save) used
    afterwards.

    :param callback_function: The callable to register.
    :return: Integer handle identifying the registration. If the callable is
        already registered, the handle it was first registered under is returned
        and no new entry is created.
    """
    registry = WeightsFileHandler._model_pre_callbacks

    # Already registered: hand back the existing handle instead of adding a duplicate.
    if callback_function in registry.values():
        existing_handles = [
            handle for handle, registered in registry.items() if registered == callback_function
        ]
        return existing_handles[0]

    # Draw random handles until one is found that is not currently in use.
    new_handle = randint(0, 1 << 31)
    while new_handle in registry:
        new_handle = randint(0, 1 << 31)

    registry[new_handle] = callback_function
    return new_handle
@staticmethod
def add_post_callback(callback_function):
    # type: (Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]) -> int
    """
    Register a function to be invoked after a model object was created for a load or save.

    The handler invokes the callback with six positional arguments: the action
    (``'load'`` or ``'save'``), the model object, two path/file-name strings,
    the framework and the task. The value the callback returns replaces the
    model object used afterwards.

    :param callback_function: The callable to register.
    :return: Integer handle identifying the registration. If the callable is
        already registered, the handle it was first registered under is returned
        and no new entry is created.
    """
    registry = WeightsFileHandler._model_post_callbacks

    # Already registered: hand back the existing handle instead of adding a duplicate.
    if callback_function in registry.values():
        existing_handles = [
            handle for handle, registered in registry.items() if registered == callback_function
        ]
        return existing_handles[0]

    # Draw random handles until one is found that is not currently in use.
    new_handle = randint(0, 1 << 31)
    while new_handle in registry:
        new_handle = randint(0, 1 << 31)

    registry[new_handle] = callback_function
    return new_handle
@staticmethod
def remove_pre_callback(handle):
    # type: (int) -> bool
    """
    Unregister a pre load/save callback.

    :param handle: The handle that was returned when the callback was registered.
    :return: True if a callback was registered under ``handle`` and has been
        removed, False if the handle is unknown.
    """
    registry = WeightsFileHandler._model_pre_callbacks
    if handle not in registry:
        return False

    registry.pop(handle, None)
    return True
@staticmethod
def remove_post_callback(handle):
    # type: (int) -> bool
    """
    Unregister a post load/save callback.

    :param handle: The handle that was returned when the callback was registered.
    :return: True if a callback was registered under ``handle`` and has been
        removed, False if the handle is unknown.
    """
    registry = WeightsFileHandler._model_post_callbacks
    if handle not in registry:
        return False

    registry.pop(handle, None)
    return True
@staticmethod @staticmethod
def restore_weights_file(model, filepath, framework, task): def restore_weights_file(model, filepath, framework, task):
if task is None: if task is None:
return filepath return filepath
# call pre model callback functions
for cb in WeightsFileHandler._model_pre_callbacks.values():
filepath = cb('load', filepath, framework, task)
if not filepath: if not filepath:
get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged") get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
return filepath return filepath
@ -102,6 +155,10 @@ class WeightsFileHandler(object):
create_as_published=False, create_as_published=False,
) )
# call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values():
trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
ref_model = weakref.ref(model) ref_model = weakref.ref(model)
@ -167,7 +224,11 @@ class WeightsFileHandler(object):
except Exception: except Exception:
pass pass
else: else:
target_filename = files[0] target_filename = Path(files[0]).name
# call pre model callback functions
for cb in WeightsFileHandler._model_pre_callbacks.values():
target_filename = cb('save', target_filename, framework, task)
# check if object already has InputModel # check if object already has InputModel
if trains_out_model is None: if trains_out_model is None:
@ -191,7 +252,9 @@ class WeightsFileHandler(object):
# config_dict=config, # config_dict=config,
name=(task.name + ' - ' + model_name) if model_name else None, name=(task.name + ' - ' + model_name) if model_name else None,
label_enumeration=task.get_labels_enumeration(), label_enumeration=task.get_labels_enumeration(),
framework=framework, base_model_id=in_model_id) framework=framework,
base_model_id=in_model_id
)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
ref_model = weakref.ref(model) ref_model = weakref.ref(model)
@ -199,23 +262,27 @@ class WeightsFileHandler(object):
ref_model = None ref_model = None
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
# call post model callback functions
for cb in WeightsFileHandler._model_post_callbacks.values():
trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
# upload files if we found them, or just register the original path # upload files if we found them, or just register the original path
if trains_out_model.upload_storage_uri: if trains_out_model.upload_storage_uri:
if len(files) > 1: if len(files) > 1:
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, trains_out_model.update_weights_package(
target_filename=target_filename) weights_filenames=files, auto_delete_file=False, target_filename=target_filename)
else: else:
# create a copy of the stored file, # create a copy of the stored file,
# protect against someone deletes/renames the file before async upload finish is done # protect against someone deletes/renames the file before async upload finish is done
target_filename = Path(files[0]).name
# HACK: if pytorch-lightning is used, remove the temp '.part' file extension # HACK: if pytorch-lightning is used, remove the temp '.part' file extension
if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'): if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'):
target_filename = target_filename[:-len('.part')] target_filename = target_filename[:-len('.part')]
fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp') fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp')
os.close(fd) os.close(fd)
shutil.copy(files[0], temp_file) shutil.copy(files[0], temp_file)
trains_out_model.update_weights(weights_filename=temp_file, auto_delete_file=True, trains_out_model.update_weights(
target_filename=target_filename) weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
else: else:
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
except Exception as ex: except Exception as ex: