mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add callback for framework save/load binding
This commit is contained in:
		
							parent
							
								
									b865fc0072
								
							
						
					
					
						commit
						f86198bbe5
					
				@ -3,15 +3,20 @@ import shutil
 | 
			
		||||
import sys
 | 
			
		||||
import threading
 | 
			
		||||
import weakref
 | 
			
		||||
from random import randint
 | 
			
		||||
from tempfile import mkstemp
 | 
			
		||||
from typing import TYPE_CHECKING, Callable, Union
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
from pathlib2 import Path
 | 
			
		||||
 | 
			
		||||
from ...debugging.log import get_logger
 | 
			
		||||
from ...config import running_remotely
 | 
			
		||||
from ...model import InputModel, OutputModel
 | 
			
		||||
from ...backend_interface.model import Model
 | 
			
		||||
from ...config import running_remotely
 | 
			
		||||
from ...debugging.log import get_logger
 | 
			
		||||
from ...model import InputModel, OutputModel
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from ...task import Task
 | 
			
		||||
 | 
			
		||||
TrainsFrameworkAdapter = 'frameworks'
 | 
			
		||||
_recursion_guard = {}
 | 
			
		||||
@ -47,12 +52,60 @@ class WeightsFileHandler(object):
 | 
			
		||||
    _model_out_store_lookup = {}
 | 
			
		||||
    _model_in_store_lookup = {}
 | 
			
		||||
    _model_store_lookup_lock = threading.Lock()
 | 
			
		||||
    _model_pre_callbacks = {}
 | 
			
		||||
    _model_post_callbacks = {}
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def add_pre_callback(callback_function):
        # type: (Callable[[Union['load', 'save'], str, str, Task], str]) -> int
        """
        Register a function to run before a framework model file is loaded or saved.

        Registered functions are invoked as
        ``callback_function(operation, path, framework, task)`` with ``operation``
        being ``'load'`` or ``'save'``; the value they return replaces the path
        (the file path on load, the target filename on save).

        A function that is already registered is not stored twice: the handle it
        is already stored under is returned instead.

        :param callback_function: the callable to register
        :return: integer handle identifying the registration
            (accepted by ``remove_pre_callback``)
        """
        registry = WeightsFileHandler._model_pre_callbacks

        # Already known: hand back the handle it was registered under.
        if callback_function in registry.values():
            known_handles = [
                key for key, registered in registry.items() if registered == callback_function
            ]
            return known_handles[0]

        # Draw random handles until one is not taken yet.
        new_handle = randint(0, 1 << 31)
        while new_handle in registry:
            new_handle = randint(0, 1 << 31)

        registry[new_handle] = callback_function
        return new_handle
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def add_post_callback(callback_function):
        # type: (Callable[[Union['load', 'save'], Model, str, str, str, Task], Model]) -> int
        """
        Register a function to run after a model object was created for a load or save.

        Registered functions are invoked with six positional arguments: the
        operation (``'load'`` or ``'save'``), the model object, two path/filename
        strings, the framework and the task. The value they return replaces the
        model object used by the caller.

        A function that is already registered is not stored twice: the handle it
        is already stored under is returned instead.

        :param callback_function: the callable to register
        :return: integer handle identifying the registration
            (accepted by ``remove_post_callback``)
        """
        registry = WeightsFileHandler._model_post_callbacks

        # Already known: hand back the handle it was registered under.
        if callback_function in registry.values():
            known_handles = [
                key for key, registered in registry.items() if registered == callback_function
            ]
            return known_handles[0]

        # Draw random handles until one is not taken yet.
        new_handle = randint(0, 1 << 31)
        while new_handle in registry:
            new_handle = randint(0, 1 << 31)

        registry[new_handle] = callback_function
        return new_handle
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def remove_pre_callback(handle):
        # type: (int) -> bool
        """
        Unregister a pre load/save callback by its handle.

        :param handle: handle previously returned by ``add_pre_callback``
        :return: True if a callback was registered under ``handle`` and has been
            removed, False if the handle is unknown
        """
        registry = WeightsFileHandler._model_pre_callbacks

        # Unknown handle: nothing to remove.
        if handle not in registry:
            return False

        registry.pop(handle, None)
        return True
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def remove_post_callback(handle):
        # type: (int) -> bool
        """
        Unregister a post load/save callback by its handle.

        :param handle: handle previously returned by ``add_post_callback``
        :return: True if a callback was registered under ``handle`` and has been
            removed, False if the handle is unknown
        """
        registry = WeightsFileHandler._model_post_callbacks

        # Unknown handle: nothing to remove.
        if handle not in registry:
            return False

        registry.pop(handle, None)
        return True
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def restore_weights_file(model, filepath, framework, task):
 | 
			
		||||
        if task is None:
 | 
			
		||||
            return filepath
 | 
			
		||||
 | 
			
		||||
        # call pre model callback functions
 | 
			
		||||
        for cb in WeightsFileHandler._model_pre_callbacks.values():
 | 
			
		||||
            filepath = cb('load', filepath, framework, task)
 | 
			
		||||
 | 
			
		||||
        if not filepath:
 | 
			
		||||
            get_logger(TrainsFrameworkAdapter).debug("Could not retrieve model file location, model is not logged")
 | 
			
		||||
            return filepath
 | 
			
		||||
@ -102,6 +155,10 @@ class WeightsFileHandler(object):
 | 
			
		||||
                    create_as_published=False,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            # call post model callback functions
 | 
			
		||||
            for cb in WeightsFileHandler._model_post_callbacks.values():
 | 
			
		||||
                trains_in_model = cb('load', trains_in_model, filepath, filepath, framework, task)
 | 
			
		||||
 | 
			
		||||
            # noinspection PyBroadException
 | 
			
		||||
            try:
 | 
			
		||||
                ref_model = weakref.ref(model)
 | 
			
		||||
@ -167,7 +224,11 @@ class WeightsFileHandler(object):
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    pass
 | 
			
		||||
            else:
 | 
			
		||||
                target_filename = files[0]
 | 
			
		||||
                target_filename = Path(files[0]).name
 | 
			
		||||
 | 
			
		||||
            # call pre model callback functions
 | 
			
		||||
            for cb in WeightsFileHandler._model_pre_callbacks.values():
 | 
			
		||||
                target_filename = cb('save', target_filename, framework, task)
 | 
			
		||||
 | 
			
		||||
            # check if object already has InputModel
 | 
			
		||||
            if trains_out_model is None:
 | 
			
		||||
@ -191,7 +252,9 @@ class WeightsFileHandler(object):
 | 
			
		||||
                    # config_dict=config,
 | 
			
		||||
                    name=(task.name + ' - ' + model_name) if model_name else None,
 | 
			
		||||
                    label_enumeration=task.get_labels_enumeration(),
 | 
			
		||||
                    framework=framework, base_model_id=in_model_id)
 | 
			
		||||
                    framework=framework,
 | 
			
		||||
                    base_model_id=in_model_id
 | 
			
		||||
                )
 | 
			
		||||
                # noinspection PyBroadException
 | 
			
		||||
                try:
 | 
			
		||||
                    ref_model = weakref.ref(model)
 | 
			
		||||
@ -199,23 +262,27 @@ class WeightsFileHandler(object):
 | 
			
		||||
                    ref_model = None
 | 
			
		||||
                WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
 | 
			
		||||
 | 
			
		||||
            # call post model callback functions
 | 
			
		||||
            for cb in WeightsFileHandler._model_post_callbacks.values():
 | 
			
		||||
                trains_out_model = cb('save', trains_out_model, target_filename, saved_path, framework, task)
 | 
			
		||||
 | 
			
		||||
            # upload files if we found them, or just register the original path
 | 
			
		||||
            if trains_out_model.upload_storage_uri:
 | 
			
		||||
                if len(files) > 1:
 | 
			
		||||
                    trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
 | 
			
		||||
                                                            target_filename=target_filename)
 | 
			
		||||
                    trains_out_model.update_weights_package(
 | 
			
		||||
                        weights_filenames=files, auto_delete_file=False, target_filename=target_filename)
 | 
			
		||||
                else:
 | 
			
		||||
                    # create a copy of the stored file,
 | 
			
		||||
                    # protect against someone deletes/renames the file before async upload finish is done
 | 
			
		||||
                    target_filename = Path(files[0]).name
 | 
			
		||||
 | 
			
		||||
                    # HACK: if pytorch-lightning is used, remove the temp '.part' file extension
 | 
			
		||||
                    if sys.modules.get('pytorch_lightning') and target_filename.lower().endswith('.part'):
 | 
			
		||||
                        target_filename = target_filename[:-len('.part')]
 | 
			
		||||
                    fd, temp_file = mkstemp(prefix='.trains.upload_model_', suffix='.tmp')
 | 
			
		||||
                    os.close(fd)
 | 
			
		||||
                    shutil.copy(files[0], temp_file)
 | 
			
		||||
                    trains_out_model.update_weights(weights_filename=temp_file, auto_delete_file=True,
 | 
			
		||||
                                                    target_filename=target_filename)
 | 
			
		||||
                    trains_out_model.update_weights(
 | 
			
		||||
                        weights_filename=temp_file, auto_delete_file=True, target_filename=target_filename)
 | 
			
		||||
            else:
 | 
			
		||||
                trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
 | 
			
		||||
        except Exception as ex:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user