diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index babe178a..d7276b0d 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -5,7 +5,7 @@ import threading import weakref from random import randint from tempfile import mkstemp -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Any import six from pathlib2 import Path @@ -24,6 +24,7 @@ _recursion_guard = {} def _patched_call(original_fn, patched_fn): def _inner_patch(*args, **kwargs): + # noinspection PyProtectedMember,PyUnresolvedReferences ident = threading._get_ident() if six.PY2 else threading.get_ident() if ident in _recursion_guard: return original_fn(*args, **kwargs) @@ -55,62 +56,118 @@ class WeightsFileHandler(object): _model_pre_callbacks = {} _model_post_callbacks = {} - @staticmethod - def add_pre_callback(callback_function): - # type: (Callable[[str, str, str, Task], str]) -> int + class ModelInfo(object): + def __init__(self, model, upload_filename, local_model_path, local_model_id, framework, task): + # type: (Optional[Model], Optional[str], str, str, str, Task) -> None + """ + :param model: None, OutputModel or InputModel + :param upload_filename: example 'filename.ext' + :param local_model_path: example '/local/copy/filename.random_number.ext' + :param local_model_id: example '/local/copy/filename.ext' + :param framework: example 'PyTorch' + :param task: Task object + """ + self.model = model + self.upload_filename = upload_filename + self.local_model_path = local_model_path + self.local_model_id = local_model_id + self.framework = framework + self.task = task - # callback is Callable[[Union['load', 'save'], str, str, Task], str] - if callback_function in WeightsFileHandler._model_pre_callbacks.values(): - return [k for k, v in WeightsFileHandler._model_pre_callbacks.items() if v == callback_function][0] + @staticmethod + def _add_callback(func, target): 
+ # type: (Callable, Dict[int, Callable]) -> int + + if func in target.values(): + return [k for k, v in target.items() if v == func][0] while True: h = randint(0, 1 << 31) - if h not in WeightsFileHandler._model_pre_callbacks: + if h not in target: break - WeightsFileHandler._model_pre_callbacks[h] = callback_function + + target[h] = func return h @staticmethod - def add_post_callback(callback_function): - # type: (Callable[[str, Model, str, str, str, Task], Model]) -> int - - # callback is Callable[[Union['load', 'save'], Model, str, str, str, Task], Model] - if callback_function in WeightsFileHandler._model_post_callbacks.values(): - return [k for k, v in WeightsFileHandler._model_post_callbacks.items() if v == callback_function][0] - - while True: - h = randint(0, 1 << 31) - if h not in WeightsFileHandler._model_post_callbacks: - break - WeightsFileHandler._model_post_callbacks[h] = callback_function - return h - - @staticmethod - def remove_pre_callback(handle): - # type: (int) -> bool - if handle in WeightsFileHandler._model_pre_callbacks: - WeightsFileHandler._model_pre_callbacks.pop(handle, None) + def _remove_callback(handle, target): + # type: (int, Dict[int, Callable]) -> bool + if handle in target: + target.pop(handle, None) return True return False - @staticmethod - def remove_post_callback(handle): + @classmethod + def add_pre_callback(cls, callback_function): + # type: (Callable[[str, ModelInfo], ModelInfo]) -> int + """ + Add a pre-save/load callback for weights files and return its handle. If the callback was already added, + return the existing handle. + + Use this callback to modify the weights filename registered in the Trains Server. In case Trains is + configured to upload the weights file, this will affect the uploaded filename as well. 
+ + :param callback_function: A function accepting action type ("load" or "save"), + callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo + :return Callback handle + """ + return cls._add_callback(callback_function, cls._model_pre_callbacks) + + @classmethod + def add_post_callback(cls, callback_function): + # type: (Callable[[str, ModelInfo], ModelInfo]) -> int + """ + Add a post-save/load callback for weights files and return its handle. + If the callback was already added, return the existing handle. + + :param callback_function: A function accepting action type ("load" or "save"), + callback_function('load' or 'save', WeightsFileHandler.ModelInfo) -> WeightsFileHandler.ModelInfo + :return Callback handle + """ + return cls._add_callback(callback_function, cls._model_post_callbacks) + + @classmethod + def remove_pre_callback(cls, handle): # type: (int) -> bool - if handle in WeightsFileHandler._model_post_callbacks: - WeightsFileHandler._model_post_callbacks.pop(handle, None) - return True - return False + """ + Remove a pre-save/load callback for weights files, identified by its handle. + If the handle is not registered, nothing is removed. + + :param handle: A callback handle returned from :meth:WeightsFileHandler.add_pre_callback + :return True if callback removed, False otherwise + """ + return cls._remove_callback(handle, cls._model_pre_callbacks) + + @classmethod + def remove_post_callback(cls, handle): + # type: (int) -> bool + """ + Remove a post-save/load callback for weights files, identified by its handle. + If the handle is not registered, nothing is removed. 
+ + :param handle: A callback handle returned from :meth:WeightsFileHandler.add_post_callback + :return True if callback removed, False otherwise + """ + return cls._remove_callback(handle, cls._model_post_callbacks) @staticmethod def restore_weights_file(model, filepath, framework, task): + # type: (Optional[Any], Optional[str], Optional[str], Optional[Task]) -> str if task is None: return filepath + model_info = WeightsFileHandler.ModelInfo( + model=None, upload_filename=None, local_model_path=filepath, + local_model_id=filepath, framework=framework, task=task) # call pre model callback functions for cb in WeightsFileHandler._model_pre_callbacks.values(): - filepath = cb('load', filepath, framework, task) + # noinspection PyBroadException + try: + model_info = cb('load', model_info) + except Exception: + pass - if not filepath: + if not model_info.local_model_path: get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged") return filepath @@ -118,12 +175,16 @@ class WeightsFileHandler(object): WeightsFileHandler._model_store_lookup_lock.acquire() # check if object already has InputModel - trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get( - id(model) if model is not None else None, (None, None)) - if ref_model is not None and model != ref_model(): - # old id pop it - it was probably reused because the object is dead - WeightsFileHandler._model_in_store_lookup.pop(id(model)) - trains_in_model, ref_model = None, None + if model_info.model: + trains_in_model = model_info.model + else: + trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get( + id(model) if model is not None else None, (None, None)) + # noinspection PyCallingNonCallable + if ref_model is not None and model != ref_model(): + # old id pop it - it was probably reused because the object is dead + WeightsFileHandler._model_in_store_lookup.pop(id(model)) + trains_in_model, ref_model = None, None # check if object 
already has InputModel model_name_id = getattr(model, 'name', '') if model else '' @@ -139,30 +200,39 @@ class WeightsFileHandler(object): except Exception: config_text = None - # check if we already have the model object: - model_id, model_uri = Model._local_model_to_id_uri.get(filepath, (None, None)) - if model_id: - # noinspection PyBroadException - try: - trains_in_model = InputModel(model_id) - except Exception: - model_id = None + if not trains_in_model: + # check if we already have the model object: + # noinspection PyProtectedMember + model_id, model_uri = Model._local_model_to_id_uri.get( + model_info.local_model_id or model_info.local_model_path, (None, None)) + if model_id: + # noinspection PyBroadException + try: + trains_in_model = InputModel(model_id) + except Exception: + model_id = None - # if we do not, we need to import the model - if not model_id: - trains_in_model = InputModel.import_model( - weights_url=filepath, - config_dict=config_dict, - config_text=config_text, - name=task.name + (' ' + model_name_id) if model_name_id else '', - label_enumeration=task.get_labels_enumeration(), - framework=framework, - create_as_published=False, - ) + # if we do not, we need to import the model + if not model_id: + trains_in_model = InputModel.import_model( + weights_url=model_info.local_model_path, + config_dict=config_dict, + config_text=config_text, + name=task.name + (' ' + model_name_id) if model_name_id else '', + label_enumeration=task.get_labels_enumeration(), + framework=framework, + create_as_published=False, + ) + model_info.model = trains_in_model # call post model callback functions for cb in WeightsFileHandler._model_post_callbacks.values(): - trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task) + # noinspection PyBroadException + try: + model_info = cb('load', model_info) + except Exception: + pass + trains_in_model = model_info.model if model is not None: # noinspection PyBroadException @@ -179,13 +249,19 @@ 
class WeightsFileHandler(object): if False and running_remotely(): # reload the model model_config = trains_in_model.config_dict - # verify that this is the same model so we are not deserializing a diff model + # verify that this is the same model so we are not deserializing a different model if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and config_dict.get('config').get('name') == model_config.get('config').get('name')) or \ (not config_dict and not model_config): filepath = trains_in_model.get_weights() # update filepath to point to downloaded weights file # actual model weights loading will be done outside the try/exception block + + # update back the internal Model lookup, and replace the local file with our file + # noinspection PyProtectedMember + Model._local_model_to_id_uri[model_info.local_model_id] = ( + trains_in_model.id, trains_in_model.url) + except Exception as ex: get_logger(TrainsFrameworkAdapter).debug(str(ex)) finally: @@ -195,6 +271,7 @@ class WeightsFileHandler(object): @staticmethod def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None): + # type: (Optional[Any], Optional[str], Optional[str], Optional[Task], bool, Optional[str]) -> str if task is None: return saved_path @@ -205,55 +282,70 @@ class WeightsFileHandler(object): trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get( id(model) if model is not None else None, (None, None)) # notice ref_model() is not an error/typo this is a weakref object call + # noinspection PyCallingNonCallable if ref_model is not None and model != ref_model(): # old id pop it - it was probably reused because the object is dead WeightsFileHandler._model_out_store_lookup.pop(id(model)) trains_out_model, ref_model = None, None - if not saved_path: + model_info = WeightsFileHandler.ModelInfo( + model=trains_out_model, upload_filename=None, local_model_path=saved_path, + local_model_id=saved_path, 
framework=framework, task=task) + + if not model_info.local_model_path: get_logger(TrainsFrameworkAdapter).warning( "Could not retrieve model location, skipping auto model logging") return saved_path # check if we have output storage, and generate list of files to upload - if Path(saved_path).is_dir(): - files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()] + if Path(model_info.local_model_path).is_dir(): + files = [str(f) for f in Path(model_info.local_model_path).rglob('*') if f.is_file()] elif singlefile: - files = [str(Path(saved_path).absolute())] + files = [str(Path(model_info.local_model_path).absolute())] else: - files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')] + files = [str(f) for f in Path(model_info.local_model_path).parent.glob( + str(Path(model_info.local_model_path).name) + '.*')] target_filename = None if len(files) > 1: # noinspection PyBroadException try: - target_filename = Path(saved_path).stem + target_filename = Path(model_info.local_model_path).stem except Exception: pass else: target_filename = Path(files[0]).name # call pre model callback functions + model_info.upload_filename = target_filename for cb in WeightsFileHandler._model_pre_callbacks.values(): - target_filename = cb('save', target_filename, framework, task) + # noinspection PyBroadException + try: + model_info = cb('save', model_info) + except Exception: + pass + trains_out_model = model_info.model # check if object already has InputModel if trains_out_model is None: - in_model_id, model_uri = Model._local_model_to_id_uri.get(saved_path, (None, None)) + # noinspection PyProtectedMember + in_model_id, model_uri = Model._local_model_to_id_uri.get( + model_info.local_model_id or model_info.local_model_path, (None, None)) if not in_model_id: # if we are overwriting a local file, try to load registered model # if there is an output_uri, then by definition we will not overwrite previously stored models. 
if not task.output_uri: + # noinspection PyBroadException try: - in_model_id = InputModel.load_model(weights_url=saved_path) + in_model_id = InputModel.load_model(weights_url=model_info.local_model_path) if in_model_id: in_model_id = in_model_id.id get_logger(TrainsFrameworkAdapter).info( "Found existing registered model id={} [{}] reusing it.".format( - in_model_id, saved_path)) - except: + in_model_id, model_info.local_model_path)) + except Exception: in_model_id = None else: in_model_id = None @@ -274,9 +366,16 @@ class WeightsFileHandler(object): ref_model = None WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) + model_info.model = trains_out_model # call post model callback functions for cb in WeightsFileHandler._model_post_callbacks.values(): - trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task) + # noinspection PyBroadException + try: + model_info = cb('save', model_info) + except Exception: + pass + trains_out_model = model_info.model + target_filename = model_info.upload_filename # upload files if we found them, or just register the original path if trains_out_model.upload_storage_uri: @@ -294,9 +393,16 @@ class WeightsFileHandler(object): os.close(fd) shutil.copy(files[0], temp_file) trains_out_model.update_weights( - weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename) + weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename, + update_comment=False) else: - trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) + trains_out_model.update_weights(weights_filename=None, register_uri=model_info.local_model_path) + + # update back the internal Model lookup, and replace the local file with our file + # noinspection PyProtectedMember + Model._local_model_to_id_uri[model_info.local_model_id] = ( + trains_out_model.id, trains_out_model.url) + except Exception as ex: 
get_logger(TrainsFrameworkAdapter).debug(str(ex)) finally: