diff --git a/trains/binding/frameworks/__init__.py b/trains/binding/frameworks/__init__.py index b7937ea8..3e9dacbc 100644 --- a/trains/binding/frameworks/__init__.py +++ b/trains/binding/frameworks/__init__.py @@ -1,15 +1,15 @@ import threading import weakref -from logging import getLogger import six from pathlib2 import Path +from ...debugging.log import get_logger from ...config import running_remotely from ...model import InputModel, OutputModel from ...backend_interface.model import Model -TrainsFrameworkAdapter = 'TrainsFrameworkAdapter' +TrainsFrameworkAdapter = 'frameworks' _recursion_guard = {} @@ -50,7 +50,7 @@ class WeightsFileHandler(object): return filepath if not filepath: - getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored") + get_logger(TrainsFrameworkAdapter).debug("Could retrieve model file location, model is not logged") return filepath try: @@ -120,7 +120,7 @@ class WeightsFileHandler(object): # update filepath to point to downloaded weights file # actual model weights loading will be done outside the try/exception block except Exception as ex: - getLogger(TrainsFrameworkAdapter).warning(str(ex)) + get_logger(TrainsFrameworkAdapter).debug(str(ex)) finally: WeightsFileHandler._model_store_lookup_lock.release() @@ -136,19 +136,57 @@ class WeightsFileHandler(object): # check if object already has InputModel trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None)) + # notice ref_model() is not an error/typo this is a weakref object call if ref_model is not None and model != ref_model(): # old id pop it - it was probably reused because the object is dead WeightsFileHandler._model_out_store_lookup.pop(id(model)) trains_out_model, ref_model = None, None + if not saved_path: + get_logger(TrainsFrameworkAdapter).warning("Could retrieve model location, skipping auto model logging") + return saved_path + + # check if we have output storage, and generate list of files to upload + if Path(saved_path).is_dir(): + files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()] + elif singlefile: + files = [str(Path(saved_path).absolute())] + else: + files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')] + + target_filename = None + if len(files) > 1: + # noinspection PyBroadException + try: + target_filename = Path(saved_path).stem + except Exception: + pass + else: + target_filename = files[0] + # check if object already has InputModel if trains_out_model is None: + # if we are overwriting a local file, try to load registered model + # if there is an output_uri, then by definition we will not overwrite previously stored models. + if not task.output_uri: + try: + in_model_id = InputModel.load_model(weights_url=saved_path) + if in_model_id: + in_model_id = in_model_id.id + get_logger(TrainsFrameworkAdapter).info( + "Found existing registered model id={} [{}] reusing it.".format( + in_model_id, saved_path)) + except: + in_model_id = None + else: + in_model_id = None + trains_out_model = OutputModel( task=task, # config_dict=config, name=(task.name + ' - ' + model_name) if model_name else None, label_enumeration=task.get_labels_enumeration(), - framework=framework, ) + framework=framework, base_model_id=in_model_id) # noinspection PyBroadException try: ref_model = weakref.ref(model) @@ -156,29 +194,9 @@ class WeightsFileHandler(object): ref_model = None WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model) - if not saved_path: - getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ") - return saved_path - - # check if we have output storage, and generate list of files to upload - if trains_out_model.upload_storage_uri: - if Path(saved_path).is_dir(): - files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()] - elif singlefile: - files = [str(Path(saved_path).absolute())] - else: - files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')] - else: - files = None - # upload files if we found them, or just register the original path - if files: + if trains_out_model.upload_storage_uri: if len(files) > 1: - # noinspection PyBroadException - try: - target_filename = Path(saved_path).stem - except Exception: - target_filename = None trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False, target_filename=target_filename) else: @@ -186,7 +204,7 @@ class WeightsFileHandler(object): else: trains_out_model.update_weights(weights_filename=None, register_uri=saved_path) except Exception as ex: - getLogger(TrainsFrameworkAdapter).warning(str(ex)) + get_logger(TrainsFrameworkAdapter).debug(str(ex)) finally: WeightsFileHandler._model_store_lookup_lock.release()