import threading
from random import randint


class WeightsFileHandler(object):
    """
    Callback registry part of the framework weights-file handler.

    Two kinds of user callbacks can be registered; both are identified by an
    integer handle returned at registration time:

    * pre callbacks are invoked as ``cb(action, path, framework, task)`` and
      their return value replaces the path that is used afterwards.
    * post callbacks are invoked as
      ``cb(action, model, target_filename, saved_path, framework, task)`` and
      their return value replaces the model object that is used afterwards.

    ``action`` is the string ``'load'`` or ``'save'``.
    """
    # id(model) -> (output model object, weak reference to the framework model)
    _model_out_store_lookup = {}
    # presumably the input-model counterpart of the lookup above -- TODO confirm
    _model_in_store_lookup = {}
    # NOTE(review): looks like this guards the two lookups above; it is NOT
    # taken by the callback registry methods below
    _model_store_lookup_lock = threading.Lock()
    # handle (int) -> pre callback
    _model_pre_callbacks = {}
    # handle (int) -> post callback
    _model_post_callbacks = {}

    @staticmethod
    def _register_callback(registry, callback_function):
        # type: (dict, Callable) -> int
        """
        Store ``callback_function`` in ``registry`` and return its handle.

        Registering a callback that is already stored does not add a second
        entry; the handle of the existing entry is returned instead.
        """
        # single pass: find the existing entry and its handle at the same time
        for handle, registered in registry.items():
            if registered is callback_function or registered == callback_function:
                return handle

        # draw random handles until we hit one that is not in use
        while True:
            handle = randint(0, 1 << 31)
            if handle not in registry:
                break
        registry[handle] = callback_function
        return handle

    @staticmethod
    def _unregister_callback(registry, handle):
        # type: (dict, int) -> bool
        """Drop ``handle`` from ``registry``; return False if it was unknown."""
        try:
            del registry[handle]
        except KeyError:
            return False
        return True

    @staticmethod
    def add_pre_callback(callback_function):
        # type: (Callable[[str, str, str, Task], str]) -> int
        """
        Register a callback that runs before a model file is loaded or saved.

        :param callback_function: called as ``cb(action, path, framework, task)``
            with ``action`` being ``'load'`` or ``'save'``. The value it returns
            replaces ``path`` for the rest of the load / save flow.
        :return: integer handle, to be passed to ``remove_pre_callback``.
            The same handle is returned if the callback is already registered.
        """
        return WeightsFileHandler._register_callback(
            WeightsFileHandler._model_pre_callbacks, callback_function)

    @staticmethod
    def add_post_callback(callback_function):
        # type: (Callable[[str, Model, str, str, str, Task], Model]) -> int
        """
        Register a callback that runs after the model object was created.

        :param callback_function: called as
            ``cb(action, model, target_filename, saved_path, framework, task)``
            with ``action`` being ``'load'`` or ``'save'``. The value it returns
            replaces ``model`` for the rest of the load / save flow.
        :return: integer handle, to be passed to ``remove_post_callback``.
            The same handle is returned if the callback is already registered.
        """
        return WeightsFileHandler._register_callback(
            WeightsFileHandler._model_post_callbacks, callback_function)

    @staticmethod
    def remove_pre_callback(handle):
        # type: (int) -> bool
        """
        Unregister a pre callback.

        :param handle: value previously returned by ``add_pre_callback``.
        :return: True if a callback was removed, False if the handle is unknown.
        """
        return WeightsFileHandler._unregister_callback(
            WeightsFileHandler._model_pre_callbacks, handle)

    @staticmethod
    def remove_post_callback(handle):
        # type: (int) -> bool
        """
        Unregister a post callback.

        :param handle: value previously returned by ``add_post_callback``.
        :return: True if a callback was removed, False if the handle is unknown.
        """
        return WeightsFileHandler._unregister_callback(
            WeightsFileHandler._model_post_callbacks, handle)
weakref.ref(model) @@ -199,23 +262,27 @@ class WeightsFileHandler(object): ref_model = None WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) + # call post model callback functions + for cb in WeightsFileHandler._model_post_callbacks.values(): + trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task) + # upload files if we found them, or just register the original path if trains_out_model.upload_storage_uri: if len(files) > 1: - trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, - target_filename=target_filename) + trains_out_model.update_weights_package( + weights_filenames=files, auto_delete_file=False, target_filename=target_filename) else: # create a copy of the stored file, # protect against someone deletes/renames the file before async upload finish is done - target_filename = Path(files[0]).name + # HACK: if pytorch-lightning is used, remove the temp '.part' file extension if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'): target_filename = target_filename[:-len('.part')] fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp') os.close(fd) shutil.copy(files[0], temp_file) - trains_out_model.update_weights(weights_filename=temp_file, auto_delete_file=True, - target_filename=target_filename) + trains_out_model.update_weights( + weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename) else: trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) except Exception as ex: